├── .gitignore
├── LICENSE
├── Makefile
├── README.md
├── conftest.py
├── noxfile.py
├── pyproject.toml
├── setup.cfg
├── setup.py
├── src
└── disaggregators
│ ├── __init__.py
│ ├── disaggregation_modules
│ ├── __init__.py
│ ├── age
│ │ └── __init__.py
│ ├── continent
│ │ └── __init__.py
│ ├── disaggregation_module.py
│ ├── gender
│ │ └── __init__.py
│ ├── pronoun
│ │ └── __init__.py
│ └── religion
│ │ └── __init__.py
│ └── disaggregator.py
└── tests
├── __init__.py
├── disaggregation_modules
├── test_age.py
├── test_continent.py
├── test_gender.py
├── test_pronoun.py
└── test_religion.py
├── integration
└── test_disaggregation.py
├── test_disaggregation_module.py
└── test_disaggregator.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Locked files
2 | *.lock
3 | !dvc.lock
4 |
5 | # Extracted dummy data
6 | datasets/**/dummy_data-zip-extracted/
7 |
8 | # Compiled python modules.
9 | *.pyc
10 |
11 | # Byte-compiled
12 | __pycache__/
13 | .cache/
14 |
15 | # Python egg metadata, regenerated from source files by setuptools.
16 | *.egg-info
17 | .eggs/
18 |
19 | # PyPI distribution artifacts.
20 | build/
21 | dist/
22 |
23 | # Environments
24 | .env
25 | .venv
26 | env/
27 | venv/
28 | ENV/
29 | env.bak/
30 | venv.bak/
31 |
32 | # pyenv
33 | .python-version
34 |
35 | # Tests
36 | .pytest_cache/
37 |
38 | # Other
39 | *.DS_Store
40 |
41 | # PyCharm/vscode
42 | .idea
43 | .vscode
44 |
45 | # Vim
46 | .*.swp
47 |
48 | # playground
49 | /playground
50 |
51 | # Sphinx documentation
52 | docs/_build/
53 | docs/source/_build/
54 |
55 | notebooks/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: quality style test
2 |
3 | # Check that source code meets quality standards
4 |
5 | quality:
6 | black --check --line-length 119 --target-version py36 tests src
7 | isort --check-only tests src
8 | flake8 tests src
9 |
10 | # Format source code automatically
11 |
12 | style:
13 | black --line-length 119 --target-version py36 tests src
14 | isort tests src
15 |
16 | # Run tests for the library
17 |
18 | test:
19 | python -m pytest -n auto --dist=loadfile ./tests/
20 |
21 | test-fast:
22 | python -m pytest -n auto --dist=loadfile -m "not slow" ./tests/
23 |
24 |
25 | # Utility for Nox
26 |
27 | load_spacy_model:
28 | spacy validate | grep en_core_web_lg || spacy download en_core_web_lg
29 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 | > ⚠️ Please note: This library is in early development, and the disaggregation modules that are included are proofs of concept that are _not_ production-ready. Additionally, all APIs are subject to breaking changes any time before a 1.0.0 release. Rigorously tested versions of the included modules will be released in the future, so stay tuned. [We'd love your feedback in the meantime!](https://github.com/huggingface/disaggregators/discussions/23)
17 |
18 | The `disaggregators` library allows you to easily add new features to your datasets to enable disaggregated data exploration and disaggregated model evaluation. `disaggregators` is preloaded with disaggregation modules for text data, with image modules coming soon!
19 |
20 | This library is intended to be used with [🤗 Datasets](https://github.com/huggingface/datasets), but should work with any other "mappable" interface to a dataset.
21 |
22 | ## Requirements and Installation
23 |
24 | `disaggregators` has been tested on Python 3.8, 3.9, and 3.10.
25 |
26 | `pip install disaggregators` will fetch the latest release from PyPI.
27 |
28 | Note that some disaggregation modules require extra dependencies such as SpaCy modules, which may need to be installed manually. If these dependencies aren't installed, `disaggregators` will inform you about how to install them.
29 |
30 | To install directly from this GitHub repo, use the following command:
31 | ```shell
32 | pip install git+https://github.com/huggingface/disaggregators.git
33 | ```
34 |
35 | ## Usage
36 |
37 | You will likely want to use 🤗 Datasets with `disaggregators`.
38 |
39 | ```shell
40 | pip install datasets
41 | ```
42 |
43 | The snippet below loads the IMDB dataset from the Hugging Face Hub, and initializes a disaggregator for "pronoun" that will run on the IMDB dataset's "text" column. If you would like to run multiple disaggregations, you can pass a list to the `Disaggregator` constructor (e.g. `Disaggregator(["pronoun", "gender"], column="text")`). We then use the 🤗 Datasets `map` method to apply the disaggregation to the dataset.
44 |
45 | ```python
46 | from disaggregators import Disaggregator
47 | from datasets import load_dataset
48 |
49 | dataset = load_dataset("imdb", split="train")
50 | disaggregator = Disaggregator("pronoun", column="text")
51 |
52 | ds = dataset.map(disaggregator) # New boolean columns are added for she/her, he/him, and they/them
53 | ```
54 |
55 | The resulting dataset can now be used for data exploration and disaggregated model evaluation.
56 |
57 | You can also run disaggregations on Pandas DataFrames with `.apply` and `.merge`:
58 |
59 | ```python
60 | from disaggregators import Disaggregator
61 | import pandas as pd
62 | df = pd.DataFrame({"text": ["They went to the park."]})
63 |
64 | disaggregator = Disaggregator("pronoun", column="text")
65 |
66 | new_cols = df.apply(disaggregator, axis=1)
67 | df = pd.merge(df, pd.json_normalize(new_cols), left_index=True, right_index=True)
68 | ```
69 |
70 | ### Available Disaggregation Modules
71 |
72 | The following modules are currently available:
73 |
74 | - `"age"`
75 | - `"gender"`
76 | - `"pronoun"`
77 | - `"religion"`
78 | - `"continent"`
79 |
80 | Note that `disaggregators` is in active development, and that these (and future) modules are subject to changing interfaces and implementations at any time before a `1.0.0` release. Each module provides its own method for overriding the default configuration, with the general interface documented below.
81 |
82 | ### Module Configurations
83 |
84 | Modules may make certain variables and functionality configurable. If you'd like to configure a module, import the module, its labels, and its config class. Then, override the labels and set the configuration as needed while instantiating the module. Once instantiated, you can pass the module to the `Disaggregator`. The example below shows this with the `Age` module.
85 |
86 | ```python
87 | from disaggregators import Disaggregator
88 | from disaggregators.disaggregation_modules.age import Age, AgeLabels, AgeConfig
89 |
90 | class MeSHAgeLabels(AgeLabels):
91 | INFANT = "infant"
92 | CHILD_PRESCHOOL = "child_preschool"
93 | CHILD = "child"
94 | ADOLESCENT = "adolescent"
95 | ADULT = "adult"
96 | MIDDLE_AGED = "middle_aged"
97 | AGED = "aged"
98 | AGED_80_OVER = "aged_80_over"
99 |
100 | age = Age(
101 | config=AgeConfig(
102 | labels=MeSHAgeLabels,
103 | ages=list(MeSHAgeLabels),
104 | breakpoints=[0, 2, 5, 12, 18, 44, 64, 79]
105 | ),
106 | column="question"
107 | )
108 |
109 | disaggregator = Disaggregator([age, "gender"], column="question")
110 | ```
111 |
112 | ### Custom Modules
113 |
114 | Custom modules can be created by extending the `CustomDisaggregator`. All custom modules must have `labels` and a `module_id`, and must implement a `__call__` method.
115 |
116 | ```python
117 | from disaggregators import Disaggregator, DisaggregationModuleLabels, CustomDisaggregator
118 |
119 | class TabsSpacesLabels(DisaggregationModuleLabels):
120 | TABS = "tabs"
121 | SPACES = "spaces"
122 |
123 | class TabsSpaces(CustomDisaggregator):
124 | module_id = "tabs_spaces"
125 | labels = TabsSpacesLabels
126 |
127 | def __call__(self, row, *args, **kwargs):
128 | if "\t" in row[self.column]:
129 | return {self.labels.TABS: True, self.labels.SPACES: False}
130 | else:
131 | return {self.labels.TABS: False, self.labels.SPACES: True}
132 |
133 | disaggregator = Disaggregator(TabsSpaces, column="text")
134 | ```
135 |
136 | ## Development
137 |
138 | Development requirements can be installed with `pip install .[dev]`. See the `Makefile` for useful targets, such as code quality and test running.
139 |
140 | To run tests locally across multiple Python versions (3.8, 3.9, and 3.10), ensure that you have all the Python versions available and then run `nox -r`. Note that this is quite slow, so it's only worth doing to double-check your code before you open a Pull Request.
141 |
142 | ## Contact
143 |
144 | Nima Boscarino – `nima <at> huggingface <dot> co`
145 |
--------------------------------------------------------------------------------
/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import os
3 |
4 | from typing import Type
5 |
6 | from disaggregators import DisaggregationModule, DisaggregationModuleLabels, DisaggregationModuleConfig
7 |
8 |
def pytest_collection_modifyitems(items):
    """Mark every collected test whose node id contains "integration" as slow and integration."""
    integration_items = (item for item in items if "integration" in item.nodeid)
    for item in integration_items:
        # Both markers are declared in pyproject.toml, so they can be (de)selected with ``-m``.
        for marker in (pytest.mark.slow, pytest.mark.integration):
            item.add_marker(marker)
14 |
15 |
def pytest_generate_tests(metafunc):
    """Parametrize ``test_each_module`` with the name of every disaggregation module package.

    Module names are the sub-directories of ``src/disaggregators/disaggregation_modules``
    (resolved against the pytest root directory) whose names do not start with ``__``.
    The names are sorted so the generated parameters have the same order on every run
    and on every pytest-xdist worker, regardless of filesystem listing order.

    Args:
        metafunc: The pytest ``Metafunc`` object for the test function being collected.
    """
    if metafunc.definition.name != "test_each_module":
        return

    modules_dir = metafunc.definition.config.rootdir / "src/disaggregators/disaggregation_modules"

    # The context manager closes the scandir iterator as soon as the names are collected.
    with os.scandir(modules_dir) as entries:
        module_names = sorted(
            entry.name for entry in entries if entry.is_dir() and not entry.name.startswith("__")
        )

    metafunc.parametrize("module", module_names)
23 |
24 | # Fixtures
25 |
26 |
class DummyLabels(DisaggregationModuleLabels):
    # Two placeholder labels shared by the dummy-module fixtures below.
    DUMMY_ONE = "dummy-value-1"
    DUMMY_TWO = "dummy-value-2"
30 |
31 |
@pytest.fixture
def dummy_labels():
    """Provide the ``DummyLabels`` class itself (not an instance) to tests."""
    labels_cls = DummyLabels
    return labels_cls
35 |
36 |
@pytest.fixture
def dummy_module_config(dummy_labels):
    """Provide a config class whose only setting is the label set to use."""

    class DummyModuleConfig(DisaggregationModuleConfig):
        def __init__(self, labels: Type[dummy_labels]):
            # The labels are stored as given; the dummy module reads them in ``_apply_config``.
            self.labels = labels

    return DummyModuleConfig
44 |
45 |
@pytest.fixture
def dummy_module(dummy_labels, dummy_module_config):
    """Provide a minimal ``DisaggregationModule`` subclass that reports every label as ``True``."""

    class DummyModule(DisaggregationModule):
        labels = dummy_labels

        def __init__(self, module_id="dummy-module", *args, **kwargs):
            super().__init__(module_id=module_id, *args, **kwargs)

        def _apply_config(self, config):
            # The label set is the only thing a config can override here.
            self.labels = config.labels

        def __call__(self, row, *args, **kwargs):
            # The row content is ignored: each label of the active label set maps to True.
            return dict.fromkeys(self.labels, True)

    return DummyModule
61 |
62 |
@pytest.fixture
def custom_dummy_labels():
    """Provide a label class with three labels, used to override a module's default labels."""

    class CustomDummyLabels(DisaggregationModuleLabels):
        # Same two values as ``DummyLabels`` plus one additional label.
        DUMMY_ONE = "dummy-value-1"
        DUMMY_TWO = "dummy-value-2"
        DUMMY_THREE = "dummy-value-3"

    return CustomDummyLabels
71 |
72 |
@pytest.fixture
def configured_module(custom_dummy_labels, dummy_module_config, dummy_module):
    """Provide a dummy module instance configured with the custom label set and no column."""
    config = dummy_module_config(labels=custom_dummy_labels)
    return dummy_module(config=config, column=None)
76 |
77 |
@pytest.fixture
def configured_dummy_expected_results(custom_dummy_labels, configured_module):
    """Provide the expected output mapping: one ``"<module name>.<label>"`` key per custom label, all ``True``."""
    expected_keys = (f"{configured_module.name}.{label}" for label in custom_dummy_labels)
    return dict.fromkeys(expected_keys, True)
81 |
--------------------------------------------------------------------------------
/noxfile.py:
--------------------------------------------------------------------------------
1 | import nox
2 |
python_versions = ["3.8", "3.9", "3.10"]


@nox.session(python=python_versions)
def tests(session):
    """Install the package with its test extras, then run the Makefile targets for the test suite."""
    session.install(".[tests]")
    # ``make`` lives outside the session's virtualenv, hence ``external=True``.
    for make_target in ("load_spacy_model", "test"):
        session.run("make", make_target, external=True)
11 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.pytest.ini_options]
2 | markers = [
3 | "slow: marks tests as slow (deselect with '-m \"not slow\"')",
4 | "integration: marks integration tests (deselect with '-m \"not integration\"')",
5 | ]
6 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | license_file = LICENSE
3 |
4 | [isort]
5 | ensure_newline_before_comments = True
6 | force_grid_wrap = 0
7 | include_trailing_comma = True
8 | line_length = 119
9 | lines_after_imports = 2
10 | multi_line_output = 3
11 | use_parentheses = True
12 |
13 | [flake8]
14 | ignore = E203, E501, W503
15 | max-line-length = 119
16 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Lint as: python3
2 | """ HuggingFace/Disaggregators is an open library for disaggregating datasets.
3 |
4 | Note:
5 |
6 | VERSION needs to be formatted following the MAJOR.MINOR.PATCH convention
7 |
8 | Simple check list for release from AllenNLP repo: https://github.com/allenai/allennlp/blob/master/setup.py
9 |
10 | To create the package for pypi.
11 |
12 | 0. Prerequisites:
13 | - Dependencies:
14 | - twine: "pip install twine"
15 | - Create an account in (and join the 'disaggregators' project):
16 | - PyPI: https://pypi.org/
17 | - Test PyPI: https://test.pypi.org/
18 |
19 | 1. Change the version in:
20 | - __init__.py
21 | - setup.py
22 |
23 | 2. Commit these changes: "git commit -m 'Release: VERSION'"
24 |
25 | 3. Add a tag in git to mark the release: "git tag VERSION -m 'Add tag VERSION for pypi'"
26 | Push the tag to remote: git push --tags origin main
27 |
28 | 4. Build both the sources and the wheel. Do not change anything in setup.py between
29 | creating the wheel and the source distribution (obviously).
30 |
31 | First, delete any "build" directory that may exist from previous builds.
32 |
33 | For the wheel, run: "python setup.py bdist_wheel" in the top level directory.
34 | (this will build a wheel for the python version you use to build it).
35 |
36 | For the sources, run: "python setup.py sdist"
37 | You should now have a /dist directory with both .whl and .tar.gz source versions.
38 |
39 | 5. Check that everything looks correct by uploading the package to the pypi test server:
40 |
41 | twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
42 |
43 | Check that you can install it in a virtualenv/notebook by running:
44 | pip install -i https://test.pypi.org/simple/ disaggregators
45 |
46 | 6. Upload the final version to actual pypi:
47 | twine upload dist/* -r pypi
48 |
49 | 7. Fill release notes in the tag in GitHub once everything is looking hunky-dory.
50 |
51 | 8. Change the version in __init__.py and setup.py to X.X.X+1.dev0 (e.g. VERSION=1.18.3 -> 1.18.4.dev0).
52 | Then push the change with a message 'set dev version'
53 | """
54 |
55 | from setuptools import find_packages, setup
56 |
57 |
# Runtime dependencies installed with the package.
REQUIRED_PKGS = [
    "packaging",  # Utilities from PyPA to e.g., compare versions
    "spacy",
    "datasets",
    "aenum>=3.1.11",
    "sentence-transformers>=2.2.2",
    "geograpy3",
    "nltk",
    "requests",
]

# Dependencies needed only to run the test suite.
TESTS_REQUIRE = [
    "pytest",
    "pytest-datadir",
    "pytest-xdist",
    "pytest-mock",
    "nox",
    "pandas",
]

# Dependencies needed only for formatting and linting.
QUALITY_REQUIRE = [
    "black~=22.0",
    "flake8>=3.8.3",
    "isort>=5.0.0",
    "pyyaml>=5.3.1",
]


# Optional dependency groups, installable as e.g. ``pip install .[dev]``.
EXTRAS_REQUIRE = dict(
    dev=TESTS_REQUIRE + QUALITY_REQUIRE,
    tests=TESTS_REQUIRE,
    quality=QUALITY_REQUIRE,
)
88 |
# Read the README inside a context manager so the file handle is closed right away
# instead of being left open until garbage collection.
with open("README.md", encoding="utf-8") as readme_file:
    LONG_DESCRIPTION = readme_file.read()

setup(
    name="disaggregators",
    version="0.1.3.dev0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
    description="HuggingFace community-driven open-source library for dataset disaggregation",
    long_description=LONG_DESCRIPTION,
    long_description_content_type="text/markdown",
    author="HuggingFace Inc.",
    author_email="nima@huggingface.co",
    url="https://github.com/NimaBoscarino/disaggregators",
    download_url="https://github.com/NimaBoscarino/disaggregators/tags",
    license="Apache 2.0",
    # The importable package lives under ``src/``.
    package_dir={"": "src"},
    packages=find_packages("src"),
    install_requires=REQUIRED_PKGS,
    extras_require=EXTRAS_REQUIRE,
    python_requires=">=3.7.0",
    classifiers=[
        "Development Status :: 1 - Planning",
        "Intended Audience :: Developers",
        "Intended Audience :: Education",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: Apache Software License",
        "Operating System :: OS Independent",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
    keywords="machine learning evaluate evaluation disaggregation",
    zip_safe=False,  # Required for mypy to find the py.typed file
)
121 |
--------------------------------------------------------------------------------
/src/disaggregators/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | # Copyright 2022 The HuggingFace Disaggregators Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Lint as: python3
17 | # pylint: enable=line-too-long
18 | # pylint: disable=g-import-not-at-top,g-bad-import-order,wrong-import-position
19 |
__version__ = "0.1.3.dev0"

from packaging import version

from disaggregators.disaggregation_modules import (
    CustomDisaggregator,
    DisaggregationModule,
    DisaggregationModuleConfig,
    DisaggregationModuleFactory,
    DisaggregationModuleLabels,
)

from .disaggregator import Disaggregator


# Development releases resolve to "main"; any other release resolves to its own version string.
if version.parse(__version__).is_devrelease:
    SCRIPTS_VERSION = "main"
else:
    SCRIPTS_VERSION = __version__

# ``version`` was only needed for the check above; remove it from the package namespace.
del version
38 |
--------------------------------------------------------------------------------
/src/disaggregators/disaggregation_modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .age import Age
2 | from .continent import Continent
3 | from .disaggregation_module import (
4 | CustomDisaggregator,
5 | DisaggregationModule,
6 | DisaggregationModuleConfig,
7 | DisaggregationModuleFactory,
8 | DisaggregationModuleLabels,
9 | )
10 | from .gender import Gender
11 | from .pronoun import Pronoun
12 | from .religion import Religion
13 |
14 |
# Maps each built-in module's string identifier to the class that implements it.
AVAILABLE_MODULES = {
    "pronoun": Pronoun,
    "age": Age,
    "gender": Gender,
    "religion": Religion,
    "continent": Continent,
}

# Names exported by ``from disaggregators.disaggregation_modules import *``.
__all__ = [
    "DisaggregationModule",
    "DisaggregationModuleFactory",
    "DisaggregationModuleLabels",
    "CustomDisaggregator",
    "DisaggregationModuleConfig",
]
24 |
--------------------------------------------------------------------------------
/src/disaggregators/disaggregation_modules/age/__init__.py:
--------------------------------------------------------------------------------
1 | import re
2 | from bisect import bisect
3 | from typing import List, Type
4 |
5 | import spacy
6 |
7 | from ..disaggregation_module import DisaggregationModule, DisaggregationModuleConfig, DisaggregationModuleLabels
8 |
9 |
class AgeLabels(DisaggregationModuleLabels):
    # Default age groups, listed from youngest to oldest.
    CHILD = "child"
    YOUTH = "youth"
    ADULT = "adult"
    SENIOR = "senior"
15 |
16 |
class AgeConfig(DisaggregationModuleConfig):
    """Overrides for the ``Age`` module: the label class, the age buckets and their breakpoints."""

    def __init__(self, labels: Type[AgeLabels], ages: List, breakpoints: List):
        # Stored verbatim; ``Age._apply_config`` copies these three values onto the module.
        self.labels, self.ages, self.breakpoints = labels, ages, breakpoints
22 |
23 |
class Age(DisaggregationModule):
    """Flag the age groups mentioned in the text of a row.

    The text is run through a spaCy pipeline; the first run of digits inside each ``DATE``
    entity is treated as a candidate age and mapped to a bucket by bisecting
    ``AGE_BREAKPOINTS``.
    """

    labels = AgeLabels
    # ``AGES[i]`` is the bucket for values from ``AGE_BREAKPOINTS[i]`` up to the next breakpoint.
    AGES = [AgeLabels.CHILD, AgeLabels.YOUTH, AgeLabels.ADULT, AgeLabels.SENIOR]
    AGE_BREAKPOINTS = [0, 12, 20, 65]
    spacy_model = "en_core_web_lg"

    def __init__(self, *args, **kwargs):
        """Load the spaCy model, then initialise the base module with the ``"age"`` id.

        Raises:
            ValueError: If the spaCy model cannot be loaded (spaCy raised ``OSError``).
        """
        try:
            self.nlp = spacy.load(self.spacy_model, enable="ner")
        except OSError as err:
            # Chain explicitly so the traceback shows the spaCy error as the direct cause.
            raise ValueError(
                f"This disaggregation module depends on the {self.spacy_model} model from spaCy.\n"
                f"You can install it by running: python -m spacy download {self.spacy_model}"
            ) from err

        super().__init__(module_id="age", *args, **kwargs)

    def _apply_config(self, config: AgeConfig):
        """Replace the default labels, buckets and breakpoints with the configured ones."""
        self.labels = config.labels
        self.AGES = config.ages
        self.AGE_BREAKPOINTS = config.breakpoints

    def __call__(self, row, *args, **kwargs):
        """Return a mapping from every age bucket to whether the row's text mentions it.

        Args:
            row: Mapping-like row; the text is read from ``row[self.column]``.

        Returns:
            A dict with one key per entry of ``self.AGES`` and boolean values.
        """
        return_ages = {age: False for age in self.AGES}
        doc = self.nlp(row[self.column])

        for entity in doc.ents:
            if entity.label_ != "DATE":
                continue

            number = re.search(r"\d+", entity.text)
            if number is None:
                continue

            value = int(number[0])
            # Values above 120 are ignored (presumably years or other non-age numbers).
            if value > 120:
                continue

            # ``bisect`` gives how many breakpoints are <= value; 0 means below the first bucket.
            age_bucket = bisect(self.AGE_BREAKPOINTS, value)
            if age_bucket >= 1:
                return_ages[self.AGES[age_bucket - 1]] = True

        return return_ages
66 |
--------------------------------------------------------------------------------
/src/disaggregators/disaggregation_modules/continent/__init__.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | # noinspection PyPackageRequirements
4 | import geograpy # from geograpy3
5 | import nltk
6 | import requests
7 |
8 | from ..disaggregation_module import DisaggregationModule, DisaggregationModuleLabels
9 |
10 |
class ContinentLabels(DisaggregationModuleLabels):
    # Labels reported by the Continent module; attribute names are looked up
    # from upper-cased continent names at call time.
    AFRICA = "africa"
    AMERICAS = "americas"
    ASIA = "asia"
    EUROPE = "europe"
    OCEANIA = "oceania"
17 |
18 |
class Continent(DisaggregationModule):
    """Flag the continent of the first country detected in a text column.

    Construction needs network access: it downloads a continent/region/country
    table and the NLTK resources fetched below.
    """

    labels = ContinentLabels
    continents = [
        ContinentLabels.AFRICA,
        ContinentLabels.AMERICAS,
        ContinentLabels.ASIA,
        ContinentLabels.EUROPE,
        ContinentLabels.OCEANIA,
    ]
    # Location of the continent -> region -> country table.
    countries_url = (
        "https://raw.githubusercontent.com/bigscience-workshop/data_sourcing/master/"
        "sourcing_sprint/resources/country_regions.json"
    )
    # Seconds to wait for the table download; without a timeout a stalled
    # connection would block construction forever.
    request_timeout = 30

    def __init__(self, *args, **kwargs):
        """Initialise the module and build the per-continent lookup lists.

        Raises:
            requests.HTTPError: If the country table cannot be downloaded.
        """
        # Positional id so that `column` may also be passed positionally.
        super().__init__("continent", *args, **kwargs)

        response = requests.get(self.countries_url, timeout=self.request_timeout)
        response.raise_for_status()
        table = response.json()

        self.continents = table[0]
        countries = table[1]
        region_countries = table[2]

        def get_countries_and_regions(continent_or_region):
            """Collect the direct sub-regions and all (nested) countries of an entry."""
            found = {"countries": [], "regions": []}

            # An entry without children is treated as empty instead of crashing.
            for member in region_countries.get(continent_or_region) or []:
                if member in countries:
                    found["countries"].append(member)
                else:
                    found["regions"].append(member)
                    found["countries"].extend(get_countries_and_regions(member)["countries"])

            return found

        continent_maps = {c: get_countries_and_regions(c) for c in self.continents}

        # Each list starts with the continent name, followed by its regions and countries.
        self.continent_lists = [
            [continent, *members["regions"], *members["countries"]]
            for continent, members in continent_maps.items()
        ]

        for resource in ("punkt", "averaged_perceptron_tagger", "maxent_ne_chunker", "words"):
            nltk.download(resource)

    def __call__(self, row, *args, **kwargs):
        """Return a ``{continent_label: bool}`` dict for the text in ``row[self.column]``.

        Only the first country reported by geograpy is considered.
        """
        return_continent = {continent: False for continent in list(ContinentLabels)}

        places = geograpy.get_geoPlace_context(text=row[self.column]).countries
        if not places:
            return return_continent

        first_place = places[0]
        for continent_list in self.continent_lists:
            if first_place in continent_list:
                # Continents without a matching label are ignored rather than raising.
                label = getattr(ContinentLabels, continent_list[0].upper(), None)
                if label is not None:
                    return_continent[label] = True
                break

        return return_continent
83 |
--------------------------------------------------------------------------------
/src/disaggregators/disaggregation_modules/disaggregation_module.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import List, Optional, Type, Union
3 |
4 | from aenum import Constant
5 |
6 | from disaggregators import disaggregation_modules
7 |
8 |
class DisaggregationModuleLabels(Constant):
    # Base class for the label constants declared by each disaggregation module.
    pass
11 |
12 |
class DisaggregationModuleConfig:
    # Base class for the configuration objects handed to a module's `_apply_config`.
    # Label class the configured module should expose.
    labels: Type[DisaggregationModuleLabels]
15 |
16 |
class DisaggregationModule(ABC):
    """Abstract base class of every disaggregation module.

    A module is a callable that receives a data row and returns a
    ``{label: value}`` dict. Subclasses provide ``labels`` and ``__call__`` and
    may override ``_apply_config`` to honour a configuration object.
    """

    # Class-level default; subclasses may declare their own references here.
    citations: List[str] = []

    def __init__(
        self,
        module_id: str,
        column: Optional[str],
        config: Optional["DisaggregationModuleConfig"] = None,
    ):
        """Store the module id and column, then apply ``config`` when one is given.

        Args:
            module_id: Name of the module, used as prefix of its field names.
            column: Key of the row entry the module reads.
            config: Optional configuration forwarded to ``_apply_config``.
        """
        self.name = module_id
        self.column = column
        # Copy instead of resetting to [] so that citations declared on a subclass
        # survive, while each instance still owns an independent list.
        self.citations = list(self.citations)

        if config:
            self._apply_config(config)

    @abstractmethod
    def __call__(self, row, *args, **kwargs):
        """Return the label values computed for ``row``."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def labels(self) -> Type["DisaggregationModuleLabels"]:
        """Label class (or other iterable of labels) produced by the module."""
        pass

    def _apply_config(self, config: "DisaggregationModuleConfig"):
        """Hook for subclasses that support configuration; does nothing by default."""
        pass

    @property
    def field_names(self):
        """Set of ``"<module name>.<label>"`` strings, one per label."""
        return {f"{self.name}.{label}" for label in self.labels}
41 |
42 |
class CustomDisaggregator(DisaggregationModule, ABC):
    """
    This class exists to provide a simple interface for creating custom disaggregation modules. This is useful because
    the DisaggregationModule abstract class may enforce extra rules that we don't want users to have to worry about.

    Subclasses declare ``module_id`` and ``labels`` (plain class attributes are enough) and implement ``__call__``.
    """

    def __init__(self, *args, **kwargs):
        # The id is passed positionally: as a keyword it would collide with the first
        # positional argument and make e.g. `MyModule("text")` raise TypeError.
        super().__init__(self.module_id, *args, **kwargs)

    @property
    @abstractmethod
    def module_id(self):
        """Name of the module, used as prefix of its field names."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def labels(self):
        """Labels produced by the module."""
        raise NotImplementedError()
61 |
62 |
class DisaggregationModuleFactory:
    """Create disaggregation modules from a module id, a custom class or an instance."""

    @staticmethod
    def create_module(module: Union[str, Type[CustomDisaggregator], DisaggregationModule], *args, **kwargs):
        """Return a module instance for ``module``.

        Strings are resolved through ``AVAILABLE_MODULES``, instances are returned
        unchanged and ``CustomDisaggregator`` subclasses are instantiated with the
        remaining arguments.

        Raises:
            ValueError: If ``module`` is none of the supported kinds.
        """
        if isinstance(module, str):
            return DisaggregationModuleFactory.create_from_id(module, *args, **kwargs)
        if isinstance(module, DisaggregationModule):
            return module
        # issubclass() raises TypeError for non-class arguments, so make sure we have a
        # class first; anything else falls through to the documented ValueError.
        if isinstance(module, type) and issubclass(module, CustomDisaggregator):
            return DisaggregationModuleFactory.create_from_class(module, *args, **kwargs)

        raise ValueError("Invalid module type received.")

    @staticmethod
    def create_from_id(module_id: str, *args, **kwargs) -> DisaggregationModule:
        """Instantiate the built-in module registered under ``module_id``.

        Raises:
            ValueError: If no module is registered under ``module_id``.
        """
        if module_id not in disaggregation_modules.AVAILABLE_MODULES:
            raise ValueError("Invalid module_id received.")

        return disaggregation_modules.AVAILABLE_MODULES[module_id](*args, **kwargs)

    @staticmethod
    def create_from_class(module: Type[CustomDisaggregator], *args, **kwargs) -> DisaggregationModule:
        """Instantiate a user-provided module class."""
        return module(*args, **kwargs)
85 |
--------------------------------------------------------------------------------
/src/disaggregators/disaggregation_modules/gender/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Type
2 |
3 | import pandas as pd
4 | import spacy
5 | from datasets import load_dataset
6 |
7 | from ..disaggregation_module import DisaggregationModule, DisaggregationModuleConfig, DisaggregationModuleLabels
8 |
9 |
class GenderLabels(DisaggregationModuleLabels):
    # Default labels reported by the Gender module; also used as the column
    # names of its word-list table.
    MALE = "male"
    FEMALE = "female"
13 |
14 |
class GenderConfig(DisaggregationModuleConfig):
    """Configuration adding or replacing word lists of the ``Gender`` module.

    Args:
        labels: Label class the configured module should expose.
        word_lists: Words to look for, keyed by the label they indicate.
    """

    def __init__(self, labels: Type[GenderLabels], word_lists: Dict[GenderLabels, List[str]]):
        self.word_lists = word_lists
        self.labels = labels
19 |
20 |
class Gender(DisaggregationModule):
    """Flag gendered words found in the noun phrases of a text column.

    Words are looked up in a table with one column per label, loaded from the
    ``md_gender_bias`` dataset and optionally extended through ``GenderConfig``.
    """

    labels = GenderLabels
    spacy_model = "en_core_web_lg"

    citations = [
        """
        @inproceedings{dinan-etal-2020-multi,
            title = "Multi-Dimensional Gender Bias Classification",
            author = "Dinan, Emily  and
              Fan, Angela  and
              Wu, Ledell  and
              Weston, Jason  and
              Kiela, Douwe  and
              Williams, Adina",
            booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP)", # noqa: E501
            year = "2020",
            publisher = "Association for Computational Linguistics",
            url = "https://www.aclweb.org/anthology/2020.emnlp-main.23",
            doi = "10.18653/v1/2020.emnlp-main.23",
        }
    """
    ]

    def __init__(self, *args, **kwargs):
        """Load the word table and the spaCy pipeline, then initialise the module.

        Raises:
            ValueError: If the required spaCy model is not installed.
        """
        self.gender_df = load_dataset("md_gender_bias", "gendered_words", split="train").to_pandas()
        self.gender_df.columns = [GenderLabels.MALE, GenderLabels.FEMALE]

        try:
            self.nlp = spacy.load(self.spacy_model)
        except OSError as err:
            raise ValueError(
                f"This disaggregation module depends on the {self.spacy_model} model from spaCy.\n"
                f"You can install it by running: python -m spacy download {self.spacy_model}"
            ) from err

        # Positional id so that `column` may also be passed positionally.
        super().__init__("gender", *args, **kwargs)

    def _apply_config(self, config: GenderConfig):
        """Expose the configured labels and store each configured word list as a table column."""
        self.labels = config.labels
        for gender, word_list in config.word_lists.items():
            self.gender_df[gender] = pd.DataFrame(word_list)

    def __call__(self, row, *args, **kwargs):
        """Return a ``{gender_label: bool}`` dict for the text in ``row[self.column]``."""
        return_genders = {gender: False for gender in list(GenderLabels)}
        # Configured labels are table columns too; include them up front so that every
        # row yields the same keys whether or not one of their words occurs.
        return_genders.update({gender: False for gender in self.gender_df.columns})

        doc = self.nlp(row[self.column])

        # Every token below the root of a noun chunk; a set avoids re-scanning duplicates.
        nouns = {token.text for chunk in doc.noun_chunks for token in chunk.root.subtree}

        for noun in nouns:
            column_hits = (self.gender_df == noun).any()
            for gender_hit in column_hits[column_hits].keys():
                return_genders[gender_hit] = True

        return return_genders
76 |
--------------------------------------------------------------------------------
/src/disaggregators/disaggregation_modules/pronoun/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Set, Type
2 |
3 | from ..disaggregation_module import DisaggregationModule, DisaggregationModuleConfig, DisaggregationModuleLabels
4 |
5 |
class PronounLabels(DisaggregationModuleLabels):
    # Default pronoun groups reported by the Pronoun module.
    SHE_HER = "she_her"
    HE_HIM = "he_him"
    THEY_THEM = "they_them"
10 |
11 |
class PronounConfig(DisaggregationModuleConfig):
    """Configuration adding pronoun groups to the ``Pronoun`` module.

    Args:
        labels: Label class the configured module should expose.
        pronouns: Lower-case pronouns to look for, keyed by their label.
    """

    def __init__(self, labels: Type[PronounLabels], pronouns: Dict[PronounLabels, Set[str]]):
        self.pronouns = pronouns
        self.labels = labels
16 |
17 |
class Pronoun(DisaggregationModule):
    """Flag the pronoun groups that occur in a text column."""

    labels = PronounLabels
    AVAILABLE_PRONOUNS = {
        PronounLabels.SHE_HER: {"she", "her", "hers", "herself"},
        PronounLabels.HE_HIM: {"he", "him", "his", "himself"},
        PronounLabels.THEY_THEM: {"they", "them", "their", "theirs", "themself", "themselves"},
    }
    # Stripped from both ends of every token so that "him." or "(she)" still match.
    _PUNCTUATION = ".,;:!?\"'()[]{}"

    def __init__(self, *args, **kwargs):
        # Positional id so that `column` may also be passed positionally.
        super().__init__("pronoun", *args, **kwargs)

    def _apply_config(self, config: PronounConfig):
        """Expose the configured labels and merge the configured pronoun groups.

        Configured groups take precedence over the defaults, so a config may
        both add new groups and replace a built-in one.
        """
        self.labels = config.labels
        self.AVAILABLE_PRONOUNS = {**self.AVAILABLE_PRONOUNS, **config.pronouns}

    def __call__(self, row, *args, **kwargs):
        """Return a ``{pronoun_label: bool}`` dict for the text in ``row[self.column]``."""
        text = row[self.column]
        # Tokenise once, case-insensitively, instead of once per pronoun.
        tokens = {token.strip(self._PUNCTUATION) for token in text.lower().split()}

        return {label: not tokens.isdisjoint(pronouns) for label, pronouns in self.AVAILABLE_PRONOUNS.items()}
41 |
--------------------------------------------------------------------------------
/src/disaggregators/disaggregation_modules/religion/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Type
2 |
3 | from sentence_transformers import SentenceTransformer
4 | from sentence_transformers.util import semantic_search
5 |
6 | from ..disaggregation_module import DisaggregationModule, DisaggregationModuleConfig, DisaggregationModuleLabels
7 |
8 |
class ReligionLabels(DisaggregationModuleLabels):
    # Default labels reported by the Religion module; their string form is what
    # gets embedded and compared against the input text.
    JUDAISM = "judaism"
    ISLAM = "islam"
    BUDDHISM = "buddhism"
    CHRISTIANITY = "christianity"
14 |
15 |
class ReligionConfig(DisaggregationModuleConfig):
    """Configuration for the ``Religion`` module; both settings are optional.

    Args:
        labels: Label class whose members are added to the religions searched for.
        threshold: Similarity score a hit must exceed to be reported.
    """

    def __init__(self, labels: Type[ReligionLabels] = None, threshold: float = None):
        self.threshold = threshold
        self.labels = labels
20 |
21 |
class Religion(DisaggregationModule):
    """Flag the religion whose name is semantically closest to a text column.

    The label names and the input text are embedded with a sentence-transformer
    model; the best match is reported when its score exceeds ``threshold``.
    """

    labels = ReligionLabels
    threshold = 0.14  # Arbitrary threshold, hand-tuned.
    religions = [
        ReligionLabels.JUDAISM,
        ReligionLabels.ISLAM,
        ReligionLabels.BUDDHISM,
        ReligionLabels.CHRISTIANITY,
    ]

    def __init__(self, *args, **kwargs):
        self.model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")

        # Positional id so that `column` may also be passed positionally.
        super().__init__("religion", *args, **kwargs)

        # Encoded after the config has been applied so that configured labels are included.
        self.embeddings = self.model.encode([str(religion) for religion in self.religions], convert_to_tensor=True)

    def _apply_config(self, config: ReligionConfig):
        """Add the configured labels to the search corpus and override the threshold."""
        if config.labels:
            self.labels = config.labels
            self.religions = self.religions + list(config.labels)

        # `is not None` keeps an explicit threshold of 0.0, which `or` would discard.
        if config.threshold is not None:
            self.threshold = config.threshold

    def __call__(self, row, *args, **kwargs):
        """Return a ``{religion_label: bool}`` dict for the text in ``row[self.column]``."""
        # Built from the searched religions so that configured labels are always present.
        return_religion = {religion: False for religion in self.religions}

        query = self.model.encode(row[self.column], convert_to_tensor=True)
        religion_hit = semantic_search(query, self.embeddings, top_k=1)[0][0]

        if religion_hit["score"] > self.threshold:
            return_religion[self.religions[religion_hit["corpus_id"]]] = True

        return return_religion
56 |
--------------------------------------------------------------------------------
/src/disaggregators/disaggregator.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, List, Optional, Set, Type, Union
2 |
3 | from disaggregators.disaggregation_modules import (
4 | CustomDisaggregator,
5 | DisaggregationModule,
6 | DisaggregationModuleFactory,
7 | )
8 |
9 |
class Disaggregator:
    """Apply one or more disaggregation modules to a data row.

    Instances are callables suitable for ``Dataset.map`` or ``DataFrame.apply``:
    they return a flat dict keyed by ``"<module name>.<label>"``.
    """

    def __init__(
        self,
        module: Optional[
            Union[
                str,
                List[str],
                "DisaggregationModule",
                List["DisaggregationModule"],
                Type["CustomDisaggregator"],
                List[Type["CustomDisaggregator"]],
            ]
        ] = None,
        *args,
        **kwargs,
    ):
        """Create the requested modules.

        Args:
            module: A module id, module instance or custom module class, or a
                list/tuple of those. ``None`` creates an empty disaggregator.
            *args: Forwarded to the module factory for every module.
            **kwargs: Forwarded to the module factory for every module.
        """
        if module is None:
            module_list = []
        elif isinstance(module, (list, tuple)):
            module_list = list(module)
        else:
            module_list = [module]

        self.modules = [DisaggregationModuleFactory.create_module(item, *args, **kwargs) for item in module_list]

    def _disaggregate(self, x) -> dict:
        """Run every module on ``x`` and merge the results into one flat dict."""
        merged = {}
        for module in self.modules:
            for label, value in module(x).items():
                merged[f"{module.name}.{str(label)}"] = value
        return merged

    def get_function(self) -> Callable:
        """Return the row -> dict function applied by this disaggregator."""
        return self._disaggregate

    def __call__(self, x) -> dict:
        """Return the merged module results for the row ``x``."""
        return self._disaggregate(x)

    @property
    def fields(self) -> Set:
        """Set of all ``"<module name>.<label>"`` keys this disaggregator can produce."""
        return {f"{module.name}.{str(label)}" for module in self.modules for label in module.labels}
50 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/disaggregators/b8ea3170b119e6768812874b71367b8932db0d37/tests/__init__.py
--------------------------------------------------------------------------------
/tests/disaggregation_modules/test_age.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from disaggregators.disaggregation_modules.age import Age, AgeConfig, AgeLabels
4 |
5 |
def test_initialize():
    """A default Age module reports its id and the four built-in labels."""
    module = Age(column=None)
    expected_labels = {AgeLabels.CHILD, AgeLabels.YOUTH, AgeLabels.ADULT, AgeLabels.SENIOR}

    assert module.name == "age"
    assert set(module.labels) == expected_labels
10 |
11 |
@pytest.mark.slow
@pytest.mark.parametrize(
    "text,expected",
    [
        ("The man went to the park.", []),
        ("The 40-year-old man went to the park.", [AgeLabels.ADULT]),
        ("The 10 year old child ate a hot dog.", [AgeLabels.CHILD]),
        ("Clara is 18 years old.", [AgeLabels.YOUTH]),
        ("Hamed's grandpa is an 86 yr old smoker.", [AgeLabels.SENIOR]),
        ("The 40-year-old man went to the park with his 10 year old daughter.", [AgeLabels.ADULT, AgeLabels.CHILD]),
        ("Farzaneh is 18 years old and her grandmother is 86 yrs old", [AgeLabels.YOUTH, AgeLabels.SENIOR]),
    ],
)
def test_call_default(text, expected):
    """Only the expected age labels are flagged; every other label stays False."""
    expected_flags = dict.fromkeys(list(AgeLabels), False)
    expected_flags.update(dict.fromkeys(expected, True))

    module = Age(column="text")

    assert module({"text": text}) == expected_flags
31 |
32 |
@pytest.mark.slow
def test_call_custom():
    """Configured labels and breakpoints replace the default age buckets."""

    class CustomAgeLabels(AgeLabels):
        NIMA = "nima"
        OLDER_THAN_NIMA = "older_nima"
        YOUNGER_THAN_NIMA = "younger_nima"

    def only(flagged):
        # Expected result with exactly one custom label set.
        flags = {
            CustomAgeLabels.NIMA: False,
            CustomAgeLabels.OLDER_THAN_NIMA: False,
            CustomAgeLabels.YOUNGER_THAN_NIMA: False,
        }
        flags[flagged] = True
        return flags

    config = AgeConfig(
        labels=CustomAgeLabels,
        ages=[CustomAgeLabels.YOUNGER_THAN_NIMA, CustomAgeLabels.NIMA, CustomAgeLabels.OLDER_THAN_NIMA],
        breakpoints=[0, 25, 27],
    )
    module = Age(config=config, column="text")

    texts = ["Jenny is 20 years old.", "Nima is 26 years old.", "Mark Knopfler is 73 years old."]
    results = [module({"text": text}) for text in texts]

    assert results == [
        only(CustomAgeLabels.YOUNGER_THAN_NIMA),
        only(CustomAgeLabels.NIMA),
        only(CustomAgeLabels.OLDER_THAN_NIMA),
    ]
61 |
--------------------------------------------------------------------------------
/tests/disaggregation_modules/test_continent.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from disaggregators.disaggregation_modules.continent import Continent, ContinentLabels
4 |
5 |
def test_initialize():
    """A default Continent module reports its id and the five built-in labels."""
    module = Continent(column=None)
    expected_labels = {
        ContinentLabels.AFRICA,
        ContinentLabels.AMERICAS,
        ContinentLabels.ASIA,
        ContinentLabels.EUROPE,
        ContinentLabels.OCEANIA,
    }

    assert module.name == "continent"
    assert set(module.labels) == expected_labels
16 |
17 |
@pytest.mark.slow
@pytest.mark.parametrize(
    "text,expected",
    [
        ("Tomorrow I am visiting Italy.", [ContinentLabels.EUROPE]),
        (
            "A gentle breeze blows through the leaves of a cherry blossom tree in Vancouver.",
            [ContinentLabels.AMERICAS],
        ),
        ("Let's eat pizza.", []),
    ],
)
def test_call_default(text, expected):
    """Only the expected continent is flagged; every other label stays False."""
    expected_flags = dict.fromkeys(list(ContinentLabels), False)
    expected_flags.update(dict.fromkeys(expected, True))

    module = Continent(column="text")

    assert module({"text": text}) == expected_flags
42 |
--------------------------------------------------------------------------------
/tests/disaggregation_modules/test_gender.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from disaggregators.disaggregation_modules.gender import Gender, GenderConfig, GenderLabels
4 |
5 |
def test_initialize():
    """A default Gender module reports its id and the two built-in labels."""
    module = Gender(column=None)
    expected_labels = {GenderLabels.MALE, GenderLabels.FEMALE}

    assert module.name == "gender"
    assert set(module.labels) == expected_labels
10 |
11 |
@pytest.mark.slow
@pytest.mark.parametrize(
    "text,expected",
    [
        ("That is one large cat!", []),
        ("The 40-year-old man went to the park.", [GenderLabels.MALE]),
        ("The clown gave the girl an ice cream cone.", [GenderLabels.FEMALE]),
        ("What is the boy's name?", [GenderLabels.MALE]),
        ("I checked the lady's ticket.", [GenderLabels.FEMALE]),
        ("The guy gave the woman a high-five.", [GenderLabels.MALE, GenderLabels.FEMALE]),
    ],
)
def test_call_default(text, expected):
    """Only the expected gender labels are flagged; every other label stays False."""
    expected_flags = dict.fromkeys(list(GenderLabels), False)
    expected_flags.update(dict.fromkeys(expected, True))

    module = Gender(column="text")

    assert module({"text": text}) == expected_flags
30 |
31 |
@pytest.mark.slow
def test_call_custom():
    """A configured word list makes the module flag its custom label."""

    class CustomGenderLabels(GenderLabels):
        NON_BINARY = "non-binary"

    word_lists = {CustomGenderLabels.NON_BINARY: ["clown"]}
    config = GenderConfig(labels=CustomGenderLabels, word_lists=word_lists)
    module = Gender(config=config, column="text")

    results = module({"text": "The sad clown went to the doctor."})

    expected_flags = {
        CustomGenderLabels.MALE: False,
        CustomGenderLabels.FEMALE: False,
        CustomGenderLabels.NON_BINARY: True,
    }
    assert results == expected_flags
49 |
--------------------------------------------------------------------------------
/tests/disaggregation_modules/test_pronoun.py:
--------------------------------------------------------------------------------
1 | from disaggregators.disaggregation_modules.pronoun import Pronoun, PronounConfig, PronounLabels
2 |
3 |
def test_initialize():
    """A default Pronoun module reports its id and the three built-in labels."""
    module = Pronoun(column=None)
    expected_labels = {PronounLabels.HE_HIM, PronounLabels.SHE_HER, PronounLabels.THEY_THEM}

    assert module.name == "pronoun"
    assert set(module.labels) == expected_labels
8 |
9 |
def test_call_default():
    """Only the pronoun group present in the text is flagged."""
    module = Pronoun(column="text")

    results = module({"text": "He went to the park."})

    expected_flags = {
        PronounLabels.HE_HIM: True,
        PronounLabels.SHE_HER: False,
        PronounLabels.THEY_THEM: False,
    }
    assert results == expected_flags
15 |
16 |
def test_call_custom():
    """A configured pronoun group is reported alongside the default groups."""

    class CustomPronounLabels(PronounLabels):
        ZE_ZIR = "ze_zir"

    pronoun_mapping = {CustomPronounLabels.ZE_ZIR: {"ze", "zir", "zirs", "zirself"}}
    config = PronounConfig(labels=CustomPronounLabels, pronouns=pronoun_mapping)
    module = Pronoun(config=config, column="text")

    results = module({"text": "Ze went to the park."})

    expected_flags = {
        CustomPronounLabels.ZE_ZIR: True,
        CustomPronounLabels.HE_HIM: False,
        CustomPronounLabels.SHE_HER: False,
        CustomPronounLabels.THEY_THEM: False,
    }
    assert results == expected_flags
34 |
--------------------------------------------------------------------------------
/tests/disaggregation_modules/test_religion.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from disaggregators.disaggregation_modules.religion import Religion, ReligionConfig, ReligionLabels
4 |
5 |
def test_initialize():
    """A default Religion module reports its id and the four built-in labels."""
    module = Religion(column=None)
    expected_labels = {
        ReligionLabels.JUDAISM,
        ReligionLabels.ISLAM,
        ReligionLabels.BUDDHISM,
        ReligionLabels.CHRISTIANITY,
    }

    assert module.name == "religion"
    assert set(module.labels) == expected_labels
15 |
16 |
@pytest.mark.slow
@pytest.mark.parametrize(
    "text,expected",
    [
        (
            "The menorah is a seven-branched candelabrum that is described in the Hebrew Bible.",
            [ReligionLabels.JUDAISM],
        ),
        (
            "Traditionally, Eid al-Fitr begins at sunset on the night of the first sighting of the crescent moon.",
            [ReligionLabels.ISLAM],
        ),
        (
            # The trailing space matters: the two literals are concatenated and used to read "fastprior".
            "Three main prevailing theories exist on the finalization of Lent as a forty-day fast "
            "prior to the arrival of Easter Sunday.",
            [ReligionLabels.CHRISTIANITY],
        ),
        (
            "Samsara means 'wandering' or 'world', with the connotation of cyclic, circuitous change.",
            [ReligionLabels.BUDDHISM],
        ),
        (
            "CBC Music is a Canadian FM radio network operated by the Canadian Broadcasting Corporation.",
            [],
        ),
    ],
)
def test_call_default(text, expected):
    """Only the expected religion is flagged; every other label stays False."""
    base_labels = {religion: False for religion in list(ReligionLabels)}
    data = {"text": text}
    disagg_module = Religion(column="text")
    results = disagg_module(data)
    assert results == {**base_labels, **{label: True for label in expected}}
50 |
51 |
@pytest.mark.slow
def test_call_custom():
    """A label added through the config can be returned as the best match."""

    class CustomReligionLabels(ReligionLabels):
        BOKONONISM = "bokononism"

    config = ReligionConfig(labels=CustomReligionLabels)
    module = Religion(config=config, column="text")

    results = module({"text": "Busy, busy, busy."})

    assert results[CustomReligionLabels.BOKONONISM]
63 |
--------------------------------------------------------------------------------
/tests/integration/test_disaggregation.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import pytest
3 | from datasets import Dataset
4 |
5 | from disaggregators import Disaggregator
6 | from disaggregators.disaggregation_modules.pronoun import Pronoun, PronounConfig, PronounLabels
7 |
8 |
@pytest.fixture()
def dataset():
    """Two-row dataset with a single ``text`` column."""
    rows = {"text": ["Hello world!", "Fizz buzz."]}
    return Dataset.from_dict(rows)
12 |
13 |
class CustomPronounLabels(PronounLabels):
    # Extra pronoun group used to exercise a configured module instance.
    ZE_ZIR = "ze_zir"
16 |
17 |
@pytest.fixture()
def disaggregator():
    """Disaggregator mixing two built-in module ids with a configured Pronoun instance."""
    pronoun_mapping = {CustomPronounLabels.ZE_ZIR: {"ze", "zir", "zirs", "zirself"}}
    config = PronounConfig(labels=CustomPronounLabels, pronouns=pronoun_mapping)
    pronoun_module = Pronoun(config=config, column="text")

    return Disaggregator(["age", "gender", pronoun_module], column="text")
27 |
28 |
@pytest.fixture()
def expected_features():
    """Field names the age and gender modules must contribute."""
    age_fields = {"age.child", "age.youth", "age.adult", "age.senior"}
    gender_fields = {"gender.male", "gender.female"}
    return age_fields | gender_fields
39 |
40 |
def test_datasets(dataset, disaggregator, expected_features):
    """Mapping a disaggregator over a Dataset adds the expected feature columns."""
    mapped_features = set(dataset.map(disaggregator).features)
    assert expected_features.issubset(mapped_features)
44 |
45 |
def test_pandas(dataset, disaggregator, expected_features):
    """Applying a disaggregator row-wise to a DataFrame yields the expected columns."""
    frame = dataset.to_pandas()
    flags = pd.json_normalize(frame.apply(disaggregator, axis=1))
    frame = pd.merge(frame, flags, left_index=True, right_index=True)

    assert expected_features.issubset(set(frame.columns))
51 |
52 |
def test_each_module(dataset, module):
    """Every field a module advertises shows up after mapping it over a Dataset."""
    disaggregator = Disaggregator(module, column="text")
    mapped_features = set(dataset.map(disaggregator).features)

    assert disaggregator.fields.issubset(mapped_features)
60 |
--------------------------------------------------------------------------------
/tests/test_disaggregation_module.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from disaggregators import DisaggregationModule, DisaggregationModuleFactory
4 |
5 |
def test_create_subclassed_module(dummy_module):
    """A module subclass exposes the id it was built with as its name."""
    module = dummy_module(column=None)
    assert module.name == "dummy-module"
9 |
10 |
def test_load_module_from_string_id(mocker, dummy_module):
    """The factory instantiates the class registered under a known module id."""
    registry = mocker.MagicMock()
    registry.AVAILABLE_MODULES = {"dummy-module": dummy_module}
    mocker.patch("disaggregators.disaggregation_modules.disaggregation_module.disaggregation_modules", registry)

    module = DisaggregationModuleFactory.create_from_id(module_id="dummy-module", column=None)

    assert issubclass(type(module), DisaggregationModule)
    assert module.name == "dummy-module"
19 |
20 |
def test_load_module_with_bad_string_id(mocker):
    """An unregistered module id is rejected with a ValueError."""
    registry = mocker.MagicMock()
    registry.AVAILABLE_MODULES = {}
    mocker.patch("disaggregators.disaggregation_modules.disaggregation_module.disaggregation_modules", registry)

    with pytest.raises(ValueError, match="Invalid module_id received."):
        DisaggregationModuleFactory.create_from_id(module_id="bad-module")
28 |
29 |
def test_override_module_config(configured_module, custom_dummy_labels, dummy_module, dummy_labels):
    """Configuring one instance changes its labels without touching the module class."""
    assert configured_module.labels == custom_dummy_labels

    # A freshly created, unconfigured instance must still expose the default labels.
    untouched = dummy_module(column=None)
    assert untouched.labels == dummy_labels
36 |
--------------------------------------------------------------------------------
/tests/test_disaggregator.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 |
3 | import pytest
4 | from datasets import Dataset
5 |
6 | from disaggregators import CustomDisaggregator, DisaggregationModule, Disaggregator
7 |
8 |
class TestDisaggregator:
    """Unit tests for ``Disaggregator`` with the module factory replaced by a stub."""

    @pytest.fixture(autouse=True)
    def mock_module_factory(self, mocker, dummy_module):
        # Every requested module becomes a dummy module carrying the requested id.
        def build_dummy(module_id=None, column=None):
            return dummy_module(module_id=module_id, column=column)

        factory = mocker.MagicMock()
        factory.create_module.side_effect = build_dummy

        mocker.patch("disaggregators.disaggregator.DisaggregationModuleFactory", factory)

    def test_create_empty_disaggregator(self):
        disagg = Disaggregator()
        assert isinstance(disagg.modules, list)
        assert len(disagg.modules) == 0

    def test_create_disaggregator_with_single_module(self):
        disagg = Disaggregator("pronoun")
        assert len(disagg.modules) == 1

        (only_module,) = disagg.modules
        assert isinstance(only_module, DisaggregationModule)
        assert only_module.name == "pronoun"

    def test_create_disaggregator_with_multiple_modules(self):
        disagg = Disaggregator(["pronoun", "spelling"])
        assert len(disagg.modules) == 2
        assert all(isinstance(module, DisaggregationModule) for module in disagg.modules)
        assert disagg.modules[0].name == "pronoun"
        assert disagg.modules[1].name == "spelling"

    def test_get_disaggregator_function_single_aggregation_module(self):
        row = {"a": 1, "b": 2}
        expected = {"dummy-module.dummy-value-1": True, "dummy-module.dummy-value-2": True}

        disagg = Disaggregator("dummy-module")
        assert disagg(row) == expected

        disagg_func = disagg.get_function()
        assert callable(disagg_func)
        assert disagg_func(row) == expected

    def test_get_disaggregator_function_multiple_aggregation_modules(self):
        row = {"a": 1, "b": 2}
        expected = {
            "dummy-one.dummy-value-1": True,
            "dummy-one.dummy-value-2": True,
            "dummy-two.dummy-value-1": True,
            "dummy-two.dummy-value-2": True,
        }

        disagg = Disaggregator(["dummy-one", "dummy-two"])
        assert disagg(row) == expected

        disagg_func = disagg.get_function()
        assert callable(disagg_func)
        assert disagg_func(row) == expected

    @pytest.mark.parametrize(
        "modules,expected",
        [
            ([], set()),
            (["dummy-one"], {"dummy-one.dummy-value-1", "dummy-one.dummy-value-2"}),
            (
                ["dummy-one", "dummy-two"],
                {
                    "dummy-one.dummy-value-1",
                    "dummy-one.dummy-value-2",
                    "dummy-two.dummy-value-1",
                    "dummy-two.dummy-value-2",
                },
            ),
        ],
    )
    def test_get_fields(self, modules, expected):
        assert Disaggregator(modules).fields == expected
83 |
84 |
@pytest.fixture
def custom_module(dummy_labels):
    """Minimal ``CustomDisaggregator`` subclass returning fixed values for both dummy labels."""

    class CustomModule(CustomDisaggregator):
        module_id = "custom"
        labels = dummy_labels

        def __call__(self, row, *args, **kwargs):
            fixed_values = {dummy_labels.DUMMY_ONE: "cat", dummy_labels.DUMMY_TWO: "dog"}
            return fixed_values

    return CustomModule
95 |
96 |
def test_inject_custom_module_subclass(custom_module, dummy_labels):
    """A custom module class can be handed to the Disaggregator and mapped over a Dataset."""
    disagg = Disaggregator(custom_module, column=None)
    assert disagg.fields == {"custom.dummy-value-1", "custom.dummy-value-2"}

    mapped = Dataset.from_dict({"text": ["Hello world!"]}).map(disagg)
    first_row = mapped[0]

    assert set(mapped.features) == {"text", *disagg.fields}
    assert first_row == {"text": "Hello world!", "custom.dummy-value-1": "cat", "custom.dummy-value-2": "dog"}
104 |
105 |
def test_module_instance(configured_module, configured_dummy_expected_results):
    """An already configured module instance can be handed to the Disaggregator as-is."""
    disagg = Disaggregator(configured_module, column=None)
    assert disagg.fields == configured_module.field_names

    mapped = Dataset.from_dict({"text": ["Hello world!"]}).map(disagg)
    first_row = mapped[0]

    assert set(mapped.features) == {"text", *disagg.fields}
    assert first_row == {"text": "Hello world!", **configured_dummy_expected_results}
113 |
--------------------------------------------------------------------------------