├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── conftest.py ├── noxfile.py ├── pyproject.toml ├── setup.cfg ├── setup.py ├── src └── disaggregators │ ├── __init__.py │ ├── disaggregation_modules │ ├── __init__.py │ ├── age │ │ └── __init__.py │ ├── continent │ │ └── __init__.py │ ├── disaggregation_module.py │ ├── gender │ │ └── __init__.py │ ├── pronoun │ │ └── __init__.py │ └── religion │ │ └── __init__.py │ └── disaggregator.py └── tests ├── __init__.py ├── disaggregation_modules ├── test_age.py ├── test_continent.py ├── test_gender.py ├── test_pronoun.py └── test_religion.py ├── integration └── test_disaggregation.py ├── test_disaggregation_module.py └── test_disaggregator.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Locked files 2 | *.lock 3 | !dvc.lock 4 | 5 | # Extracted dummy data 6 | datasets/**/dummy_data-zip-extracted/ 7 | 8 | # Compiled python modules. 9 | *.pyc 10 | 11 | # Byte-compiled 12 | __pycache__/ 13 | .cache/ 14 | 15 | # Python egg metadata, regenerated from source files by setuptools. 16 | *.egg-info 17 | .eggs/ 18 | 19 | # PyPI distribution artifacts. 20 | build/ 21 | dist/ 22 | 23 | # Environments 24 | .env 25 | .venv 26 | env/ 27 | venv/ 28 | ENV/ 29 | env.bak/ 30 | venv.bak/ 31 | 32 | # pyenv 33 | .python-version 34 | 35 | # Tests 36 | .pytest_cache/ 37 | 38 | # Other 39 | *.DS_Store 40 | 41 | # PyCharm/vscode 42 | .idea 43 | .vscode 44 | 45 | # Vim 46 | .*.swp 47 | 48 | # playground 49 | /playground 50 | 51 | # Sphinx documentation 52 | docs/_build/ 53 | docs/source/_build/ 54 | 55 | notebooks/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | 
http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions.
9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. 
For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: quality style test 2 | 3 | # Check that source code meets quality standards 4 | 5 | quality: 6 | black --check --line-length 119 --target-version py36 tests src 7 | isort --check-only tests src 8 | flake8 tests src 9 | 10 | # Format source code automatically 11 | 12 | style: 13 | black --line-length 119 --target-version py36 tests src 14 | isort tests src 15 | 16 | # Run tests for the library 17 | 18 | test: 19 | python -m pytest -n auto --dist=loadfile ./tests/ 20 | 21 | test-fast: 22 | python -m pytest -n auto --dist=loadfile -m "not slow" ./tests/ 23 | 24 | 25 | # Utility for Nox 26 | 27 | load_spacy_model: 28 | spacy validate | grep en_core_web_lg || spacy download en_core_web_lg 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 |
3 | Hugging Face Disaggregators 4 |
5 |

6 | 7 |

8 | 9 | GitHub 10 | 11 | 12 | GitHub release 13 | 14 |

