├── .coveragerc ├── .flake8 ├── .github └── workflows │ └── ci.yml ├── .gitignore ├── .isort.cfg ├── .pre-commit-config.yaml ├── CITATION.cff ├── CONTRIBUTING.md ├── LICENSE.md ├── Makefile ├── Pipfile ├── Pipfile.lock ├── README.md ├── dcbench ├── __init__.py ├── __main__.py ├── common │ ├── __init__.py │ ├── artifact.py │ ├── artifact_container.py │ ├── modeling.py │ ├── problem.py │ ├── result.py │ ├── solution.py │ ├── solution_set.py │ ├── solve.py │ ├── solver.py │ ├── table.py │ ├── task.py │ ├── trial.py │ └── utils.py ├── config.py ├── constants.py ├── tasks │ ├── .DS_Store │ ├── __init__.py │ ├── budgetclean │ │ ├── __init__.py │ │ ├── baselines.py │ │ ├── common.py │ │ ├── cpclean │ │ │ ├── README.md │ │ │ ├── algorithm │ │ │ │ ├── min_max.py │ │ │ │ ├── select.py │ │ │ │ ├── sort_count.py │ │ │ │ └── utils.py │ │ │ ├── clean.py │ │ │ ├── debugger.py │ │ │ ├── knn_evaluator.py │ │ │ ├── query.py │ │ │ └── utils.py │ │ └── problem.py │ ├── minidata │ │ ├── __init__.py │ │ └── unagi_configs.py │ └── slice_discovery │ │ ├── __init__.py │ │ ├── baselines.py │ │ ├── metrics.py │ │ ├── pipeline.py │ │ ├── problem.py │ │ └── run.py └── version.py ├── docs ├── Makefile ├── assets │ ├── banner.png │ └── logo.png ├── make.bat ├── populate_docs.py ├── requirements.txt └── source │ ├── apidocs │ ├── dcbench.budgetclean.rst │ ├── dcbench.common.rst │ ├── dcbench.rst │ ├── dcbench.tasks.budgetclean.rst │ ├── dcbench.tasks.minidata.rst │ ├── dcbench.tasks.rst │ ├── dcbench.tasks.slice.rst │ ├── dcbench.tasks.slice_discovery.rst │ └── modules.rst │ ├── conf.py │ ├── index.rst │ ├── install.rst │ ├── intro.rst │ ├── task_descriptions │ ├── budgetclean.rst │ ├── minidata.rst │ └── slice_discovery.rst │ ├── task_template.rst │ └── tasks.rst ├── notebooks ├── Untitled.ipynb ├── dcbench_budgetclean.ipynb ├── dcbench_slice_discovery-Copy1.ipynb └── dcbench_slice_discovery.ipynb ├── pyproject.toml ├── setup.py └── tests ├── __init__.py ├── conftest.py └── dcbench ├── 
__init__.py ├── common ├── __init__.py ├── test_artifact.py ├── test_artifact_container.py ├── test_problem.py └── test_task.py ├── tasks └── test_slice_discovery.py ├── test_config.py └── test_dcbench.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | branch = True 3 | source = dcbench 4 | 5 | [report] 6 | exclude_lines = 7 | if self.debug: 8 | pragma: no cover 9 | raise NotImplementedError 10 | raise NotImplemented 11 | if __name__ == .__main__.: 12 | ignore_errors = True 13 | omit = 14 | tests/* 15 | dcbench/tasks/budgetclean/* 16 | setup.py -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | # This is our code-style check. We currently allow the following exceptions: 2 | # - E731: do not assign a lambda expression, use a def 3 | # - W503: line break before binary operator 4 | # - E741: do not use variables named 'l', 'O', or 'I' 5 | # - E203: whitespace before ':' 6 | [flake8] 7 | count = True 8 | max-line-length = 88 9 | statistics = True 10 | ignore = E731,W503,E741,E203 11 | exclude = setup.py -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [ '*' ] 6 | pull_request: 7 | branches: [ '*' ] 8 | 9 | # Allows you to run this workflow manually from the Actions tab 10 | workflow_dispatch: 11 | 12 | jobs: 13 | 14 | Linting: 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | python-version: ['3.7', '3.8', '3.9'] 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | 27 | - uses: actions/cache@v2 28 | with: 29 | path: ~/.cache/pip 30 | key: 
${{ runner.os }}-pip 31 | 32 | - name: Install Dependencies 33 | run: | 34 | python -m pip install --upgrade pip 35 | make dev 36 | - name: Lint with isort, black, docformatter, flake8 37 | run: | 38 | make lint 39 | 40 | Documentation: 41 | needs: Linting 42 | runs-on: ubuntu-latest 43 | strategy: 44 | matrix: 45 | python-version: ['3.8', '3.9'] 46 | 47 | steps: 48 | - uses: actions/checkout@v2 49 | - name: Set up Python ${{ matrix.python-version }} 50 | uses: actions/setup-python@v2 51 | with: 52 | python-version: ${{ matrix.python-version }} 53 | 54 | - uses: actions/cache@v2 55 | with: 56 | path: ~/.cache/pip 57 | key: ${{ runner.os }}-pip 58 | 59 | - name: Install Dependencies 60 | run: | 61 | python -m pip install --upgrade pip 62 | make dev 63 | 64 | - name: Generate Docs 65 | run: | 66 | make docs 67 | Build: 68 | permissions: 69 | contents: 'read' 70 | id-token: 'write' 71 | if: 72 | contains(' 73 | refs/heads/main 74 | ', github.event.pull_request.base.ref) 75 | runs-on: ${{ matrix.os }} 76 | strategy: 77 | matrix: 78 | os: [ubuntu-latest] #, macos-latest] 79 | python-version: ['3.7', '3.8', '3.9'] 80 | 81 | steps: 82 | - uses: actions/checkout@v2 83 | 84 | 85 | 86 | - name: Set up Python ${{ matrix.python-version }} 87 | uses: actions/setup-python@v2 88 | with: 89 | python-version: ${{ matrix.python-version }} 90 | 91 | 92 | - id: auth 93 | uses: google-github-actions/auth@v0 94 | with: 95 | workload_identity_provider: 'projects/419310667461/locations/global/workloadIdentityPools/github-actions-pool/providers/github-actions-provider' 96 | service_account: 'github-actions@hai-gcp-fine-grained.iam.gserviceaccount.com' 97 | 98 | - name: Set up gcloud Cloud SDK environment 99 | uses: google-github-actions/setup-gcloud@v0.2.0 100 | 101 | - name: Use gcloud CLI 102 | run: gcloud info 103 | 104 | - uses: actions/cache@v2 105 | with: 106 | path: ~/.cache/pip 107 | key: ${{ runner.os }}-pip 108 | 109 | - name: Install Dependencies 110 | run: | 111 | pip install 
-e ".[dev]" 112 | 113 | - name: Test with pytest 114 | run: | 115 | make test-cov 116 | 117 | - name: Upload to codecov.io 118 | uses: codecov/codecov-action@v1 119 | with: 120 | file: ./coverage.xml 121 | flags: unittests 122 | name: codecov-umbrella 123 | fail_ci_if_error: true 124 | 125 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | dcbench-config.yaml 6 | slurm 7 | 8 | # C extensions 9 | *.so 10 | *.DS_Store 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | tmp/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | 135 | # Other stuff 136 | .vscode/ 137 | tmp/ -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | multi_line_output = 3 3 | include_trailing_comma = True 4 | force_grid_wrap = 0 5 | use_parentheses = True 6 | ensure_newline_before_comments = True 7 | line_length = 88 -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/timothycrosley/isort 3 | rev: 5.7.0 4 | hooks: 5 | - id: isort 6 | - repo: https://github.com/psf/black 7 | rev: 20.8b1 8 | hooks: 9 | - id: black 10 | language_version: python3 11 | - repo: https://gitlab.com/pycqa/flake8 12 | rev: 3.8.4 13 | hooks: 14 | - id: flake8 -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this benchmark, please cite it as below." 
3 | authors: 4 | - family-names: Eyuboglu 5 | given-names: Sabri 6 | orcid: "https://orcid.org/0000-0002-8412-0266" 7 | - family-names: Karlaš 8 | given-names: Bojan 9 | - family-names: Zhang 10 | given-names: Ce 11 | - family-names: Ré 12 | given-names: Christopher 13 | - family-names: Zou 14 | given-names: James 15 | title: "dcbench" 16 | version: 1.0.0 17 | doi: 10.5281/zenodo.1234 18 | date-released: 2021-11-29 19 | url: "https://github.com/data-centric-ai/dcbench" 20 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to dcbench 2 | 3 | We welcome contributions of all kinds: code, documentation, feedback and support. If 4 | you use dcbench in your work (blogs posts, research, company) and find it 5 | useful, spread the word! 6 | 7 | This contribution borrows from and is heavily inspired by [Huggingface transformers](https://github.com/huggingface/transformers). 8 | 9 | ## How to contribute 10 | 11 | There are 4 ways you can contribute: 12 | * Issues: raising bugs, suggesting new features 13 | * Fixes: resolving outstanding bugs 14 | * Features: contributing new features 15 | * Documentation: contributing documentation or examples 16 | 17 | ## Submitting a new issue or feature request 18 | 19 | Do your best to follow these guidelines when submitting an issue or a feature 20 | request. It will make it easier for us to give feedback and move your request forward. 21 | 22 | ### Bugs 23 | 24 | First, we would really appreciate it if you could **make sure the bug was not 25 | already reported** (use the search bar on Github under Issues). 26 | 27 | If you didn't find anything, please use the bug issue template to file a Github issue. 28 | 29 | 30 | ### Features 31 | 32 | A world-class feature request addresses the following points: 33 | 34 | 1. 
Motivation first: 35 | * Is it related to a problem/frustration with the library? If so, please explain 36 | why. Providing a code snippet that demonstrates the problem is best. 37 | * Is it related to something you would need for a project? We'd love to hear 38 | about it! 39 | * Is it something you worked on and think could benefit the community? 40 | Awesome! Tell us what problem it solved for you. 41 | 2. Write a *full paragraph* describing the feature; 42 | 3. Provide a **code snippet** that demonstrates its future use; 43 | 4. In case this is related to a paper, please attach a link; 44 | 5. Attach any additional information (drawings, screenshots, etc.) you think may help. 45 | 46 | If your issue is well written we're already 80% of the way there by the time you 47 | post it. 48 | 49 | ## Contributing (Pull Requests) 50 | 51 | Before writing code, we strongly advise you to search through the existing PRs or 52 | issues to make sure that nobody is already working on the same thing. If you are 53 | unsure, it is always a good idea to open an issue to get some feedback. 54 | 55 | You will need basic `git` proficiency to be able to contribute to 56 | `dcbench`. `git` is not the easiest tool to use but it has the greatest 57 | manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro 58 | Git](https://git-scm.com/book/en/v2) is a very good reference. 59 | 60 | Follow these steps to start contributing: 61 | 62 | 1. Fork the [repository](https://github.com/data-centric-ai/dcbench) by 63 | clicking on the 'Fork' button on the repository's page. 64 | This creates a copy of the code under your GitHub user account. 65 | 66 | 2. Clone your fork to your local disk, and add the base repository as a remote: 67 | 68 | ```bash 69 | $ git clone git@github.com:/dcbench.git 70 | $ cd dcbench 71 | $ git remote add upstream https://github.com/data-centric-ai/dcbench.git 72 | ``` 73 | 74 | 3. 
Create a new branch to hold your development changes: 75 | 76 | ```bash 77 | $ git checkout -b a-descriptive-name-for-my-changes 78 | ``` 79 | 80 | **Do not** work on the `main` branch. 81 | 82 | 4. dcbench manages dependencies using [`poetry`](https://python-poetry.org). 83 | Set up a development environment with `poetry` by running the following command in 84 | a virtual environment: 85 | 86 | ```bash 87 | $ pip install poetry 88 | $ poetry install 89 | ``` 90 | Note: in order to pass the full test suite (step 5), you'll need to install all extra in addition. 91 | ```bash 92 | $ poetry install --extras "adversarial augmentation summarization text vision" 93 | ``` 94 | 5. Develop features on your branch. 95 | 96 | As you work on the features, you should make sure that the test suite 97 | passes: 98 | 99 | ```bash 100 | $ pytest 101 | ``` 102 | 103 | dcbench relies on `black` and `isort` to format its source code 104 | consistently. After you make changes, autoformat them with: 105 | 106 | ```bash 107 | $ make autoformat 108 | ``` 109 | 110 | dcbench also uses `flake8` to check for coding mistakes. Quality control 111 | runs in CI, however you should also run the same checks with: 112 | 113 | ```bash 114 | $ make lint 115 | ``` 116 | 117 | If you're modifying documents under `docs/source`, make sure to validate that 118 | they can still be built. This check also runs in CI. To run a local check 119 | make sure you have installed the documentation builder requirements, by 120 | running `pip install -r docs/requirements.txt` from the root of this repository 121 | and then run: 122 | 123 | ```bash 124 | $ make docs 125 | ``` 126 | 127 | Once you're happy with your changes, add changed files using `git add` and 128 | make a commit with `git commit` to record your changes locally: 129 | 130 | ```bash 131 | $ git add modified_file.py 132 | $ git commit 133 | ``` 134 | 135 | Please write [good commit messages](https://chris.beams.io/posts/git-commit/). 
136 | 137 | It is a good idea to sync your copy of the code with the original 138 | repository regularly. This way you can quickly account for changes: 139 | 140 | ```bash 141 | $ git fetch upstream 142 | $ git rebase upstream/main 143 | ``` 144 | 145 | Push the changes to your account using: 146 | 147 | ```bash 148 | $ git push -u origin a-descriptive-name-for-my-changes 149 | ``` 150 | 151 | You can use `pre-commit` to make sure you don't forget to format your code properly, 152 | the dependency should already be made available by `poetry`. 153 | 154 | Just install `pre-commit` for the `dcbench` directory, 155 | 156 | ```bash 157 | $ pre-commit install 158 | ``` 159 | 160 | 6. Once you are satisfied (**and the checklist below is happy too**), go to the 161 | webpage of your fork on GitHub. Click on 'Pull request' to send your changes 162 | to the project maintainers for review. 163 | 164 | 7. It's ok if maintainers ask you for changes. It happens to core contributors 165 | too! So everyone can see the changes in the Pull request, work in your local 166 | branch and push the changes to your fork. They will automatically appear in 167 | the pull request. 168 | 169 | 8. We follow a one-commit-per-PR policy. Before your PR can be merged, you will have to 170 | `git rebase` to squash your changes into a single commit. 171 | 172 | ### Checklist 173 | 174 | 0. One commit per PR. 175 | 1. The title of your pull request should be a summary of its contribution; 176 | 2. If your pull request addresses an issue, please mention the issue number in 177 | the pull request description to make sure they are linked (and people 178 | consulting the issue know you are working on it); 179 | 3. To indicate a work in progress please prefix the title with `[WIP]`. These 180 | are useful to avoid duplicated work, and to differentiate it from PRs ready 181 | to be merged; 182 | 4. Make sure existing tests pass; 183 | 5. Add high-coverage tests. No quality testing = no merge. 184 | 6. 
All public methods must have informative docstrings that work nicely with sphinx. 185 | 186 | 187 | ### Tests 188 | 189 | A test suite is included to test the library behavior. 190 | Library tests can be found in the 191 | [tests folder](https://github.com/data-centric-aidcbench/tree/main/tests). 192 | 193 | From the root of the 194 | repository, here's how to run tests with `pytest` for the library: 195 | 196 | ```bash 197 | $ make test 198 | ``` 199 | 200 | You can specify a smaller set of tests in order to test only the feature 201 | you're working on. 202 | 203 | Per the checklist above, all PRs should include high-coverage tests. 204 | To produce a code coverage report, run the following `pytest` 205 | ``` 206 | pytest --cov-report term-missing,html --cov=dcbench . 207 | ``` 208 | This will populate a directory `htmlcov` with an HTML report. 209 | Open `htmlcov/index.html` in a browser to view the report. 210 | 211 | 212 | ### Style guide 213 | 214 | For documentation strings, dcbench follows the 215 | [google style](https://google.github.io/styleguide/pyguide.html). -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2021 The Meerkat Team. 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | autoformat: 2 | black dcbench/ tests/ 3 | autoflake --in-place --remove-all-unused-imports -r dcbench tests 4 | isort --atomic dcbench/ tests/ 5 | docformatter --in-place --recursive dcbench tests 6 | 7 | lint: 8 | isort -c dcbench/ tests/ 9 | black dcbench/ tests/ --check 10 | flake8 dcbench/ tests/ 11 | 12 | test: 13 | pytest 14 | 15 | test-basic: 16 | set -e 17 | python -c "import dcbench as mk" 18 | python -c "import dcbench.version as mversion" 19 | 20 | test-cov: 21 | pytest --cov=./ --cov-report=xml 22 | 23 | docs: 24 | sphinx-build -b html docs/source/ docs/build/html/ 25 | 26 | docs-check: 27 | sphinx-build -b html docs/source/ docs/build/html/ -W 28 | 29 | livedocs: 30 | sphinx-autobuild -b html docs/source/ docs/build/html/ 31 | 32 | dev: 33 | pip install black isort flake8 docformatter pytest-cov sphinx-rtd-theme nbsphinx recommonmark pre-commit 34 | 35 | all: autoformat lint docs test 36 | 
-------------------------------------------------------------------------------- /Pipfile: -------------------------------------------------------------------------------- 1 | [[source]] 2 | url = "https://pypi.org/simple" 3 | verify_ssl = true 4 | name = "pypi" 5 | 6 | [dev-packages] 7 | twine = "*" 8 | ipython = "*" 9 | 10 | [requires] 11 | python_version = "3.8" 12 | 13 | [dev-packages.dcbench] 14 | path = "." 15 | extras = [ "dev",] 16 | editable = true 17 | 18 | [packages.dcbench] 19 | editable = true 20 | path = "." 21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |
3 | banner 4 | 5 | ----- 6 | ![GitHub Workflow Status](https://img.shields.io/github/workflow/status/data-centric-ai/dcbench/CI) 7 | ![GitHub](https://img.shields.io/github/license/data-centric-ai/dcbench) 8 | [![Documentation Status](https://readthedocs.org/projects/dcbench/badge/?version=latest)](https://dcbench.readthedocs.io/en/latest/?badge=latest) 9 | [![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://github.com/pre-commit/pre-commit) 10 | [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/dcbench)](https://pypi.org/project/dcbench/) 11 | [![codecov](https://codecov.io/gh/data-centric-ai/dcbench/branch/main/graph/badge.svg?token=MOLQYUSYQU)](https://codecov.io/gh/data-centric-ai/dcbench) 12 | 13 | A benchmark of data-centric tasks from across the machine learning lifecycle. 14 | 15 | [**Getting Started**](#%EF%B8%8F-quickstart) 16 | | [**What is dcbench?**](#-what-is-dcbench) 17 | | [**Docs**](https://dcbench.readthedocs.io/en/latest/index.html) 18 | | [**Contributing**](CONTRIBUTING.md) 19 | | [**Website**](https://www.datacentricai.cc/) 20 | | [**About**](#%EF%B8%8F-about) 21 |
22 | 23 | 24 | ## ⚡️ Quickstart 25 | 26 | ```bash 27 | pip install dcbench 28 | ``` 29 | > Optional: some parts of Meerkat rely on optional dependencies. If you know which optional dependencies you'd like to install, you can do so using something like `pip install dcbench[dev]` instead. See setup.py for a full list of optional dependencies. 30 | 31 | > Installing from dev: `pip install "dcbench[dev] @ git+https://github.com/data-centric-ai/dcbench@main"` 32 | 33 | Using a Jupyter notebook or some other interactive environment, you can import the library 34 | and explore the data-centric problems in the benchmark: 35 | 36 | ```python 37 | import dcbench 38 | dcbench.tasks 39 | ``` 40 | To learn more, follow the [walkthrough](https://dcbench.readthedocs.io/en/latest/intro.html#api-walkthrough) in the docs. 41 | 42 | 43 | ## 💡 What is dcbench? 44 | This benchmark evaluates the steps in your machine learning workflow beyond model training and tuning. This includes feature cleaning, slice discovery, and coreset selection. We call these “data-centric” tasks because they're focused on exploring and manipulating data – not training models. ``dcbench`` supports a growing list of them: 45 | 46 | * [Minimal Data Selection](https://dcbench.readthedocs.io/en/latest/tasks.html#minimal-data-selection) 47 | * [Slice Discovery](https://dcbench.readthedocs.io/en/latest/tasks.html#slice-discovery) 48 | * [Minimal Feature Cleaning](https://dcbench.readthedocs.io/en/latest/tasks.html#minimal-feature-cleaning) 49 | 50 | 51 | ``dcbench`` includes tasks that look very different from one another: the inputs and 52 | outputs of the slice discovery task are not the same as those of the 53 | minimal data cleaning task. However, we think it important that 54 | researchers and practitioners be able to run evaluations on data-centric 55 | tasks across the ML lifecycle without having to learn a bunch of 56 | different APIs or rewrite evaluation scripts. 
57 | 58 | So, ``dcbench`` is designed to be a common home for these diverse, but 59 | related, tasks. In ``dcbench`` all of these tasks are structured in a 60 | similar manner and they are supported by a common Python API that makes 61 | it easy to download data, run evaluations, and compare methods. 62 | 63 | 64 | ## ✉️ About 65 | `dcbench` is being developed alongside the data-centric-ai benchmark. Reach out to Bojan Karlaš (karlasb [at] inf [dot] ethz [dot] ch) and Sabri Eyuboglu (eyuboglu [at] stanford [dot] edu if you would like to get involved or contribute!) 66 | -------------------------------------------------------------------------------- /dcbench/__init__.py: -------------------------------------------------------------------------------- 1 | """The dcbench module is a collection for benchmarks that test various apsects 2 | of data preparation and handling in the context of AI workflows.""" 3 | # flake8: noqa 4 | 5 | from .common import Artifact, Problem, Solution, Table, Task 6 | from .common.artifact import ( 7 | CSVArtifact, 8 | DataPanelArtifact, 9 | ModelArtifact, 10 | VisionDatasetArtifact, 11 | YAMLArtifact, 12 | ) 13 | from .config import config 14 | from .tasks.budgetclean import BudgetcleanProblem, BudgetcleanSolution 15 | from .tasks.minidata import MiniDataProblem, MiniDataSolution 16 | from .tasks.slice_discovery import SliceDiscoveryProblem, SliceDiscoverySolution 17 | 18 | __all__ = [ 19 | "Artifact", 20 | "Problem", 21 | "Solution", 22 | "BudgetcleanProblem", 23 | "MiniDataProblem", 24 | "SliceDiscoveryProblem", 25 | "BudgetcleanSolution", 26 | "MiniDataSolution", 27 | "SliceDiscoverySolution", 28 | "Task", 29 | "ModelArtifact", 30 | "YAMLArtifact", 31 | "DataPanelArtifact", 32 | "VisionDatasetArtifact", 33 | "CSVArtifact", 34 | "config", 35 | ] 36 | 37 | from .tasks.budgetclean import task as budgetclean_task 38 | from .tasks.minidata import task as minidata_task 39 | from .tasks.slice_discovery import task as slice_discovery_task 40 | 
41 | tasks = Table( 42 | [ 43 | minidata_task, 44 | slice_discovery_task, 45 | budgetclean_task, 46 | ] 47 | ) 48 | -------------------------------------------------------------------------------- /dcbench/__main__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | __all__ = ("main",) 4 | _log = logging.getLogger(__name__) 5 | 6 | # flake8: noqa 7 | BANNER = """ 8 | ____ _________ ____ ____ _______ __________ __ 9 | / __ \/ ____/ | / _/ / __ )/ ____/ | / / ____/ / / / 10 | / / / / / / /| | / / / __ / __/ / |/ / / / /_/ / 11 | / /_/ / /___/ ___ |_/ / / /_/ / /___/ /| / /___/ __ / 12 | /_____/\____/_/ |_/___/ /_____/_____/_/ |_/\____/_/ /_/ 13 | 14 | """ 15 | 16 | 17 | # class MainGroup(click.Group): 18 | # def format_usage(self, ctx, formatter): 19 | # formatter.write(BANNER) 20 | # super().format_usage(ctx, formatter) 21 | 22 | 23 | # @click.group( 24 | # context_settings=dict( 25 | # help_option_names=["-h", "--help"], auto_envvar_prefix="DCBENCH" 26 | # ), 27 | # cls=MainGroup, 28 | # ) 29 | # @click.version_option(prog_name="dcbench", version=__version__) 30 | # @click.option( 31 | # "--optional-artifacts-url", 32 | # help="The URL pointing to the optional artifacts bundle.", 33 | # ) 34 | # def main( 35 | # hidden_artifacts_url: Optional[str] = None, 36 | # working_dir: Optional[str] = None, 37 | # **kwargs 38 | # ): 39 | # """Collection of benchmarks that test various aspects of ML data 40 | # preprocessing and management.""" 41 | 42 | 43 | # @main.command(help="List all available scenarios.") 44 | # def scenarios(): 45 | # for id in Scenario.list(): 46 | # click.echo(id) 47 | 48 | 49 | # @main.command( 50 | # help="List solutions for a given scenario and corresponding evaluation results if available." 
51 | # ) 52 | # @click.option("--scenario-id", type=str, help="The ID of the scenario.", required=True) 53 | # def solutions(scenario_id: str): 54 | # scenario = Scenario.scenarios.get(scenario_id, None)() 55 | # if scenario is None: 56 | # click.echo( 57 | # "The scenario with identifier '%s' not found." % scenario_id, err=True 58 | # ) 59 | # return 60 | # click.echo(scenario.solutions) 61 | 62 | 63 | # @main.command(help="Create a new solution for a given scenario.") 64 | # @click.option("--scenario-id", type=str, help="The ID of the scenario.", required=True) 65 | # @click.option("--name", type=str, help="The name of the new solution.") 66 | # @click.option( 67 | # "--paper", 68 | # type=str, 69 | # help="The URL pointing to a paper describing the solution method.", 70 | # ) 71 | # @click.option( 72 | # "--code", 73 | # type=str, 74 | # help="The URL pointing to a repository or notebook containing the solution code.", 75 | # ) 76 | # @click.option( 77 | # "--artifacts-url", type=str, help="The URL pointing to the solution artifacts." 78 | # ) 79 | # def new_solution( 80 | # scenario_id: str, 81 | # name: Optional[str], 82 | # paper: Optional[str], 83 | # code: Optional[str], 84 | # artifacts_url: Optional[str], 85 | # ): 86 | # scenario = Scenario.scenarios.get(scenario_id, None)() 87 | # if scenario is None: 88 | # click.echo( 89 | # "The scenario with identifier '%s' not found." % scenario_id, err=True 90 | # ) 91 | # return 92 | # solution = Solution( 93 | # scenario, name=name, paper=paper, code=code, artifacts_url=artifacts_url 94 | # ) 95 | # solution.save() 96 | # click.echo("New solution saved to:", err=True) 97 | # click.echo(solution.location) 98 | 99 | 100 | # @main.command(help="Evaluate solutions for one or more scenarios.") 101 | # @click.option( 102 | # "--scenario-id", 103 | # type=str, 104 | # help="The ID of the scenario. 
If omitted then all scenarios are considered.", 105 | # ) 106 | # @click.option( 107 | # "--force", 108 | # type=bool, 109 | # is_flag=True, 110 | # help="Evaluates even if a previous evaluation result exists.", 111 | # ) 112 | # def solve(scenario_id: Optional[str], force: bool): 113 | # scenarios = [] 114 | # if scenario_id is not None: 115 | # scenario = Scenario.scenarios.get(scenario_id, None) 116 | # if scenario is None: 117 | # click.echo( 118 | # "The scenario with identifier '%s' not found." % scenario_id, err=True 119 | # ) 120 | # return 121 | # scenarios.append(scenario) 122 | # else: 123 | # scenarios = [ 124 | # Scenario.scenarios[id]() for id in sorted(Scenario.scenarios.keys()) 125 | # ] 126 | 127 | # for scenario in scenarios: 128 | # for solution in scenario.solutions.values(): 129 | # if solution.result is not None or force: 130 | # click.echo( 131 | # "Evaluating solution '%s' of scenario '%s'." 132 | # % (solution.id, scenario.id), 133 | # err=True, 134 | # ) 135 | # solution.evaluate() 136 | # solution.save() 137 | # click.echo("Result:", err=True) 138 | # click.echo(solution.result) 139 | 140 | 141 | # @main.command(help="Evaluate solutions for one or more scenarios.") 142 | # @click.option( 143 | # "--scenario-id", 144 | # type=str, 145 | # help="The ID of the scenario. If omitted then all scenarios are considered.", 146 | # ) 147 | # @click.option( 148 | # "--force", 149 | # type=bool, 150 | # is_flag=True, 151 | # help="Evaluates even if a previous evaluation result exists.", 152 | # ) 153 | # def evaluate(scenario_id: Optional[str], force: bool): 154 | # scenarios = [] 155 | # if scenario_id is not None: 156 | # scenario = Scenario.scenarios.get(scenario_id, None) 157 | # if scenario is None: 158 | # click.echo( 159 | # "The scenario with identifier '%s' not found." 
% scenario_id, err=True 160 | # ) 161 | # return 162 | # scenarios.append(scenario) 163 | # else: 164 | # scenarios = [ 165 | # Scenario.scenarios[id]() for id in sorted(Scenario.scenarios.keys()) 166 | # ] 167 | 168 | # for scenario in scenarios: 169 | # for solution in scenario.solutions.values(): 170 | # if solution.result is not None or force: 171 | # click.echo( 172 | # "Evaluating solution '%s' of scenario '%s'." 173 | # % (solution.id, scenario.id), 174 | # err=True, 175 | # ) 176 | # solution.evaluate() 177 | # solution.save() 178 | # click.echo("Result:", err=True) 179 | # click.echo(solution.result) 180 | 181 | 182 | if __name__ == "__main__": 183 | main(prog_name="dcbench") 184 | -------------------------------------------------------------------------------- /dcbench/common/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | from .artifact import Artifact 4 | from .problem import Problem 5 | from .result import Result 6 | from .solution import Solution 7 | from .table import Table 8 | from .task import Task 9 | 10 | __all__ = ["Artifact", "Problem", "Solution", "Task", "Table", "Result"] 11 | -------------------------------------------------------------------------------- /dcbench/common/artifact_container.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | import uuid 5 | from abc import ABC 6 | from collections.abc import Mapping 7 | from dataclasses import dataclass 8 | from typing import Any 9 | 10 | import yaml 11 | from meerkat.tools.lazy_loader import LazyLoader 12 | 13 | import dcbench.constants as constants 14 | from dcbench.config import config 15 | 16 | from .artifact import Artifact 17 | from .table import Attribute, AttributeSpec, RowMixin 18 | 19 | storage = LazyLoader("google.cloud.storage") 20 | 21 | 22 | @dataclass 23 | class ArtifactSpec: 24 | description: str 25 | 
artifact_type: type 26 | optional: bool = False 27 | 28 | 29 | class ArtifactContainer(ABC, Mapping, RowMixin): 30 | """A logical collection of artifacts and attributes (simple tags describing the 31 | container), which are useful for finding, sorting and grouping containers. 32 | 33 | Args: 34 | artifacts (Mapping[str, Union[Artifact, Any]]): A mapping with the same keys 35 | as the `ArtifactContainer.artifact_specs` (possibly excluding optional 36 | artifacts). Each value can either be an :class:`Artifact`, in which case the 37 | artifact type must match the type specified in the corresponding 38 | :class:`ArtifactSpec`, or a raw object, in which case a new artifact of the 39 | type specified in `artifact_specs` is created from the raw object and an 40 | ``artifact_id`` is generated according to the following pattern: 41 | ``//artifacts//``. 42 | attributes (Mapping[str, PRIMITIVE_TYPE], optional): A mapping with the same 43 | keys as the `ArtifactContainer.attribute_specs` (possibly excluding optional 44 | attributes). Each value must be of the type specified in the corresponding 45 | :class:`AttributeSpec`. Defaults to None. 46 | container_id (str, optional): The ID of the container. Defaults to None, in 47 | which case a UUID is generated. 48 | 49 | Attributes: 50 | artifacts (Dict[str, Artifact]): A dictionary of artifacts, indexed by name. 51 | 52 | .. Tip:: 53 | We can use the index operator directly on :class:`ArtifactContainer` 54 | objects to both fetch the artifact, download it if necessary, and load 55 | it into memory. For example, to load the artifact ``"data"`` into 56 | memory from a container ``container``, we can simply call 57 | ``container["data"]``, which is equivalent to calling 58 | ``container.artifacts["data"].download()`` followed by 59 | ``container.artifacts["data"].load()``. 60 | 61 | attributes (Dict[str, Attribute]): A dictionary of attributes, indexed by 62 | name. 63 | 64 | .. 
Tip:: Accessing attributes 65 | Atttributes can be accessed via a dot-notation (as long as the attribute 66 | name does not conflict). For example, to access the attribute ``"data"`` 67 | in a container ``container``, we can simply call ``container.data``. 68 | 69 | 70 | Notes 71 | ----- 72 | 73 | :class:`ArtifactContainer` is an abstract base class, and should not be 74 | instantiated directly. There are two main groups of :class:`ArtifactContainer` 75 | subclasses: 76 | 77 | #. :class:`dcbench.Problem` - A logical collection of artifacts and 78 | attributes that correspond to a specific problem to be solved. 79 | 80 | - Example subclasses: :class:`dcbench.SliceDiscoveryProblem`, 81 | :class:`dcbench.BudgetcleanProblem` 82 | #. :class:`dcbench.Solution` - A logical collection of artifacts and 83 | attributes that correspond to a solution to a problem. 84 | 85 | - Example subclasses: :class:`dcbench.SliceDiscoverySolution`, 86 | :class:`dcbench.BudgetcleanSolution` 87 | 88 | A concrete (i.e. non-abstract) subclass of :class:`ArtifactContainer` must include 89 | (1) a specification for the artifacts it holds, (2) a specification for the 90 | attributes used to tag it, and (3) a `task_id` linking the subclass 91 | to one of dcbench's tasks (see :ref:`task-intro`). For example, in the code block 92 | below we include such a specification in the definition of a simple container that 93 | holds a training dataset and a test dataset (see 94 | :class:`dcbench.SliceDiscoveryProblem` for a real example): 95 | 96 | .. code-block:: python 97 | 98 | class DemoContainer(ArtifactContainer): 99 | artifact_specs = { 100 | "train_dataset": ArtifactSpec( 101 | artifact_type=CSVArtifact, 102 | description="A CSV containing training data." 103 | ), 104 | "test_dataset": ArtifactSpec( 105 | artifact_type=CSVArtifact, 106 | description="A CSV containing test data." 
107 | ), 108 | } 109 | attribute_specs = { 110 | "dataset_name": AttributeSpec( 111 | attribute_type=str, 112 | description="The name of the dataset." 113 | ), 114 | } 115 | task_id = "slice_discovery" 116 | 117 | """ 118 | 119 | artifact_specs: Mapping[str, ArtifactSpec] 120 | task_id: str 121 | attribute_specs: Mapping[str, AttributeSpec] = {} 122 | 123 | # abstract subclasses like Problem and Solution specify this so that all of their 124 | # subclasses may be grouped by container_type when stored on disk 125 | container_type: str = "artifact_container" 126 | 127 | def __init__( 128 | self, 129 | artifacts: Mapping[str, Artifact], 130 | attributes: Mapping[str, Attribute] = None, 131 | container_id: str = None, 132 | ): 133 | if container_id is None: 134 | container_id = uuid.uuid4().hex 135 | 136 | super().__init__(id=container_id) 137 | self._check_artifact_specs(artifacts=artifacts) 138 | artifacts = self._create_artifacts(artifacts=artifacts) 139 | self.artifacts = artifacts 140 | 141 | if attributes is None: 142 | attributes = {} 143 | self.attributes = attributes # This setter will check the artifact_specs 144 | 145 | @property 146 | def is_downloaded(self) -> bool: 147 | """Checks if all of the artifacts in the container are downloaded to the local 148 | directory specified in the config file at ``config.local_dir``. 149 | 150 | Returns: 151 | bool: True if artifact is downloaded, False otherwise. 152 | """ 153 | return all(x.is_downloaded for x in self.artifacts.values()) 154 | 155 | @property 156 | def is_uploaded(self) -> bool: 157 | """Checks if all of the artifacts in the container are uploaded to the GCS 158 | bucket specified in the config file at ``config.public_bucket_name``. 159 | 160 | Returns: 161 | bool: True if artifact is uploaded, False otherwise. 
162 | """ 163 | return all(x.is_uploaded for x in self.artifacts.values()) 164 | 165 | def upload(self, force: bool = False, bucket: "storage.Bucket" = None): 166 | """Uploads all of the artifacts in the container to a GCS bucket, skipping 167 | artifacts that are already uploaded. 168 | 169 | Args: 170 | force (bool, optional): Force upload even if an artifact is already 171 | uploaded. Defaults to False. 172 | bucket (storage.Bucket, optional): The GCS bucket to which the artifacts are 173 | uploaded. Defaults to None, in which case the artifact is uploaded to 174 | the bucket speciried in the config file at config.public_bucket_name. 175 | 176 | Returns: 177 | bool: True if any artifacts were uploaded, False otherwise. 178 | """ 179 | if bucket is None: 180 | client = storage.Client() 181 | bucket = client.get_bucket(config.public_bucket_name) 182 | 183 | return any( 184 | [ 185 | artifact.upload(force=force, bucket=bucket) 186 | for artifact in self.artifacts.values() 187 | ] 188 | ) 189 | 190 | def download(self, force: bool = False) -> bool: 191 | """Downloads artifacts in the container from the GCS bucket specified in the 192 | config file at ``config.public_bucket_name`` to the local directory specified 193 | in the config file at ``config.local_dir``. The relative path to the 194 | artifact within that directory is ``self.path``, which by default is 195 | just the artifact ID with the default extension. 196 | 197 | Args: 198 | force (bool, optional): Force download even if an artifact is already 199 | downloaded. Defaults to False. 200 | 201 | Returns: 202 | bool: True if any artifacts were downloaded, False otherwise. 203 | """ 204 | return any( 205 | [artifact.download(force=force) for artifact in self.artifacts.values()] 206 | ) 207 | 208 | @staticmethod 209 | def from_yaml(loader: yaml.Loader, node): 210 | """This function is called by the YAML loader to convert a YAML node 211 | into an :class:`ArtifactContainer` object. 
212 | 213 | It should not be called directly. 214 | """ 215 | data = loader.construct_mapping(node, deep=True) 216 | return data["class"]( 217 | container_id=data["container_id"], 218 | artifacts=data["artifacts"], 219 | attributes=data["attributes"], 220 | ) 221 | 222 | @staticmethod 223 | def to_yaml(dumper: yaml.Dumper, data: ArtifactContainer): 224 | """This function is called by the YAML dumper to convert an 225 | :class:`ArtifactContainer` object into a YAML node. 226 | 227 | It should not be called directly. 228 | """ 229 | data = { 230 | "class": type(data), 231 | "container_id": data.id, 232 | "attributes": data._attributes, 233 | "artifacts": data.artifacts, 234 | } 235 | return dumper.represent_mapping("!ArtifactContainer", data) 236 | 237 | # Provide dict interface for accessing artifacts by name 238 | def __getitem__(self, key): 239 | artifact = self.artifacts.__getitem__(key) 240 | if not artifact.is_downloaded: 241 | artifact.download() 242 | return self.artifacts.__getitem__(key).load() 243 | 244 | def __iter__(self): 245 | return self.artifacts.__iter__() 246 | 247 | def __len__(self): 248 | return self.artifacts.__len__() 249 | 250 | def __getattr__(self, k: str) -> Any: 251 | if k == "_attributes" or k == "attributes": 252 | # avoids recursion error when unpickling an ArtifactContainer 253 | raise AttributeError(k) 254 | 255 | try: 256 | return self.attributes[k] 257 | except KeyError: 258 | raise AttributeError(k) 259 | 260 | def __repr__(self): 261 | artifacts = {k: v.__class__.__name__ for k, v in self.artifacts.items()} 262 | return ( 263 | f"{self.__class__.__name__}(artifacts={artifacts}, " 264 | f"attributes={self.attributes})" 265 | ) 266 | 267 | @classmethod 268 | def _check_artifact_specs(cls, artifacts: Mapping[str, Artifact]): 269 | for name, artifact in artifacts.items(): 270 | if name not in cls.artifact_specs: 271 | raise ValueError( 272 | f"Passed artifact name '{name}', but the specification for" 273 | f" {cls.__name__} doesn't 
include it." 274 | ) 275 | 276 | # defer the check to see if an artifact can actually be created from the raw 277 | # data to _create_artifacts 278 | if isinstance(artifact, Artifact) and not isinstance( 279 | artifact, cls.artifact_specs[name].artifact_type 280 | ): 281 | raise ValueError( 282 | f"Passed an artifact of type {type(artifact)} to {cls.__name__}" 283 | f" for the artifact named '{name}'. The specification for" 284 | f" {cls.__name__} expects an Artifact of type" 285 | f" {cls.artifact_specs[name].artifact_type}." 286 | ) 287 | 288 | for name, spec in cls.artifact_specs.items(): 289 | if name not in artifacts: 290 | if spec.optional: 291 | continue 292 | raise ValueError( 293 | f"Must pass required artifact with key '{name}' to {cls.__name__}." 294 | ) 295 | 296 | def _create_artifacts(self, artifacts: Mapping[str, Artifact]): 297 | return { 298 | name: artifact 299 | if isinstance(artifact, Artifact) 300 | else self.artifact_specs[name].artifact_type.from_data( 301 | data=artifact, 302 | artifact_id=os.path.join( 303 | self.task_id, 304 | self.container_type, 305 | constants.ARTIFACTS_DIR, 306 | self.id, 307 | name, 308 | ), 309 | ) 310 | for name, artifact in artifacts.items() 311 | } 312 | 313 | 314 | yaml.add_multi_representer(ArtifactContainer, ArtifactContainer.to_yaml) 315 | yaml.add_constructor("!ArtifactContainer", ArtifactContainer.from_yaml) 316 | -------------------------------------------------------------------------------- /dcbench/common/modeling.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | import PIL 4 | import pytorch_lightning as pl 5 | import torch 6 | import torch.nn as nn 7 | import torchvision.transforms as transforms 8 | from torch.hub import load_state_dict_from_url 9 | from torchvision.models import DenseNet as _DenseNet 10 | from torchvision.models import ResNet as _ResNet 11 | from torchvision.models.densenet import _load_state_dict 12 | from 
torchvision.models.densenet import model_urls as densenet_model_urls 13 | from torchvision.models.resnet import BasicBlock, Bottleneck 14 | from torchvision.models.resnet import model_urls as resnet_model_urls 15 | 16 | 17 | class Model(pl.LightningModule): 18 | 19 | DEFAULT_CONFIG = {} 20 | 21 | def __init__(self, config: dict = None): 22 | super().__init__() 23 | self.config = self.DEFAULT_CONFIG.copy() 24 | if config is not None: 25 | self.config.update(config) 26 | 27 | self._set_model() 28 | 29 | @abstractmethod 30 | def _set_model(self): 31 | raise NotImplementedError() 32 | 33 | 34 | class ResNet(_ResNet): 35 | 36 | ACTIVATION_DIMS = [64, 128, 256, 512] 37 | ACTIVATION_WIDTH_HEIGHT = [64, 32, 16, 8] 38 | RESNET_TO_ARCH = {"resnet18": [2, 2, 2, 2], "resnet50": [3, 4, 6, 3]} 39 | 40 | def __init__( 41 | self, 42 | num_classes: int, 43 | arch: str = "resnet18", 44 | dropout: float = 0.0, 45 | pretrained: bool = True, 46 | ): 47 | if arch not in self.RESNET_TO_ARCH: 48 | raise ValueError( 49 | f"config['classifier'] must be one of: {self.RESNET_TO_ARCH.keys()}" 50 | ) 51 | 52 | block = BasicBlock if arch == "resnet18" else Bottleneck 53 | super().__init__(block, self.RESNET_TO_ARCH[arch]) 54 | if pretrained: 55 | state_dict = load_state_dict_from_url( 56 | resnet_model_urls[arch], progress=True 57 | ) 58 | self.load_state_dict(state_dict) 59 | 60 | # self.fc = nn.Linear(512 * block.expansion, num_classes) 61 | self.fc = nn.Sequential( 62 | nn.Dropout(dropout), nn.Linear(512 * block.expansion, num_classes) 63 | ) 64 | 65 | 66 | def default_transform(img: PIL.Image.Image): 67 | return transforms.Compose( 68 | [ 69 | transforms.Resize(256), 70 | transforms.CenterCrop(224), 71 | transforms.ToTensor(), 72 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 73 | ] 74 | )(img) 75 | 76 | 77 | def default_train_transform(img: PIL.Image.Image): 78 | return transforms.Compose( 79 | [ 80 | transforms.RandomResizedCrop(224), 81 | 
transforms.RandomHorizontalFlip(), 82 | transforms.ToTensor(), 83 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 84 | ] 85 | )(img) 86 | 87 | 88 | class DenseNet(_DenseNet): 89 | 90 | DENSENET_TO_ARCH = { 91 | "densenet121": { 92 | "growth_rate": 32, 93 | "block_config": (6, 12, 24, 16), 94 | "num_init_features": 64, 95 | } 96 | } 97 | 98 | def __init__( 99 | self, num_classes: int, arch: str = "densenet121", pretrained: bool = True 100 | ): 101 | if arch not in self.DENSENET_TO_ARCH: 102 | raise ValueError( 103 | f"config['classifier'] must be one of: {self.DENSENET_TO_ARCH.keys()}" 104 | ) 105 | 106 | super().__init__(**self.DENSENET_TO_ARCH[arch]) 107 | if pretrained: 108 | _load_state_dict(self, densenet_model_urls[arch], progress=True) 109 | 110 | self.classifier = nn.Linear(self.classifier.in_features, num_classes) 111 | 112 | 113 | class VisionClassifier(Model): 114 | 115 | DEFAULT_CONFIG = { 116 | "lr": 1e-4, 117 | "model_name": "resnet", 118 | "arch": "resnet18", 119 | "pretrained": True, 120 | "num_classes": 2, 121 | "transform": default_transform, 122 | "train_transform": default_train_transform, 123 | } 124 | 125 | def _set_model(self): 126 | if self.config["model_name"] == "resnet": 127 | self.model = ResNet( 128 | num_classes=self.config["num_classes"], 129 | arch=self.config["arch"], 130 | pretrained=self.config["pretrained"], 131 | ) 132 | elif self.config["model_name"] == "densenet": 133 | self.model = DenseNet( 134 | num_classes=self.config["num_classes"], arch=self.config["arch"] 135 | ) 136 | else: 137 | raise ValueError(f"Model name {self.config['model_name']} not supported.") 138 | 139 | def forward(self, x): 140 | return self.model(x) 141 | 142 | def training_step(self, batch, batch_idx): 143 | inputs, targets, _ = batch["input"], batch["target"], batch["id"] 144 | outs = self.forward(inputs) 145 | 146 | loss = nn.functional.cross_entropy(outs, targets) 147 | self.log("train_loss", loss, on_step=True, 
logger=True) 148 | return loss 149 | 150 | def validation_step(self, batch, batch_idx): 151 | inputs, targets = batch["input"], batch["target"] 152 | 153 | outs = self.forward(inputs) 154 | loss = nn.functional.cross_entropy(outs, targets) 155 | self.log("valid_loss", loss) 156 | 157 | def validation_epoch_end(self, outputs) -> None: 158 | for metric_name, metric in self.metrics.items(): 159 | self.log(f"valid_{metric_name}", metric.compute()) 160 | metric.reset() 161 | 162 | def test_epoch_end(self, outputs) -> None: 163 | return self.validation_epoch_end(outputs) 164 | 165 | def test_step(self, batch, batch_idx): 166 | return self.validation_step(batch, batch_idx) 167 | 168 | def configure_optimizers(self): 169 | optimizer = torch.optim.Adam(self.parameters(), lr=self.config["lr"]) 170 | return optimizer 171 | -------------------------------------------------------------------------------- /dcbench/common/problem.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from abc import abstractmethod 4 | from typing import TYPE_CHECKING, Any, Callable, Optional 5 | 6 | from .artifact_container import ArtifactContainer 7 | from .result import Result 8 | from .table import Table 9 | 10 | if TYPE_CHECKING: 11 | from .solution import Solution 12 | from .trial import Trial 13 | 14 | 15 | class Problem(ArtifactContainer): 16 | """A logical collection of :class:`Artifact`s and "attributes" that correspond to a 17 | specific problem to be solved. 18 | 19 | See the walkthrough section on :ref:`problem-intro` for more information. 
20 | """ 21 | 22 | container_type: str = "problem" 23 | 24 | # these class properties must be filled in by problem subclasses 25 | name: str 26 | summary: str 27 | task_id: str 28 | solution_class: type 29 | 30 | @abstractmethod 31 | def solve(self, **kwargs: Any) -> Solution: 32 | raise NotImplementedError() 33 | 34 | @abstractmethod 35 | def evaluate(self, solution: Solution) -> Result: 36 | raise NotImplementedError() 37 | 38 | 39 | class ProblemTable(Table): 40 | def trial(self, solver: Optional[Callable[[Problem], Solution]] = None) -> Trial: 41 | from .trial import Trial 42 | 43 | return Trial(problems=list(self.values()), solver=solver) 44 | -------------------------------------------------------------------------------- /dcbench/common/result.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | 4 | from .table import RowMixin 5 | 6 | 7 | class Result(RowMixin): 8 | pass 9 | -------------------------------------------------------------------------------- /dcbench/common/solution.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Iterator, Mapping 3 | 4 | from pandas import Series 5 | 6 | from .artifact_container import ArtifactContainer 7 | 8 | 9 | class Solution(ArtifactContainer): 10 | container_type: str = "solution" 11 | -------------------------------------------------------------------------------- /dcbench/common/solution_set.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from dataclasses import dataclass 4 | import datetime 5 | import uuid 6 | 7 | from dcbench.common.table import RowMixin 8 | 9 | from typing import List 10 | import os 11 | import yaml 12 | from .solution import Solution 13 | from dcbench.config import config 14 | 15 | 16 | @dataclass 17 | class SolutionSet(RowMixin): 18 | 19 | set_id: str 20 | name: str 21 | summary: str 22 | 23 | task_id: str 24 | 25 | 
def __post_init__(self): 26 | super().__init__(id=self.set_id) 27 | 28 | @property 29 | def local_dir(self): 30 | return os.path.join(config.local_dir, self.set_id) 31 | 32 | @property 33 | def dir(self): 34 | return os.path.join(self.task_id, "solution_sets", self.set_id) 35 | 36 | def _write_solutions(self, containers: List[Solution]): 37 | ids = [] 38 | for container in containers: 39 | assert isinstance(container, self.solution_class) 40 | ids.append(container.id) 41 | 42 | if len(set(ids)) != len(ids): 43 | raise ValueError( 44 | "Duplicate container ids in the containers passed to `write_solutions`." 45 | ) 46 | 47 | 48 | path = os.path.join(self.local_dir, "solutions.yaml") 49 | os.makedirs(self.local_dir, exist_ok=True) 50 | yaml.dump(containers, open(path, "w")) 51 | return path 52 | 53 | def _write_state(self): 54 | path = os.path.join(self.local_dir, "state.yaml") 55 | os.makedirs(self.local_dir, exist_ok=True) 56 | yaml.dump({ 57 | "set_id": self.set_id, 58 | "name": self.name, 59 | "summary": self.summary, 60 | "task_id": self.task_id, 61 | "solution_class": self.solution_class 62 | }, open(path, "w")) 63 | return path 64 | 65 | @property 66 | def solutions_path(self): 67 | return os.path.join(self.dir, "solutions.yaml") 68 | 69 | @property 70 | def solution_class(self): 71 | import dcbench 72 | return dcbench.tasks[self.task_id].solution_class 73 | 74 | 75 | @classmethod 76 | def from_solutions( 77 | cls, 78 | solutions: List[Solution], 79 | name: str = None, 80 | summary: str = None, 81 | task_id: str = None, 82 | ): 83 | 84 | if len(solutions) == 0: 85 | raise ValueError("At least one solution must be provided.") 86 | 87 | if name is None: 88 | name = f"{datetime.date.today():%y-%m-%d-%H-%M-%S}" 89 | set_id = f"{name}-{str(uuid.uuid4())[:8]}" 90 | 91 | task_id = solutions[0].task_id 92 | # check that all solutions have same task id 93 | for solution in solutions: 94 | if solution.task_id != task_id: 95 | raise ValueError( 96 | "All solutions must be 
from the same task." 97 | ) 98 | 99 | instance = cls( 100 | set_id=set_id, 101 | name=name, 102 | summary=summary, 103 | task_id=task_id, 104 | ) 105 | 106 | instance._write_solutions(solutions) 107 | instance._write_state() 108 | return instance 109 | 110 | @classmethod 111 | def from_dir(cls, dir: str): 112 | state = yaml.load(open(os.path.join(dir, "state.yaml"))) 113 | return cls(**state) -------------------------------------------------------------------------------- /dcbench/common/solve.py: -------------------------------------------------------------------------------- 1 | # from typing import List 2 | 3 | # import pandas as pd 4 | # import ray 5 | 6 | # from dcbench.common.method import Method 7 | 8 | 9 | # @ray.remote 10 | # def _solve_scenario(scenario, method: Method): 11 | # solution = scenario.solve(method) 12 | # solution.dump() 13 | # return solution 14 | 15 | 16 | # def solve(scenarios, methods: List[Method]) -> pd.DataFrame: 17 | # method_refs = [ray.put(method) for method in methods] 18 | # solution_refs = [] 19 | # for scenario in scenarios: 20 | # for method in method_refs: 21 | # solution_refs.append(_solve_scenario.remote(scenario, method)) 22 | # solutions = ray.get(solution_refs) 23 | # return pd.DataFrame([solution.meta() for solution in solutions]) 24 | -------------------------------------------------------------------------------- /dcbench/common/solver.py: -------------------------------------------------------------------------------- 1 | def solver(id: str, summary: str): 2 | def _solver(fn: callable): 3 | fn.id = id 4 | fn.attributes = {"summary": summary} 5 | return fn 6 | 7 | return _solver 8 | -------------------------------------------------------------------------------- /dcbench/common/table.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from dataclasses import dataclass 3 | from itertools import chain 4 | from typing import Dict, Iterator, Mapping, Optional, Sequence, 
Union 5 | 6 | import pandas as pd 7 | 8 | Attribute = Union[int, float, str, bool] 9 | 10 | 11 | @dataclass 12 | class AttributeSpec: 13 | description: str 14 | attribute_type: type 15 | optional: bool = False 16 | 17 | 18 | class RowMixin: 19 | 20 | attribute_specs: Mapping[str, AttributeSpec] 21 | 22 | def __init__(self, id: str, attributes: Mapping[str, Attribute] = None): 23 | self.id = id 24 | self._attributes = attributes 25 | 26 | @property 27 | def attributes(self) -> Optional[Mapping[str, Attribute]]: 28 | return self._attributes 29 | 30 | @attributes.setter 31 | def attributes(self, value: Mapping[str, Attribute]): 32 | self._check_attribute_specs(value) 33 | self._attributes = value 34 | 35 | @classmethod 36 | def _check_attribute_specs(cls, attributes: Mapping[str, Attribute]): 37 | for name, attribute in attributes.items(): 38 | if name not in cls.attribute_specs: 39 | raise ValueError( 40 | f"Passed attribute name '{name}', but the specification for" 41 | f" {cls.__name__} doesn't include it." 42 | ) 43 | 44 | if not isinstance(attribute, cls.attribute_specs[name].attribute_type): 45 | raise ValueError( 46 | f"Passed an attribute of type {type(attribute)} to {cls.__name__}" 47 | f" for the attribute named '{name}'. The specification for" 48 | f" {cls.__name__} expects an attribute of type" 49 | f" {cls.attribute_specs[name].attribute_type}." 50 | ) 51 | for name, attribute_spec in cls.attribute_specs.items(): 52 | if attribute_spec.optional: 53 | continue 54 | if name not in attributes: 55 | raise ValueError( 56 | f"Must pass required attribute with key {name} to {cls.__name__}." 
def predicate(a: "Attribute", b: "Union[Attribute, slice, Sequence[Attribute]]") -> bool:
    """Return True if value ``a`` matches selector ``b``.

    ``b`` may be a plain attribute value (equality test), a ``slice``
    (half-open range test, ``start <= a < stop``), or a sequence of values
    (membership test).

    Bug fix: open-ended slices previously never matched, because a ``None``
    bound made ``b.start is not None and a >= b.start`` short-circuit to
    False. A missing ``start``/``stop`` now means "unbounded on that side",
    matching standard slice semantics.
    """
    if isinstance(b, slice):
        return (b.start is None or a >= b.start) and (
            b.stop is None or a < b.stop
        )
    elif isinstance(b, Sequence):
        return a in b
    else:
        return a == b
    def __repr__(self) -> str:
        # Delegate to the pandas DataFrame representation of the table.
        return self.df.__repr__()

    def _repr_html_(self) -> Optional[str]:
        # Rich HTML rendering for Jupyter notebooks, via pandas.
        return self.df._repr_html_()

    def __add__(self, other: RowMixin) -> "Table":
        """Return a deep copy of this table with ``other`` appended.

        The original table is left unchanged.
        """
        result = copy.deepcopy(self)
        result._add_row(other)
        return result

    def __iadd__(self, other: RowMixin) -> "Table":
        """Add ``other`` to this table in place (``table += row``)."""
        self._add_row(other)
        return self
__post_init__(self): 37 | super().__init__( 38 | id=self.task_id, attributes={"name": self.name, "summary": self.summary} 39 | ) 40 | 41 | @property 42 | def problems_path(self): 43 | return os.path.join(self.task_id, "problems.yaml") 44 | 45 | @property 46 | def local_problems_path(self): 47 | return os.path.join(config.local_dir, self.problems_path) 48 | 49 | @property 50 | def remote_problems_url(self): 51 | return os.path.join(config.public_remote_url, self.problems_path) 52 | 53 | def write_problems(self, containers: List[Problem], append: bool = True): 54 | ids = [] 55 | for container in containers: 56 | assert isinstance(container, self.problem_class) 57 | ids.append(container.id) 58 | 59 | if len(set(ids)) != len(ids): 60 | raise ValueError( 61 | "Duplicate container ids in the containers passed to `write_problems`." 62 | ) 63 | 64 | if append: 65 | for id, problem in self.problems.items(): 66 | if id not in ids: 67 | containers.append(problem) 68 | 69 | os.makedirs(os.path.dirname(self.local_problems_path), exist_ok=True) 70 | yaml.dump(containers, open(self.local_problems_path, "w")) 71 | self._load_problems.cache_clear() 72 | 73 | def solution_set_path(self, set_id: str = None): 74 | if set_id is None: 75 | # create unique id with today's date formatted like YY-MM-DD and a hash 76 | set_id = f"{datetime.date.today():%y-%m-%d}-{str(uuid.uuid4())[:8]}" 77 | return os.path.join(self.task_id, f"solution_sets/{set_id}/solutions.yaml") 78 | 79 | def local_solution_set_path(self, set_id: str = None): 80 | path = self.solution_set_path(set_id=set_id) 81 | return os.path.join(config.local_dir, path) 82 | 83 | 84 | def upload_problems(self, include_artifacts: bool = False, force: bool = True): 85 | """ 86 | Uploads the problems to the remote storage. 87 | 88 | Args: 89 | include_artifacts (bool): If True, also uploads the artifacts of the 90 | problems. 91 | force (bool): If True, if the problem overwrites the remote problems. 92 | Defaults to True. 93 | .. 
    def download_problems(self, include_artifacts: bool = False):
        """Fetch this task's problems YAML from remote storage.

        Args:
            include_artifacts: When True, also download the artifacts of
                every problem in the set.
        """
        os.makedirs(os.path.dirname(self.local_problems_path), exist_ok=True)
        # TODO: figure out issue with caching on this call to urlretrieve
        urlretrieve(self.remote_problems_url, self.local_problems_path)
        # The problems file changed on disk, so drop the memoized copy.
        self._load_problems.cache_clear()

        for container in self.problems.values():
            assert isinstance(container, self.problem_class)
            if include_artifacts:
                container.download()

    @functools.lru_cache()
    def _load_problems(self):
        # Memoized loader, cleared by write_problems/download_problems when
        # the on-disk file changes. NOTE(review): lru_cache on an instance
        # method keeps the Task alive for the cache's lifetime — presumably
        # acceptable because tasks are long-lived singletons, but confirm.
        if not os.path.exists(self.local_problems_path):
            self.download_problems()
        problems = yaml.load(open(self.local_problems_path), Loader=yaml.FullLoader)
        return ProblemTable(problems)

    @property
    def problems(self):
        """ProblemTable of all problems (downloads the YAML on first use)."""
        return self._load_problems()
class Trial(Table):
    """A batch experiment: run ``solver`` over ``problems`` and record results.

    The trial itself is a :class:`Table`; each evaluated (problem, solution)
    pair becomes one row combining the attributes of the problem, the
    solution, and the result.
    """

    def __init__(
        self,
        problems: Optional[Sequence[Problem]] = None,
        solver: Optional[Callable[[Problem], Solution]] = None,
    ):
        # `solver` must be set before `evaluate` is called.
        self.problems = problems or []
        self.solver = solver
        self.solutions: Dict[str, Solution] = {}
        self.results: Dict[str, Result] = {}
        super().__init__([])

    def evaluate(self, repeat: int = 1, quiet: bool = False) -> "Trial":
        """Solve every problem ``repeat`` times and store the outcomes.

        Args:
            repeat: Number of times each problem is solved (must be >= 1).
            quiet: When True, suppress the tqdm progress bars.

        Returns:
            This trial, to allow chaining.
        """
        assert repeat >= 1
        assert self.solver is not None

        problems = self.problems if quiet else tqdm(self.problems, desc="Problems")
        for problem in problems:
            repetitions = (
                range(repeat)
                if quiet or repeat == 1
                else tqdm(range(repeat), desc="Repetitions")
            )
            for _ in repetitions:
                solution = self.solver(problem)
                result = problem.evaluate(solution)
                # Rows are keyed by solution.id: a repetition overwrites an
                # earlier one only if the solver reuses the same id.
                self.solutions[solution.id] = solution
                self.results[solution.id] = result
                self._add_row(
                    RowUnion(id=solution.id, elements=[problem, solution, result])
                )
        return self

    def save(self) -> None:
        # Persistence is not implemented yet.
        raise NotImplementedError()
CONFIG_ENV_VARIABLE = "DCBENCH_CONFIG"
default_local_dir = os.path.join(Path.home(), ".dcbench")


def get_config_location():
    """Return the path of the dcbench config file.

    Uses the ``DCBENCH_CONFIG`` environment variable when set, otherwise
    falls back to ``~/.dcbench/dcbench-config.yaml``.
    """
    path = os.environ.get(
        CONFIG_ENV_VARIABLE, os.path.join(default_local_dir, "dcbench-config.yaml")
    )
    return path


def get_config():
    """Load the user configuration as a dict.

    Returns:
        dict: Parsed config, or an empty dict when the config file does not
        exist or is empty.
    """
    path = get_config_location()
    if not os.path.exists(path):
        config = {}
    else:
        # Context manager closes the handle; the original open() was leaked.
        # `or {}` guards against an empty YAML file, which parses to None
        # and would crash DCBenchConfig(**None) downstream.
        with open(path, "r") as f:
            config = yaml.load(f, Loader=yaml.FullLoader) or {}
    return config
# Directory-name constants used to lay out dcbench storage.
ARTIFACTS_DIR = "artifacts"
PROBLEMS_DIR = "problems"
SOLUTIONS_DIR = "solutions"
# Artifact subdirectories by visibility.
# NOTE(review): HIDDEN_ARTIFACTS_DIR's value is "optional", not "hidden" —
# confirm this mismatch between name and value is intentional.
PUBLIC_ARTIFACTS_DIR = "public"
HIDDEN_ARTIFACTS_DIR = "optional"
LOCAL_ARTIFACTS_DIR = "local"

# Standard filenames written alongside stored containers.
METADATA_FILENAME = "metadata.json"
RESULT_FILENAME = "result.json"
@solver(
    id="random_clean", summary="Always selects a random subset of the data to clean."
)
def random_clean(problem: BudgetcleanProblem, seed: int = 1337) -> BudgetcleanSolution:
    """Baseline that cleans a uniformly random subset of the training data.

    Args:
        problem: The budget-clean problem instance to solve.
        seed: RNG seed, for reproducibility.

    Returns:
        The solution produced by ``problem.solve``.
    """
    size = len(problem["X_train_dirty"])
    budget = int(problem.attributes["budget"] * size)
    # Use a private Random instance rather than random.seed(), which mutates
    # the module-global RNG shared with every other caller. Random(seed)
    # yields exactly the same sample as seeding the global generator.
    rng = random.Random(seed)
    # Membership tests against a set keep the scan O(n) instead of O(n^2).
    selection = set(rng.sample(range(size), budget))
    idx_selected = [idx in selection for idx in range(size)]
    return problem.solve(idx_selected=idx_selected)
48 | def length(x): 49 | if isinstance(x, list): 50 | return len(x) 51 | return 0 52 | 53 | num_repairs = X_train_dirty.applymap(length).max().max() 54 | 55 | # Reconstruct separate repair data frames. 56 | X_train_repairs = {} 57 | for i in range(num_repairs): 58 | 59 | def getitem(x): 60 | if isinstance(x, list): 61 | return x[i] if len(x) > i else None 62 | return x 63 | 64 | X_train_repairs["repair%02d" % i] = X_train_dirty.applymap(getitem) 65 | 66 | # Replace lists with None values. 67 | def clearlists(x): 68 | if isinstance(x, list): 69 | return None 70 | return x 71 | 72 | X_train_dirty = X_train_dirty.applymap(clearlists) 73 | 74 | # Preprocess data. 75 | preprocessor = Preprocessor() 76 | preprocessor.fit(X_train_dirty, y_train) 77 | X_train_clean, y_train = preprocessor.transform(X_train_clean, y_train) 78 | X_val = preprocessor.transform(X_val) 79 | for name, X in X_train_repairs.items(): 80 | X_train_repairs[name] = preprocessor.transform(X=X) 81 | 82 | d_train_repairs = [] 83 | repair_methods = sorted(X_train_repairs.keys()) 84 | X_train_repairs_sorted = [X_train_repairs[m] for m in repair_methods] 85 | for X in X_train_repairs_sorted: 86 | d = np.sum((X - X_train_clean) ** 2, axis=1) 87 | d_train_repairs.append(d) 88 | d_train_repairs = np.array(d_train_repairs).T 89 | gt_indices = np.argmin(d_train_repairs, axis=1) 90 | X_train_gt = [] 91 | 92 | for i, gt_i in enumerate(gt_indices): 93 | X_train_gt.append(X_train_repairs_sorted[gt_i][i]) 94 | X_train_gt = np.array(X_train_gt) 95 | 96 | # Perform cleaning using CPClean. 
97 | cleaner = CPClean(K=kparam, n_jobs=n_jobs, random_state=seed) 98 | X_train_repairs = np.array([X_train_repairs[k] for k in X_train_repairs]) 99 | space, S_val, gt_indices, MM = cleaner.make_space( 100 | X_train_repairs, X_val, gt=X_train_gt 101 | ) 102 | init_querier = Querier(kparam, S_val, y_train, n_jobs=n_jobs, random_state=seed) 103 | 104 | start = time.time() 105 | after_entropy_val = sort_count_after_clean_multi( 106 | S_val, y_train, kparam, n_jobs, MM=MM 107 | ) 108 | end = time.time() 109 | print("sort_count_after_clean_multi", end - start) 110 | 111 | start = time.time() 112 | _, before_entropy_val = init_querier.run_q2(MM=MM, return_entropy=True) 113 | end = time.time() 114 | print("run_q2", end - start) 115 | 116 | dirty_rows = [i for i, x in enumerate(S_val[0]) if len(x) > 1] 117 | info_gain = entropy_expected( 118 | after_entropy_val, dirty_rows, before_entropy_val, n_jobs=n_jobs 119 | ) 120 | selection = np.argpartition(info_gain, -budget)[-budget:] 121 | 122 | # Produce solution. 
123 | idx_selected = [idx in selection for idx in range(len(X_train_dirty))] 124 | return problem.solve(idx_selected=idx_selected) 125 | -------------------------------------------------------------------------------- /dcbench/tasks/budgetclean/common.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.impute import SimpleImputer 3 | from sklearn.pipeline import Pipeline 4 | from sklearn.preprocessing import LabelEncoder, MinMaxScaler, OneHotEncoder 5 | 6 | 7 | class Preprocessor(object): 8 | """docstring for Preprocessor.""" 9 | 10 | def __init__(self, num_strategy="mean"): 11 | super(Preprocessor, self).__init__() 12 | self.num_transformer = Pipeline( 13 | steps=[ 14 | ("imputer", SimpleImputer(strategy=num_strategy)), 15 | ("scaler", MinMaxScaler()), 16 | ] 17 | ) 18 | self.feature_enc = OneHotEncoder(sparse=False, handle_unknown="ignore") 19 | self.cat_imputer = SimpleImputer(strategy="constant", fill_value="missing") 20 | self.label_enc = LabelEncoder() 21 | 22 | def fit(self, X_train, y_train, X_full=None): 23 | self.num_features = X_train.select_dtypes(include="number").columns 24 | self.cat_features = X_train.select_dtypes(exclude="number").columns 25 | 26 | if len(self.num_features) > 0: 27 | self.num_transformer.fit(X_train[self.num_features].values) 28 | 29 | if len(self.cat_features) > 0: 30 | if X_full is None: 31 | X_full = X_train 32 | # self.feature_enc.fit(X_full[self.cat_features].values) 33 | # self.cat_imputer.fit(X_train[self.cat_features].values) 34 | self.cat_transformer = Pipeline( 35 | steps=[("imputer", self.cat_imputer), ("onehot", self.feature_enc)] 36 | ) 37 | self.cat_transformer.fit(X_full[self.cat_features].values) 38 | 39 | self.label_enc.fit(y_train.values.ravel()) 40 | 41 | def transform(self, X=None, y=None): 42 | if X is not None: 43 | X_after = [] 44 | if len(self.num_features) > 0: 45 | X_arr = X[self.num_features].values 46 | if len(X_arr.shape) == 1: 47 
| X_arr = X_arr.reshape(1, -1) 48 | X_num = self.num_transformer.transform(X_arr) 49 | X_after.append(X_num) 50 | 51 | if len(self.cat_features) > 0: 52 | X_arr = X[self.cat_features].values.astype(object) 53 | if len(X_arr.shape) == 1: 54 | X_arr = X_arr.reshape(1, -1) 55 | X_cat = self.cat_transformer.transform(X_arr) 56 | X_after.append(X_cat) 57 | 58 | X = np.hstack(X_after) 59 | 60 | if y is not None: 61 | y = self.label_enc.transform(y.values.ravel()) 62 | 63 | if X is None: 64 | return y 65 | elif y is None: 66 | return X 67 | else: 68 | return X, y 69 | -------------------------------------------------------------------------------- /dcbench/tasks/budgetclean/cpclean/README.md: -------------------------------------------------------------------------------- 1 | # Implementation of the CP Clean algorithm 2 | 3 | This code was taken from the [CPClean repository](https://github.com/chu-data-lab/CPClean). 4 | 5 | If you use it in your research, please cite the following paper: 6 | 7 | ```bibtex 8 | @article{karlas2020vldb, 9 | author = {Bojan Karlaš and Peng Li and Renzhi Wu and Nezihe Merve Gürel and Xu Chu and Wentao Wu and Ce Zhang}, 10 | journal = {Proceedings of the VLDB Endowment}, 11 | title = {Nearest neighbor classifiers over incomplete information: From certain answers to certain predictions}, 12 | year = {2020} 13 | } 14 | ``` 15 | -------------------------------------------------------------------------------- /dcbench/tasks/budgetclean/cpclean/algorithm/min_max.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .utils import majority_vote 4 | 5 | 6 | def min_max(mm, y, K): 7 | """MinMax algorithm. 8 | 9 | Given a similarity matrix, return whether it is CP or not and the 10 | best scenario for each label. mm (np.array): shape Nx2. the min and 11 | max similarity of each row. 
def min_max_val(MM, y, K):
    """Run the MinMax certainty check for every validation example.

    Args:
        MM: iterable of per-example min/max similarity matrices (Nx2 each).
        y: training labels.
        K: KNN hyperparameter.

    Returns:
        Tuple of (per-example CP flags as an np.array of bool, best
        scenarios, candidate prediction lists).
    """
    flags, scenarios, preds_per_example = [], [], []
    for mm in MM:
        is_cc, best, preds = min_max(mm, y, K)
        flags.append(is_cc)
        scenarios.append(best)
        preds_per_example.append(preds)
    return np.array(flags), scenarios, preds_per_example
def entropy_expected(
    after_entropy_val, dirty_rows, before_entropies_val, n_jobs=4
):  # already checked
    """Compute the expected information gain of cleaning each dirty row.

    Args:
        after_entropy_val (list): counts/entropies after cleaning each cell
            for each test example (only for dirty rows)
        dirty_rows (list): indices of dirty rows
        before_entropies_val (np.array): per-example entropies before clean
        n_jobs (int): unused here; kept for interface compatibility

    Returns:
        np.array of per-row info gains; exact-zero gains become -inf.
    """
    averaged = [
        compute_avg_dirty_entropies(ae, dirty_rows) for ae in after_entropy_val
    ]

    # Rows with no recorded post-clean entropy keep their pre-clean entropy.
    for idx, row in enumerate(averaged):
        row[np.isnan(row)] = before_entropies_val[idx]

    averaged = np.array(averaged)
    gain = (before_entropies_val.reshape(-1, 1) - averaged).mean(axis=0)
    gain[gain == 0] = float("-inf")
    return gain
def compute_entropy_by_labels(A):
    """Compute entropy over a list of labels.

    Args:
        A (list): a list of labels (e.g. [0, 0, 1, 1, 0])
    """
    total = len(A)
    frequencies = [count / total for count in Counter(A).values()]
    return entropy(frequencies)


def product(a):
    """Compute the product of all element in an integer array.

    Args:
        A (list): a list of integers
    """
    if 0 in a:
        return 0

    result = 1
    for value, multiplicity in Counter(a).items():
        result *= value ** multiplicity
    return result


def majority_vote(A):
    """Take the majority vote from a list of labels.

    Args:
        A (list): a list of labels (e.g. [0, 0, 1, 1, 0])
    """
    counts = np.bincount(A)
    return int(np.argmax(counts))
self.gt_val_acc, self.gt_test_acc = KNNEvaluator( 30 | self.data["X_train_gt"], 31 | self.data["y_train"], 32 | self.data["X_val"], 33 | self.data["y_val"], 34 | self.data["X_test"], 35 | self.data["y_test"], 36 | ).score() 37 | self.X_train_mean = deepcopy(self.data["X_train_repairs"]["mean"]) 38 | self.selection = [] 39 | 40 | self.logging = [] 41 | mean_val_acc, mean_test_acc = KNNEvaluator( 42 | self.X_train_mean, 43 | self.data["y_train"], 44 | self.data["X_val"], 45 | self.data["y_val"], 46 | self.data["X_test"], 47 | self.data["y_test"], 48 | ).score() 49 | 50 | self.logging.append( 51 | [ 52 | 0, 53 | self.n_val, 54 | None, 55 | None, 56 | percent_cc, 57 | 0, 58 | self.clean_val_acc, 59 | self.gt_val_acc, 60 | mean_val_acc, 61 | self.clean_test_acc, 62 | self.gt_test_acc, 63 | mean_test_acc, 64 | ] 65 | ) 66 | self.save_log() 67 | 68 | def save_log(self): 69 | columns = [ 70 | "n_iter", 71 | "n_val", 72 | "selection", 73 | "time", 74 | "percent_cc", 75 | "percent_clean", 76 | "clean_val_acc", 77 | "gt_val_acc", 78 | "mean_val_acc", 79 | "clean_test_acc", 80 | "gt_test_acc", 81 | "mean_test_acc", 82 | ] 83 | logging_save = pd.DataFrame(self.logging, columns=columns) 84 | logging_save.to_csv(utils.makedir([self.debug_dir], "details.csv"), index=False) 85 | 86 | def log(self, n_iter, sel, sel_time, percent_cc): 87 | self.selection.append(sel) 88 | 89 | percent_clean = len(self.selection) / self.n_dirty 90 | self.X_train_mean[sel] = self.data["X_train_gt"][sel] 91 | 92 | mean_val_acc, mean_test_acc = KNNEvaluator( 93 | self.X_train_mean, 94 | self.data["y_train"], 95 | self.data["X_val"], 96 | self.data["y_val"], 97 | self.data["X_test"], 98 | self.data["y_test"], 99 | ).score() 100 | 101 | self.logging.append( 102 | [ 103 | n_iter, 104 | self.n_val, 105 | sel, 106 | sel_time, 107 | percent_cc, 108 | percent_clean, 109 | self.clean_val_acc, 110 | self.gt_val_acc, 111 | mean_val_acc, 112 | self.clean_test_acc, 113 | self.gt_test_acc, 114 | mean_test_acc, 115 | ] 
import numpy as np


def compute_distances(X_train, X_test):
    """Return the Euclidean distance between every test and train example.

    Args:
        X_train (np.ndarray): training features, shape (n_train, d).
        X_test (np.ndarray): test features, shape (n_test, d).

    Returns:
        np.ndarray: distances with shape (n_test, n_train); row i holds the
        distances from ``X_test[i]`` to every training example.
    """
    dists = np.array(
        [np.sqrt(np.sum((X_train - x_test) ** 2, axis=1)) for x_test in X_test]
    )
    return dists


def majority_vote(A):
    """Return the most frequent label in ``A``.

    Labels are assumed to be non-negative integers (``np.bincount``
    requirement).  Ties are broken in favor of the smallest label because
    ``np.argmax`` returns the first maximal count.
    """
    major = np.argmax(np.bincount(A))
    return int(major)


class KNNEvaluator(object):
    """Evaluate a K-nearest-neighbors classifier on validation and test sets.

    Similarities are precomputed once in the constructor as
    ``1 / (1 + distance)`` so that ``score`` only has to sort them.
    """

    def __init__(self, X_train, y_train, X_val, y_val, X_test, y_test, K=3):
        """Precompute train/val and train/test similarities.

        Args:
            X_train, y_train: training features and labels.
            X_val, y_val: validation features and labels.
            X_test, y_test: test features and labels.
            K (int): number of neighbors used for prediction.
        """
        # BUG FIX: the original called ``super(KNNEvaluator).__init__()``,
        # which builds an *unbound* super object and never initializes the
        # instance; the standard zero-argument form is used instead.
        super().__init__()
        dists_val = compute_distances(X_train, X_val)
        dists_test = compute_distances(X_train, X_test)
        # Monotonically decreasing transform of distance: closer -> larger.
        self.sim_val = 1 / (1 + dists_val)
        self.sim_test = 1 / (1 + dists_test)
        self.K = K
        self.y_train = y_train
        self.y_val = y_val
        self.y_test = y_test

    def predict(self, sim):
        """Predict labels by majority vote over the ``K`` most similar
        training examples for each row of ``sim``."""
        # Stable sort keeps deterministic neighbor order on similarity ties.
        order = np.argsort(-sim, kind="stable", axis=1)
        top_K_idx = order[:, : self.K]
        top_K = self.y_train[top_K_idx]
        pred = np.array([majority_vote(top) for top in top_K])
        return pred

    def score(self):
        """Return ``(val_accuracy, test_accuracy)`` of the KNN classifier."""
        pred_val = self.predict(self.sim_val)
        pred_test = self.predict(self.sim_test)

        val_acc = (pred_val == self.y_val).mean()
        test_acc = (pred_test == self.y_test).mean()

        return val_acc, test_acc
class Querier(object):
    """Answer the CPClean counting queries (Q1/Q2/Q3) over an uncertain
    ("possible worlds") KNN training set.

    Q1 asks which validation examples are already certainly predicted (can be
    "CP'ed"), Q2 counts the worlds supporting each label, and Q3 selects the
    next training row to clean.
    """

    def __init__(self, K, S_val, y_train, n_jobs=4, random_state=1):
        """Constructor.

        Args:
            K (int): KNN hyper-parameter
            S_val: one entry per validation example; each entry holds, per
                training example, the candidate values across repairs.
                NOTE(review): the original docstring described parameters
                (X_val, y_val, gt_indices) that are not in this signature;
                the exact structure of ``S_val`` should be confirmed against
                the callers.
            y_train (np.array): labels of training set
            n_jobs (int): number of worker processes used by the multi
                variants of the counting algorithms
            random_state (int): seed (stored; used by random selection)
        """
        self.K = K
        self.S_val = S_val
        self.y_train = y_train
        # Distinct labels, used to build per-label count dictionaries.
        self.classes = list(set(y_train))
        self.n_jobs = n_jobs
        self.random_state = random_state

    def run_q1(self, return_preds=False, MM=None):
        """Solution for q1.

        Args:
            return_preds (bool): also return the certain-prediction sets.
            MM: optional precomputed per-example [min, max] similarity bounds;
                recomputed from ``S_val`` when omitted.

        Return:
            q1_results (list of boolean): for each example in test set, whether it can
                be CP'ed.
        """
        if MM is None:
            # Bound each training example by its extreme candidate values.
            MM = []
            for S in self.S_val:
                mm = np.array([[min(s), max(s)] for s in S])
                MM.append(mm)

        q1_results, _, pred_sets = min_max_val(MM, self.y_train, self.K)

        if return_preds:
            return q1_results, pred_sets
        else:
            return q1_results

    def run_q2(self, return_entropy=False, MM=None):
        """Solution for q2.

        Return:
            results (list of dict): the number of worlds supporting each label for each
                example in test set.  When ``return_entropy`` is true, also
                returns the per-example entropy of those counts.
        """
        q2_results = sort_count_dp_multi(
            self.S_val, self.y_train, self.K, n_jobs=self.n_jobs, MM=MM
        )
        if return_entropy:
            entropies_val = np.array(
                [compute_entropy_by_counts(counts) for counts in q2_results]
            )
            return q2_results, entropies_val
        else:
            return q2_results

    def run_q1q2(self, MM=None, return_entropy=True):
        """Run Q1, then Q2 only on the examples Q1 could not certify.

        Examples that can be CP'ed have a degenerate count (all worlds agree),
        so the expensive DP count is skipped for them.
        """
        q1_results, cp_preds = self.run_q1(return_preds=True, MM=MM)

        cp_idx = [i for i, cp in enumerate(q1_results) if cp]
        not_cp_idx = [i for i, cp in enumerate(q1_results) if not cp]

        # CP'ed examples: a single certain prediction -> unit mass on it.
        q2_cp = []
        for i in cp_idx:
            assert len(cp_preds[i]) == 1
            pred = cp_preds[i][0]
            res = {c: 0 for c in self.classes}
            res[pred] = 1
            q2_cp.append(res)

        # Remaining examples need the full world-counting DP.
        S_val_no_cp = [self.S_val[i] for i in not_cp_idx]
        q2_no_cp = sort_count_dp_multi(
            S_val_no_cp, self.y_train, self.K, n_jobs=self.n_jobs
        )
        q2_results = self.merge_result([q2_cp, q2_no_cp], [cp_idx, not_cp_idx])

        if return_entropy:
            entropies_val = np.array(
                [compute_entropy_by_counts(counts) for counts in q2_results]
            )
            return q1_results, q2_results, entropies_val

        return q1_results, q2_results

    def run_q3_select(self, method="cpclean", before_entropy_val=None, MM=None):
        """Solution for q3: pick the next dirty training row to clean.

        Args:
            method (str): "cpclean" for expected-entropy-reduction selection,
                "random" for a random dirty row.
            before_entropy_val: optional precomputed Q2 entropies (recomputed
                via ``run_q2`` when omitted).
            MM: precomputed min/max bounds; required for "cpclean".

        Returns:
            (selected_row, after_entropy_for_selection) — the second element
            is None for the "random" method.
        """
        # NOTE(review): dirty rows are derived from ``self.S_val[0]`` only —
        # this assumes every validation example shares the same per-row
        # candidate structure (a row is dirty iff it has >1 candidate).
        # TODO: confirm against the construction of S_val.
        dirty_rows = [i for i, x in enumerate(self.S_val[0]) if len(x) > 1]

        if method == "cpclean":
            assert len(self.S_val) == len(MM)
            # Entropy of each validation example after hypothetically
            # cleaning each dirty row.
            after_entropy_val = sort_count_after_clean_multi(
                self.S_val, self.y_train, self.K, self.n_jobs, MM=MM
            )

            # extract counters for dirty rows
            if before_entropy_val is None:
                _, before_entropy_val = self.run_q2(MM=MM, return_entropy=True)

            sel = min_entropy_expected(
                after_entropy_val, dirty_rows, before_entropy_val, n_jobs=self.n_jobs
            )

            after_entropy_val_sel = [ae[sel] for ae in after_entropy_val]

            return sel, after_entropy_val_sel
        elif method == "random":
            sel = random_select(dirty_rows)
            return sel, None
        else:
            raise Exception("Wrong method")

    def merge_result(self, results, indices):
        """Interleave two partial result lists back into original order.

        Args:
            results: pair (res_a, res_b) of result lists.
            indices: pair (idx_a, idx_b) of the original positions of each
                result; together they must cover 0..len-1 exactly once.

        Returns:
            list: merged results, position i holding the result for example i.
        """
        res_a, res_b = results
        idx_a, idx_b = indices

        l = len(idx_a) + len(idx_b)
        merge = [None for _ in range(l)]
        for res, i in zip(res_a, idx_a):
            merge[i] = res
        for res, i in zip(res_b, idx_b):
            merge[i] = res

        # Sanity check: the index sets covered every slot.
        for res in merge:
            assert res is not None
        return merge
class Pool(object):
    """A minimal multiprocessing map that preserves input order.

    Work is split into at most ``n_jobs`` index-tagged batches, each batch is
    evaluated in its own process, and the tagged results are reassembled in
    the original order.
    """

    def __init__(self, n_jobs):
        super(Pool, self).__init__()
        self.n_jobs = n_jobs

    def fn_batch(self, fn, arg_batch, q):
        """Worker body: evaluate one batch and push the tagged results."""
        q.put([(idx, fn(item)) for idx, item in arg_batch])

    def array_split(self, arr, n):
        """Split ``arr`` into at most ``n`` batches of (index, item) pairs."""
        if len(arr) <= n:
            # Fewer items than workers: one singleton batch per item.
            return [[(pos, item)] for pos, item in enumerate(arr)]
        batches = []
        for chunk in np.array_split(np.arange(len(arr)), n):
            batches.append([(pos, arr[pos]) for pos in chunk])
        return batches

    def map(self, fn, args):
        """Apply ``fn`` to every element of ``args`` across processes.

        Returns the results in the same order as ``args``.
        """
        batches = self.array_split(args, self.n_jobs)

        queue = Queue()
        workers = [
            Process(target=self.fn_batch, args=(fn, batch, queue))
            for batch in batches
        ]

        for worker in workers:
            worker.start()

        # Drain one message per worker *before* joining: with a plain Queue,
        # joining first could deadlock on a full pipe buffer.
        gathered = []
        for _ in workers:
            gathered.extend(queue.get())

        for worker in workers:
            worker.join()

        # Tags are unique, so sorting restores the original order.
        return [value for _, value in sorted(gathered)]
class BudgetcleanProblem(Problem):
    """Budget-constrained data-cleaning problem.

    The solver receives a dirty training set and selects which rows to clean
    under a fixed budget; ``evaluate`` measures how much of the accuracy gap
    between training on dirty vs. fully clean data the selection closes.
    """

    artifact_specs: Mapping[str, ArtifactSpec] = {
        "X_train_dirty": ArtifactSpec(
            artifact_type=CSVArtifact,
            # BUG FIX: a stray trailing comma inside the parentheses made
            # this description a 1-tuple instead of a string.
            description=(
                "Features of the dirty training dataset which we need to clean. "
                "Each dirty cell contains an embedded list of clean "
                "candidate values."
            ),
        ),
        "X_train_clean": ArtifactSpec(
            artifact_type=CSVArtifact,
            description="Features of the clean training dataset where each dirty value "
            "from the dirty dataset is replaced with the correct "
            "clean candidate.",
        ),
        "y_train": ArtifactSpec(
            artifact_type=CSVArtifact, description="Labels of the training dataset."
        ),
        "X_val": ArtifactSpec(
            artifact_type=CSVArtifact,
            # Typo fix: "Feature of the validtion" -> "Features of the validation".
            description="Features of the validation dataset which can be used to guide "
            "the cleaning optimization process.",
        ),
        "y_val": ArtifactSpec(
            artifact_type=CSVArtifact, description="Labels of the validation dataset."
        ),
        "X_test": ArtifactSpec(
            artifact_type=CSVArtifact,
            # BUG FIX: same stray-comma tuple issue as "X_train_dirty".
            description=(
                "Features of the test dataset used to produce the final evaluation "
                "score of the model."
            ),
        ),
        "y_test": ArtifactSpec(
            artifact_type=CSVArtifact, description="Labels of the test dataset."
        ),
    }

    attribute_specs = {
        "budget": AttributeSpec(
            attribute_type=float,
            description="TODO",
        ),
        "dataset": AttributeSpec(
            attribute_type=str,
            description="TODO",
        ),
        "mode": AttributeSpec(
            attribute_type=str,
            description="TODO",
        ),
        "model": AttributeSpec(
            attribute_type=str,
            description="TODO",
        ),
    }

    task_id: str = "budgetclean"

    @classmethod
    def list(cls):
        """Yield every known problem instance of this task."""
        for scenario_id in cls.scenario_df["id"]:
            yield cls.from_id(scenario_id)

    @classmethod
    def from_id(cls, scenario_id: str):
        # TODO: not implemented — returns None, so ``list`` currently yields
        # None for every scenario.
        pass

    def solve(self, idx_selected: Any, **kwargs: Any) -> Solution:
        """Build a :class:`BudgetcleanSolution` from a row-selection mask.

        Args:
            idx_selected: boolean selection per training row, as a list or a
                single-column DataFrame.

        Returns:
            BudgetcleanSolution: the wrapped selection, carrying this
            problem's attributes.

        Raises:
            ValueError: if ``idx_selected`` has the wrong type or length, or
                selects more rows than the budget allows.
        """
        # Construct the solution object as a Pandas DataFrame.
        idx_selected_df = None
        if isinstance(idx_selected, pd.DataFrame):
            idx_selected_df = pd.DataFrame(
                {"idx_selected": idx_selected.iloc[:, 0].values}
            ).astype("bool")
        elif isinstance(idx_selected, list):
            idx_selected_df = pd.DataFrame({"idx_selected": idx_selected}).astype(
                "bool"
            )
        else:
            raise ValueError(
                "The provided idx_selected object must be either a list or a DataFrame."
            )

        # Check if the content of the solution object is valid.
        X_train_dirty = self["X_train_dirty"]
        if len(X_train_dirty) != len(idx_selected_df):
            raise ValueError(
                "The number of elements of the provided solution object must be the "
                "same as for the training dataset. (expected: %d, found: %d)"
                % (len(X_train_dirty), len(idx_selected_df))
            )

        num_selected = idx_selected_df["idx_selected"].sum()
        budget = int(self.attributes["budget"] * len(X_train_dirty))
        if num_selected > budget:
            raise ValueError(
                "The number of selected data examples is "
                "higher than the allowed budget. "
                "(expected: %d, found: %d)" % (budget, num_selected)
            )
        if num_selected < budget:
            # Under-spending the budget is legal but probably suboptimal.
            warnings.warn(
                "The number of selected data examples is below the allowed budget. "
                "(expected: %d, found: %d)" % (budget, num_selected)
            )

        # Construct and return a solution object.
        solution = BudgetcleanSolution.from_artifacts({"idx_selected": idx_selected_df})
        solution.attributes["problem_id"] = self.id
        for k, v in self.attributes.items():
            solution.attributes[k] = v
        return solution

    def evaluate(self, solution: BudgetcleanSolution) -> "Result":
        """Score a solution by the fraction of the dirty→clean accuracy gap
        it closes on the validation and test sets."""
        # Load scenario artifacts.
        X_train_dirty = self["X_train_dirty"]
        X_train_clean = self["X_train_clean"]
        y_train = self["y_train"]
        X_val = self["X_val"]
        y_val = self["y_val"]
        X_test = self["X_test"]
        y_test = self["y_test"]

        # Replace candidate-value lists (dirty cells) with None values.
        def clearlists(x):
            if isinstance(x, list):
                return None
            return x

        X_train_dirty = X_train_dirty.applymap(clearlists)

        # Load solution artifacts.
        idx_selected = solution["idx_selected"]["idx_selected"]

        # Determine the solution training datasets: selected rows take their
        # clean values, the rest stay dirty.
        X_train_solution = X_train_dirty.mask(idx_selected, X_train_clean)

        # Fit data preprocessor on the dirty data so all variants share the
        # same feature space.
        preprocessor = Preprocessor()
        preprocessor.fit(X_train_dirty, y_train)

        # Preprocess the data.
        X_train_solution, y_train = preprocessor.transform(X_train_solution, y_train)
        X_train_dirty = preprocessor.transform(X_train_dirty)
        X_train_clean = preprocessor.transform(X_train_clean)
        X_val, y_val = preprocessor.transform(X_val, y_val)
        X_test, y_test = preprocessor.transform(X_test, y_test)

        # Train the solution, clean and dirty models.
        if self.attributes["model"] == "logreg":
            model_solution = LogisticRegression().fit(X_train_solution, y_train)
            model_dirty = LogisticRegression().fit(X_train_dirty, y_train)
            model_clean = LogisticRegression().fit(X_train_clean, y_train)
        elif self.attributes["model"] == "randomf":
            model_solution = RandomForestClassifier().fit(X_train_solution, y_train)
            model_dirty = RandomForestClassifier().fit(X_train_dirty, y_train)
            model_clean = RandomForestClassifier().fit(X_train_clean, y_train)
        else:
            raise ValueError("Unknown model attribute '%s'." % self.attributes["model"])

        # Evaluate the model.
        # NOTE(review): both gap-closed ratios divide by (clean - dirty)
        # accuracy, which is zero when cleaning does not help — this would
        # raise/ZeroDivide. Left unchanged pending confirmation of intended
        # semantics for that case.
        result_dict = {}
        acc_val_solution = model_solution.score(X_val, y_val)
        acc_val_dirty = model_dirty.score(X_val, y_val)
        acc_val_clean = model_clean.score(X_val, y_val)
        result_dict["acc_val_gapclosed"] = (acc_val_solution - acc_val_dirty) / (
            acc_val_clean - acc_val_dirty
        )
        acc_test_solution = model_solution.score(X_test, y_test)
        acc_test_dirty = model_dirty.score(X_test, y_test)
        acc_test_clean = model_clean.score(X_test, y_test)
        result_dict["acc_test_gapclosed"] = (acc_test_solution - acc_test_dirty) / (
            acc_test_clean - acc_test_dirty
        )

        result_dict = {**result_dict, **solution.attributes}

        return Result(id=solution.id, attributes=result_dict)
class MiniDataSolution(Solution):
    """Solution container for the minidata task: the ids of the selected
    training examples."""

    artifact_specs: Mapping[str, ArtifactSpec] = {
        "train_ids": ArtifactSpec(
            artifact_type=YAMLArtifact,
            description=(
                "A list of train example ids from the "
                " ``id`` column of ``train_data``."
            ),
        ),
    }
    task_id: str = "minidata"

    @classmethod
    def from_ids(cls, train_ids: Sequence[str], problem_id: str):
        """Build a solution from a list of training-example ids.

        Args:
            train_ids: values from the ``id`` column of ``train_data``.
            problem_id: id of the problem this solution answers.

        Returns:
            MiniDataSolution: the constructed solution.
        """
        # BUG FIX: the constructed solution was previously discarded
        # (missing ``return``), making this factory always yield None.
        return cls.from_artifacts(
            {"train_ids": train_ids}, attributes={"problem_id": problem_id}
        )


class MiniDataProblem(Problem):
    """Minimal-data-selection problem: pick the smallest training subset that
    still reaches a performance threshold."""

    artifact_specs: Mapping[str, ArtifactSpec] = {
        "train_data": ArtifactSpec(
            artifact_type=DataPanelArtifact,
            description="A DataPanel of train examples with columns ``id``, "
            "``input``, and ``target``.",
        ),
        "val_data": ArtifactSpec(
            artifact_type=DataPanelArtifact,
            description="A DataPanel of validation examples with columns ``id``, "
            "``input``, and ``target``.",
        ),
        "test_data": ArtifactSpec(
            artifact_type=DataPanelArtifact,
            description="A DataPanel of test examples with columns ``id``, "
            "``input``, and ``target``.",
        ),
    }

    task_id: str = "minidata"

    def solve(self, idx_selected: Any, **kwargs: Any) -> Solution:
        # NOTE(review): this method appears to be copied from the budgetclean
        # task — it reads an "X_train_dirty" artifact and builds an
        # "idx_selected" artifact, neither of which is declared for minidata
        # (see artifact_specs above and MiniDataSolution, which expects
        # "train_ids").  It also uses ``self.container_id`` where the
        # budgetclean variant uses ``self.id``.  Left unchanged pending
        # confirmation of the intended behavior; use
        # ``MiniDataSolution.from_ids`` to construct solutions.

        # Construct the solution object as a Pandas DataFrame.
        idx_selected_dp = None
        if isinstance(idx_selected, mk.DataPanel):
            idx_selected_dp = mk.DataPanel(
                {
                    "idx_selected": idx_selected[idx_selected.columns[0]].data.astype(
                        bool
                    )
                }
            )
        elif isinstance(idx_selected, pd.DataFrame):
            idx_selected_dp = mk.DataPanel(
                {"idx_selected": idx_selected.iloc[:, 0].values.astype(bool)}
            )
        elif isinstance(idx_selected, list):
            idx_selected_dp = mk.DataPanel({"idx_selected": idx_selected}).astype(
                "bool"
            )
        else:
            raise ValueError(
                "The provided idx_selected object must be either a list or a DataFrame."
            )

        # Check if the content of the solution object is valid.
        X_train_dirty = self["X_train_dirty"]
        if len(X_train_dirty) != len(idx_selected_dp):
            raise ValueError(
                "The number of elements of the provided solution object must be the "
                "same as for the training dataset. (expected: %d, found: %d)"
                % (len(X_train_dirty), len(idx_selected_dp))
            )

        # Construct and return a solution object.
        solution = MiniDataSolution.from_artifacts({"idx_selected": idx_selected_dp})
        solution.attributes["problem_id"] = self.container_id
        for k, v in self.attributes.items():
            solution.attributes[k] = v
        return solution

    def evaluate(self, solution: Solution):
        """Train a model on the selected subset and (eventually) score it.

        Writes the selected training rows to a temporary DataPanel, trains a
        ResNet via unagi on it, then cleans up the temporary directory.
        """
        train_dp = self["train_data"]
        train_ids = solution["train_ids"]

        # Restrict training data to the ids chosen by the solution.
        train_dp = train_dp.lz[train_dp["id"].isin(train_ids)]
        dirpath = tempfile.mkdtemp()
        dp_path = os.path.join(dirpath, "dataset.mk")

        train_dp.write(dp_path)

        from unagi.unagi import main
        from unagi.utils.config_utils import build_config

        from .unagi_configs import RESNET_CONFIG

        config = RESNET_CONFIG.copy()
        config["dataset"]["path_to_dp"] = dp_path
        config["dataset"]["index_name"] = "id"
        config = build_config(config)

        main(config)
        shutil.rmtree(dirpath)

        # TODO: Plug unagi in here
        # model = fit(train_dp)
        # score = score(self["test_dp"], model)
        # returnscore


task = Task(
    task_id="minidata",
    name="Minimal Data Selection",
    # flake8: noqa
    summary="Given a large training dataset, what is the smallest subset you can sample that still achieves some threshold of performance.",
    problem_class=MiniDataProblem,
    solution_class=MiniDataSolution,
    baselines=None,
)
-------------------------------------------------------------------------------- 1 | MIXER_CONFIG = { 2 | "model": { 3 | "name": "mixer", 4 | "train": True, 5 | "d": 256, 6 | "num_heads": 8, 7 | "head_dropout": 0.1, 8 | "label_smoothing": True, 9 | "mlp_dim": 512, 10 | "num_layers": 7, 11 | "patch_size": 4, 12 | "dropout": 0.05, 13 | "max_sequence_length": 64, 14 | "patch_emb_type": "square", 15 | }, 16 | "augmentations": { 17 | "raw": { 18 | "image_pil": [ 19 | {"type": "RandomResizeCrop", "params": {"prob": 1.0, "size": 32}}, 20 | {"type": "HorizontalFlip", "params": {"prob": 0.5}}, 21 | {"type": "ColorDistortion", "params": {"prob": 1.0}}, 22 | {"type": "GaussianBlur", "params": {"prob": 0.5, "kernel_size": 3}}, 23 | ], 24 | "image_pil_default_transform_transformer": [ 25 | {"type": "ToTensor"}, 26 | { 27 | "type": "Normalize", 28 | "params": { 29 | "mean": [0.49139968, 0.48215841, 0.44653091], 30 | "std": [0.24703223, 0.24348513, 0.26158784], 31 | }, 32 | }, 33 | {"type": "Reshape2D", "params": {"h_dim": 3, "w_dim": 1024}}, 34 | ], 35 | "image_pil_default_transform_resnet": [ 36 | {"type": "ToTensor"}, 37 | { 38 | "type": "Normalize", 39 | "params": { 40 | "mean": [0.49139968, 0.48215841, 0.44653091], 41 | "std": [0.24703223, 0.24348513, 0.26158784], 42 | }, 43 | }, 44 | ], 45 | }, 46 | "patch": {"type": None}, 47 | "feature": {"type": None}, 48 | }, 49 | "tasks": {"supervised": {"loss_fn": "cross_entropy", "label": "target"}}, 50 | "dataset": { 51 | "name": "meerkat_dataset", 52 | "meerkat_dataset_name": None, 53 | "index_name": "id", 54 | "task": "multi_class", 55 | "input_features": [ 56 | { 57 | "name": "image", 58 | "type": "image", 59 | "transformation": "image_pil", 60 | "default_transformation": "image_pil_default_transform_transformer", 61 | } 62 | ], 63 | "output_features": [{"name": "target"}], 64 | "path_to_dp": None, 65 | "batch_size": 128, 66 | "val_batch_size": 128, 67 | "num_workers": 4, 68 | }, 69 | "learner_config": { 70 | "n_epochs": 2, 71 | 
"train_split": "train", 72 | "valid_split": None, 73 | "test_split": "test", 74 | "optimizer_config": {"optimizer": "adamw", "lr": 0.001}, 75 | "lr_scheduler_config": { 76 | "lr_scheduler": "plateau", 77 | "lr_scheduler_step_unit": "epoch", 78 | "plateau_config": {"factor": 0.2, "patience": 10, "threshold": 0.01}, 79 | }, 80 | }, 81 | } 82 | 83 | TRANSFORMER_CONFIG = { 84 | "model": { 85 | "name": "patchnet", 86 | "train": True, 87 | "d": 256, 88 | "mlp_dim": 512, 89 | "num_layers": 7, 90 | "patch_size": 4, 91 | "grayscale": False, 92 | "num_heads": 8, 93 | "head_dropout": 0.1, 94 | "label_smoothing": True, 95 | "dropout": 0.05, 96 | "tie_weights": False, 97 | "learn_pos": True, 98 | "use_cls_token": True, 99 | "use_all_tokens": False, 100 | "max_sequence_length": 65, 101 | "patch_emb_type": "square", 102 | }, 103 | "augmentations": { 104 | "raw": { 105 | "image_pil": [ 106 | {"type": "RandomResizeCrop", "params": {"prob": 1.0, "size": 32}}, 107 | {"type": "HorizontalFlip", "params": {"prob": 0.5}}, 108 | {"type": "ColorDistortion", "params": {"prob": 1.0}}, 109 | {"type": "GaussianBlur", "params": {"prob": 0.5, "kernel_size": 3}}, 110 | ], 111 | "image_pil_default_transform_transformer": [ 112 | {"type": "ToTensor"}, 113 | { 114 | "type": "Normalize", 115 | "params": { 116 | "mean": [0.49139968, 0.48215841, 0.44653091], 117 | "std": [0.24703223, 0.24348513, 0.26158784], 118 | }, 119 | }, 120 | {"type": "Reshape2D", "params": {"h_dim": 3, "w_dim": 1024}}, 121 | ], 122 | "image_pil_default_transform_resnet": [ 123 | {"type": "ToTensor"}, 124 | { 125 | "type": "Normalize", 126 | "params": { 127 | "mean": [0.49139968, 0.48215841, 0.44653091], 128 | "std": [0.24703223, 0.24348513, 0.26158784], 129 | }, 130 | }, 131 | ], 132 | }, 133 | "patch": {"type": None}, 134 | "feature": {"type": None}, 135 | }, 136 | "tasks": {"supervised": {"loss_fn": "cross_entropy", "label": "target"}}, 137 | "dataset": { 138 | "name": "meerkat_dataset", 139 | "meerkat_dataset_name": None, 140 
| "index_name": "id", 141 | "task": "multi_class", 142 | "batch_size": 128, 143 | "val_batch_size": 128, 144 | "num_workers": 4, 145 | "input_features": [ 146 | { 147 | "name": "image", 148 | "type": "image", 149 | "transformation": "image_pil", 150 | "default_transformation": "image_pil_default_transform_transformer", 151 | } 152 | ], 153 | "output_features": [{"name": "target"}], 154 | "path_to_dp": None, 155 | }, 156 | "learner_config": { 157 | "n_epochs": 2, 158 | "train_split": "train", 159 | "valid_split": None, 160 | "test_split": "test", 161 | "optimizer_config": {"optimizer": "adamw", "lr": 0.001}, 162 | "lr_scheduler_config": { 163 | "lr_scheduler": "plateau", 164 | "lr_scheduler_step_unit": "epoch", 165 | "plateau_config": {"factor": 0.2, "patience": 10, "threshold": 0.01}, 166 | }, 167 | }, 168 | } 169 | 170 | RESNET_CONFIG = { 171 | "model": { 172 | "name": "resnet", 173 | "train": True, 174 | "decoder_hidden_dim": 512, 175 | "decoder_projection_dim": 512, 176 | "model": "resnet50", 177 | }, 178 | "augmentations": { 179 | "raw": { 180 | "image_pil": [ 181 | {"type": "RandomResizeCrop", "params": {"prob": 1.0, "size": 32}}, 182 | {"type": "HorizontalFlip", "params": {"prob": 0.5}}, 183 | {"type": "ColorDistortion", "params": {"prob": 1.0}}, 184 | {"type": "GaussianBlur", "params": {"prob": 0.5, "kernel_size": 3}}, 185 | ], 186 | "image_pil_default_transform_transformer": [ 187 | {"type": "ToTensor"}, 188 | { 189 | "type": "Normalize", 190 | "params": { 191 | "mean": [0.49139968, 0.48215841, 0.44653091], 192 | "std": [0.24703223, 0.24348513, 0.26158784], 193 | }, 194 | }, 195 | {"type": "Reshape2D", "params": {"h_dim": 3, "w_dim": 1024}}, 196 | ], 197 | "image_pil_default_transform_resnet": [ 198 | {"type": "ToTensor"}, 199 | { 200 | "type": "Normalize", 201 | "params": { 202 | "mean": [0.49139968, 0.48215841, 0.44653091], 203 | "std": [0.24703223, 0.24348513, 0.26158784], 204 | }, 205 | }, 206 | ], 207 | }, 208 | "patch": {"type": None}, 209 | 
from dcbench.common import Task

# from .baselines import confusion_sdm, domino_sdm
from .problem import SliceDiscoveryProblem, SliceDiscoverySolution

__all__ = [
    # "confusion_sdm",
    # "domino_sdm",
    "SliceDiscoveryProblem",
    "SliceDiscoverySolution",
]

# Task registration for slice discovery.
# Fixes to the user-facing summary: "Machine learnings" -> "Machine learning",
# "erors" -> "errors", and doubled spaces at the string-concatenation seams.
task = Task(
    task_id="slice_discovery",
    name="Slice Discovery",
    summary=(
        "Machine learning models that achieve high overall accuracy often make "
        "systematic errors on important subgroups (or *slices*) of data. When "
        "working with high-dimensional inputs (*e.g.* images, audio) where data "
        "slices are often unlabeled, identifying underperforming slices is "
        "challenging. In this task, we'll develop automated slice discovery "
        "methods that mine unstructured data for underperforming slices."
    ),
    problem_class=SliceDiscoveryProblem,
    solution_class=SliceDiscoverySolution,
    baselines=None,
)
27 | # """ 28 | 29 | # # the budget of predicted slices allowed by the problem 30 | # n_pred_slices: int = problem.n_pred_slices 31 | 32 | # # the only aritfact used by this simple baseline is the model predictions 33 | # predictions_dp = problem["test_predictions"] 34 | 35 | # pred_slices = np.stack( 36 | # [ 37 | # (predictions_dp["target"] == target_idx) 38 | # * (predictions_dp["probs"][:, pred_idx]).numpy() 39 | # for target_idx in range(predictions_dp["probs"].shape[1]) 40 | # for pred_idx in range(predictions_dp["probs"].shape[1]) 41 | # ], 42 | # axis=-1, 43 | # ) 44 | # if pred_slices.shape[1] > n_pred_slices: 45 | # raise ValueError( 46 | # "ConfusionSDM is not configured to return enough slices to " 47 | # "capture the full confusion matrix." 48 | # ) 49 | 50 | # if pred_slices.shape[1] < n_pred_slices: 51 | # # fill in the other predicted slices with zeros 52 | # pred_slices = np.concatenate( 53 | # [ 54 | # pred_slices, 55 | # np.zeros((pred_slices.shape[0], n_pred_slices - pred_slices.shape[1])), 56 | # ], 57 | # axis=1, 58 | # ) 59 | 60 | # return problem.solve( 61 | # pred_slices_dp=mk.DataPanel( 62 | # {"id": predictions_dp["id"], "pred_slices": pred_slices} 63 | # ) 64 | # ) 65 | 66 | 67 | # @solver(id="domino_sdm", summary=("An error aware mixture model.")) 68 | # def domino_sdm(problem: SliceDiscoveryProblem) -> SliceDiscoverySolution: 69 | # from .domino import DominoMixture 70 | 71 | # # the budget of predicted slices allowed by the problem 72 | # n_pred_slices: int = problem.n_pred_slices 73 | 74 | # mm = DominoMixture( 75 | # n_components=25, 76 | # weight_y_log_likelihood=10, 77 | # init_params="error", 78 | # covariance_type="diag", 79 | # ) 80 | 81 | # dp = mk.merge(problem["val_predictions"], problem["clip"], on="id") 82 | # emb = dp["emb"] 83 | 84 | # pca = PCA(n_components=128) 85 | # pca.fit(X=emb) 86 | # pca.fit(X=emb) 87 | # emb = pca.transform(X=emb) 88 | 89 | # mm.fit(X=emb, y=dp["target"], y_hat=dp["probs"]) 90 | 91 | # 
from typing import List, Tuple

import meerkat as mk
import numpy as np
import pandas as pd
import sklearn.metrics as skmetrics
from domino.utils import unpack_args
from scipy.stats import rankdata
from tqdm import tqdm

from dcbench import SliceDiscoveryProblem, SliceDiscoverySolution


def compute_metrics(
    solutions: List[SliceDiscoverySolution], run_id: int = None
) -> mk.DataPanel:
    """Compute per-slice metrics for a batch of slice-discovery solutions.

    Args:
        solutions: The solutions to evaluate, one per problem.
        run_id: Unused; retained so existing callers that pass it keep working.

    Returns:
        A DataPanel with one row per (solution, target slice) pair.
    """
    # NOTE(review): the previous implementation unpacked the return value of
    # ``compute_solution_metrics`` into two names (``g, s = ...``), but that
    # function returns a single list of records — the unpack raised a
    # ValueError for any solution yielding != 2 records. Aggregate directly.
    slice_metrics: List[dict] = []
    for solution in tqdm(solutions):
        slice_metrics.extend(compute_solution_metrics(solution))
    return mk.DataPanel(slice_metrics)


def compute_solution_metrics(
    solution: SliceDiscoverySolution,
) -> List[dict]:
    """Score a single solution and tag each record with its ids.

    Returns:
        A list of metric records, one per ground-truth slice (the best-scoring
        predicted slice for that target slice).
    """
    metrics = _compute_metrics(
        data=solution.merge(),
        slice_target_column="slices",
        slice_pred_column="slice_preds",
        slice_prob_column="slice_probs",
        slice_names=solution.problem.slice_names,
    )
    for row in metrics:
        row["solution_id"] = solution.id
        row["problem_id"] = solution.problem_id
    return metrics


def _compute_metrics(
    data: mk.DataPanel,
    slice_target_column: str,
    slice_pred_column: str,
    slice_prob_column: str,
    slice_names: List[str],
) -> List[dict]:
    """Score every predicted slice (and its complement) against every
    ground-truth slice, then keep the best match per target slice.

    Args:
        data: A DataPanel holding the target, prediction, and probability
            columns named below.
        slice_target_column: Column of binary ground-truth slice membership,
            shape (n, n_target_slices).
        slice_pred_column: Column of binary predicted slice membership.
        slice_prob_column: Column of predicted slice probabilities.
        slice_names: Human-readable name for each target slice.

    Returns:
        One record per target slice: the metrics of the predicted slice with
        the highest average precision for that target.
    """
    slice_targets, slice_preds, slice_probs = unpack_args(
        data, slice_target_column, slice_pred_column, slice_prob_column
    )

    # Also score the complement of every predicted slice, since a slicer may
    # isolate a slice "inverted" (members assigned to the complement cluster).
    slice_preds = np.concatenate([slice_preds, 1 - slice_preds], axis=1)
    slice_probs = np.concatenate([slice_probs, 1 - slice_probs], axis=1)

    def precision_at_k(slc: np.ndarray, slice_prob: np.ndarray, k: int = 25):
        # rankdata ranks ascending, so negate to mark the k highest-probability
        # examples as positive. Exactly k positives are predicted, so there is
        # no zero-division risk in precision.
        return skmetrics.precision_score(
            slc, rankdata(-slice_prob, method="ordinal") <= k
        )

    records = []
    for slice_idx in range(slice_targets.shape[1]):
        slc = slice_targets[:, slice_idx]
        slice_name = slice_names[slice_idx]
        for pred_slice_idx in range(slice_preds.shape[1]):
            slice_pred = slice_preds[:, pred_slice_idx]
            slice_prob = slice_probs[:, pred_slice_idx]

            precision, recall, f1_score, support = (
                skmetrics.precision_recall_fscore_support(
                    y_true=slc,
                    y_pred=slice_pred,
                    average="binary",
                    # if slc is empty recall is 0; if pred is empty precision is 0
                    zero_division=0,
                )
            )
            records.append(
                {
                    "target_slice_idx": slice_idx,
                    "target_slice_name": slice_name,
                    "pred_slice_idx": pred_slice_idx,
                    "average_precision": skmetrics.average_precision_score(
                        y_true=slc, y_score=slice_prob
                    ),
                    "precision-at-10": precision_at_k(slc, slice_prob, k=10),
                    "precision-at-25": precision_at_k(slc, slice_prob, k=25),
                    "precision": precision,
                    "recall": recall,
                    "f1_score": f1_score,
                    "support": support,
                }
            )

    df = pd.DataFrame(records)
    primary_metric = "average_precision"
    # groupby(...).idxmax() returns index *labels*; df has a fresh RangeIndex
    # so labels happen to equal positions, but .loc is the correct accessor
    # for labels (the old .iloc worked only by that coincidence).
    best_rows = df.loc[df.groupby("target_slice_name")[primary_metric].idxmax()]
    return best_rows.to_dict("records")


if __name__ == "__main__":
    # Sweep the Domino mixture likelihood weights over all slice-discovery
    # problems (originally dcbench/tasks/slice_discovery/pipeline.py).
    import dcbench
    from domino import DominoSlicer
    from dcbench.tasks.slice_discovery.run import run_sdms

    task = dcbench.tasks["slice_discovery"]
    for weight in [1, 2, 3, 5, 7, 10, 20, 50]:
        solutions, metrics = run_sdms(
            list(task.problems.values()),
            slicer_class=DominoSlicer,
            slicer_config=dict(
                y_hat_log_likelihood_weight=weight,
                y_log_likelihood_weight=weight,
            ),
            encoder="clip",
            num_workers=10,
        )
class SliceDiscoverySolution(Solution):
    """A solution to a slice-discovery problem.

    Holds the per-example predicted slice memberships produced by a slicer,
    plus the attributes needed to reproduce the run (slicer class, config,
    and embedding column).
    """

    artifact_specs: Mapping[str, ArtifactSpec] = {
        "pred_slices": ArtifactSpec(
            artifact_type=DataPanelArtifact,
            description="A DataPanel of predicted slice labels with columns `id`"
            " and `pred_slices`.",
        ),
    }

    attribute_specs = {
        "problem_id": AttributeSpec(
            description="A unique identifier for this problem.",
            attribute_type=str,
        ),
        "slicer_class": AttributeSpec(
            # fix: description was truncated to "The " in the original
            description="The slicer class used to produce this solution.",
            attribute_type=type,
        ),
        "slicer_config": AttributeSpec(
            description="The configuration for the slicer.",
            attribute_type=dict,
        ),
        "embedding_column": AttributeSpec(
            description="The column name of the embedding.",
            attribute_type=str,
        ),
    }

    task_id: str = "slice_discovery"

    @property
    def problem(self) -> "SliceDiscoveryProblem":
        """Look up the problem this solution solves from the task registry."""
        # imported lazily to avoid a circular import at module load time
        from dcbench import tasks

        return tasks["slice_discovery"].problems[self.problem_id]

    def merge(self) -> mk.DataPanel:
        """Join the predicted slices with the problem's labeled test split."""
        return self["pred_slices"].merge(
            self.problem.merge(split="test", slices=True), on="id", how="left"
        )


class SliceDiscoveryProblem(Problem):
    """A slice-discovery problem: audit a trained model by finding
    underperforming subpopulations (slices) of the dataset."""

    artifact_specs: Mapping[str, ArtifactSpec] = {
        "val_predictions": ArtifactSpec(
            artifact_type=DataPanelArtifact,
            description=(
                "A DataPanel of the model's predictions with columns `id`,"
                "`target`, and `probs.`"
            ),
        ),
        "test_predictions": ArtifactSpec(
            artifact_type=DataPanelArtifact,
            description=(
                "A DataPanel of the model's predictions with columns `id`,"
                "`target`, and `probs.`"
            ),
        ),
        "test_slices": ArtifactSpec(
            artifact_type=DataPanelArtifact,
            description="A DataPanel of the ground truth slice labels with columns "
            " `id`, `slices`.",
        ),
        "activations": ArtifactSpec(
            artifact_type=DataPanelArtifact,
            description="A DataPanel of the model's activations with columns `id`,"
            "`act`",
        ),
        "model": ArtifactSpec(
            artifact_type=ModelArtifact,
            description="A trained PyTorch model to audit.",
        ),
        "base_dataset": ArtifactSpec(
            artifact_type=VisionDatasetArtifact,
            description="A DataPanel representing the base dataset with columns `id` "
            "and `image`.",
        ),
        "clip": ArtifactSpec(
            artifact_type=DataPanelArtifact,
            description="A DataPanel of the image embeddings from OpenAI's CLIP model",
        ),
    }

    attribute_specs = {
        "n_pred_slices": AttributeSpec(
            description="The number of slice predictions that each slice discovery "
            "method can return.",
            attribute_type=int,
        ),
        "slice_category": AttributeSpec(
            description="The type of slice .", attribute_type=str
        ),
        "target_name": AttributeSpec(
            description="The name of the target column in the dataset.",
            attribute_type=str,
        ),
        "dataset": AttributeSpec(
            description="The name of the dataset being audited.",
            attribute_type=str,
        ),
        "alpha": AttributeSpec(
            description="The alpha parameter for the AUC metric.",
            attribute_type=float,
        ),
        "slice_names": AttributeSpec(
            description="The names of the slices in the dataset.",
            attribute_type=list,
        ),
    }

    task_id: str = "slice_discovery"

    def merge(self, split: str = "val", slices: bool = False) -> mk.DataPanel:
        """Join the predictions for ``split`` with the base dataset.

        Args:
            split: Which prediction artifact to use (``"val"`` or ``"test"``).
            slices: If True, also join the ground-truth slice labels.

        NOTE(review): only a ``test_slices`` artifact is declared above, so
        ``merge(split="val", slices=True)`` would look up a missing
        ``val_slices`` artifact — confirm whether val slices are intended.
        """
        base_dataset = self["base_dataset"]
        # drop the base dataset's own `split` column to avoid a merge collision
        base_dataset = base_dataset[[c for c in base_dataset.columns if c != "split"]]
        dp = self[f"{split}_predictions"].merge(base_dataset, on="id", how="left")
        if slices:
            dp = dp.merge(self[f"{split}_slices"], on="id", how="left")
        return dp

    def solve(self, pred_slices_dp: mk.DataPanel) -> SliceDiscoverySolution:
        """Package a DataPanel of predicted slices as a solution.

        Raises:
            ValueError: If ``pred_slices_dp`` is missing the required
                ``id`` or ``pred_slices`` column.
        """
        if ("id" not in pred_slices_dp) or ("pred_slices" not in pred_slices_dp):
            raise ValueError(
                f"DataPanel passed to {self.__class__.__name__} must include columns "
                "`id` and `pred_slices`"
            )

        return SliceDiscoverySolution(
            artifacts={"pred_slices": pred_slices_dp},
            attributes={"problem_id": self.id},
        )

    def evaluate(self):
        """Placeholder — evaluation is implemented in ``metrics.py``."""
        pass
def _run_sdms(problems: List[SliceDiscoveryProblem], **kwargs) -> list:
    """Run ``run_sdm`` sequentially over ``problems``.

    Kept as a module-level function so it can be wrapped with ``ray.remote``
    for parallel execution.
    """
    return [run_sdm(problem, **kwargs) for problem in problems]


def run_sdms(
    problems: List[SliceDiscoveryProblem],
    slicer_class: type,
    slicer_config: dict,
    encoder: str = "clip",
    variant: str = "ViT-B/32",
    batch_size: int = 1,
    num_workers: int = 0,
):
    """Run a slice-discovery method over a list of problems.

    Args:
        problems: The problems to solve.
        slicer_class: The slicer class to instantiate for each problem.
        slicer_config: Keyword arguments forwarded to ``slicer_class``.
        encoder: Name of the embedding encoder (e.g. ``"clip"``).
        variant: Encoder variant identifier.
        batch_size: Number of problems per (possibly remote) batch.
        num_workers: If > 0, run batches in parallel with ray.

    Returns:
        A tuple ``(solutions, metrics_df)`` where ``metrics_df`` is a
        DataFrame of per-slice metric records, also written next to the
        solutions on disk.
    """
    # Prepare embeddings once per base dataset, reusing a cached artifact
    # when it already exists locally.
    base_datasets = {p.artifacts["base_dataset"].id for p in problems}
    embs = {}
    for base_dataset in base_datasets:
        dataset_artifact = dcbench.VisionDatasetArtifact(base_dataset)
        emb_artifact_id = (
            f"common/embeddings/{base_dataset}/"
            f"{encoder}-{variant.replace('/', '-')}"
        )

        emb_artifact = dcbench.DataPanelArtifact(emb_artifact_id)
        if os.path.exists(emb_artifact.local_path):
            emb_dp = emb_artifact.load()
        else:
            dataset_artifact.download()
            emb_dp = embed(
                dataset_artifact.load(),
                input_col="image",
                encoder=encoder,
                variant=variant,
                device=0,
                num_workers=12,
            )
            # normalize the embedding column name to "emb"
            emb_dp["emb"] = emb_dp[f"{encoder}(image)"]
            emb_dp.remove_column(f"{encoder}(image)")
            # presumably Artifact.from_data persists the embeddings so the
            # next run hits the cache above — TODO confirm
            emb_artifact = Artifact.from_data(emb_dp, artifact_id=emb_artifact_id)
        embs[base_dataset] = emb_dp

    if num_workers > 0:
        import ray

        ray.init()
        run_fn = ray.remote(_run_sdms).remote
        # share one copy of the embeddings across workers
        embs = ray.put(embs)
    else:
        run_fn = _run_sdms

    results = []
    t = tqdm(total=len(problems))

    for start_idx in range(0, len(problems), batch_size):
        batch = problems[start_idx : start_idx + batch_size]

        result = run_fn(
            problems=batch,
            embs=embs,
            slicer_class=slicer_class,
            slicer_config=slicer_config,
        )

        if num_workers == 0:
            t.update(n=len(result))
            results.extend(result)
        else:
            # The remote call returns an object reference immediately, so
            # progress is tracked below as references complete.
            results.append(result)

    if num_workers > 0:
        # Wait for remote batches and update progress as they finish.
        result_refs = results
        results = []
        while result_refs:
            done, result_refs = ray.wait(result_refs)
            for result in done:
                result = ray.get(result)
                results.extend(result)
                t.update(n=len(result))
        ray.shutdown()
    t.close()  # fix: the progress bar was never closed

    solutions, metrics = zip(*results)
    # each element of `metrics` is a list of per-slice records; flatten it
    metrics = [row for slices in metrics for row in slices]

    path = task.write_solutions(solutions)
    metrics_df = pd.DataFrame(metrics)
    metrics_df.to_csv(os.path.join(os.path.dirname(path), "metrics.csv"), index=False)

    return solutions, metrics_df


def run_sdm(
    problem: SliceDiscoveryProblem,
    slicer_class: type,
    slicer_config: dict,
    embs: Mapping[str, mk.DataPanel],
) -> Tuple[SliceDiscoverySolution, List[dict]]:
    """Fit a slicer on the validation split and predict slices on test.

    Args:
        problem: The problem to solve.
        slicer_class: Slicer class to instantiate.
        slicer_config: Keyword arguments for the slicer.
        embs: Mapping from base-dataset id to its embedding DataPanel
            (must contain columns ``id`` and ``emb``).

    Returns:
        The solution together with its per-slice metric records.
        (fix: the annotation previously claimed a bare
        ``SliceDiscoverySolution`` although a tuple is returned)
    """
    emb_dp = embs[problem.artifacts["base_dataset"].id]
    val_dp = problem.merge(split="val")
    val_dp = val_dp.merge(emb_dp["id", "emb"], on="id", how="left")

    slicer = slicer_class(pbar=False, n_slices=problem.n_pred_slices, **slicer_config)
    slicer.fit(val_dp, embeddings="emb", targets="target", pred_probs="probs")

    test_dp = problem.merge(split="test")
    test_dp = test_dp.merge(emb_dp["id", "emb"], on="id", how="left")
    result = mk.DataPanel({"id": test_dp["id"]})
    result["slice_preds"] = slicer.predict(
        test_dp, embeddings="emb", targets="target", pred_probs="probs"
    )
    result["slice_probs"] = slicer.predict_proba(
        test_dp, embeddings="emb", targets="target", pred_probs="probs"
    )

    solution = SliceDiscoverySolution(
        artifacts={"pred_slices": result},
        attributes={
            "problem_id": problem.id,
            "slicer_class": slicer_class,
            "slicer_config": slicer_config,
            "embedding_column": "emb",
        },
    )
    metrics = compute_solution_metrics(solution)
    return solution, metrics
You can set these variables from the command line, and also 9 | # from the environment for the first two. 10 | SPHINXOPTS ?= 11 | SPHINXBUILD ?= sphinx-build 12 | SOURCEDIR = source 13 | BUILDDIR = build 14 | 15 | # Put it first so that "make" without argument is like "make help". 16 | help: 17 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 18 | 19 | .PHONY: help Makefile 20 | 21 | # Catch-all target: route all unknown targets to Sphinx using the new 22 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 23 | %: Makefile populate_docs 24 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 25 | -------------------------------------------------------------------------------- /docs/assets/banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/data-centric-ai/dcbench/831ab2359d686739d0b0c7a589974ce08448e58d/docs/assets/banner.png -------------------------------------------------------------------------------- /docs/assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/data-centric-ai/dcbench/831ab2359d686739d0b0c7a589974ce08448e58d/docs/assets/logo.png -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. 
import os

import pandas as pd
from tabulate import tabulate

import dcbench

# Browser URL for the public GCS bucket that stores task artifacts.
BUCKET_BROWSER_URL = "https://console.cloud.google.com/storage/browser/dcbench"


def get_rst_class_ref(klass: type) -> str:
    """Return an RST ``:class:`` cross-reference for a dcbench class."""
    return f":class:`dcbench.{klass.__name__}`"


def get_link(text: str, url: str) -> str:
    """Return an RST external hyperlink for ``text`` pointing at ``url``."""
    return f"`{text} <{url}>`_"


def get_artifact_table(task: dcbench.Task) -> str:
    """Render a container's artifact specs as an RST grid table."""
    df = pd.DataFrame(
        [
            {
                "name": f"``{name}``",
                "type": get_rst_class_ref(spec.artifact_type),
                "description": spec.description,
            }
            for name, spec in task.artifact_specs.items()
        ]
    ).set_index(keys="name")

    return tabulate(df, headers="keys", tablefmt="rst")


# Build source/tasks.rst by filling the task template once per registered task.
# fix: use context managers so file handles are closed (and the output is
# flushed) even on error, and read/write UTF-8 explicitly — the content
# contains non-ASCII characters that break under a cp1252 default encoding.
sections = [".. _tasks:\n\n🎯 Tasks\n========="]
with open("source/task_template.rst", encoding="utf-8") as f:
    template = f.read()
for task in dcbench.tasks.values():
    description_path = os.path.join(
        "source/task_descriptions", f"{task.task_id}.rst"
    )
    with open(description_path, encoding="utf-8") as f:
        longer_description = f.read()
    section = template.format(
        name=task.name,
        summary=task.summary,
        num_problems=len(task.problems),
        task_id=task.task_id,
        problem_class=get_rst_class_ref(task.problem_class),
        problem_artifact_table=get_artifact_table(task.problem_class),
        solution_class=get_rst_class_ref(task.solution_class),
        solution_artifact_table=get_artifact_table(task.solution_class),
        storage_url=os.path.join(BUCKET_BROWSER_URL, task.task_id),
        longer_description=longer_description,
    )

    sections.append(section)

with open("source/tasks.rst", "w", encoding="utf-8") as f:
    f.write("\n\n".join(sections))
automodule:: dcbench.tasks.budgetclean.problem 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | Module contents 24 | --------------- 25 | 26 | .. automodule:: dcbench.tasks.budgetclean 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | -------------------------------------------------------------------------------- /docs/source/apidocs/dcbench.common.rst: -------------------------------------------------------------------------------- 1 | dcbench.common package 2 | ====================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | dcbench.common.artifact module 8 | ------------------------------ 9 | 10 | .. automodule:: dcbench.common.artifact 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | dcbench.common.artifact\_container module 16 | ----------------------------------------- 17 | 18 | .. automodule:: dcbench.common.artifact_container 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | dcbench.common.method module 24 | ---------------------------- 25 | 26 | .. automodule:: dcbench.common.method 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | dcbench.common.modeling module 32 | ------------------------------ 33 | 34 | .. automodule:: dcbench.common.modeling 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | dcbench.common.problem module 40 | ----------------------------- 41 | 42 | .. automodule:: dcbench.common.problem 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | dcbench.common.result module 48 | ---------------------------- 49 | 50 | .. automodule:: dcbench.common.result 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | dcbench.common.solution module 56 | ------------------------------ 57 | 58 | .. automodule:: dcbench.common.solution 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | dcbench.common.solve module 64 | --------------------------- 65 | 66 | .. 
automodule:: dcbench.common.solve 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | dcbench.common.solver module 72 | ---------------------------- 73 | 74 | .. automodule:: dcbench.common.solver 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | 79 | dcbench.common.table module 80 | --------------------------- 81 | 82 | .. automodule:: dcbench.common.table 83 | :members: 84 | :undoc-members: 85 | :show-inheritance: 86 | 87 | dcbench.common.task module 88 | -------------------------- 89 | 90 | .. automodule:: dcbench.common.task 91 | :members: 92 | :undoc-members: 93 | :show-inheritance: 94 | 95 | dcbench.common.trial module 96 | --------------------------- 97 | 98 | .. automodule:: dcbench.common.trial 99 | :members: 100 | :undoc-members: 101 | :show-inheritance: 102 | 103 | dcbench.common.utils module 104 | --------------------------- 105 | 106 | .. automodule:: dcbench.common.utils 107 | :members: 108 | :undoc-members: 109 | :show-inheritance: 110 | 111 | Module contents 112 | --------------- 113 | 114 | .. automodule:: dcbench.common 115 | :members: 116 | :undoc-members: 117 | :show-inheritance: 118 | -------------------------------------------------------------------------------- /docs/source/apidocs/dcbench.rst: -------------------------------------------------------------------------------- 1 | dcbench package 2 | =============== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | dcbench.common 11 | dcbench.tasks 12 | 13 | Submodules 14 | ---------- 15 | 16 | dcbench.config module 17 | --------------------- 18 | 19 | .. automodule:: dcbench.config 20 | :members: 21 | :undoc-members: 22 | :show-inheritance: 23 | 24 | dcbench.constants module 25 | ------------------------ 26 | 27 | .. automodule:: dcbench.constants 28 | :members: 29 | :undoc-members: 30 | :show-inheritance: 31 | 32 | dcbench.version module 33 | ---------------------- 34 | 35 | .. 
automodule:: dcbench.version 36 | :members: 37 | :undoc-members: 38 | :show-inheritance: 39 | 40 | Module contents 41 | --------------- 42 | 43 | .. automodule:: dcbench 44 | :members: 45 | :undoc-members: 46 | :show-inheritance: 47 | -------------------------------------------------------------------------------- /docs/source/apidocs/dcbench.tasks.budgetclean.rst: -------------------------------------------------------------------------------- 1 | dcbench.tasks.budgetclean package 2 | ================================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | dcbench.tasks.budgetclean.baselines module 8 | ------------------------------------------ 9 | 10 | .. automodule:: dcbench.tasks.budgetclean.baselines 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | dcbench.tasks.budgetclean.common module 16 | --------------------------------------- 17 | 18 | .. automodule:: dcbench.tasks.budgetclean.common 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | dcbench.tasks.budgetclean.problem module 24 | ---------------------------------------- 25 | 26 | .. automodule:: dcbench.tasks.budgetclean.problem 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | Module contents 32 | --------------- 33 | 34 | .. automodule:: dcbench.tasks.budgetclean 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | -------------------------------------------------------------------------------- /docs/source/apidocs/dcbench.tasks.minidata.rst: -------------------------------------------------------------------------------- 1 | dcbench.tasks.minidata package 2 | ============================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | dcbench.tasks.minidata.unagi\_configs module 8 | -------------------------------------------- 9 | 10 | .. automodule:: dcbench.tasks.minidata.unagi_configs 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. 
automodule:: dcbench.tasks.minidata 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /docs/source/apidocs/dcbench.tasks.rst: -------------------------------------------------------------------------------- 1 | dcbench.tasks package 2 | ===================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | dcbench.tasks.budgetclean 11 | dcbench.tasks.minidata 12 | dcbench.tasks.slice_discovery 13 | 14 | Module contents 15 | --------------- 16 | 17 | .. automodule:: dcbench.tasks 18 | :members: 19 | :undoc-members: 20 | :show-inheritance: 21 | -------------------------------------------------------------------------------- /docs/source/apidocs/dcbench.tasks.slice.rst: -------------------------------------------------------------------------------- 1 | dcbench.tasks.slice package 2 | =========================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | dcbench.tasks.slice.build\_problems module 8 | ------------------------------------------ 9 | 10 | .. automodule:: dcbench.tasks.slice.build_problems 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. automodule:: dcbench.tasks.slice 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /docs/source/apidocs/dcbench.tasks.slice_discovery.rst: -------------------------------------------------------------------------------- 1 | dcbench.tasks.slice\_discovery package 2 | ====================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | dcbench.tasks.slice\_discovery.baselines module 8 | ----------------------------------------------- 9 | 10 | .. 
automodule:: dcbench.tasks.slice_discovery.baselines 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | dcbench.tasks.slice\_discovery.metrics module 16 | --------------------------------------------- 17 | 18 | .. automodule:: dcbench.tasks.slice_discovery.metrics 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | dcbench.tasks.slice\_discovery.problem module 24 | --------------------------------------------- 25 | 26 | .. automodule:: dcbench.tasks.slice_discovery.problem 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | Module contents 32 | --------------- 33 | 34 | .. automodule:: dcbench.tasks.slice_discovery 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | -------------------------------------------------------------------------------- /docs/source/apidocs/modules.rst: -------------------------------------------------------------------------------- 1 | dcbench 2 | ======= 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | dcbench 8 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
# Sphinx configuration for the dcbench documentation.
# Reference: https://www.sphinx-doc.org/en/master/usage/configuration.html

# -- Path setup --------------------------------------------------------------
import os
import sys
from pathlib import Path

# Read the package version without importing dcbench itself, so the docs
# build does not pull in the package's heavy runtime dependencies.
version_path = Path(__file__).parent.parent.parent / "dcbench" / "version.py"
metadata = {}
with open(str(version_path)) as ver_file:
    exec(ver_file.read(), metadata)

# Make the repository importable for autodoc.
sys.path.insert(0, os.path.abspath("."))
sys.path.insert(0, os.path.abspath(".."))
sys.path.insert(0, os.path.abspath("../.."))
sys.setrecursionlimit(1500)


# -- Project information -----------------------------------------------------

project = "dcbench"
copyright = "2021, Data Centric AI"
author = "Data Centric AI"


# -- General configuration ---------------------------------------------------

# Sphinx extension modules, both builtin ('sphinx.ext.*') and third-party.
extensions = [
    "sphinx.ext.autodoc",
    "sphinx.ext.coverage",
    "sphinx.ext.napoleon",
    "sphinx.ext.viewcode",
    "sphinx.ext.autodoc.typehints",
    "sphinx.ext.autosummary",
    "sphinx.ext.autosectionlabel",
    "IPython.sphinxext.ipython_directive",
    "IPython.sphinxext.ipython_console_highlighting",
    "sphinx_rtd_theme",
    "nbsphinx",
    "recommonmark",
]
autodoc_typehints = "description"

# Prefix section labels with the document name, so documents may reuse
# identical headings without collisions.
autosectionlabel_prefix_document = True

# Template directories, relative to this file.
templates_path = ["_templates"]

# Glob patterns (relative to the source dir) excluded from the build; this
# also affects html_static_path and html_extra_path.
exclude_patterns = []


# -- Options for HTML output -------------------------------------------------

html_theme = "furo"

# Custom static files (e.g. style sheets), copied after the builtin static
# files so "default.css" here would shadow the builtin one.
html_static_path = ["_static"]

# Logo displayed at the top of the sidebar/title page.
html_logo = "../assets/logo.png"

html_theme_options = {}

# Hide module prefixes in front of class names.
add_module_names = False

# Keep members in source order instead of alphabetizing them.
autodoc_member_order = "bysource"
See the documentation for 71 | # a list of builtin themes. 72 | # 73 | html_theme = "furo" 74 | 75 | 76 | # Add any paths that contain custom static files (such as style sheets) here, 77 | # relative to this directory. They are copied after the builtin static files, 78 | # so a file named "default.css" will overwrite the builtin "default.css". 79 | html_static_path = ["_static"] 80 | 81 | # The name of an image file (relative to this directory) to place at the top of 82 | # the title page. 83 | html_logo = "../assets/logo.png" 84 | 85 | html_theme_options = {} 86 | 87 | # Don't show module names in front of class names. 88 | add_module_names = False 89 | 90 | # Don't alphabetize the method names. 91 | autodoc_member_order = "bysource" 92 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. DCBench documentation master file, created by 2 | sphinx-quickstart on Fri Jan 1 16:41:09 2021. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to dcbench 7 | ========================================== 8 | 9 | .. _Issues: https://github.com/data-centric-ai/dcbench/issues/ 10 | .. _installation: getting-started/install.md 11 | 12 | .. toctree:: 13 | :maxdepth: 2 14 | 15 | intro.md 16 | tasks.rst 17 | install.md 18 | 19 | 20 | 21 | .. toctree:: 22 | :maxdepth: 2 23 | :caption: API Docs 24 | 25 | apidocs/dcbench.rst 26 | 27 | 28 | .. 29 | Indices and tables 30 | ================== 31 | 32 | * :ref:`genindex` 33 | * :ref:`modindex` 34 | * :ref:`search` 35 | 36 | -------------------------------------------------------------------------------- /docs/source/install.rst: -------------------------------------------------------------------------------- 1 | 2 | .. 
_installing: 3 | 4 | 🚀 Installing dcbench 5 | ============================ 6 | 7 | This section describes how to install the ``dcbench`` Python package. 8 | 9 | .. code-block:: bash 10 | 11 | pip install dcbench 12 | 13 | .. admonition:: Optional 14 | 15 | Some parts of ``dcbench`` rely on optional dependencies. If you know which optional dependencies you'd like to install, you can do so using something like ``pip install dcbench[dev]`` instead. See ``setup.py`` for a full list of optional dependencies. 16 | 17 | Installing from branch 18 | ----------------------- 19 | 20 | To install from a specific branch use the command below, replacing ``main`` with the name of `any branch in the dcbench repository `_. 21 | 22 | .. code-block:: bash 23 | 24 | pip install "dcbench @ git+https://github.com/data-centric-ai/dcbench@main" 25 | 26 | 27 | Installing from clone 28 | ----------------------- 29 | You can install from a clone of the ``dcbench`` `repo `_ with: 30 | 31 | .. code-block:: bash 32 | 33 | git clone https://github.com/data-centric-ai/dcbench.git 34 | cd dcbench 35 | pip install -e . 36 | 37 | .. _configuring: 38 | 39 | ⚙️ Configuring dcbench 40 | ============================ 41 | 42 | Several aspects of ``dcbench`` behavior can be configured by the user. 43 | For example, one may wish to change the directory in which ``dcbench`` downloads artifacts (by default this is ``~/.dcbench``). 44 | 45 | You can see the current state of the ``dcbench`` configuration with: 46 | 47 | .. ipython:: python 48 | 49 | import dcbench 50 | dcbench.config 51 | 52 | Configuring with YAML 53 | ---------------------- 54 | 55 | To change the configuration create a YAML file, like the one below: 56 | 57 | .. code-block:: yaml 58 | local_dir: "/path/to/storage" 59 | public_bucket_name: "dcbench-test" 60 | 61 | Then set the environment variable ``DCBENCH_CONFIG`` to point to the file: 62 | 63 | .. 
code-block:: bash 64 | 65 | export DCBENCH_CONFIG="/path/to/dcbench-config.yaml" 66 | 67 | If you're using a conda, you can permanently set this variable for your environment: 68 | 69 | .. code-block:: bash 70 | 71 | conda env config vars set DCBENCH_CONFIG="path/to/dcbench-config.yaml" 72 | conda activate env_name # need to reactivate the environment 73 | 74 | 75 | Configuring Programmatically 76 | ------------------------------ 77 | 78 | You can also update the config programmatically, though unlike the YAML method above, these changes will not persist beyond the lifetime of your program. 79 | 80 | .. code-block:: python 81 | 82 | dcbench.config.local_dir = "/path/to/storage" 83 | dcbench.config.public_bucket_name = "dcbench-test" 84 | 85 | -------------------------------------------------------------------------------- /docs/source/intro.rst: -------------------------------------------------------------------------------- 1 | 💡 What is dcbench? 2 | ------------------- 3 | 4 | This benchmark evaluates the steps in your machine learning workflow beyond model training and tuning. This includes feature cleaning, slice discovery, and coreset selection. We call these “data-centric” tasks because they're focused on exploring and manipulating data – not training models. ``dcbench`` supports a growing number of them: 5 | 6 | * :any:`minidata`: Find the smallest subset of training data on which a fixed model architecture achieves accuracy above a threshold. 7 | * :any:`slice_discovery`: Identify subgroups on which a model underperforms. 8 | * :any:`budgetclean`: Given a fixed budget, clean input features of training data to improve model performance. 9 | 10 | 11 | ``dcbench`` includes tasks that look very different from one another: the inputs and 12 | outputs of the slice discovery task are not the same as those of the 13 | minimal data cleaning task. 
However, we think it important that 14 | researchers and practitioners be able to run evaluations on data-centric 15 | tasks across the ML lifecycle without having to learn a bunch of 16 | different APIs or rewrite evaluation scripts. 17 | 18 | So, ``dcbench`` is designed to be a common home for these diverse, but 19 | related, tasks. In ``dcbench`` all of these tasks are structured in a 20 | similar manner and they are supported by a common Python API that makes 21 | it easy to download data, run evaluations, and compare methods. 22 | 23 | .. py:currentmodule:: dcbench 24 | 25 | 26 | 27 | 28 | 🧭 API Walkthrough 29 | --------------------------------------- 30 | .. 31 | TODO: Add a schematic outlining the clas structure 32 | 33 | .. code-block:: bash 34 | 35 | pip install dcbench 36 | 37 | .. _task-intro: 38 | 39 | 40 | ``Task`` 41 | ~~~~~~~~~~~~ 42 | ``dcbench`` supports a diverse set of data-centric tasks (*e.g.* :any:`slice_discovery`). 43 | You can explore the supported tasks in the documentation (:any:`tasks`) or via the Python API: 44 | 45 | .. ipython:: python 46 | 47 | import dcbench 48 | dcbench.tasks 49 | 50 | 51 | In the ``dcbench`` API, each task is represented by a :class:`dcbench.Task` object that can be accessed by *task_id* (*e.g.* ``dcbench.slice_discovery``). These task objects hold metadata about the task and hold pointers to task-specific :class:`dcbench.Problem` and :class:`dcbench.Solution` subclasses, discussed below. 52 | 53 | .. _problem-intro: 54 | 55 | ``Problem`` 56 | ~~~~~~~~~~~~ 57 | Each task features a collection of *problems* (*i.e.* instances of the task). For example, the :any:`slice_discovery` task includes hundreds of problems across a number of different datasets. We can explore a task's problems in ``dcbench``: 58 | 59 | .. ipython:: python 60 | 61 | dcbench.tasks["slice_discovery"].problems 62 | 63 | All of a task's problems share the same structure and use the same evaluation scripts. 
64 | This is specified via task-specific subclasses of :class:`dcbench.Problem` (*e.g.* :class:`~dcbench.SliceDiscoveryProblem`). The problems themselves are instances of these subclasses. We can access a problem using it's id: 65 | 66 | .. ipython:: python 67 | 68 | problem = dcbench.tasks["slice_discovery"].problems["p_118919"] 69 | problem 70 | 71 | 72 | ``Artifact`` 73 | ~~~~~~~~~~~~ 74 | 75 | Each *problem* is made up of a set of artifacts: a dataset with features to clean, a dataset and a model to perform error analysis on. In ``dcbench`` , these artifacts are represented by instances of 76 | :class:`dcbench.Artifact`. We can think of each :class:`Problem` object as a container for :class:`Artifact` objects. 77 | 78 | .. ipython:: python 79 | 80 | problem.artifacts 81 | 82 | Note that :class:`~dcbench.Artifact` objects don't actually hold their underlying data in memory. Instead, they hold pointers to where the :class:`Artifact` lives in ``dcbench`` `cloud storage `_ and, if it's been downloaded, where it lives locally on disk. This makes the :class:`Problem` objects very lightweight. 83 | 84 | ``dcbench`` includes loading functionality for each artifact type. To load an artifact into memory we can use :meth:`~dcbench.Artifact.load()` . Note that this will also download the artifact to disk if it hasn't yet been downloaded. 85 | 86 | .. ipython:: python 87 | 88 | problem.artifacts["model"] 89 | 90 | Easier yet, we can use the index operator directly on :class:`Problem` objects to both fetch the artifact and load it into memory. 91 | 92 | .. ipython:: python 93 | 94 | problem["activations"] # shorthand for problem.artifacts["model"].load() 95 | 96 | 97 | .. admonition:: Downloading to Disk 98 | 99 | By default, ``dcbench`` downloads artifacts to ``~/.dcbench`` but this can be configured by creating a ``dcbench-config.yaml`` as described in :any:`configuring`. To download an :class:`Artifact` via the Python API, use :meth:`Artifact.download()`. 
You can also download all the artifacts in a problem with :class:`Problem.download()`. 100 | 101 | 102 | ``Solution`` 103 | ~~~~~~~~~~~~ -------------------------------------------------------------------------------- /docs/source/task_descriptions/budgetclean.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/data-centric-ai/dcbench/831ab2359d686739d0b0c7a589974ce08448e58d/docs/source/task_descriptions/budgetclean.rst -------------------------------------------------------------------------------- /docs/source/task_descriptions/minidata.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/data-centric-ai/dcbench/831ab2359d686739d0b0c7a589974ce08448e58d/docs/source/task_descriptions/minidata.rst -------------------------------------------------------------------------------- /docs/source/task_descriptions/slice_discovery.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/data-centric-ai/dcbench/831ab2359d686739d0b0c7a589974ce08448e58d/docs/source/task_descriptions/slice_discovery.rst -------------------------------------------------------------------------------- /docs/source/task_template.rst: -------------------------------------------------------------------------------- 1 | .. _{task_id}: 2 | 3 | {name} 4 | -------------------------------------------- 5 | 6 | .. sidebar:: 7 | Task Details 8 | 9 | :Task ID: ``{task_id}`` 10 | :Problems: {num_problems} 11 | 12 | {summary} 13 | 14 | **Classes**: {problem_class} {solution_class} 15 | 16 | .. admonition:: Cloud Storage 17 | 18 | We recommend downloading Artifacts through the Python API, but you can also explore the Artifacts on the `Google Cloud Console <{storage_url}>`_. 
19 | 20 | 21 | Problem Artifacts 22 | __________________ 23 | {problem_artifact_table} 24 | 25 | Solution Artifacts 26 | ____________________ 27 | {solution_artifact_table} 28 | 29 | {longer_description} 30 | -------------------------------------------------------------------------------- /docs/source/tasks.rst: -------------------------------------------------------------------------------- 1 | .. _tasks: 2 | 3 | 🎯 Tasks 4 | ========= 5 | 6 | .. _minidata: 7 | 8 | Minimal Data Selection 9 | -------------------------------------------- 10 | 11 | .. sidebar:: 12 | Task Details 13 | 14 | :Task ID: ``minidata`` 15 | :Problems: 1 16 | 17 | Given a large training dataset, what is the smallest subset you can sample that still achieves some threshold of performance. 18 | 19 | **Classes**: :class:`dcbench.MiniDataProblem` :class:`dcbench.MiniDataSolution` 20 | 21 | .. admonition:: Cloud Storage 22 | 23 | We recommend downloading Artifacts through the Python API, but you can also explore the Artifacts on the `Google Cloud Console `_. 24 | 25 | 26 | Problem Artifacts 27 | __________________ 28 | ============== ================================== ================================================================================== 29 | name type description 30 | ============== ================================== ================================================================================== 31 | ``train_data`` :class:`dcbench.DataPanelArtifact` A DataPanel of train examples with columns ``id``, ``input``, and ``target``. 32 | ``val_data`` :class:`dcbench.DataPanelArtifact` A DataPanel of validation examples with columns ``id``, ``input``, and ``target``. 33 | ``test_data`` :class:`dcbench.DataPanelArtifact` A DataPanel of test examples with columns ``id``, ``input``, and ``target``. 
34 | ============== ================================== ================================================================================== 35 | 36 | Solution Artifacts 37 | ____________________ 38 | ============= ============================= ====================================================================== 39 | name type description 40 | ============= ============================= ====================================================================== 41 | ``train_ids`` :class:`dcbench.YAMLArtifact` A list of train example ids from the ``id`` column of ``train_data``. 42 | ============= ============================= ====================================================================== 43 | 44 | 45 | 46 | 47 | .. _slice_discovery: 48 | 49 | Slice Discovery 50 | -------------------------------------------- 51 | 52 | .. sidebar:: 53 | Task Details 54 | 55 | :Task ID: ``slice_discovery`` 56 | :Problems: 20 57 | 58 | Machine learnings models that achieve high overall accuracy often make systematic erors on important subgroups (or *slices*) of data. When working with high-dimensional inputs (*e.g.* images, audio) where data slices are often unlabeled, identifying underperforming slices is challenging. In this task, we'll develop automated slice discovery methods that mine unstructured data for underperforming slices. 59 | 60 | **Classes**: :class:`dcbench.SliceDiscoveryProblem` :class:`dcbench.SliceDiscoverySolution` 61 | 62 | .. admonition:: Cloud Storage 63 | 64 | We recommend downloading Artifacts through the Python API, but you can also explore the Artifacts on the `Google Cloud Console `_. 
65 | 66 | 67 | Problem Artifacts 68 | __________________ 69 | ==================== ====================================== =============================================================================== 70 | name type description 71 | ==================== ====================================== =============================================================================== 72 | ``val_predictions`` :class:`dcbench.DataPanelArtifact` A DataPanel of the model's predictions with columns `id`,`target`, and `probs.` 73 | ``test_predictions`` :class:`dcbench.DataPanelArtifact` A DataPanel of the model's predictions with columns `id`,`target`, and `probs.` 74 | ``test_slices`` :class:`dcbench.DataPanelArtifact` A DataPanel of the ground truth slice labels with columns `id`, `slices`. 75 | ``activations`` :class:`dcbench.DataPanelArtifact` A DataPanel of the model's activations with columns `id`,`act` 76 | ``model`` :class:`dcbench.ModelArtifact` A trained PyTorch model to audit. 77 | ``base_dataset`` :class:`dcbench.VisionDatasetArtifact` A DataPanel representing the base dataset with columns `id` and `image`. 78 | ``clip`` :class:`dcbench.DataPanelArtifact` A DataPanel of the image embeddings from OpenAI's CLIP model 79 | ==================== ====================================== =============================================================================== 80 | 81 | Solution Artifacts 82 | ____________________ 83 | =============== ================================== ========================================================================== 84 | name type description 85 | =============== ================================== ========================================================================== 86 | ``pred_slices`` :class:`dcbench.DataPanelArtifact` A DataPanel of predicted slice labels with columns `id` and `pred_slices`. 
87 | =============== ================================== ========================================================================== 88 | 89 | 90 | 91 | 92 | .. _budgetclean: 93 | 94 | Data Cleaning on a Budget 95 | -------------------------------------------- 96 | 97 | .. sidebar:: 98 | Task Details 99 | 100 | :Task ID: ``budgetclean`` 101 | :Problems: 144 102 | 103 | When it comes to data preparation, data cleaning is an essential yet quite costly task. If we are given a fixed cleaning budget, the challenge is to find the training data examples that would would bring the biggest positive impact on model performance if we were to clean them. 104 | 105 | **Classes**: :class:`dcbench.BudgetcleanProblem` :class:`dcbench.BudgetcleanSolution` 106 | 107 | .. admonition:: Cloud Storage 108 | 109 | We recommend downloading Artifacts through the Python API, but you can also explore the Artifacts on the `Google Cloud Console `_. 110 | 111 | 112 | Problem Artifacts 113 | __________________ 114 | ================= ============================ ======================================================================================================================================== 115 | name type description 116 | ================= ============================ ======================================================================================================================================== 117 | ``X_train_dirty`` :class:`dcbench.CSVArtifact` ('Features of the dirty training dataset which we need to clean. Each dirty cell contains an embedded list of clean candidate values.',) 118 | ``X_train_clean`` :class:`dcbench.CSVArtifact` Features of the clean training dataset where each dirty value from the dirty dataset is replaced with the correct clean candidate. 119 | ``y_train`` :class:`dcbench.CSVArtifact` Labels of the training dataset. 120 | ``X_val`` :class:`dcbench.CSVArtifact` Feature of the validtion dataset which can be used to guide the cleaning optimization process. 
121 | ``y_val`` :class:`dcbench.CSVArtifact` Labels of the validation dataset. 122 | ``X_test`` :class:`dcbench.CSVArtifact` ('Features of the test dataset used to produce the final evaluation score of the model.',) 123 | ``y_test`` :class:`dcbench.CSVArtifact` Labels of the test dataset. 124 | ================= ============================ ======================================================================================================================================== 125 | 126 | Solution Artifacts 127 | ____________________ 128 | ================ ============================ ============= 129 | name type description 130 | ================ ============================ ============= 131 | ``idx_selected`` :class:`dcbench.CSVArtifact` 132 | ================ ============================ ============= 133 | 134 | 135 | -------------------------------------------------------------------------------- /notebooks/Untitled.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "e92c135b", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%load_ext autoreload\n", 11 | "%autoreload 2" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "id": "e5ed4031", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "import dcbench" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 3, 27 | "id": "01273a96", 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "task = dcbench.tasks[\"slice_discovery\"]" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 11, 37 | "id": "3960e86a", 38 | "metadata": {}, 39 | "outputs": [ 40 | { 41 | "data": { 42 | "text/plain": [ 43 | "dcbench.tasks.slice_discovery.problem.SliceDiscoverySolution" 44 | ] 45 | }, 46 | "execution_count": 11, 47 | "metadata": {}, 48 | "output_type": "execute_result" 49 | } 50 | ], 51 | "source": [ 52 | 
"task.solution_class" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 5, 58 | "id": "d249ee2b", 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "import yaml\n", 63 | "import domino\n", 64 | "out = yaml.load(\n", 65 | " open(\"/oak/stanford/groups/jamesz/eyuboglu/.dcbench/slice_discovery/solution_sets/22-06-07-c5ef5603/solutions.yaml\"),\n", 66 | " Loader=yaml.FullLoader\n", 67 | ")" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 6, 73 | "id": "62985a71", 74 | "metadata": {}, 75 | "outputs": [ 76 | { 77 | "data": { 78 | "text/plain": [ 79 | "SliceDiscoverySolution(artifacts={'pred_slices': 'DataPanelArtifact'}, attributes={'embedding_column': 'emb', 'problem_id': 'p_72877', 'slicer_class': , 'slicer_config': {'n_slices': 5, 'y_hat_log_likelihood_weight': 10, 'y_log_likelihood_weight': 10}})" 80 | ] 81 | }, 82 | "execution_count": 6, 83 | "metadata": {}, 84 | "output_type": "execute_result" 85 | } 86 | ], 87 | "source": [ 88 | "out[0]" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 7, 94 | "id": "f9c43f36", 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "from dcbench.common.solution_set import SolutionSet" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 13, 104 | "id": "43287e5d", 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "sset = SolutionSet.from_solutions(out, \"test\")" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 15, 114 | "id": "ff176846", 115 | "metadata": {}, 116 | "outputs": [ 117 | { 118 | "ename": "AttributeError", 119 | "evalue": "'SolutionSet' object has no attribute 'solutions'", 120 | "output_type": "error", 121 | "traceback": [ 122 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 123 | "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", 124 | "Input \u001b[0;32mIn [15]\u001b[0m, in 
\u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43msset\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msolutions\u001b[49m\n", 125 | "\u001b[0;31mAttributeError\u001b[0m: 'SolutionSet' object has no attribute 'solutions'" 126 | ] 127 | } 128 | ], 129 | "source": [ 130 | "sset.solutions" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": null, 136 | "id": "f7f921d0", 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [] 140 | } 141 | ], 142 | "metadata": { 143 | "kernelspec": { 144 | "display_name": "Python 3", 145 | "language": "python", 146 | "name": "python3" 147 | }, 148 | "language_info": { 149 | "codemirror_mode": { 150 | "name": "ipython", 151 | "version": 3 152 | }, 153 | "file_extension": ".py", 154 | "mimetype": "text/x-python", 155 | "name": "python", 156 | "nbconvert_exporter": "python", 157 | "pygments_lexer": "ipython3", 158 | "version": "3.9.12" 159 | } 160 | }, 161 | "nbformat": 4, 162 | "nbformat_minor": 5 163 | } 164 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | build-backend = "setuptools.build_meta:__legacy__" -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Note: To use the 'upload' functionality of this file, you must: 5 | # $ pipenv install twine --dev 6 | 7 | import io 8 | import os 9 | import sys 10 | from distutils.util import convert_path 11 | from shutil import rmtree 12 | 13 | from setuptools import Command, find_packages, setup 14 | 15 | main_ns = {} 16 | ver_path = convert_path("dcbench/version.py") 17 | with open(ver_path) as ver_file: 18 | exec(ver_file.read(), main_ns) 19 | 20 | 21 | # 
Package meta-data. 22 | NAME = "dcbench" 23 | DESCRIPTION = ( 24 | "This is a benchmark that tests various data-centric aspects of improving the " 25 | "quality of machine learning workflows." 26 | ) 27 | URL = "" 28 | EMAIL = "sabri@eyuboglu.us" 29 | AUTHOR = "https://github.com/data-centric-ai/dcbench" 30 | REQUIRES_PYTHON = ">=3.7.0" 31 | VERSION = main_ns["__version__"] 32 | 33 | REQUIRED = [ 34 | "click>=8.0.0", 35 | "pyyaml>=5.4", 36 | "pre-commit", 37 | # "pytorch-lightning", 38 | "pandas", 39 | "numpy>=1.18.0", 40 | # "cytoolz", 41 | "ujson", 42 | "jsonlines>=1.2.0", 43 | # "torch>=1.8.0", 44 | "tqdm>=4.49.0", 45 | "scikit-learn", 46 | "meerkat-ml[dev,vision,ml]", 47 | # "torchvision>=0.9.0", 48 | # "wandb", 49 | # "ray[default]", 50 | # `"torchxrayvision",` 51 | ] 52 | EXTRAS = { 53 | "dev": [ 54 | "black==21.5b0", 55 | "isort>=5.7.0", 56 | "autoflake", 57 | "flake8>=3.8.4", 58 | "mypy>=0.9", 59 | "docformatter>=1.4", 60 | "pytest-cov>=2.10.1", 61 | "sphinx-rtd-theme>=0.5.1", 62 | "nbsphinx>=0.8.0", 63 | "recommonmark>=0.7.1", 64 | "parameterized", 65 | "pre-commit>=2.9.3", 66 | "sphinx-autobuild", 67 | "google-cloud-storage", 68 | "furo", 69 | ], 70 | } 71 | 72 | # The rest you shouldn't have to touch too much :) 73 | # ------------------------------------------------ 74 | # Except, perhaps the License and Trove Classifiers! 75 | # If you do change the License, remember to change the Trove Classifier for that! 76 | 77 | here = os.path.abspath(os.path.dirname(__file__)) 78 | 79 | # Import the README and use it as the long-description. 80 | # Note: this will only work if 'README.md' is present in your MANIFEST.in file! 81 | try: 82 | with io.open(os.path.join(here, "README.md"), encoding="utf-8") as f: 83 | long_description = "\n" + f.read() 84 | except FileNotFoundError: 85 | long_description = DESCRIPTION 86 | 87 | # Load the package's __version__.py module as a dictionary. 
88 | about = {} 89 | if not VERSION: 90 | project_slug = NAME.lower().replace("-", "_").replace(" ", "_") 91 | with open(os.path.join(here, project_slug, "__version__.py")) as f: 92 | exec(f.read(), about) 93 | else: 94 | about["__version__"] = VERSION 95 | 96 | 97 | class UploadCommand(Command): 98 | """Support setup.py upload.""" 99 | 100 | description = "Build and publish the package." 101 | user_options = [] 102 | 103 | @staticmethod 104 | def status(s): 105 | """Prints things in bold.""" 106 | print("\033[1m{0}\033[0m".format(s)) 107 | 108 | def initialize_options(self): 109 | pass 110 | 111 | def finalize_options(self): 112 | pass 113 | 114 | def run(self): 115 | try: 116 | self.status("Removing previous builds…") 117 | rmtree(os.path.join(here, "dist")) 118 | except OSError: 119 | pass 120 | 121 | self.status("Building Source and Wheel (universal) distribution…") 122 | os.system("{0} setup.py sdist bdist_wheel --universal".format(sys.executable)) 123 | 124 | self.status("Uploading the package to PyPI via Twine…") 125 | os.system("twine upload dist/*") 126 | 127 | self.status("Pushing git tags…") 128 | os.system("git tag v{0}".format(about["__version__"])) 129 | os.system("git push --tags") 130 | 131 | sys.exit() 132 | 133 | 134 | # Where the magic happens: 135 | setup( 136 | name=NAME, 137 | version=about["__version__"], 138 | description=DESCRIPTION, 139 | long_description=long_description, 140 | long_description_content_type="text/markdown", 141 | author=AUTHOR, 142 | author_email=EMAIL, 143 | python_requires=REQUIRES_PYTHON, 144 | url=URL, 145 | packages=find_packages(exclude=["tests", "*.tests", "*.tests.*", "tests.*"]), 146 | # If your package is a single module, use this instead of 'packages': 147 | # py_modules=['mypackage'], 148 | # entry_points={ 149 | # 'console_scripts': ['mycli=mymodule:cli'], 150 | # }, 151 | install_requires=REQUIRED, 152 | extras_require=EXTRAS, 153 | include_package_data=True, 154 | license="Apache 2.0", 155 | classifiers=[ 
156 | # Trove classifiers 157 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers 158 | "Programming Language :: Python", 159 | "Programming Language :: Python :: 3", 160 | "Programming Language :: Python :: 3.7", 161 | "Programming Language :: Python :: 3.8", 162 | "Programming Language :: Python :: 3.9", 163 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 164 | ], 165 | # $ setup.py publish support. 166 | cmdclass={ 167 | "upload": UploadCommand, 168 | }, 169 | ) 170 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/data-centric-ai/dcbench/831ab2359d686739d0b0c7a589974ce08448e58d/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # contents of conftest.py 2 | import os 3 | 4 | import google.cloud.storage as storage 5 | import pytest 6 | 7 | 8 | @pytest.fixture(autouse=True) 9 | def set_test_bucket(monkeypatch): 10 | test_bucket_name = "dcbench" 11 | monkeypatch.setattr("dcbench.config.public_bucket_name", test_bucket_name) 12 | 13 | 14 | @pytest.fixture() 15 | def set_tmp_bucket(monkeypatch): 16 | test_bucket_name = "dcbench-test" 17 | monkeypatch.setattr("dcbench.config.public_bucket_name", test_bucket_name) 18 | 19 | # code above this yield will be executed before every test 20 | yield 21 | # code below this yield will be executed after every test 22 | 23 | assert test_bucket_name != "dcbench" # ensure we don't empty the production bucket 24 | client = storage.Client() 25 | bucket = client.get_bucket(test_bucket_name) 26 | blobs = list(bucket.list_blobs()) 27 | bucket.delete_blobs(blobs) 28 | return test_bucket_name 29 | 30 | 31 | @pytest.fixture(autouse=True) 32 | def set_test_local(monkeypatch, tmpdir): 33 | 
# tests/dcbench/common/test_artifact.py -- tests for Artifact subclasses.
import os
import shutil
from typing import Any

import meerkat as mk
import numpy as np
import pandas as pd
import pytest
import torch
import torch.nn as nn
import yaml

from dcbench.common.artifact import (
    Artifact,
    CSVArtifact,
    DataPanelArtifact,
    ModelArtifact,
    VisionDatasetArtifact,
    YAMLArtifact,
)
from dcbench.common.modeling import Model


class SimpleModel(Model):
    """Minimal deterministic model used to exercise ``ModelArtifact``."""

    def _set_model(self):
        # Fixed seed so two instances built from the same config get
        # identical weights -- is_data_equal relies on this.
        torch.manual_seed(0)
        self.layer = nn.Linear(in_features=self.config["in_features"], out_features=2)


@pytest.fixture(params=["csv", "datapanel", "model", "yaml"])
def artifact(request):
    """Build one small artifact of each supported type."""
    artifact_type = request.param

    artifact_id = f"test_artifact_{artifact_type}"
    if artifact_type == "csv":
        return CSVArtifact.from_data(
            pd.DataFrame({"a": np.arange(5), "b": np.ones(5)}), artifact_id=artifact_id
        )
    elif artifact_type == "datapanel":
        return DataPanelArtifact.from_data(
            mk.DataPanel({"a": np.arange(5), "b": np.ones(5)}), artifact_id=artifact_id
        )
    elif artifact_type == "model":
        return ModelArtifact.from_data(
            SimpleModel({"in_features": 4}), artifact_id=artifact_id
        )
    elif artifact_type == "yaml":
        return YAMLArtifact.from_data([1, 2, 3], artifact_id=artifact_id)
    else:
        raise ValueError(f"Artifact type '{artifact_type}' not supported.")


def is_data_equal(data1: Any, data2: Any) -> bool:
    """Return True if two supported payloads have equal type and content.

    Raises:
        ValueError: if ``data1`` is of an unsupported type.
    """
    if not isinstance(data1, type(data2)):
        return False

    if isinstance(data1, pd.DataFrame):
        return data1.equals(data2)
    elif isinstance(data1, mk.DataPanel):
        # NOTE(review): only verifies data1's columns exist and match in
        # data2 -- extra columns in data2 would go undetected.
        for col in data1.columns:
            if col not in data2.columns:
                return False
            if not (data1[col] == data2[col]).all():
                return False
    elif isinstance(data1, Model):
        return (
            data1.layer.weight == data2.layer.weight
        ).all() and data1.config == data2.config
    elif isinstance(data1, (list, dict)):
        return data1 == data2
    else:
        raise ValueError(f"Data type '{type(data1)}' not supported.")
    return True


def test_artifact_upload_download(set_tmp_bucket, artifact):
    """Upload/download round-trips data; force=False is a no-op once synced."""
    data = artifact.load()
    uploaded = artifact.upload(force=True)
    assert uploaded

    downloaded = artifact.download(force=True)
    assert downloaded
    assert is_data_equal(data, artifact.load())

    # check that upload without force does not upload
    uploaded = artifact.upload()
    assert not uploaded

    # check that download without force does not download
    downloaded = artifact.download()
    assert not downloaded


def test_to_yaml_from_yaml(artifact):
    """YAML round-trip preserves identity, paths, and loaded data."""
    yaml_str = yaml.dump(artifact)
    artifact_from_yaml = yaml.load(yaml_str, Loader=yaml.FullLoader)
    assert artifact_from_yaml.id == artifact.id
    assert isinstance(artifact.load(), type(artifact_from_yaml.load()))
    assert artifact.remote_url == artifact_from_yaml.remote_url
    # fixed: previously compared artifact.local_path to itself (a tautology)
    assert artifact.local_path == artifact_from_yaml.local_path
    assert is_data_equal(artifact.load(), artifact_from_yaml.load())


def test_load_without_download_errors(artifact):
    """Loading after the local copy is deleted raises a ValueError."""
    if os.path.isdir(artifact.local_path):
        shutil.rmtree(artifact.local_path)
    else:
        os.remove(artifact.local_path)

    with pytest.raises(ValueError) as excinfo:
        artifact.load()

    assert "`Artifact`" in str(excinfo.value)


def test_upload_without_save_errors(artifact):
    """Uploading after the local copy is deleted raises a ValueError."""
    if os.path.isdir(artifact.local_path):
        shutil.rmtree(artifact.local_path)
    else:
        os.remove(artifact.local_path)

    with pytest.raises(ValueError) as excinfo:
        artifact.upload()

    assert "Artifact" in str(excinfo.value)


def test_from_data():
    """Artifact.from_data dispatches each payload type to the right subclass."""
    artifact = Artifact.from_data(pd.DataFrame({"a": np.arange(5), "b": np.ones(5)}))
    assert isinstance(artifact, CSVArtifact)

    artifact = Artifact.from_data(mk.DataPanel({"a": np.arange(5), "b": np.ones(5)}))
    assert isinstance(artifact, DataPanelArtifact)

    artifact = Artifact.from_data(SimpleModel({"in_features": 4}))
    assert isinstance(artifact, ModelArtifact)

    artifact = Artifact.from_data([1, 2, 3])
    assert isinstance(artifact, YAMLArtifact)

    with pytest.raises(ValueError) as excinfo:
        Artifact.from_data(None)
    assert "Artifact" in str(excinfo.value)


def test_vision_dataset_artifact(monkeypatch):
    """from_name loads known datasets via mk.datasets.get and re-downloads on demand."""
    downloads = []
    celeba_dp = mk.DataPanel(
        {
            "image": np.random.rand(10, 3, 4, 4),
            "identity": np.random.randint(0, 10, 10),
            "split": np.random.randint(0, 10, 10),
            "image_id": np.random.randint(0, 10, 10),
        }
    )

    imagenet_dp = mk.DataPanel(
        {
            "image": np.random.rand(10, 3, 4, 4),
            "name": np.random.randint(0, 10, 10),
            "synset": np.random.randint(0, 10, 10),
            "image_id": np.random.randint(0, 10, 10),
        }
    )

    def mock_get(name, dataset_dir, **kwargs):
        # Record each fetch so the test can count downloads.
        if name == "celeba":
            downloads.append("celeba")
            return celeba_dp.view()
        elif name == "imagenet":
            downloads.append("imagenet")
            return imagenet_dp.view()

    monkeypatch.setattr(mk.datasets, "get", mock_get)

    artifact = VisionDatasetArtifact.from_name("celeba")
    loaded_artifact = artifact.load()
    assert isinstance(loaded_artifact, mk.DataPanel)
    assert np.allclose(loaded_artifact["image"], celeba_dp["image"])
    assert len(downloads) == 1

    artifact.download()
    loaded_artifact = artifact.load()
    assert isinstance(loaded_artifact, mk.DataPanel)
    assert np.allclose(loaded_artifact["image"], celeba_dp["image"])
    assert len(downloads) == 2

    artifact = VisionDatasetArtifact.from_name("imagenet")
    loaded_artifact = artifact.load()
    assert isinstance(loaded_artifact, mk.DataPanel)
    assert np.allclose(loaded_artifact["image"], imagenet_dp["image"])
    assert len(downloads) == 3

    artifact.download()
    loaded_artifact = artifact.load()
    assert isinstance(loaded_artifact, mk.DataPanel)
    assert np.allclose(loaded_artifact["image"], imagenet_dp["image"])
    assert len(downloads) == 4

    with pytest.raises(ValueError) as excinfo:
        artifact = VisionDatasetArtifact.from_name("nonexistent")

    assert "nonexistent" in str(excinfo.value) and "dcbench" in str(excinfo.value)
    assert len(downloads) == 4
# tests/dcbench/common/test_artifact_container.py -- tests for ArtifactContainer.
import meerkat as mk
import numpy as np
import pandas as pd
import pytest
import yaml

from dcbench.common.artifact import Artifact, CSVArtifact, DataPanelArtifact
from dcbench.common.artifact_container import ArtifactContainer, ArtifactSpec
from dcbench.common.table import AttributeSpec

from .test_artifact import is_data_equal


class SimpleContainer(ArtifactContainer):
    """Tiny concrete container: two required CSVs, one required and one
    optional DataPanel, and a single optional int attribute."""

    artifact_specs = {
        "csv1": ArtifactSpec("Description of csv artifact", CSVArtifact),
        "csv2": ArtifactSpec("Description of common csv artifact", CSVArtifact),
        "dp1": ArtifactSpec("Description of datapanel artifact", DataPanelArtifact),
        "dp2": ArtifactSpec(
            "Description of datapanel artifact", DataPanelArtifact, optional=True
        ),
    }
    attribute_specs = {
        "test_attr": AttributeSpec(
            description="A test attribute", attribute_type=int, optional=True
        )
    }
    task_id = "test_task"
    container_type = "test_container_type"


@pytest.fixture
def container():
    """A SimpleContainer with all required artifacts and the test attribute."""
    return SimpleContainer(
        artifacts={
            "csv1": CSVArtifact.from_data(
                pd.DataFrame({"a": np.arange(5), "b": np.ones(5)}),
                artifact_id="csv1",
            ),
            "dp1": DataPanelArtifact.from_data(
                mk.DataPanel({"a": np.arange(5), "b": np.ones(5)}),
                artifact_id="dp1",
            ),
            "csv2": CSVArtifact.from_data(
                pd.DataFrame({"a": np.arange(5), "b": np.ones(5)}),
                artifact_id="csv2",
            ),
        },
        attributes={
            "test_attr": 5,
        },
        container_id="test_container",
    )


def test_artifact_container_from_raw_data():
    """Raw objects are wrapped in artifacts with auto-generated ids;
    pre-built artifacts keep their own id."""
    df = pd.DataFrame({"a": np.arange(5), "b": np.ones(5)})
    dp = mk.DataPanel({"a": np.arange(5), "b": np.ones(5)})

    artifact = Artifact.from_data(df, artifact_id="common_artifact")

    container = SimpleContainer(
        artifacts={"csv1": df, "dp1": dp, "csv2": artifact},
        container_id="test_container",
    )

    assert is_data_equal(container.artifacts["csv1"].load(), df)
    assert is_data_equal(container.artifacts["dp1"].load(), dp)
    assert is_data_equal(container.artifacts["csv2"].load(), df)

    # if the container is passed raw object, an ID is automatically generated for it
    assert (
        container.artifacts["csv1"].id
        == "test_task/test_container_type/artifacts/test_container/csv1"
    )
    assert (
        container.artifacts["dp1"].id
        == "test_task/test_container_type/artifacts/test_container/dp1"
    )
    assert container.artifacts["csv2"].id == "common_artifact"


def test_artifact_container_invalid_artifact(container):
    """Unknown names, wrong artifact types, and missing required artifacts
    all raise ValueError with a descriptive message."""
    df = pd.DataFrame({"a": np.arange(5), "b": np.ones(5)})
    dp = mk.DataPanel({"a": np.arange(5), "b": np.ones(5)})

    with pytest.raises(ValueError) as excinfo:
        SimpleContainer(
            artifacts={"csv1": df, "dp1": dp, "csv2": df, "nonexistent": df},
            container_id="test_container",
        )
    assert "Passed artifact name 'nonexistent'" in str(excinfo.value)

    with pytest.raises(ValueError) as excinfo:
        SimpleContainer(
            artifacts={"csv1": df, "csv2": Artifact.from_data(dp), "dp1": dp},
            container_id="test_container",
        )
    assert "for the artifact named 'csv2'" in str(excinfo.value)

    with pytest.raises(ValueError) as excinfo:
        SimpleContainer(
            artifacts={"dp1": dp, "csv2": df},
            container_id="test_container",
        )

    assert "Must pass required artifact with key 'csv1'" in str(excinfo.value)


@pytest.mark.parametrize("use_force", [True, False])
def test_artifact_container_download(monkeypatch, container, use_force: bool):
    """Container.download fans out to every artifact and reports whether
    anything was actually downloaded."""
    downloads = []

    # mock the download function
    # fixed: annotation was `force: str = True` -- force is a bool flag
    def mock_download(self, force: bool = True):
        if not use_force:
            return False
        downloads.append(self.id)
        return True

    monkeypatch.setattr(Artifact, "download", mock_download)

    downloaded = container.download(force=use_force)
    assert downloaded == use_force

    if use_force:
        assert len(downloads) == 3
        assert set(downloads) == set(["csv1", "dp1", "csv2"])
    else:
        assert len(downloads) == 0


@pytest.mark.parametrize("use_force", [True, False])
def test_artifact_container_upload(monkeypatch, container, use_force: bool):
    """Container.upload fans out to every artifact, mirroring download."""
    uploads = []

    # mock the upload function
    # fixed: annotation was `force: str = False` -- force is a bool flag
    def mock_upload(self, force: bool = False, bucket: str = None):
        if not force:
            return False
        uploads.append(self.id)
        return True

    monkeypatch.setattr(Artifact, "upload", mock_upload)

    uploaded = container.upload(force=use_force)
    assert uploaded == use_force

    if use_force:
        assert len(uploads) == 3
        assert set(uploads) == set(["csv1", "dp1", "csv2"])
    else:
        assert len(uploads) == 0


def test_artifact_container_is_uploaded(monkeypatch, container):
    """is_uploaded is True when every artifact reports uploaded."""

    def mock_is_uploaded(self):
        return True

    monkeypatch.setattr(Artifact, "is_uploaded", property(mock_is_uploaded))

    is_uploaded = container.is_uploaded
    assert is_uploaded


def test_artifact_container_is_not_uploaded(monkeypatch, container):
    """is_uploaded is False if any single artifact is missing remotely."""

    def mock_is_not_uploaded(self):
        # (removed leftover debug print of self.id)
        return self.id != "csv1"

    monkeypatch.setattr(Artifact, "is_uploaded", property(mock_is_not_uploaded))

    is_uploaded = container.is_uploaded
    assert not is_uploaded


def test_artifact_container_is_downloaded(monkeypatch, container):
    """is_downloaded is True when every artifact reports downloaded."""

    def mock_is_downloaded(self):
        return True

    monkeypatch.setattr(Artifact, "is_downloaded", property(mock_is_downloaded))

    is_downloaded = container.is_downloaded
    assert is_downloaded


def test_artifact_container_is_not_downloaded(monkeypatch, container):
    """is_downloaded is False if any single artifact is missing locally."""

    def mock_is_downloaded(self):
        return self.id != "csv1"

    monkeypatch.setattr(Artifact, "is_downloaded", property(mock_is_downloaded))

    is_downloaded = container.is_downloaded
    assert not is_downloaded


def test_artifact_container_to_yaml_from_yaml(container):
    """YAML round-trip preserves the container's id, type, and artifact data."""
    yaml_str = yaml.dump(container)
    container_from_yaml = yaml.load(yaml_str, Loader=yaml.FullLoader)

    assert container.id == container_from_yaml.id
    assert isinstance(container, type(container_from_yaml))
    assert is_data_equal(
        container.artifacts["csv1"].load(), container_from_yaml.artifacts["csv1"].load()
    )
    assert is_data_equal(
        container.artifacts["dp1"].load(), container_from_yaml.artifacts["dp1"].load()
    )
    assert is_data_equal(
        container.artifacts["csv2"].load(), container_from_yaml.artifacts["csv2"].load()
    )


def test_artifact_container_repr(container):
    assert "SimpleContainer" in str(container)


def test_artifact_container_len(container):
    assert len(container) == 3


def test_artifact_container_getitem(container):
    """container[name] returns the loaded data of the named artifact."""
    assert is_data_equal(container["csv1"], container.artifacts["csv1"].load())
    assert is_data_equal(container["dp1"], container.artifacts["dp1"].load())
    assert is_data_equal(container["csv2"], container.artifacts["csv2"].load())


def test_artifact_container_iter(container):
    """Iterating a container yields artifact names in insertion order."""
    assert [key for key in container] == ["csv1", "dp1", "csv2"]


def test_attribute_access(container):
    """Declared attributes are reachable as plain attributes; others raise."""
    assert container.test_attr == 5

    with pytest.raises(AttributeError):
        container.nonexistent_attr
# tests/dcbench/common/test_problem.py -- parametrized problem-class fixture.
import pytest

import dcbench


@pytest.fixture(params=["slice_discovery"])
def problem_class(request):
    """Resolve each parametrized task id to its dcbench problem module."""
    task = request.param
    if task != "slice_discovery":
        raise ValueError(f"Task '{task}' not supported.")
    return dcbench.slice_discovery
# tests/dcbench/tasks/test_slice_discovery.py -- slice-discovery task tests.
import meerkat as mk
import numpy as np
import torch.nn as nn

import dcbench
from dcbench import DataPanelArtifact, Table, Task
from dcbench.tasks.slice_discovery.metrics import compute_metrics, roc_auc_score


def test_solve():
    """Submitting an all-zero slice prediction is accepted by Problem.solve."""
    task = dcbench.tasks["slice_discovery"]
    problem = task.problems["p_117634"]

    example_ids = problem["test_predictions"]["id"]
    empty_preds = np.zeros((len(example_ids), 5))
    solution_dp = mk.DataPanel({"id": example_ids, "pred_slices": empty_preds})
    problem.solve(pred_slices_dp=solution_dp)


def test_problems():
    """The task registry exposes a Task whose problems form a Table."""
    task = dcbench.tasks["slice_discovery"]
    assert isinstance(task, Task)
    assert isinstance(task.problems, Table)


def test_problem():
    """Every named artifact of a problem loads as the expected type."""
    task = dcbench.tasks["slice_discovery"]
    problem = task.problems["p_117634"]

    for name in ("test_predictions", "val_predictions", "test_slices", "activations"):
        assert isinstance(problem[name], mk.DataPanel)

    assert isinstance(problem["model"], nn.Module)

    # attribute access should not raise
    problem.slice_category


def test_artifacts():
    """problem.artifacts maps names to DataPanelArtifacts that download on demand."""
    task = dcbench.tasks["slice_discovery"]
    problem = task.problems["p_117634"]
    artifacts = problem.artifacts
    assert isinstance(artifacts, dict)
    for name in ("test_predictions", "val_predictions", "test_slices", "activations"):
        artifact = artifacts[name]
        assert isinstance(artifact, DataPanelArtifact)
        assert not artifact.is_downloaded

        artifact.download()

        assert artifact.is_downloaded


def test_metrics():
    """compute_metrics reproduces sklearn-style per-slice AUROC."""
    slices = np.array([[0, 1], [0, 0], [1, 0], [1, 1], [0, 0], [1, 0], [0, 1]])

    pred_slices = np.array(
        [[0, 10, 3], [0, 0, 4], [1, 0, 0], [0, 10, 0], [0, 0, 0], [1, 0, 0], [0, 10, 0]]
    )

    metrics = compute_metrics(pred_slices=pred_slices, slices=slices)

    for col in (0, 1):
        expected = roc_auc_score(slices[:, col], pred_slices[:, col])
        assert metrics["auroc"][col] == expected
# tests/dcbench/test_config.py -- configuration loading via environment.
import os

import yaml

from dcbench.config import DCBenchConfig, get_config


def test_env(tmpdir, monkeypatch):
    """A config file pointed to by DCBENCH_CONFIG overrides local_dir.

    Fixed two defects: the env var was set with a bare ``os.environ``
    assignment (leaking into every subsequent test) and the config file
    handle was never closed (``yaml.dump(..., open(...))``).
    """
    config_path = os.path.join(tmpdir, "config.yaml")

    new_local_dir = os.path.join(tmpdir, ".dcbench-env")

    with open(config_path, "w") as f:
        yaml.dump({"local_dir": new_local_dir}, f)
    # monkeypatch restores the environment after the test
    monkeypatch.setenv("DCBENCH_CONFIG", config_path)

    config = DCBenchConfig(**get_config())
    assert config.local_dir == new_local_dir


# tests/dcbench/test_dcbench.py -- top-level task registry tests.
import pandas as pd

import dcbench
from dcbench.common.table import Table
from dcbench.common.task import Task


def test_tasks():
    """dcbench exposes exactly three tasks in a Table."""
    assert isinstance(dcbench.tasks, Table)
    assert len(dcbench.tasks) == 3


def test_tasks_html():
    # smoke test: HTML repr should not raise
    dcbench.tasks._repr_html_()


def test_tasks_df():
    df = dcbench.tasks.df
    assert isinstance(df, pd.DataFrame)


def test_get_tasks():
    """Each task id in the registry resolves to a Task with a matching id."""
    for task_id in dcbench.tasks:
        out = dcbench.tasks[task_id]
        assert isinstance(out, Task)
        assert task_id == out.id