15 | 16 | > ⚠️ Please note: This library is in early development, and the disaggregation modules that are included are proofs of concept that are _not_ production-ready. Additionally, all APIs are subject to breaking changes any time before a 1.0.0 release. Rigorously tested versions of the included modules will be released in the future, so stay tuned. [We'd love your feedback in the meantime!](https://github.com/huggingface/disaggregators/discussions/23) 17 | 18 | The `disaggregators` library allows you to easily add new features to your datasets to enable disaggregated data exploration and disaggregated model evaluation. `disaggregators` is preloaded with disaggregation modules for text data, with image modules coming soon! 19 | 20 | This library is intended to be used with [🤗 Datasets](https://github.com/huggingface/datasets), but should work with any other "mappable" interface to a dataset. 21 | 22 | ## Requirements and Installation 23 | 24 | `disaggregators` has been tested on Python 3.8, 3.9, and 3.10. 25 | 26 | `pip install disaggregators` will fetch the latest release from PyPI. 27 | 28 | Note that some disaggregation modules require extra dependencies such as SpaCy modules, which may need to be installed manually. If these dependencies aren't installed, `disaggregators` will inform you about how to install them. 29 | 30 | To install directly from this GitHub repo, use the following command: 31 | ```shell 32 | pip install git+https://github.com/huggingface/disaggregators.git 33 | ``` 34 | 35 | ## Usage 36 | 37 | You will likely want to use 🤗 Datasets with `disaggregators`. 38 | 39 | ```shell 40 | pip install datasets 41 | ``` 42 | 43 | The snippet below loads the IMDB dataset from the Hugging Face Hub, and initializes a disaggregator for "pronoun" that will run on the IMDB dataset's "text" column. If you would like to run multiple disaggregations, you can pass a list to the `Disaggregator` constructor (e.g. 
`Disaggregator(["pronoun", "age"], column="text")`). We then use the 🤗 Datasets `map` method to apply the disaggregation to the dataset. 44 | 45 | ```python 46 | from disaggregators import Disaggregator 47 | from datasets import load_dataset 48 | 49 | dataset = load_dataset("imdb", split="train") 50 | disaggregator = Disaggregator("pronoun", column="text") 51 | 52 | ds = dataset.map(disaggregator) # New boolean columns are added for she/her, he/him, and they/them 53 | ``` 54 | 55 | The resulting dataset can now be used for data exploration and disaggregated model evaluation. 56 | 57 | You can also run disaggregations on Pandas DataFrames with `.apply` and `.merge`: 58 | 59 | ```python 60 | from disaggregators import Disaggregator 61 | import pandas as pd 62 | df = pd.DataFrame({"text": ["They went to the park."]}) 63 | 64 | disaggregator = Disaggregator("pronoun", column="text") 65 | 66 | new_cols = df.apply(disaggregator, axis=1) 67 | df = pd.merge(df, pd.json_normalize(new_cols), left_index=True, right_index=True) 68 | ``` 69 | 70 | ### Available Disaggregation Modules 71 | 72 | The following modules are currently available: 73 | 74 | - `"age"` 75 | - `"gender"` 76 | - `"pronoun"` 77 | - `"religion"` 78 | - `"continent"` 79 | 80 | Note that `disaggregators` is in active development, and that these (and future) modules are subject to changing interfaces and implementations at any time before a `1.0.0` release. Each module provides its own method for overriding the default configuration, with the general interface documented below. 81 | 82 | ### Module Configurations 83 | 84 | Modules may make certain variables and functionality configurable. If you'd like to configure a module, import the module, its labels, and its config class. Then, override the labels and set the configuration as needed while instantiating the module. Once instantiated, you can pass the module to the `Disaggregator`. The example below shows this with the `Age` module.
85 | 86 | ```python 87 | from disaggregators import Disaggregator 88 | from disaggregators.disaggregation_modules.age import Age, AgeLabels, AgeConfig 89 | 90 | class MeSHAgeLabels(AgeLabels): 91 | INFANT = "infant" 92 | CHILD_PRESCHOOL = "child_preschool" 93 | CHILD = "child" 94 | ADOLESCENT = "adolescent" 95 | ADULT = "adult" 96 | MIDDLE_AGED = "middle_aged" 97 | AGED = "aged" 98 | AGED_80_OVER = "aged_80_over" 99 | 100 | age = Age( 101 | config=AgeConfig( 102 | labels=MeSHAgeLabels, 103 | ages=list(MeSHAgeLabels), 104 | breakpoints=[0, 2, 6, 13, 19, 45, 65, 80] # inclusive lower bound (in years) of each label in `ages` 105 | ), 106 | column="question" 107 | ) 108 | 109 | disaggregator = Disaggregator([age, "gender"], column="question") 110 | ``` 111 | 112 | ### Custom Modules 113 | 114 | Custom modules can be created by extending the `CustomDisaggregator`. All custom modules must have `labels` and a `module_id`, and must implement a `__call__` method. 115 | 116 | ```python 117 | from disaggregators import Disaggregator, DisaggregationModuleLabels, CustomDisaggregator 118 | 119 | class TabsSpacesLabels(DisaggregationModuleLabels): 120 | TABS = "tabs" 121 | SPACES = "spaces" 122 | 123 | class TabsSpaces(CustomDisaggregator): 124 | module_id = "tabs_spaces" 125 | labels = TabsSpacesLabels 126 | 127 | def __call__(self, row, *args, **kwargs): 128 | if "\t" in row[self.column]: 129 | return {self.labels.TABS: True, self.labels.SPACES: False} 130 | else: 131 | return {self.labels.TABS: False, self.labels.SPACES: True} 132 | 133 | disaggregator = Disaggregator(TabsSpaces, column="text") 134 | ``` 135 | 136 | ## Development 137 | 138 | Development requirements can be installed with `pip install .[dev]`. See the `Makefile` for useful targets, such as code quality and test running. 139 | 140 | To run tests locally across multiple Python versions (3.8, 3.9, and 3.10), ensure that you have all the Python versions available and then run `nox -r`.
Note that this is quite slow, so it's only worth doing to double-check your code before you open a Pull Request. 141 | 142 | ## Contact 143 | 144 | Nima Boscarino – `nima huggingface co` 145 | -------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import os 3 | 4 | from typing import Type 5 | 6 | from disaggregators import DisaggregationModule, DisaggregationModuleLabels, DisaggregationModuleConfig 7 | 8 | 9 | def pytest_collection_modifyitems(items): 10 | for item in items: 11 | if "integration" in item.nodeid: 12 | item.add_marker(pytest.mark.slow) 13 | item.add_marker(pytest.mark.integration) 14 | 15 | 16 | def pytest_generate_tests(metafunc): 17 | if metafunc.definition.name == "test_each_module": 18 | metafunc.parametrize("module", [ 19 | x.name for x in os.scandir( 20 | metafunc.definition.config.rootdir / "src/disaggregators/disaggregation_modules" 21 | ) if x.is_dir() and not x.name.startswith("__") 22 | ]) 23 | 24 | # Fixtures 25 | 26 | 27 | class DummyLabels(DisaggregationModuleLabels): 28 | DUMMY_ONE = "dummy-value-1" 29 | DUMMY_TWO = "dummy-value-2" 30 | 31 | 32 | @pytest.fixture 33 | def dummy_labels(): 34 | return DummyLabels 35 | 36 | 37 | @pytest.fixture 38 | def dummy_module_config(dummy_labels): 39 | class DummyModuleConfig(DisaggregationModuleConfig): 40 | def __init__(self, labels: Type[dummy_labels]): 41 | self.labels = labels 42 | 43 | return DummyModuleConfig 44 | 45 | 46 | @pytest.fixture 47 | def dummy_module(dummy_labels, dummy_module_config): 48 | class DummyModule(DisaggregationModule): 49 | labels = dummy_labels 50 | 51 | def __init__(self, module_id="dummy-module", *args, **kwargs): 52 | super().__init__(module_id=module_id, *args, **kwargs) 53 | 54 | def _apply_config(self, config): 55 | self.labels = config.labels 56 | 57 | def __call__(self, row, *args, **kwargs): 58 | return {label: True for 
label in list(self.labels)} 59 | 60 | return DummyModule 61 | 62 | 63 | @pytest.fixture 64 | def custom_dummy_labels(): 65 | class CustomDummyLabels(DisaggregationModuleLabels): 66 | DUMMY_ONE = "dummy-value-1" 67 | DUMMY_TWO = "dummy-value-2" 68 | DUMMY_THREE = "dummy-value-3" 69 | 70 | return CustomDummyLabels 71 | 72 | 73 | @pytest.fixture 74 | def configured_module(custom_dummy_labels, dummy_module_config, dummy_module): 75 | return dummy_module(config=dummy_module_config(labels=custom_dummy_labels), column=None) 76 | 77 | 78 | @pytest.fixture 79 | def configured_dummy_expected_results(custom_dummy_labels, configured_module): 80 | return {f"{configured_module.name}.{label}": True for label in custom_dummy_labels} 81 | -------------------------------------------------------------------------------- /noxfile.py: -------------------------------------------------------------------------------- 1 | import nox 2 | 3 | python_versions = ["3.8", "3.9", "3.10"] 4 | 5 | 6 | @nox.session(python=python_versions) 7 | def tests(session): 8 | session.install(".[tests]") 9 | session.run("make", "load_spacy_model", external=True) 10 | session.run("make", "test", external=True) 11 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.pytest.ini_options] 2 | markers = [ 3 | "slow: marks tests as slow (deselect with '-m \"not slow\"')", 4 | "integration: marks integration tests (deselect with '-m \"not integration\"')", 5 | ] 6 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | license_file = LICENSE 3 | 4 | [isort] 5 | ensure_newline_before_comments = True 6 | force_grid_wrap = 0 7 | include_trailing_comma = True 8 | line_length = 119 9 | lines_after_imports = 2 10 | multi_line_output = 3 11 | 
use_parentheses = True 12 | 13 | [flake8] 14 | ignore = E203, E501, W503 15 | max-line-length = 119 16 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | """ HuggingFace/Disaggregators is an open library for disaggregating datasets. 3 | 4 | Note: 5 | 6 | VERSION needs to be formatted following the MAJOR.MINOR.PATCH convention 7 | 8 | Simple check list for release from AllenNLP repo: https://github.com/allenai/allennlp/blob/master/setup.py 9 | 10 | To create the package for pypi. 11 | 12 | 0. Prerequisites: 13 | - Dependencies: 14 | - twine: "pip install twine" 15 | - Create an account in (and join the 'disaggregators' project): 16 | - PyPI: https://pypi.org/ 17 | - Test PyPI: https://test.pypi.org/ 18 | 19 | 1. Change the version in: 20 | - __init__.py 21 | - setup.py 22 | 23 | 2. Commit these changes: "git commit -m 'Release: VERSION'" 24 | 25 | 3. Add a tag in git to mark the release: "git tag VERSION -m 'Add tag VERSION for pypi'" 26 | Push the tag to remote: git push --tags origin main 27 | 28 | 4. Build both the sources and the wheel. Do not change anything in setup.py between 29 | creating the wheel and the source distribution (obviously). 30 | 31 | First, delete any "build" directory that may exist from previous builds. 32 | 33 | For the wheel, run: "python setup.py bdist_wheel" in the top level directory. 34 | (this will build a wheel for the python version you use to build it). 35 | 36 | For the sources, run: "python setup.py sdist" 37 | You should now have a /dist directory with both .whl and .tar.gz source versions. 38 | 39 | 5. 
Check that everything looks correct by uploading the package to the pypi test server: 40 | 41 | twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/ 42 | 43 | Check that you can install it in a virtualenv/notebook by running: 44 | pip install -i https://testpypi.python.org/pypi disaggregators 45 | 46 | 6. Upload the final version to actual pypi: 47 | twine upload dist/* -r pypi 48 | 49 | 7. Fill release notes in the tag in GitHub once everything is looking hunky-dory. 50 | 51 | 8. Change the version in __init__.py and setup.py to X.X.X+1.dev0 (e.g. VERSION=1.18.3 -> 1.18.4.dev0). 52 | Then push the change with a message 'set dev version' 53 | """ 54 | 55 | from setuptools import find_packages, setup 56 | 57 | 58 | REQUIRED_PKGS = [ 59 | # Utilities from PyPA to e.g., compare versions 60 | "packaging", 61 | "spacy", 62 | "datasets", 63 | "aenum>=3.1.11", 64 | "sentence-transformers>=2.2.2", 65 | "geograpy3", 66 | "nltk", 67 | "requests", 68 | ] 69 | 70 | TESTS_REQUIRE = [ 71 | # test dependencies 72 | "pytest", 73 | "pytest-datadir", 74 | "pytest-xdist", 75 | "pytest-mock", 76 | "nox", 77 | "pandas", 78 | ] 79 | 80 | QUALITY_REQUIRE = ["black~=22.0", "flake8>=3.8.3", "isort>=5.0.0", "pyyaml>=5.3.1"] 81 | 82 | 83 | EXTRAS_REQUIRE = { 84 | "dev": TESTS_REQUIRE + QUALITY_REQUIRE, 85 | "tests": TESTS_REQUIRE, 86 | "quality": QUALITY_REQUIRE, 87 | } 88 | 89 | setup( 90 | name="disaggregators", 91 | version="0.1.3.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) 92 | description="HuggingFace community-driven open-source library for dataset disaggregation", 93 | long_description=open("README.md", encoding="utf-8").read(), 94 | long_description_content_type="text/markdown", 95 | author="HuggingFace Inc.", 96 | author_email="nima@huggingface.co", 97 | url="https://github.com/NimaBoscarino/disaggregators", 98 | download_url="https://github.com/NimaBoscarino/disaggregators/tags", 99 | license="Apache 
2.0", 100 | package_dir={"": "src"}, 101 | packages=find_packages("src"), 102 | install_requires=REQUIRED_PKGS, 103 | extras_require=EXTRAS_REQUIRE, 104 | python_requires=">=3.7.0", 105 | classifiers=[ 106 | "Development Status :: 1 - Planning", 107 | "Intended Audience :: Developers", 108 | "Intended Audience :: Education", 109 | "Intended Audience :: Science/Research", 110 | "License :: OSI Approved :: Apache Software License", 111 | "Operating System :: OS Independent", 112 | "Programming Language :: Python :: 3", 113 | "Programming Language :: Python :: 3.8", 114 | "Programming Language :: Python :: 3.9", 115 | "Programming Language :: Python :: 3.10", 116 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 117 | ], 118 | keywords="machine learning evaluate evaluation disaggregation", 119 | zip_safe=False, # Required for mypy to find the py.typed file 120 | ) 121 | -------------------------------------------------------------------------------- /src/disaggregators/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # Copyright 2022 The HuggingFace Disaggregators Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | # pylint: enable=line-too-long 18 | # pylint: disable=g-import-not-at-top,g-bad-import-order,wrong-import-position 19 | 20 | __version__ = "0.1.3.dev0" 21 | 22 | from packaging import version 23 | 24 | from disaggregators.disaggregation_modules import ( 25 | CustomDisaggregator, 26 | DisaggregationModule, 27 | DisaggregationModuleConfig, 28 | DisaggregationModuleFactory, 29 | DisaggregationModuleLabels, 30 | ) 31 | 32 | from .disaggregator import Disaggregator 33 | 34 | 35 | SCRIPTS_VERSION = "main" if version.parse(__version__).is_devrelease else __version__ 36 | 37 | del version 38 | -------------------------------------------------------------------------------- /src/disaggregators/disaggregation_modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .age import Age 2 | from .continent import Continent 3 | from .disaggregation_module import ( 4 | CustomDisaggregator, 5 | DisaggregationModule, 6 | DisaggregationModuleConfig, 7 | DisaggregationModuleFactory, 8 | DisaggregationModuleLabels, 9 | ) 10 | from .gender import Gender 11 | from .pronoun import Pronoun 12 | from .religion import Religion 13 | 14 | 15 | AVAILABLE_MODULES = {"pronoun": Pronoun, "age": Age, "gender": Gender, "religion": Religion, "continent": Continent} 16 | 17 | __all__ = [ 18 | "DisaggregationModule", 19 | "DisaggregationModuleFactory", 20 | "DisaggregationModuleLabels", 21 | "CustomDisaggregator", 22 | "DisaggregationModuleConfig", 23 | ] 24 | -------------------------------------------------------------------------------- /src/disaggregators/disaggregation_modules/age/__init__.py: -------------------------------------------------------------------------------- 1 | import re 2 | from bisect import bisect 3 | from typing import List, Type 4 | 5 | import spacy 6 | 7 | from ..disaggregation_module import DisaggregationModule, DisaggregationModuleConfig, DisaggregationModuleLabels 8 | 9 | 10 | class 
AgeLabels(DisaggregationModuleLabels): 11 | CHILD = "child" 12 | YOUTH = "youth" 13 | ADULT = "adult" 14 | SENIOR = "senior" 15 | 16 | 17 | class AgeConfig(DisaggregationModuleConfig): 18 | def __init__(self, labels: Type[AgeLabels], ages: List, breakpoints: List): 19 | self.labels = labels 20 | self.ages = ages 21 | self.breakpoints = breakpoints 22 | 23 | 24 | class Age(DisaggregationModule): 25 | labels = AgeLabels 26 | AGES = [AgeLabels.CHILD, AgeLabels.YOUTH, AgeLabels.ADULT, AgeLabels.SENIOR] 27 | AGE_BREAKPOINTS = [0, 12, 20, 65] 28 | spacy_model = "en_core_web_lg" 29 | 30 | def __init__(self, *args, **kwargs): 31 | try: 32 | self.nlp = spacy.load(self.spacy_model, enable="ner") 33 | except OSError: 34 | raise ValueError( 35 | f"This disaggregation module depends on the {self.spacy_model} model from spaCy.\n" 36 | f"You can install it by running: python -m spacy download {self.spacy_model}" 37 | ) 38 | 39 | super().__init__(module_id="age", *args, **kwargs) 40 | 41 | def _apply_config(self, config: AgeConfig): 42 | self.labels = config.labels 43 | self.AGES = config.ages 44 | self.AGE_BREAKPOINTS = config.breakpoints 45 | 46 | def __call__(self, row, *args, **kwargs): 47 | return_ages = {age: False for age in self.AGES} 48 | text = row[self.column] 49 | doc = self.nlp(text) 50 | date_entities = [y for y in doc.ents if y.label_ == "DATE"] 51 | 52 | for date_entity in date_entities: 53 | value = re.search(r"\d+", date_entity.text) 54 | 55 | if value is None: 56 | continue 57 | 58 | value = int(value[0]) 59 | 60 | if value <= 120: 61 | age_bucket = bisect(self.AGE_BREAKPOINTS, value) 62 | if age_bucket >= 1: 63 | return_ages.update({self.AGES[age_bucket - 1]: True}) 64 | 65 | return return_ages 66 | -------------------------------------------------------------------------------- /src/disaggregators/disaggregation_modules/continent/__init__.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | # noinspection 
PyPackageRequirements 4 | import geograpy # from geograpy3 5 | import nltk 6 | import requests 7 | 8 | from ..disaggregation_module import DisaggregationModule, DisaggregationModuleLabels 9 | 10 | 11 | class ContinentLabels(DisaggregationModuleLabels): 12 | AFRICA = "africa" 13 | AMERICAS = "americas" 14 | ASIA = "asia" 15 | EUROPE = "europe" 16 | OCEANIA = "oceania" 17 | 18 | 19 | class Continent(DisaggregationModule): 20 | labels = ContinentLabels 21 | continents = [ 22 | ContinentLabels.AFRICA, 23 | ContinentLabels.AMERICAS, 24 | ContinentLabels.ASIA, 25 | ContinentLabels.EUROPE, 26 | ContinentLabels.OCEANIA, 27 | ] 28 | 29 | def __init__(self, *args, **kwargs): 30 | super().__init__(module_id="continent", *args, **kwargs) 31 | 32 | countries_url = "https://raw.githubusercontent.com/bigscience-workshop/data_sourcing/master/"\ 33 | "sourcing_sprint/resources/country_regions.json" 34 | 35 | response = json.loads(requests.get(countries_url).text) 36 | 37 | self.continents = response[0] 38 | countries = response[1] 39 | region_countries = response[2] 40 | 41 | def get_countries_and_regions(continent_or_region): 42 | return_countries_and_regions = {"countries": [], "regions": []} 43 | 44 | for region in region_countries.get(continent_or_region): 45 | if region in countries: 46 | return_countries_and_regions["countries"] = return_countries_and_regions["countries"] + [region] 47 | else: 48 | countries_and_regions = get_countries_and_regions(region) 49 | return_countries_and_regions["regions"] = return_countries_and_regions["regions"] + [region] 50 | return_countries_and_regions["countries"] = ( 51 | return_countries_and_regions["countries"] + countries_and_regions["countries"] 52 | ) 53 | 54 | return return_countries_and_regions 55 | 56 | continent_maps = {c: get_countries_and_regions(c) for c in self.continents} 57 | 58 | self.continent_lists = [ 59 | [c, *continent_maps[c]["regions"], *continent_maps[c]["countries"]] for c in continent_maps 60 | ] 61 | 62 | 
nltk.download("punkt") 63 | nltk.download("averaged_perceptron_tagger") 64 | nltk.download("maxent_ne_chunker") 65 | nltk.download("words") 66 | 67 | def __call__(self, row, *args, **kwargs): 68 | return_continent = {continent: False for continent in list(ContinentLabels)} 69 | 70 | places = geograpy.get_geoPlace_context(text=row[self.column]).countries 71 | 72 | if not len(places) > 0: 73 | return return_continent 74 | 75 | continent_search = [cl[0] for cl in self.continent_lists if places[0] in cl] 76 | 77 | if len(continent_search) > 0: 78 | continent = continent_search[0] 79 | label = getattr(ContinentLabels, continent.upper()) 80 | return_continent.update({label: True}) 81 | 82 | return return_continent 83 | -------------------------------------------------------------------------------- /src/disaggregators/disaggregation_modules/disaggregation_module.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import List, Optional, Type, Union 3 | 4 | from aenum import Constant 5 | 6 | from disaggregators import disaggregation_modules 7 | 8 | 9 | class DisaggregationModuleLabels(Constant): 10 | pass 11 | 12 | 13 | class DisaggregationModuleConfig: 14 | labels: Type[DisaggregationModuleLabels] 15 | 16 | 17 | class DisaggregationModule(ABC): 18 | def __init__(self, module_id: str, column: Optional[str], config: DisaggregationModuleConfig = None): 19 | self.name = module_id 20 | self.column = column 21 | self.citations: List[str] = [] 22 | 23 | if config: 24 | self._apply_config(config) 25 | 26 | @abstractmethod 27 | def __call__(self, row, *args, **kwargs): 28 | raise NotImplementedError() 29 | 30 | @property 31 | @abstractmethod 32 | def labels(self) -> Type[DisaggregationModuleLabels]: 33 | pass 34 | 35 | def _apply_config(self, config: DisaggregationModuleConfig): 36 | pass 37 | 38 | @property 39 | def field_names(self): 40 | return {f"{self.name}.{x}" for x in 
list(self.labels)} 41 | 42 | 43 | class CustomDisaggregator(DisaggregationModule, ABC): 44 | """ 45 | This class exists to provide a simple interface for creating custom disaggregation modules. This is useful because 46 | the DisaggregationModule abstract class may enforce extra rules that we don't want users to have to worry about. 47 | """ 48 | 49 | def __init__(self, *args, **kwargs): 50 | super().__init__(module_id=self.module_id, *args, **kwargs) 51 | 52 | @property 53 | @abstractmethod 54 | def module_id(self): 55 | raise NotImplementedError() 56 | 57 | @property 58 | @abstractmethod 59 | def labels(self): 60 | raise NotImplementedError() 61 | 62 | 63 | class DisaggregationModuleFactory: 64 | @staticmethod 65 | def create_module(module: Union[str, Type[CustomDisaggregator], DisaggregationModule], *args, **kwargs): 66 | if isinstance(module, str): 67 | return DisaggregationModuleFactory.create_from_id(module, *args, **kwargs) 68 | elif isinstance(module, DisaggregationModule): 69 | return module 70 | elif issubclass(module, CustomDisaggregator): 71 | return DisaggregationModuleFactory.create_from_class(module, *args, **kwargs) 72 | else: 73 | raise ValueError("Invalid module type received.") 74 | 75 | @staticmethod 76 | def create_from_id(module_id: str, *args, **kwargs) -> DisaggregationModule: 77 | if module_id not in disaggregation_modules.AVAILABLE_MODULES: 78 | raise ValueError("Invalid module_id received.") 79 | 80 | return disaggregation_modules.AVAILABLE_MODULES[module_id](*args, **kwargs) 81 | 82 | @staticmethod 83 | def create_from_class(module: Type[CustomDisaggregator], *args, **kwargs) -> DisaggregationModule: 84 | return module(*args, **kwargs) 85 | -------------------------------------------------------------------------------- /src/disaggregators/disaggregation_modules/gender/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Type 2 | 3 | import pandas as pd 4 | 
import spacy 5 | from datasets import load_dataset 6 | 7 | from ..disaggregation_module import DisaggregationModule, DisaggregationModuleConfig, DisaggregationModuleLabels 8 | 9 | 10 | class GenderLabels(DisaggregationModuleLabels): 11 | MALE = "male" 12 | FEMALE = "female" 13 | 14 | 15 | class GenderConfig(DisaggregationModuleConfig): 16 | def __init__(self, labels: Type[GenderLabels], word_lists: Dict[GenderLabels, List[str]]): 17 | self.labels = labels 18 | self.word_lists = word_lists 19 | 20 | 21 | class Gender(DisaggregationModule): 22 | labels = GenderLabels 23 | spacy_model = "en_core_web_lg" 24 | 25 | citations = [ 26 | """ 27 | @inproceedings{dinan-etal-2020-multi, 28 | title = "Multi-Dimensional Gender Bias Classification", 29 | author = "Dinan, Emily and 30 | Fan, Angela and 31 | Wu, Ledell and 32 | Weston, Jason and 33 | Kiela, Douwe and 34 | Williams, Adina", 35 | booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP)", # noqa: E501 36 | year = "2020", 37 | publisher = "Association for Computational Linguistics", 38 | url = "https://www.aclweb.org/anthology/2020.emnlp-main.23", 39 | doi = "10.18653/v1/2020.emnlp-main.23", 40 | } 41 | """ 42 | ] 43 | 44 | def __init__(self, *args, **kwargs): 45 | self.gender_df = load_dataset("md_gender_bias", "gendered_words", split="train").to_pandas() 46 | self.gender_df.columns = [GenderLabels.MALE, GenderLabels.FEMALE] 47 | 48 | try: 49 | self.nlp = spacy.load(self.spacy_model) 50 | except OSError: 51 | raise ValueError( 52 | f"This disaggregation module depends on the {self.spacy_model} model from spaCy.\n" 53 | f"You can install it by running: python -m spacy download {self.spacy_model}" 54 | ) 55 | 56 | super().__init__(module_id="gender", *args, **kwargs) 57 | 58 | def _apply_config(self, config: GenderConfig): 59 | self.labels = config.labels 60 | for gender, word_list in config.word_lists.items(): 61 | self.gender_df[gender] = pd.DataFrame(word_list) 62 | 
63 | def __call__(self, row, *args, **kwargs): 64 | return_genders = {gender: False for gender in list(GenderLabels)} 65 | 66 | doc = self.nlp(row[self.column]) 67 | 68 | nouns = [c.text for x in doc.noun_chunks for c in x.root.subtree] 69 | 70 | for noun in nouns: 71 | result = self.gender_df[self.gender_df == noun].any() 72 | for gender_hit in list(result[result].keys()): 73 | return_genders.update({gender_hit: True}) 74 | 75 | return return_genders 76 | -------------------------------------------------------------------------------- /src/disaggregators/disaggregation_modules/pronoun/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Set, Type 2 | 3 | from ..disaggregation_module import DisaggregationModule, DisaggregationModuleConfig, DisaggregationModuleLabels 4 | 5 | 6 | class PronounLabels(DisaggregationModuleLabels): 7 | SHE_HER = "she_her" 8 | HE_HIM = "he_him" 9 | THEY_THEM = "they_them" 10 | 11 | 12 | class PronounConfig(DisaggregationModuleConfig): 13 | def __init__(self, labels: Type[PronounLabels], pronouns: Dict[PronounLabels, Set[str]]): 14 | self.labels = labels 15 | self.pronouns = pronouns 16 | 17 | 18 | class Pronoun(DisaggregationModule): 19 | labels = PronounLabels 20 | AVAILABLE_PRONOUNS = { 21 | PronounLabels.SHE_HER: {"she", "her", "hers", "herself"}, 22 | PronounLabels.HE_HIM: {"he", "him", "his", "himself"}, 23 | PronounLabels.THEY_THEM: {"they", "them", "their", "theirs", "themself", "themselves"}, 24 | } 25 | 26 | def __init__(self, *args, **kwargs): 27 | super().__init__(module_id="pronoun", *args, **kwargs) 28 | 29 | def _apply_config(self, config: PronounConfig): 30 | self.labels = config.labels 31 | self.AVAILABLE_PRONOUNS = {**config.pronouns, **self.AVAILABLE_PRONOUNS} 32 | 33 | def __call__(self, row, *args, **kwargs): 34 | text = row[self.column] 35 | pronoun_flag = { 36 | av_p: any(p in text.lower().split() for p in self.AVAILABLE_PRONOUNS[av_p]) 37 | for 
av_p in self.AVAILABLE_PRONOUNS 38 | } 39 | 40 | return pronoun_flag 41 | -------------------------------------------------------------------------------- /src/disaggregators/disaggregation_modules/religion/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Type 2 | 3 | from sentence_transformers import SentenceTransformer 4 | from sentence_transformers.util import semantic_search 5 | 6 | from ..disaggregation_module import DisaggregationModule, DisaggregationModuleConfig, DisaggregationModuleLabels 7 | 8 | 9 | class ReligionLabels(DisaggregationModuleLabels): 10 | JUDAISM = "judaism" 11 | ISLAM = "islam" 12 | BUDDHISM = "buddhism" 13 | CHRISTIANITY = "christianity" 14 | 15 | 16 | class ReligionConfig(DisaggregationModuleConfig): 17 | def __init__(self, labels: Type[ReligionLabels] = None, threshold: float = None): 18 | self.labels = labels 19 | self.threshold = threshold 20 | 21 | 22 | class Religion(DisaggregationModule): 23 | labels = ReligionLabels 24 | threshold = 0.14 # Arbitrary threshold, hand-tuned. 
25 | religions = [ 26 | ReligionLabels.JUDAISM, 27 | ReligionLabels.ISLAM, 28 | ReligionLabels.BUDDHISM, 29 | ReligionLabels.CHRISTIANITY, 30 | ] 31 | 32 | def __init__(self, *args, **kwargs): 33 | self.model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2") 34 | 35 | super().__init__(module_id="religion", *args, **kwargs) 36 | 37 | self.embeddings = self.model.encode([str(religion) for religion in self.religions], convert_to_tensor=True) 38 | 39 | def _apply_config(self, config: ReligionConfig): 40 | if config.labels: 41 | self.labels = config.labels 42 | self.religions = self.religions + list(config.labels) 43 | 44 | self.threshold = config.threshold or self.threshold 45 | 46 | def __call__(self, row, *args, **kwargs): 47 | return_religion = {religion: False for religion in list(ReligionLabels)} 48 | 49 | query = self.model.encode(row[self.column], convert_to_tensor=True) 50 | religion_hit = semantic_search(query, self.embeddings, top_k=1)[0][0] 51 | 52 | if religion_hit["score"] > self.threshold: 53 | return_religion.update({self.religions[religion_hit["corpus_id"]]: True}) 54 | 55 | return return_religion 56 | -------------------------------------------------------------------------------- /src/disaggregators/disaggregator.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List, Optional, Set, Type, Union 2 | 3 | from disaggregators.disaggregation_modules import ( 4 | CustomDisaggregator, 5 | DisaggregationModule, 6 | DisaggregationModuleFactory, 7 | ) 8 | 9 | 10 | class Disaggregator: 11 | def __init__( 12 | self, 13 | module: Optional[ 14 | Union[ 15 | str, 16 | List[str], 17 | DisaggregationModule, 18 | List[DisaggregationModule], 19 | Type[CustomDisaggregator], 20 | List[Type[CustomDisaggregator]], 21 | ] 22 | ] = None, 23 | *args, 24 | **kwargs, 25 | ): 26 | if module is None: 27 | module = [] 28 | 29 | if not isinstance(module, list): 30 | module_list = 
[module] 31 | else: 32 | module_list = module 33 | 34 | self.modules = [DisaggregationModuleFactory.create_module(module, *args, **kwargs) for module in module_list] 35 | 36 | def get_function(self) -> Callable: 37 | # Merge dicts - https://stackoverflow.com/a/3495395 38 | return lambda x: { 39 | f"{d[0]}.{str(k)}": v 40 | for d in [(module.name, module(x)) for module in self.modules] 41 | for k, v in d[1].items() 42 | } 43 | 44 | def __call__(self, x) -> Callable: 45 | return self.get_function()(x) 46 | 47 | @property 48 | def fields(self) -> Set: 49 | return {*[f"{module.name}.{str(label)}" for module in self.modules for label in module.labels]} 50 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/disaggregators/b8ea3170b119e6768812874b71367b8932db0d37/tests/__init__.py -------------------------------------------------------------------------------- /tests/disaggregation_modules/test_age.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from disaggregators.disaggregation_modules.age import Age, AgeConfig, AgeLabels 4 | 5 | 6 | def test_initialize(): 7 | disagg_module = Age(column=None) 8 | assert disagg_module.name == "age" 9 | assert set(disagg_module.labels) == {AgeLabels.CHILD, AgeLabels.YOUTH, AgeLabels.ADULT, AgeLabels.SENIOR} 10 | 11 | 12 | @pytest.mark.slow 13 | @pytest.mark.parametrize( 14 | "text,expected", 15 | [ 16 | ("The man went to the park.", []), 17 | ("The 40-year-old man went to the park.", [AgeLabels.ADULT]), 18 | ("The 10 year old child ate a hot dog.", [AgeLabels.CHILD]), 19 | ("Clara is 18 years old.", [AgeLabels.YOUTH]), 20 | ("Hamed's grandpa is an 86 yr old smoker.", [AgeLabels.SENIOR]), 21 | ("The 40-year-old man went to the park with his 10 year old daughter.", [AgeLabels.ADULT, AgeLabels.CHILD]), 22 
| ("Farzaneh is 18 years old and her grandmother is 86 yrs old", [AgeLabels.YOUTH, AgeLabels.SENIOR]), 23 | ], 24 | ) 25 | def test_call_default(text, expected): 26 | base_labels = {age: False for age in list(AgeLabels)} 27 | data = {"text": text} 28 | disagg_module = Age(column="text") 29 | results = disagg_module(data) 30 | assert results == {**base_labels, **{label: True for label in expected}} 31 | 32 | 33 | @pytest.mark.slow 34 | def test_call_custom(): 35 | class CustomAgeLabels(AgeLabels): 36 | NIMA = "nima" 37 | OLDER_THAN_NIMA = "older_nima" 38 | YOUNGER_THAN_NIMA = "younger_nima" 39 | 40 | examples = [ 41 | {"text": "Jenny is 20 years old."}, 42 | {"text": "Nima is 26 years old."}, 43 | {"text": "Mark Knopfler is 73 years old."}, 44 | ] 45 | 46 | disagg_module = Age( 47 | config=AgeConfig( 48 | labels=CustomAgeLabels, 49 | ages=[CustomAgeLabels.YOUNGER_THAN_NIMA, CustomAgeLabels.NIMA, CustomAgeLabels.OLDER_THAN_NIMA], 50 | breakpoints=[0, 25, 27], 51 | ), 52 | column="text", 53 | ) 54 | results = [disagg_module(example) for example in examples] 55 | 56 | assert results == [ 57 | {CustomAgeLabels.NIMA: False, CustomAgeLabels.OLDER_THAN_NIMA: False, CustomAgeLabels.YOUNGER_THAN_NIMA: True}, 58 | {CustomAgeLabels.NIMA: True, CustomAgeLabels.OLDER_THAN_NIMA: False, CustomAgeLabels.YOUNGER_THAN_NIMA: False}, 59 | {CustomAgeLabels.NIMA: False, CustomAgeLabels.OLDER_THAN_NIMA: True, CustomAgeLabels.YOUNGER_THAN_NIMA: False}, 60 | ] 61 | -------------------------------------------------------------------------------- /tests/disaggregation_modules/test_continent.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from disaggregators.disaggregation_modules.continent import Continent, ContinentLabels 4 | 5 | 6 | def test_initialize(): 7 | disagg_module = Continent(column=None) 8 | assert disagg_module.name == "continent" 9 | assert set(disagg_module.labels) == { 10 | ContinentLabels.AFRICA, 11 | 
ContinentLabels.AMERICAS, 12 | ContinentLabels.ASIA, 13 | ContinentLabels.EUROPE, 14 | ContinentLabels.OCEANIA, 15 | } 16 | 17 | 18 | @pytest.mark.slow 19 | @pytest.mark.parametrize( 20 | "text,expected", 21 | [ 22 | ( 23 | "Tomorrow I am visiting Italy.", 24 | [ContinentLabels.EUROPE], 25 | ), 26 | ( 27 | "A gentle breeze blows through the leaves of a cherry blossom tree in Vancouver.", 28 | [ContinentLabels.AMERICAS], 29 | ), 30 | ( 31 | "Let's eat pizza.", 32 | [], 33 | ), 34 | ], 35 | ) 36 | def test_call_default(text, expected): 37 | base_labels = {age: False for age in list(ContinentLabels)} 38 | data = {"text": text} 39 | disagg_module = Continent(column="text") 40 | results = disagg_module(data) 41 | assert results == {**base_labels, **{label: True for label in expected}} 42 | -------------------------------------------------------------------------------- /tests/disaggregation_modules/test_gender.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from disaggregators.disaggregation_modules.gender import Gender, GenderConfig, GenderLabels 4 | 5 | 6 | def test_initialize(): 7 | disagg_module = Gender(column=None) 8 | assert disagg_module.name == "gender" 9 | assert set(disagg_module.labels) == {GenderLabels.MALE, GenderLabels.FEMALE} 10 | 11 | 12 | @pytest.mark.slow 13 | @pytest.mark.parametrize( 14 | "text,expected", 15 | [ 16 | ("That is one large cat!", []), 17 | ("The 40-year-old man went to the park.", [GenderLabels.MALE]), 18 | ("The clown gave the girl an ice cream cone.", [GenderLabels.FEMALE]), 19 | ("What is the boy's name?", [GenderLabels.MALE]), 20 | ("I checked the lady's ticket.", [GenderLabels.FEMALE]), 21 | ("The guy gave the woman a high-five.", [GenderLabels.MALE, GenderLabels.FEMALE]), 22 | ], 23 | ) 24 | def test_call_default(text, expected): 25 | base_labels = {age: False for age in list(GenderLabels)} 26 | data = {"text": text} 27 | disagg_module = Gender(column="text") 28 | 
results = disagg_module(data) 29 | assert results == {**base_labels, **{label: True for label in expected}} 30 | 31 | 32 | @pytest.mark.slow 33 | def test_call_custom(): 34 | class CustomGenderLabels(GenderLabels): 35 | NON_BINARY = "non-binary" 36 | 37 | _CUSTOM_WORD_LISTS = {CustomGenderLabels.NON_BINARY: ["clown"]} 38 | 39 | data = {"text": "The sad clown went to the doctor."} 40 | disagg_module = Gender( 41 | config=GenderConfig(labels=CustomGenderLabels, word_lists=_CUSTOM_WORD_LISTS), column="text" 42 | ) 43 | results = disagg_module(data) 44 | assert results == { 45 | CustomGenderLabels.MALE: False, 46 | CustomGenderLabels.FEMALE: False, 47 | CustomGenderLabels.NON_BINARY: True, 48 | } 49 | -------------------------------------------------------------------------------- /tests/disaggregation_modules/test_pronoun.py: -------------------------------------------------------------------------------- 1 | from disaggregators.disaggregation_modules.pronoun import Pronoun, PronounConfig, PronounLabels 2 | 3 | 4 | def test_initialize(): 5 | disagg_module = Pronoun(column=None) 6 | assert disagg_module.name == "pronoun" 7 | assert set(disagg_module.labels) == {PronounLabels.HE_HIM, PronounLabels.SHE_HER, PronounLabels.THEY_THEM} 8 | 9 | 10 | def test_call_default(): 11 | data = {"text": "He went to the park."} 12 | disagg_module = Pronoun(column="text") 13 | results = disagg_module(data) 14 | assert results == {PronounLabels.HE_HIM: True, PronounLabels.SHE_HER: False, PronounLabels.THEY_THEM: False} 15 | 16 | 17 | def test_call_custom(): 18 | class CustomPronounLabels(PronounLabels): 19 | ZE_ZIR = "ze_zir" 20 | 21 | _CUSTOM_PRONOUN_MAPPING = {CustomPronounLabels.ZE_ZIR: {"ze", "zir", "zirs", "zirself"}} 22 | 23 | data = {"text": "Ze went to the park."} 24 | disagg_module = Pronoun( 25 | config=PronounConfig(labels=CustomPronounLabels, pronouns=_CUSTOM_PRONOUN_MAPPING), column="text" 26 | ) 27 | results = disagg_module(data) 28 | assert results == { 29 | 
CustomPronounLabels.ZE_ZIR: True, 30 | CustomPronounLabels.HE_HIM: False, 31 | CustomPronounLabels.SHE_HER: False, 32 | CustomPronounLabels.THEY_THEM: False, 33 | } 34 | -------------------------------------------------------------------------------- /tests/disaggregation_modules/test_religion.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from disaggregators.disaggregation_modules.religion import Religion, ReligionConfig, ReligionLabels 4 | 5 | 6 | def test_initialize(): 7 | disagg_module = Religion(column=None) 8 | assert disagg_module.name == "religion" 9 | assert set(disagg_module.labels) == { 10 | ReligionLabels.JUDAISM, 11 | ReligionLabels.ISLAM, 12 | ReligionLabels.BUDDHISM, 13 | ReligionLabels.CHRISTIANITY, 14 | } 15 | 16 | 17 | @pytest.mark.slow 18 | @pytest.mark.parametrize( 19 | "text,expected", 20 | [ 21 | ( 22 | "The menorah is a seven-branched candelabrum that is described in the Hebrew Bible.", 23 | [ReligionLabels.JUDAISM], 24 | ), 25 | ( 26 | "Traditionally, Eid al-Fitr begins at sunset on the night of the first sighting of the crescent moon.", 27 | [ReligionLabels.ISLAM], 28 | ), 29 | ( 30 | "Three main prevailing theories exist on the finalization of Lent as a forty-day fast" 31 | "prior to the arrival of Easter Sunday.", 32 | [ReligionLabels.CHRISTIANITY], 33 | ), 34 | ( 35 | "Samsara means 'wandering' or 'world', with the connotation of cyclic, circuitous change.", 36 | [ReligionLabels.BUDDHISM], 37 | ), 38 | ( 39 | "CBC Music is a Canadian FM radio network operated by the Canadian Broadcasting Corporation.", 40 | [], 41 | ), 42 | ], 43 | ) 44 | def test_call_default(text, expected): 45 | base_labels = {age: False for age in list(ReligionLabels)} 46 | data = {"text": text} 47 | disagg_module = Religion(column="text") 48 | results = disagg_module(data) 49 | assert results == {**base_labels, **{label: True for label in expected}} 50 | 51 | 52 | @pytest.mark.slow 53 | def 
test_call_custom(): 54 | class CustomReligionLabels(ReligionLabels): 55 | BOKONONISM = "bokononism" 56 | 57 | data = {"text": "Busy, busy, busy."} 58 | disagg_module = Religion(config=ReligionConfig(labels=CustomReligionLabels), column="text") 59 | 60 | results = disagg_module(data) 61 | 62 | assert results[CustomReligionLabels.BOKONONISM] 63 | -------------------------------------------------------------------------------- /tests/integration/test_disaggregation.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import pytest 3 | from datasets import Dataset 4 | 5 | from disaggregators import Disaggregator 6 | from disaggregators.disaggregation_modules.pronoun import Pronoun, PronounConfig, PronounLabels 7 | 8 | 9 | @pytest.fixture() 10 | def dataset(): 11 | return Dataset.from_dict({"text": ["Hello world!", "Fizz buzz."]}) 12 | 13 | 14 | class CustomPronounLabels(PronounLabels): 15 | ZE_ZIR = "ze_zir" 16 | 17 | 18 | @pytest.fixture() 19 | def disaggregator(): 20 | _CUSTOM_PRONOUN_MAPPING = {CustomPronounLabels.ZE_ZIR: {"ze", "zir", "zirs", "zirself"}} 21 | 22 | disagg_module = Pronoun( 23 | config=PronounConfig(labels=CustomPronounLabels, pronouns=_CUSTOM_PRONOUN_MAPPING), column="text" 24 | ) 25 | 26 | return Disaggregator(["age", "gender", disagg_module], column="text") 27 | 28 | 29 | @pytest.fixture() 30 | def expected_features(): 31 | return { 32 | "age.child", 33 | "age.youth", 34 | "age.adult", 35 | "age.senior", 36 | "gender.male", 37 | "gender.female", 38 | } 39 | 40 | 41 | def test_datasets(dataset, disaggregator, expected_features): 42 | ds_mapped = dataset.map(disaggregator) 43 | assert expected_features.issubset(set(ds_mapped.features)) 44 | 45 | 46 | def test_pandas(dataset, disaggregator, expected_features): 47 | df = dataset.to_pandas() 48 | new_cols = df.apply(disaggregator, axis=1) 49 | df = pd.merge(df, pd.json_normalize(new_cols), left_index=True, right_index=True) 50 | assert 
expected_features.issubset(set(df.columns)) 51 | 52 | 53 | def test_each_module(dataset, module): 54 | disaggregator = Disaggregator(module, column="text") 55 | ds_mapped = dataset.map(disaggregator) 56 | 57 | expected_features = disaggregator.fields 58 | 59 | assert expected_features.issubset(set(ds_mapped.features)) 60 | -------------------------------------------------------------------------------- /tests/test_disaggregation_module.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from disaggregators import DisaggregationModule, DisaggregationModuleFactory 4 | 5 | 6 | def test_create_subclassed_module(dummy_module): 7 | custom_module = dummy_module(column=None) 8 | assert custom_module.name == "dummy-module" 9 | 10 | 11 | def test_load_module_from_string_id(mocker, dummy_module): 12 | mock_modules = mocker.MagicMock() 13 | mock_modules.AVAILABLE_MODULES = {"dummy-module": dummy_module} 14 | mocker.patch("disaggregators.disaggregation_modules.disaggregation_module.disaggregation_modules", mock_modules) 15 | 16 | loaded_module = DisaggregationModuleFactory.create_from_id(module_id="dummy-module", column=None) 17 | assert issubclass(type(loaded_module), DisaggregationModule) 18 | assert loaded_module.name == "dummy-module" 19 | 20 | 21 | def test_load_module_with_bad_string_id(mocker): 22 | mock_modules = mocker.MagicMock() 23 | mock_modules.AVAILABLE_MODULES = {} 24 | mocker.patch("disaggregators.disaggregation_modules.disaggregation_module.disaggregation_modules", mock_modules) 25 | 26 | with pytest.raises(ValueError, match="Invalid module_id received."): 27 | DisaggregationModuleFactory.create_from_id(module_id="bad-module") 28 | 29 | 30 | def test_override_module_config(configured_module, custom_dummy_labels, dummy_module, dummy_labels): 31 | assert configured_module.labels == custom_dummy_labels 32 | 33 | # Ensure that the original module hasn't been modified 34 | original_module = 
dummy_module(column=None) 35 | assert original_module.labels == dummy_labels 36 | -------------------------------------------------------------------------------- /tests/test_disaggregator.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import pytest 4 | from datasets import Dataset 5 | 6 | from disaggregators import CustomDisaggregator, DisaggregationModule, Disaggregator 7 | 8 | 9 | class TestDisaggregator: 10 | @pytest.fixture(autouse=True) 11 | def mock_module_factory(self, mocker, dummy_module): 12 | mock_factory = mocker.MagicMock() 13 | mock_factory.create_module.side_effect = lambda module_id=None, column=None: dummy_module( 14 | module_id=module_id, column=column 15 | ) 16 | 17 | mocker.patch("disaggregators.disaggregator.DisaggregationModuleFactory", mock_factory) 18 | 19 | def test_create_empty_disaggregator(self): 20 | disagg = Disaggregator() 21 | assert isinstance(disagg.modules, list) 22 | assert len(disagg.modules) == 0 23 | 24 | def test_create_disaggregator_with_single_module(self): 25 | disagg = Disaggregator("pronoun") 26 | assert len(disagg.modules) == 1 27 | assert isinstance(disagg.modules[0], DisaggregationModule) 28 | assert disagg.modules[0].name == "pronoun" 29 | 30 | def test_create_disaggregator_with_multiple_modules(self): 31 | disagg = Disaggregator(["pronoun", "spelling"]) 32 | assert len(disagg.modules) == 2 33 | assert all([isinstance(module, DisaggregationModule) for module in disagg.modules]) 34 | assert disagg.modules[0].name == "pronoun" 35 | assert disagg.modules[1].name == "spelling" 36 | 37 | def test_get_disaggregator_function_single_aggregation_module(self): 38 | disagg = Disaggregator("dummy-module") 39 | assert disagg({"a": 1, "b": 2}) == {"dummy-module.dummy-value-1": True, "dummy-module.dummy-value-2": True} 40 | disagg_func = disagg.get_function() 41 | assert isinstance(disagg_func, Callable) 42 | assert disagg_func({"a": 1, "b": 2}) == { 43 
| "dummy-module.dummy-value-1": True, 44 | "dummy-module.dummy-value-2": True, 45 | } 46 | 47 | def test_get_disaggregator_function_multiple_aggregation_modules(self): 48 | disagg = Disaggregator(["dummy-one", "dummy-two"]) 49 | assert disagg({"a": 1, "b": 2}) == { 50 | "dummy-one.dummy-value-1": True, 51 | "dummy-one.dummy-value-2": True, 52 | "dummy-two.dummy-value-1": True, 53 | "dummy-two.dummy-value-2": True, 54 | } 55 | disagg_func = disagg.get_function() 56 | assert isinstance(disagg_func, Callable) 57 | assert disagg_func({"a": 1, "b": 2}) == { 58 | "dummy-one.dummy-value-1": True, 59 | "dummy-one.dummy-value-2": True, 60 | "dummy-two.dummy-value-1": True, 61 | "dummy-two.dummy-value-2": True, 62 | } 63 | 64 | @pytest.mark.parametrize( 65 | "modules,expected", 66 | [ 67 | ([], set()), 68 | (["dummy-one"], {"dummy-one.dummy-value-1", "dummy-one.dummy-value-2"}), 69 | ( 70 | ["dummy-one", "dummy-two"], 71 | { 72 | "dummy-one.dummy-value-1", 73 | "dummy-one.dummy-value-2", 74 | "dummy-two.dummy-value-1", 75 | "dummy-two.dummy-value-2", 76 | }, 77 | ), 78 | ], 79 | ) 80 | def test_get_fields(self, modules, expected): 81 | disagg = Disaggregator(modules) 82 | assert disagg.fields == expected 83 | 84 | 85 | @pytest.fixture 86 | def custom_module(dummy_labels): 87 | class CustomModule(CustomDisaggregator): 88 | module_id = "custom" 89 | labels = dummy_labels 90 | 91 | def __call__(self, row, *args, **kwargs): 92 | return {dummy_labels.DUMMY_ONE: "cat", dummy_labels.DUMMY_TWO: "dog"} 93 | 94 | return CustomModule 95 | 96 | 97 | def test_inject_custom_module_subclass(custom_module, dummy_labels): 98 | disagg = Disaggregator(custom_module, column=None) 99 | assert disagg.fields == {"custom.dummy-value-1", "custom.dummy-value-2"} 100 | 101 | ds = Dataset.from_dict({"text": ["Hello world!"]}).map(disagg) 102 | assert set(ds.features) == {"text", *disagg.fields} 103 | assert ds[0] == {"text": "Hello world!", "custom.dummy-value-1": "cat", "custom.dummy-value-2": "dog"} 
104 | 105 | 106 | def test_module_instance(configured_module, configured_dummy_expected_results): 107 | disagg = Disaggregator(configured_module, column=None) 108 | assert disagg.fields == configured_module.field_names 109 | 110 | ds = Dataset.from_dict({"text": ["Hello world!"]}).map(disagg) 111 | assert set(ds.features) == {"text", *disagg.fields} 112 | assert ds[0] == {"text": "Hello world!", **configured_dummy_expected_results} 113 | --------------------------------------------------------------------------------