├── .editorconfig ├── .github └── workflows │ ├── codeclimate.yml │ ├── docs.yml │ ├── pypi.yml │ └── tests.yml ├── .gitignore ├── CITATION.cff ├── LICENSE ├── README.md ├── conftest.py ├── docs ├── .gitignore ├── Makefile ├── conf.py ├── index.rst ├── make.bat ├── repositories.rst ├── structure.rst └── utils.rst ├── pyproject.toml ├── setup.cfg ├── setup.py └── skdatasets ├── __init__.py ├── repositories ├── __init__.py ├── aneurisk.py ├── base.py ├── cran.py ├── forex.py ├── keel.py ├── keras.py ├── libsvm.py ├── physionet.py ├── raetsch.py ├── sklearn.py ├── uci.py └── ucr.py ├── tests ├── __init__.py ├── repositories │ ├── __init__.py │ ├── test_cran.py │ ├── test_forex.py │ ├── test_keel.py │ ├── test_keras.py │ ├── test_libsvm.py │ ├── test_physionet.py │ ├── test_raetsch.py │ ├── test_sklearn.py │ ├── test_uci.py │ └── test_ucr.py └── utils │ ├── LinearRegression.json │ ├── LinearRegressionCustom.json │ ├── MLPClassifier.json │ ├── MLPRegressor.json │ ├── __init__.py │ ├── linear_model.py │ ├── run.py │ ├── test_estimator.py │ ├── test_experiment.py │ ├── test_run.py │ └── test_scores.py └── utils ├── __init__.py ├── estimator.py ├── experiment.py └── scores.py /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | charset = utf-8 5 | indent_style = tab 6 | indent_size = 4 7 | insert_final_newline = false 8 | end_of_line = lf 9 | 10 | [*.py] 11 | indent_style = space 12 | -------------------------------------------------------------------------------- /.github/workflows/codeclimate.yml: -------------------------------------------------------------------------------- 1 | name: CodeClimate upload 2 | on: 3 | workflow_run: 4 | workflows: [Tests] 5 | types: 6 | - completed 7 | jobs: 8 | download: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v3 12 | - name: 'Download artifact' 13 | uses: actions/github-script@v6 14 | with: 15 | script: | 16 | let allArtifacts = await github.rest.actions.listWorkflowRunArtifacts({ 17 | owner: context.repo.owner, 18 | repo: context.repo.repo, 19 | run_id: context.payload.workflow_run.id, 20 | }); 21 | let matchArtifact = allArtifacts.data.artifacts.filter((artifact) => { 22 | return artifact.name == "code-coverage-report" 23 | })[0]; 24 | let download = await github.rest.actions.downloadArtifact({ 25 | owner: context.repo.owner, 26 | repo: context.repo.repo, 27 | artifact_id: matchArtifact.id, 28 | archive_format: 'zip', 29 | }); 30 | let fs = require('fs'); 31 | fs.writeFileSync(`${process.env.GITHUB_WORKSPACE}/code-coverage-report.zip`, Buffer.from(download.data)); 32 | - name: 'Unzip artifact' 33 | run: unzip code-coverage-report.zip 34 | - name: Install dependencies 35 | run: | 36 | pip3 install codecov pytest-cov || pip3 install --user codecov pytest-cov; 37 | - name: Upload coverage to CodeClimate 38 | uses: paambaati/codeclimate-action@v3.2.0 39 | env: 40 | CC_TEST_REPORTER_ID: ${{ secrets.CC_TEST_REPORTER_ID }} -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: Pages 2 | on: 3 | push: 4 | branches: 5 | - master 6 | - develop 7 | pull_request: 8 | jobs: 9 | build: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/setup-python@v4 13 | - uses: actions/checkout@v3 14 | with: 15 | fetch-depth: 0 # otherwise, you will failed to push refs to dest repo 16 | - name: Install package 17 | run: pip3 
install --upgrade-strategy eager -v ".[all]" 18 | - name: Build and Commit 19 | uses: sphinx-notes/pages@v2 20 | - name: Push changes 21 | uses: ad-m/github-push-action@master 22 | with: 23 | github_token: ${{ secrets.GITHUB_TOKEN }} 24 | branch: gh-pages -------------------------------------------------------------------------------- /.github/workflows/pypi.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | # This workflow uses actions that are not certified by GitHub. 4 | # They are provided by a third-party and are governed by 5 | # separate terms of service, privacy policy, and support 6 | # documentation. 7 | name: Upload Python Package 8 | on: 9 | release: 10 | types: [published] 11 | permissions: 12 | contents: read 13 | jobs: 14 | deploy: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v3 18 | - name: Set up Python 19 | uses: actions/setup-python@v4 20 | with: 21 | python-version: '3.x' 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | pip install build 26 | - name: Build package 27 | run: python -m build 28 | - name: Publish package 29 | uses: pypa/gh-action-pypi-publish@release/v1 30 | with: 31 | user: __token__ 32 | password: ${{ secrets.PYPI_API_TOKEN }} -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | on: 3 | push: 4 | pull_request: 5 | jobs: 6 | build: 7 | runs-on: ${{ matrix.os }} 8 | name: Python ${{ matrix.python-version }} on ${{ matrix.os }} 9 | strategy: 10 | matrix: 11 | os: [ubuntu-latest, macos-latest, windows-latest] 12 | python-version: ['3.10', '3.11', '3.12'] 13 | steps: 14 | - uses: actions/checkout@v3 15 | - name: Set up Python ${{ matrix.python-version }} on ${{ matrix.os }} 16 | uses: actions/setup-python@v4 17 | with: 18 | python-version: ${{ matrix.python-version }} 19 | - name: Install dependencies 20 | run: | 21 | pip3 install codecov pytest-cov || pip3 install --user codecov pytest-cov; 22 | - name: Run tests 23 | run: | 24 | pip3 install --upgrade-strategy eager -v ".[test]" 25 | coverage run --source=skdatasets/ -m pytest; 26 | coverage xml -o coverage.xml # explicitely exporting coverage file to be read by coverage report command. 27 | - name: Archive code coverage results 28 | uses: actions/upload-artifact@v3 29 | with: 30 | name: code-coverage-report 31 | path: coverage.xml -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | /.pytest_cache/ 103 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - family-names: "Díaz-Vico" 5 | given-names: "David" 6 | orcid: "https://orcid.org/0000-0002-4002-5312" 7 | affiliation: "Universidad Autónoma de Madrid" 8 | - family-names: "Ramos-Carreño" 9 | given-names: "Carlos" 10 | orcid: "https://orcid.org/0000-0003-2566-7058" 11 | affiliation: "Universidad Autónoma de Madrid" 12 | email: vnmabus@gmail.com 13 | title: "scikit-datasets: Scikit-learn-compatible datasets" 14 | date-released: 2022-03-24 15 | doi: 10.5281/zenodo.6383047 16 | url: "https://github.com/daviddiazvico/scikit-datasets" 17 | license: MIT 18 | keywords: 19 | - datasets 20 | - repository 21 | - benchmark 22 | - Python 23 | identifiers: 24 | - description: "This is the collection of archived snapshots of all versions of scikit-datasets" 25 | type: doi 26 | value: 10.5281/zenodo.6383047 27 | - description: "This is the archived snapshot of version 0.2 of scikit-datasets" 28 | type: doi 29 | value: 10.5281/zenodo.6383048 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 David Díaz Vico 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # scikit-datasets 2 | Scikit-learn-compatible datasets 3 | 4 | ## Status 5 | [![Tests](https://github.com/daviddiazvico/scikit-datasets/actions/workflows/tests.yml/badge.svg)](https://github.com/daviddiazvico/scikit-datasets/actions/workflows/tests.yml) 6 | [![Maintainability](https://api.codeclimate.com/v1/badges/a37c9ee152b41a0cb577/maintainability)](https://codeclimate.com/github/daviddiazvico/scikit-datasets/maintainability) 7 | [![Test Coverage](https://api.codeclimate.com/v1/badges/a37c9ee152b41a0cb577/test_coverage)](https://codeclimate.com/github/daviddiazvico/scikit-datasets/test_coverage) 8 | [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.6383047.svg)](https://doi.org/10.5281/zenodo.6383047) 9 | 10 | ## Installation 11 | Available in [PyPI](https://pypi.python.org/pypi?:action=display&name=scikit-datasets) 12 | ``` 13 | pip install scikit-datasets 14 | ``` 15 | 16 | ## Documentation 17 | Autogenerated and hosted in [GitHub Pages](https://daviddiazvico.github.io/scikit-datasets/) 18 | 19 | ## Distribution 20 | Run the following command from the project home to create the distribution 21 | ``` 22 | python setup.py sdist bdist_wheel 23 | ``` 24 | and upload the package to [testPyPI](https://testpypi.python.org/) 25 | ``` 26 | twine upload --repository-url https://test.pypi.org/legacy/ dist/* 27 | ``` 28 | or [PyPI](https://pypi.python.org/) 29 | ``` 30 | twine upload dist/* 31 | ``` 32 | 33 | ## Citation 34 | If you find scikit-datasets useful, please cite it in your publications. 35 | You can find the appropriate citation format in the sidebar, in both APA and 36 | Bibtex. 37 | -------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | collect_ignore = ["setup.py"] 4 | 5 | 6 | def pytest_addoption(parser): 7 | parser.addoption( 8 | "--runslow", action="store_true", default=False, help="run slow tests" 9 | ) 10 | 11 | 12 | def pytest_collection_modifyitems(config, items): 13 | if config.getoption("--runslow"): 14 | # --runslow given in cli: do not skip slow tests 15 | return 16 | skip_slow = pytest.mark.skip(reason="need --runslow option to run") 17 | for item in items: 18 | if "slow" in item.keywords: 19 | item.add_marker(skip_slow) 20 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | /autosummary/ 2 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 
12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | # import os 14 | # import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | import sys 21 | 22 | import pkg_resources 23 | 24 | try: 25 | release = pkg_resources.get_distribution("scikit-datasets").version 26 | except pkg_resources.DistributionNotFound: 27 | print( 28 | "To build the documentation, The distribution information of\n" 29 | "scikit-datasets has to be available. Either install the package\n" 30 | 'into your development environment or run "setup.py develop"\n' 31 | "to setup the metadata. A virtualenv is recommended!\n" 32 | ) 33 | sys.exit(1) 34 | del pkg_resources 35 | 36 | version = ".".join(release.split(".")[:2]) 37 | 38 | project = "scikit-datasets" 39 | copyright = "2020, David Diaz Vico" 40 | author = "David Diaz Vico" 41 | 42 | 43 | # -- General configuration --------------------------------------------------- 44 | 45 | # Add any Sphinx extension module names here, as strings. They can be 46 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 47 | # ones. 48 | extensions = [ 49 | "sphinx.ext.autodoc", 50 | "sphinx.ext.autosummary", 51 | "sphinx.ext.viewcode", 52 | "sphinx.ext.githubpages", 53 | "sphinx.ext.napoleon", 54 | "sphinx.ext.intersphinx", 55 | ] 56 | 57 | # Add any paths that contain templates here, relative to this directory. 58 | templates_path = ["_templates"] 59 | 60 | # The language for content autogenerated by Sphinx. Refer to documentation 61 | # for a list of supported languages. 62 | # 63 | # This is also used if you do content translation via gettext catalogs. 64 | # Usually you set "language" from the command line for these cases. 65 | language = "en" 66 | 67 | # List of patterns, relative to source directory, that match files and 68 | # directories to ignore when looking for source files. 69 | # This pattern also affects html_static_path and html_extra_path. 70 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 71 | 72 | 73 | # -- Options for HTML output ------------------------------------------------- 74 | 75 | # The theme to use for HTML and HTML Help pages. See the documentation for 76 | # a list of builtin themes. 77 | # 78 | html_theme = "alabaster" 79 | 80 | # Add any paths that contain custom static files (such as style sheets) here, 81 | # relative to this directory. 
They are copied after the builtin static files, 82 | # so a file named "default.css" will overwrite the builtin "default.css". 83 | html_static_path = ["_static"] 84 | 85 | 86 | # -- Extension configuration ------------------------------------------------- 87 | intersphinx_mapping = { 88 | "python": ( 89 | "https://docs.python.org/{.major}".format(sys.version_info), 90 | None, 91 | ), 92 | "numpy": ("https://docs.scipy.org/doc/numpy/", None), 93 | "scipy": ("https://docs.scipy.org/doc/scipy/reference", None), 94 | "sklearn": ("https://scikit-learn.org/stable", None), 95 | "pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None), 96 | "sacred": ("https://sacred.readthedocs.io/en/stable/", None), 97 | } 98 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to scikit-datasets's documentation! 2 | =========================================== 3 | 4 | This package groups functions to fetch datasets from several sources, 5 | converting them to scikit-learn-compatible datasets. 6 | 7 | In the `project page `_ hosted by 8 | Github you can find more information related to the development of the package. 9 | 10 | Installation 11 | ------------ 12 | 13 | Currently, scikit-datasets is available in Python 3, regardless of the 14 | platform. The stable version can be installed via 15 | `PyPI `_: 16 | 17 | .. code-block:: bash 18 | 19 | pip install scikit-datasets 20 | 21 | Content tree 22 | ------------ 23 | 24 | .. toctree:: 25 | :maxdepth: 2 26 | :caption: Contents: 27 | 28 | structure 29 | repositories 30 | utils 31 | 32 | 33 | Contributions 34 | ------------- 35 | 36 | All contributions are welcome. You can help this project grow in multiple ways, 37 | from creating an issue, reporting an improvement or a bug, to doing a 38 | repository fork and creating a pull request to the development branch. 39 | 40 | If you want to contribute routines for fetching data from a new repository, 41 | please first open a new issue or draft PR to discuss it. 42 | 43 | License 44 | ------- 45 | 46 | The package is licensed under the MIT License. A copy of the 47 | `license `_ 48 | can be found along with the code or in the project page. 49 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/repositories.rst: -------------------------------------------------------------------------------- 1 | Repositories 2 | ============ 3 | 4 | The core of the scikit-datasets package consists of fetching functions to 5 | obtain data from several repositories, containing both multivariate and 6 | functional data. 7 | 8 | The subpackage :mod:`~skdatasets.repositories` contains a module per available 9 | repository. For repositories that contain data in a regular format, that module 10 | has a ``fetch`` function that returns data in a 11 | :doc:`standardized format <structure>`. 12 | For modules such as :mod:`~skdatasets.repositories.cran`, where data is in 13 | a non-regular format, specific functions are provided to return the data. 14 | 15 | The available repositories are described next. 16 | 17 | Aneurisk 18 | -------- 19 | 20 | The Aneurisk dataset repository. 21 | 22 | URL: http://ecm2.mathcs.emory.edu/aneuriskweb/index 23 | 24 | .. autosummary:: 25 | :toctree: autosummary 26 | 27 | ~skdatasets.repositories.aneurisk.fetch 28 | 29 | CRAN 30 | ---- 31 | 32 | The main repository of R packages. 33 | 34 | URL: https://cran.r-project.org/ 35 | 36 | .. autosummary:: 37 | :toctree: autosummary 38 | 39 | ~skdatasets.repositories.cran.fetch_package 40 | ~skdatasets.repositories.cran.fetch_dataset 41 | 42 | Forex 43 | ----- 44 | 45 | The foreign exchange market (Forex). 46 | 47 | URL: https://theforexapi.com/ 48 | 49 | .. autosummary:: 50 | :toctree: autosummary 51 | 52 | ~skdatasets.repositories.forex.fetch 53 | 54 | Keel 55 | ---- 56 | 57 | The KEEL-dataset repository. 58 | 59 | URL: https://sci2s.ugr.es/keel/datasets.php 60 | 61 | .. autosummary:: 62 | :toctree: autosummary 63 | 64 | ~skdatasets.repositories.keel.fetch 65 | 66 | Keras 67 | ----- 68 | 69 | The Keras example datasets. 70 | 71 | URL: https://keras.io/api/datasets 72 | 73 | .. autosummary:: 74 | :toctree: autosummary 75 | 76 | ~skdatasets.repositories.keras.fetch 77 | 78 | LIBSVM 79 | ------ 80 | 81 | The LIBSVM data repository. 82 | 83 | URL: https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/ 84 | 85 | .. autosummary:: 86 | :toctree: autosummary 87 | 88 | ~skdatasets.repositories.libsvm.fetch 89 | 90 | Rätsch 91 | ------- 92 | 93 | The Gunnar Rätsch benchmark datasets. 94 | 95 | URL: https://github.com/tdiethe/gunnar_raetsch_benchmark_datasets/ 96 | 97 | .. autosummary:: 98 | :toctree: autosummary 99 | 100 | ~skdatasets.repositories.raetsch.fetch 101 | 102 | scikit-learn 103 | ------------ 104 | 105 | The scikit-learn example datasets. 106 | 107 | URL: https://scikit-learn.org/stable/datasets.html 108 | 109 | .. autosummary:: 110 | :toctree: autosummary 111 | 112 | ~skdatasets.repositories.sklearn.fetch 113 | 114 | UCI 115 | --- 116 | 117 | The University of California Irvine (UCI) repository. 118 | 119 | URL: https://archive.ics.uci.edu 120 | 121 | .. autosummary:: 122 | :toctree: autosummary 123 | 124 | ~skdatasets.repositories.uci.fetch 125 | 126 | UCR 127 | --- 128 | 129 | The UCR/UEA time series classification archive. 130 | 131 | URL: https://www.timeseriesclassification.com 132 | 133 | .. autosummary:: 134 | :toctree: autosummary 135 | 136 | ~skdatasets.repositories.ucr.fetch 137 |
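Example
-------

As a minimal usage sketch (network access is required, since the data are
downloaded and cached on first use), any of the repositories above can be
accessed through the top-level ``skdatasets.fetch`` helper by passing the
repository name and the dataset name (some repositories additionally take a
collection name):

.. code-block:: python

    from skdatasets import fetch

    # Download (and cache) the Aneurisk65 dataset from the Aneurisk repository.
    data = fetch("aneurisk", "Aneurisk65")
    print(data.DESCR)
    print(data.data.shape)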
-------------------------------------------------------------------------------- /docs/structure.rst: -------------------------------------------------------------------------------- 1 | Dataset structure 2 | ================= 3 | 4 | Most of the repositories available in scikit-datasets have datasets in some 5 | regular format. 6 | In that case, the corresponding ``fetch`` function in scikit-datasets converts 7 | the data to a standardized format, similar to the one used in scikit-learn, but 8 | with new optional fields for additional features that some repositories 9 | include, such as indices for train, validation and test partitions. 10 | 11 | .. note:: 12 | Data in the CRAN repository is unstructured, and thus there is no ``fetch`` 13 | function for it. The data is returned in the original format. 14 | 15 | The structure is a :external:class:`~sklearn.utils.Bunch` object with the 16 | following fields: 17 | 18 | - ``data``: The matrix of observed data. A 2d NumPy array, ready to be used 19 | with scikit-learn tools. 20 | Each row corresponds to a different observation, while each column is a 21 | particular feature. 22 | For datasets with train, validation and test partitions, the whole data 23 | is included here. 24 | Use ``train_indices``, ``validation_indices`` and ``test_indices`` to 25 | select each partition. 26 | - ``target``: The target of the classification or regression problem. This 27 | is a 1d NumPy array except for multioutput problems, in which it is a 2d 28 | array, where each column corresponds to a different output. 29 | - ``DESCR``: A human-readable description of the dataset. 30 | - ``feature_names``: The list of feature names, if the repository has that 31 | information available. 32 | - ``target_names``: For classification problems, this corresponds to the names 33 | of the different classes, if available. 34 | Note that this field in scikit-learn is used in some cases for naming the 35 | outputs in multioutput problems. 36 | As we will try to maintain compatibility with scikit-learn, the meaning of 37 | this field could change in future versions. 38 | - ``train_indices``: Indices of the elements of the train partition, if 39 | available in the repository. 40 | - ``validation_indices``: Indices of the elements of the validation partition, 41 | if available in the repository. 42 | - ``test_indices``: Indices of the elements of the test partition, if 43 | available in the repository. 44 | - ``inner_cv``: A :external:term:`CV splitter` object, usable for cross 45 | validation and hyperparameter selection, if the repository provides a 46 | cross validation strategy (such as using a particular validation 47 | partition). 48 | - ``outer_cv``: A Python iterable over different train and test partitions, 49 | when they are provided in the repository. 50 |
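As an illustration of these fields, a short sketch follows; the UCR archive is
used here under the assumption that its datasets are fetched by their archive
names (e.g. ``GunPoint``) and that this dataset ships predefined train and
test partitions:

.. code-block:: python

    from skdatasets import fetch

    data = fetch("ucr", "GunPoint")
    X, y = data.data, data.target

    # Select the repository-provided partitions; for repositories without
    # them, these index lists are empty and the whole data can be used instead.
    X_train, y_train = X[data.train_indices], y[data.train_indices]
    X_test, y_test = X[data.test_indices], y[data.test_indices]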
-------------------------------------------------------------------------------- /docs/utils.rst: -------------------------------------------------------------------------------- 1 | Utilities 2 | ========= 3 | 4 | In addition to dataset fetching, scikit-datasets provides some utility functions 5 | that make dataset-related tasks easier, such as launching experiments and 6 | formatting their scores. 7 | 8 | Estimator 9 | --------- 10 | 11 | The following functions are related to :external:term:`estimators` that follow the 12 | scikit-learn API. 13 | 14 | .. autosummary:: 15 | :toctree: autosummary 16 | 17 | ~skdatasets.utils.estimator.json2estimator 18 | 19 | Experiment 20 | ---------- 21 | 22 | The following functions can be used to execute several experiments, 23 | such as classification or regression tasks, with different datasets 24 | for later comparison. 25 | These experiments are created using the Sacred library, storing the 26 | most common parameters of interest, such as the time required for training or 27 | the final scores. 28 | After the experiments have finished, the final scores can be easily 29 | retrieved in order to plot a table or perform hypothesis testing. 30 | 31 | .. autosummary:: 32 | :toctree: autosummary 33 | 34 | ~skdatasets.utils.experiment.create_experiments 35 | ~skdatasets.utils.experiment.run_experiments 36 | ~skdatasets.utils.experiment.fetch_scores 37 | ~skdatasets.utils.experiment.ScoresInfo 38 |
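These helpers rely on Sacred for bookkeeping; as a rough sketch of the kind of
experiment they automate, expressed with plain scikit-learn rather than with
the functions above (the estimator is an arbitrary illustrative choice, and the
``geyser`` dataset registered in the CRAN module additionally needs the
optional ``cran`` dependencies and network access):

.. code-block:: python

    from sklearn.linear_model import Ridge
    from sklearn.model_selection import cross_val_score

    from skdatasets import fetch

    data = fetch("cran", "geyser")

    # Use the repository-provided CV splitter when present; otherwise 5-fold.
    cv = data.inner_cv if data.inner_cv is not None else 5
    scores = cross_val_score(Ridge(), data.data, data.target, cv=cv)
    print(scores.mean())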
39 | Scores 40 | ------ 41 | 42 | The following functions can be used to format and display the scores of machine 43 | learning or hypothesis testing experiments. 44 | 45 | .. autosummary:: 46 | :toctree: autosummary 47 | 48 | ~skdatasets.utils.scores.scores_table 49 | ~skdatasets.utils.scores.hypotheses_table 50 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "scikit-datasets" 3 | version = "0.2.4" 4 | description = "Scikit-learn-compatible datasets" 5 | readme = "README.md" 6 | requires-python = ">=3.8" 7 | license = {file = "LICENSE"} 8 | keywords = ["scikit-learn", "datasets", "repository", "benchmark", "Python"] 9 | authors = [ 10 | {name = "David Diaz Vico", email = "david.diaz.vico@outlook.com"}, 11 | {name = "Carlos Ramos Carreño", email = "vnmabus@gmail.com"}, 12 | ] 13 | maintainers = [ 14 | {name = "David Diaz Vico", email = "david.diaz.vico@outlook.com"}, 15 | {name = "Carlos Ramos Carreño", email = "vnmabus@gmail.com"}, 16 | ] 17 | classifiers = [ 18 | "Intended Audience :: Science/Research", 19 | "Topic :: Scientific/Engineering", 20 | "Programming Language :: Python", 21 | "Programming Language :: Python :: 3", 22 | "Programming Language :: Python :: 3.8", 23 | "Programming Language :: Python :: 3.9", 24 | "Programming Language :: Python :: 3.10", 25 | ] 26 | dependencies = ["numpy", "scipy", "scikit-learn"] 27 | [project.optional-dependencies] 28 | cran = ["packaging", "rdata"] 29 | forex = ["forex_python>=1.6"] 30 | keel = ["pandas"] 31 | keras = ["tensorflow"] 32 | physionet = ["pandas", "wfdb"] 33 | utils-estimator = ["jsonpickle"] 34 | utils-experiments = ["sacred", "incense"] 35 | utils-scores = ["statsmodels", "jinja2"] 36 | all = ["scikit-datasets[cran, forex, keel, keras, physionet, utils-estimator, utils-experiments, utils-scores]"] 37 | test = ["pytest", "pytest-cov[all]", "coverage", "scikit-datasets[all]"] 38 | [project.urls] 39 | homepage = "https://github.com/daviddiazvico/scikit-datasets" 40 | documentation = "https://daviddiazvico.github.io/scikit-datasets/" 41 | repository = "https://github.com/daviddiazvico/scikit-datasets" 42 | download = "https://github.com/daviddiazvico/scikit-datasets/archive/v0.2.2.tar.gz" 43 | [build-system] 44 | # Minimum requirements for the build system to execute. 45 | requires = ["setuptools", "wheel"] # PEP 508 specifications. 46 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [aliases] 2 | test=pytest 3 | 4 | [tool:pytest] 5 | addopts = --doctest-modules --cov=skdatasets 6 | 7 | [metadata] 8 | description_file = README.md 9 | 10 | [darglint] 11 | docstring_style = numpy 12 | 13 | [isort] 14 | multi_line_output = 3 15 | include_trailing_comma = true 16 | use_parentheses = true 17 | combine_as_imports = 1 18 | 19 | [mypy] 20 | strict = True 21 | strict_equality = True 22 | implicit_reexport = True -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | 3 | """ 4 | @author: David Diaz Vico 5 | @license: MIT 6 | """ 7 | 8 | from setuptools import setup 9 | 10 | setup(name="scikit-datasets") 11 | -------------------------------------------------------------------------------- /skdatasets/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Scikit-learn-compatible datasets. 3 | 4 | @author: David Diaz Vico 5 | @license: MIT 6 | """ 7 | 8 | from .repositories import fetch 9 | -------------------------------------------------------------------------------- /skdatasets/repositories/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: David Diaz Vico 3 | @license: MIT 4 | """ 5 | 6 | from . import aneurisk, libsvm, raetsch, sklearn, uci, ucr 7 | 8 | repos = { 9 | "libsvm": libsvm, 10 | "raetsch": raetsch, 11 | "sklearn": sklearn, 12 | "uci": uci, 13 | "ucr": ucr, 14 | "aneurisk": aneurisk, 15 | } 16 | try: 17 | from . import cran 18 | 19 | repos["cran"] = cran 20 | except ImportError: 21 | pass 22 | try: 23 | from . import forex 24 | 25 | repos["forex"] = forex 26 | except ImportError: 27 | pass 28 | try: 29 | from . import keel 30 | 31 | repos["keel"] = keel 32 | except ImportError: 33 | pass 34 | try: 35 | from . import keras 36 | 37 | repos["keras"] = keras 38 | except ImportError: 39 | pass 40 | try: 41 | from . import physionet 42 | 43 | repos["physionet"] = physionet 44 | except ImportError: 45 | pass 46 | 47 | 48 | def fetch(repository, dataset, collection=None, **kwargs): 49 | if collection: 50 | data = repos[repository].fetch(collection, dataset, **kwargs) 51 | else: 52 | data = repos[repository].fetch(dataset, **kwargs) 53 | return data 54 | -------------------------------------------------------------------------------- /skdatasets/repositories/aneurisk.py: -------------------------------------------------------------------------------- 1 | """Data from the AneuRisk project.""" 2 | 3 | import numpy as np 4 | from sklearn.utils import Bunch 5 | 6 | from .base import fetch_zip 7 | 8 | DESCR = """ 9 | The AneuRisk data set is based on a set of three-dimensional angiographic 10 | images taken from 65 subjects, hospitalized at Niguarda Ca’ Granda 11 | Hospital (Milan), who were suspected of being affected by cerebral aneurysms. 12 | Out of these 65 subjects, 33 subjects have an aneurysm at or after the 13 | terminal bifurcation of the ICA (“Upper” group), 25 subjects have an aneurysm 14 | along the ICA (“Lower” group), and in 7 subjects no visible aneurysm 15 | was found during the angiography (“No-aneurysm” group). 
16 | 17 | For more information see: 18 | http://ecm2.mathcs.emory.edu/aneuriskdata/files/ReadMe_AneuRisk-website_2012-05.pdf 19 | """ 20 | 21 | 22 | def fetch(name="Aneurisk65", *, data_home=None, return_X_y=False): 23 | 24 | if name != "Aneurisk65": 25 | raise ValueError(f"Unknown dataset {name}") 26 | 27 | n_samples = 65 28 | 29 | url = ( 30 | "http://ecm2.mathcs.emory.edu/aneuriskdata/files/Carotid-data_MBI_workshop.zip" 31 | ) 32 | 33 | dataset_path = fetch_zip( 34 | dataname=name, 35 | urlname=url, 36 | subfolder="aneurisk", 37 | data_home=data_home, 38 | ) 39 | 40 | patient_dtype = [ 41 | ("patient", np.int_), 42 | ("code", "U8"), 43 | ("type", "U1"), 44 | ("aneurysm location", np.float_), 45 | ("left_right", "U2"), 46 | ] 47 | 48 | functions_dtype = [ 49 | ("curvilinear abscissa", np.object_), 50 | ("MISR", np.object_), 51 | ("X0 observed", np.object_), 52 | ("Y0 observed", np.object_), 53 | ("Z0 observed", np.object_), 54 | ("X0 observed FKS", np.object_), 55 | ("Y0 observed FKS", np.object_), 56 | ("Z0 observed FKS", np.object_), 57 | ("X0 observed FKS reflected", np.object_), 58 | ("X1 observed FKS", np.object_), 59 | ("Y1 observed FKS", np.object_), 60 | ("Z1 observed FKS", np.object_), 61 | ("X1 observed FKS reflected", np.object_), 62 | ("X2 observed FKS", np.object_), 63 | ("Y2 observed FKS", np.object_), 64 | ("Z2 observed FKS", np.object_), 65 | ("X2 observed FKS reflected", np.object_), 66 | ("Curvature FKS", np.object_), 67 | ] 68 | 69 | complete_dtype = patient_dtype + functions_dtype 70 | 71 | X = np.zeros(shape=n_samples, dtype=complete_dtype) 72 | 73 | X[[p[0] for p in patient_dtype]] = np.genfromtxt( 74 | dataset_path / "Patients.txt", 75 | dtype=patient_dtype, 76 | skip_header=1, 77 | missing_values=("NA",), 78 | ) 79 | 80 | for i in range(n_samples): 81 | file = f"Rawdata_FKS_{i + 1}.txt" 82 | 83 | functions = np.genfromtxt( 84 | dataset_path / file, 85 | skip_header=1, 86 | ) 87 | 88 | for j, (f_name, _) in enumerate(functions_dtype): 89 | X[i][f_name] = functions[:, j] 90 | 91 | X = np.array(X.tolist(), dtype=np.object_) 92 | 93 | if return_X_y: 94 | return X, None 95 | 96 | return Bunch( 97 | data=X, 98 | target=None, 99 | train_indices=[], 100 | validation_indices=[], 101 | test_indices=[], 102 | name=name, 103 | DESCR=DESCR, 104 | feature_names=[t[0] for t in complete_dtype], 105 | ) 106 | -------------------------------------------------------------------------------- /skdatasets/repositories/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Common utilities. 
3 | """ 4 | from __future__ import annotations 5 | 6 | import pathlib 7 | import tarfile 8 | import zipfile 9 | from os.path import basename, normpath 10 | from shutil import copyfileobj 11 | from typing import ( 12 | TYPE_CHECKING, 13 | Any, 14 | Callable, 15 | Literal, 16 | Optional, 17 | Sequence, 18 | Tuple, 19 | Union, 20 | overload, 21 | ) 22 | from urllib.error import HTTPError 23 | from urllib.request import urlopen 24 | 25 | import numpy as np 26 | from sklearn.datasets import get_data_home 27 | from sklearn.utils import Bunch 28 | 29 | if TYPE_CHECKING: 30 | from pandas import DataFrame, Series 31 | 32 | CompressedFile = Union[zipfile.ZipFile, tarfile.TarFile] 33 | 34 | OpenMethod = Callable[ 35 | [pathlib.Path, str], 36 | CompressedFile, 37 | ] 38 | 39 | 40 | class DatasetNotFoundError(ValueError): 41 | """Exception raised for dataset not found.""" 42 | 43 | def __init__(self, dataset_name: str) -> None: 44 | self.dataset_name = dataset_name 45 | super().__init__(f"Dataset '{dataset_name}' not found.") 46 | 47 | 48 | def fetch_file( 49 | dataname: str, 50 | urlname: str, 51 | subfolder: Optional[str] = None, 52 | data_home: Optional[str] = None, 53 | ) -> pathlib.Path: 54 | """Fetch dataset. 55 | 56 | Fetch a file from a given url and stores it in a given directory. 57 | 58 | Parameters 59 | ---------- 60 | dataname: string 61 | Dataset name. 62 | urlname: string 63 | Dataset url. 64 | subfolder: string, default=None 65 | The subfolder where to put the data, if any. 66 | data_home: string, default=None 67 | Dataset directory. If None, use the default of scikit-learn. 68 | 69 | Returns 70 | ------- 71 | filename: Path 72 | Name of the file. 73 | 74 | """ 75 | # check if this data set has been already downloaded 76 | data_home_path = pathlib.Path(get_data_home(data_home=data_home)) 77 | 78 | if subfolder: 79 | data_home_path /= subfolder 80 | 81 | data_home_path /= dataname 82 | if not data_home_path.exists(): 83 | data_home_path.mkdir(parents=True) 84 | filename = data_home_path / basename(normpath(urlname)) 85 | # if the file does not exist, download it 86 | if not filename.exists(): 87 | try: 88 | data_url = urlopen(urlname) 89 | except HTTPError as e: 90 | if e.code == 404: 91 | raise DatasetNotFoundError(dataname) from e 92 | raise 93 | if data_url.length == 0: 94 | raise DatasetNotFoundError(dataname) 95 | 96 | # store file 97 | try: 98 | with open(filename, "w+b") as data_file: 99 | copyfileobj(data_url, data_file) 100 | except Exception: 101 | filename.unlink() 102 | raise 103 | data_url.close() 104 | return filename 105 | 106 | 107 | @overload 108 | def _missing_files( 109 | compressed_file: zipfile.ZipFile, 110 | data_home_path: pathlib.Path, 111 | ) -> Sequence[zipfile.ZipInfo]: 112 | pass 113 | 114 | 115 | @overload 116 | def _missing_files( 117 | compressed_file: tarfile.TarFile, 118 | data_home_path: pathlib.Path, 119 | ) -> Sequence[tarfile.TarInfo]: 120 | pass 121 | 122 | 123 | def _missing_files( 124 | compressed_file: CompressedFile, 125 | data_home_path: pathlib.Path, 126 | ) -> Sequence[Union[zipfile.ZipInfo, tarfile.TarInfo]]: 127 | 128 | if isinstance(compressed_file, zipfile.ZipFile): 129 | 130 | members_zip = compressed_file.infolist() 131 | 132 | return [ 133 | info 134 | for info in members_zip 135 | if not (data_home_path / info.filename).exists() 136 | ] 137 | 138 | members_tar = compressed_file.getmembers() 139 | 140 | return [info for info in members_tar if not (data_home_path / info.name).exists()] 141 | 142 | 143 | def fetch_compressed( 144 | 
dataname: str, 145 | urlname: str, 146 | compression_open: OpenMethod, 147 | subfolder: Optional[str] = None, 148 | data_home: Optional[str] = None, 149 | open_format: str = "r", 150 | ) -> pathlib.Path: 151 | """Fetch compressed dataset. 152 | 153 | Fetch a compressed file from a given url, unzips and stores it in a given 154 | directory. 155 | 156 | Parameters 157 | ---------- 158 | dataname: string 159 | Dataset name. 160 | urlname: string 161 | Dataset url. 162 | compression_open: callable 163 | Module/class used to decompress the data. 164 | subfolder: string, default=None 165 | The subfolder where to put the data, if any. 166 | data_home: string, default=None 167 | Dataset directory. If None, use the default of scikit-learn. 168 | open_format: string 169 | Format for opening the compressed file. 170 | 171 | Returns 172 | ------- 173 | data_home: Path 174 | Directory. 175 | 176 | """ 177 | # fetch file 178 | filename = fetch_file( 179 | dataname, 180 | urlname, 181 | subfolder=subfolder, 182 | data_home=data_home, 183 | ) 184 | data_home_path = filename.parent 185 | # unzip file 186 | try: 187 | with compression_open(filename, open_format) as compressed_file: 188 | compressed_file.extractall( 189 | data_home_path, 190 | members=_missing_files(compressed_file, data_home_path), 191 | ) 192 | except Exception: 193 | filename.unlink() 194 | raise 195 | return data_home_path 196 | 197 | 198 | def fetch_zip( 199 | dataname: str, 200 | urlname: str, 201 | subfolder: Optional[str] = None, 202 | data_home: Optional[str] = None, 203 | ) -> pathlib.Path: 204 | """Fetch zipped dataset. 205 | 206 | Fetch a zip file from a given url, unzips and stores it in a given 207 | directory. 208 | 209 | Parameters 210 | ---------- 211 | dataname: string 212 | Dataset name. 213 | urlname: string 214 | Dataset url. 215 | subfolder: string, default=None 216 | The subfolder where to put the data, if any. 217 | data_home: string, default=None 218 | Dataset directory. If None, use the default of scikit-learn. 219 | 220 | Returns 221 | ------- 222 | data_home: Path 223 | Directory. 224 | 225 | """ 226 | return fetch_compressed( 227 | dataname=dataname, 228 | urlname=urlname, 229 | compression_open=zipfile.ZipFile, 230 | subfolder=subfolder, 231 | data_home=data_home, 232 | ) 233 | 234 | 235 | def fetch_tgz( 236 | dataname: str, 237 | urlname: str, 238 | subfolder: Optional[str] = None, 239 | data_home: Optional[str] = None, 240 | ) -> pathlib.Path: 241 | """Fetch tgz dataset. 242 | 243 | Fetch a tgz file from a given url, unzips and stores it in a given 244 | directory. 245 | 246 | Parameters 247 | ---------- 248 | dataname: string 249 | Dataset name. 250 | urlname: string 251 | Dataset url. 252 | subfolder: string, default=None 253 | The subfolder where to put the data, if any. 254 | data_home: string, default=None 255 | Dataset directory. If None, use the default of scikit-learn. 256 | 257 | Returns 258 | ------- 259 | data_home: Path 260 | Directory. 
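    Examples
    --------
    A minimal sketch; the dataset name and URL below are placeholders for
    any remote ``.tar.gz`` archive::

        from skdatasets.repositories.base import fetch_tgz

        # Download the archive (if not cached), extract it, and return the
        # directory containing its contents.
        path = fetch_tgz(
            "example_dataset",
            "https://example.com/example_dataset.tar.gz",
            subfolder="examples",
        )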
261 | 262 | """ 263 | return fetch_compressed( 264 | dataname=dataname, 265 | urlname=urlname, 266 | compression_open=tarfile.open, 267 | subfolder=subfolder, 268 | data_home=data_home, 269 | open_format="r:gz", 270 | ) 271 | 272 | 273 | @overload 274 | def dataset_from_dataframe( 275 | frame: DataFrame, 276 | *, 277 | DESCR: str = "", 278 | return_X_y: Literal[False], 279 | as_frame: bool, 280 | target_column: str | Sequence[str] | None, 281 | ) -> Bunch: 282 | pass 283 | 284 | 285 | @overload 286 | def dataset_from_dataframe( 287 | frame: DataFrame, 288 | *, 289 | DESCR: str = "", 290 | return_X_y: Literal[True], 291 | as_frame: Literal[False], 292 | target_column: None, 293 | ) -> Tuple[np.typing.NDArray[Any], None]: 294 | pass 295 | 296 | 297 | @overload 298 | def dataset_from_dataframe( 299 | frame: DataFrame, 300 | *, 301 | DESCR: str = "", 302 | return_X_y: Literal[True], 303 | as_frame: Literal[False], 304 | target_column: str | Sequence[str], 305 | ) -> Tuple[np.typing.NDArray[Any], np.typing.NDArray[Any]]: 306 | pass 307 | 308 | 309 | @overload 310 | def dataset_from_dataframe( 311 | frame: DataFrame, 312 | *, 313 | DESCR: str = "", 314 | return_X_y: Literal[True], 315 | as_frame: Literal[True], 316 | target_column: None, 317 | ) -> Tuple[DataFrame, None]: 318 | pass 319 | 320 | 321 | @overload 322 | def dataset_from_dataframe( 323 | frame: DataFrame, 324 | *, 325 | DESCR: str = "", 326 | return_X_y: Literal[True], 327 | as_frame: Literal[True], 328 | target_column: str, 329 | ) -> Tuple[DataFrame, Series]: 330 | pass 331 | 332 | 333 | @overload 334 | def dataset_from_dataframe( 335 | frame: DataFrame, 336 | *, 337 | DESCR: str = "", 338 | return_X_y: Literal[True], 339 | as_frame: Literal[True], 340 | target_column: Sequence[str], 341 | ) -> Tuple[DataFrame, DataFrame]: 342 | pass 343 | 344 | 345 | def dataset_from_dataframe( 346 | frame: DataFrame, 347 | *, 348 | DESCR: str = "", 349 | return_X_y: bool, 350 | as_frame: bool, 351 | target_column: str | Sequence[str] | None, 352 | ) -> ( 353 | Bunch 354 | | Tuple[np.typing.NDArray[float], np.typing.NDArray[int] | None] 355 | | Tuple[DataFrame, Series | DataFrame | None] 356 | ): 357 | 358 | data_dataframe = ( 359 | frame if target_column is None else frame.drop(target_column, axis=1) 360 | ) 361 | target_dataframe = None if target_column is None else frame.loc[:, target_column] 362 | 363 | data = data_dataframe if as_frame is True else data_dataframe.to_numpy() 364 | 365 | target = ( 366 | None 367 | if target_dataframe is None 368 | else target_dataframe 369 | if as_frame is True 370 | else target_dataframe.to_numpy() 371 | ) 372 | 373 | if return_X_y: 374 | return data, target 375 | 376 | feature_names = list(data_dataframe.keys()) 377 | target_names = None if target_dataframe is None else list(target_dataframe.keys()) 378 | 379 | bunch = Bunch( 380 | data=data, 381 | target=target, 382 | DESCR=DESCR, 383 | feature_names=feature_names, 384 | target_names=target_names, 385 | ) 386 | 387 | if as_frame: 388 | bunch["frame"] = frame 389 | 390 | return bunch 391 | -------------------------------------------------------------------------------- /skdatasets/repositories/cran.py: -------------------------------------------------------------------------------- 1 | """ 2 | Datasets extracted from R packages in CRAN (https://cran.r-project.org/). 
3 | 4 | @author: Carlos Ramos Carreño 5 | @license: MIT 6 | """ 7 | from __future__ import annotations 8 | 9 | import os 10 | import pathlib 11 | import re 12 | import urllib 13 | import warnings 14 | from html.parser import HTMLParser 15 | from pathlib import Path 16 | from typing import ( 17 | Any, 18 | Final, 19 | List, 20 | Literal, 21 | Mapping, 22 | Sequence, 23 | Tuple, 24 | TypedDict, 25 | overload, 26 | ) 27 | 28 | import numpy as np 29 | import pandas as pd 30 | import rdata 31 | from packaging.version import Version 32 | from sklearn.datasets import get_data_home 33 | from sklearn.utils import Bunch 34 | 35 | from .base import DatasetNotFoundError, fetch_tgz as _fetch_tgz 36 | 37 | CRAN_URL: Final = "https://CRAN.R-project.org" 38 | 39 | 40 | class _LatestVersionHTMLParser(HTMLParser): 41 | """Class for parsing the version in the CRAN package information page.""" 42 | 43 | def __init__(self, *, convert_charrefs: bool = True) -> None: 44 | super().__init__(convert_charrefs=convert_charrefs) 45 | 46 | self.last_is_version = False 47 | self.version: str | None = None 48 | self.version_regex = re.compile("(?i).*version.*") 49 | self.handling_td = False 50 | 51 | def handle_starttag( 52 | self, 53 | tag: str, 54 | attrs: List[Tuple[str, str | None]], 55 | ) -> None: 56 | if tag == "td": 57 | self.handling_td = True 58 | 59 | def handle_endtag(self, tag: str) -> None: 60 | self.handling_td = False 61 | 62 | def handle_data(self, data: str) -> None: 63 | if self.handling_td: 64 | if self.last_is_version: 65 | self.version = data 66 | self.last_is_version = False 67 | elif self.version_regex.match(data): 68 | self.last_is_version = True 69 | 70 | 71 | def _get_latest_version_online(package_name: str, dataset_name: str) -> str: 72 | """Get the latest version of the package from CRAN.""" 73 | parser = _LatestVersionHTMLParser() 74 | 75 | url_request = urllib.request.Request( 76 | url=f"{CRAN_URL}/package={package_name}", 77 | ) 78 | try: 79 | with urllib.request.urlopen(url_request) as url_file: 80 | url_content = url_file.read().decode("utf-8") 81 | except urllib.error.HTTPError as e: 82 | if e.code == 404: 83 | raise DatasetNotFoundError(f"{package_name}/{dataset_name}") from e 84 | raise 85 | 86 | parser.feed(url_content) 87 | 88 | if parser.version is None: 89 | raise ValueError(f"Version of package {package_name} not found") 90 | 91 | return parser.version 92 | 93 | 94 | def _get_latest_version_offline(package_name: str) -> str | None: 95 | """ 96 | Get the latest downloaded version of the package. 97 | 98 | Returns None if not found. 99 | 100 | """ 101 | home = pathlib.Path(get_data_home()) # Should allow providing data home? 102 | 103 | downloaded_packages = tuple(home.glob(package_name + "_*.tar.gz")) 104 | 105 | if downloaded_packages: 106 | versions = [ 107 | Version(p.name[(len(package_name) + 1):-len(".tar.gz")]) 108 | for p in downloaded_packages 109 | ] 110 | 111 | versions.sort() 112 | latest_version = versions[-1] 113 | 114 | return str(latest_version) 115 | 116 | return None 117 | 118 | 119 | def _get_version( 120 | package_name: str, 121 | *, 122 | dataset_name: str, 123 | version: str | None = None, 124 | ) -> str: 125 | """ 126 | Get the version of the package. 127 | 128 | If the version is specified, return it. 129 | Otherwise, try to find the last version online. 130 | If offline, try to find the downloaded version, if any. 
131 | 132 | """ 133 | if version is None: 134 | try: 135 | version = _get_latest_version_online( 136 | package_name, 137 | dataset_name=dataset_name, 138 | ) 139 | except (urllib.error.URLError, DatasetNotFoundError): 140 | version = _get_latest_version_offline(package_name) 141 | 142 | if version is None: 143 | raise 144 | 145 | return version 146 | 147 | 148 | def _get_urls( 149 | package_name: str, 150 | *, 151 | dataset_name: str, 152 | version: str | None = None, 153 | ) -> Sequence[str]: 154 | 155 | version = _get_version(package_name, dataset_name=dataset_name, version=version) 156 | 157 | filename = f"{package_name}_{version}.tar.gz" 158 | 159 | latest_url = f"{CRAN_URL}/src/contrib/{filename}" 160 | archive_url = f"{CRAN_URL}/src/contrib/Archive/{package_name}/{filename}" 161 | return (latest_url, archive_url) 162 | 163 | 164 | def _download_package_data( 165 | package_name: str, 166 | *, 167 | dataset_name: str = "*", 168 | package_url: str | None = None, 169 | version: str | None = None, 170 | folder_name: str | None = None, 171 | subdir: str | None = None, 172 | ) -> Path: 173 | if package_url is None: 174 | url_list = _get_urls( 175 | package_name, 176 | dataset_name=dataset_name, 177 | version=version, 178 | ) 179 | else: 180 | url_list = (package_url,) 181 | 182 | if folder_name is None: 183 | folder_name = os.path.basename(url_list[0]) 184 | 185 | if subdir is None: 186 | subdir = "data" 187 | 188 | for i, url in enumerate(url_list): 189 | try: 190 | directory = _fetch_tgz(folder_name, url, subfolder="cran") 191 | break 192 | except Exception: 193 | # If it is the last url, reraise 194 | if i >= len(url_list) - 1: 195 | raise 196 | 197 | data_path = directory / package_name / subdir 198 | 199 | return data_path 200 | 201 | 202 | def fetch_dataset( 203 | dataset_name: str, 204 | package_name: str, 205 | *, 206 | package_url: str | None = None, 207 | version: str | None = None, 208 | folder_name: str | None = None, 209 | subdir: str | None = None, 210 | converter: rdata.conversion.Converter | None = None, 211 | ) -> Mapping[str, Any]: 212 | """ 213 | Fetch an R dataset. 214 | 215 | Only .rda datasets in community packages can be downloaded for now. 216 | 217 | R datasets do not have a fixed structure, so this function does not 218 | attempt to force one. 219 | 220 | Parameters 221 | ---------- 222 | dataset_name: string 223 | Name of the dataset, including extension if any. 224 | package_name: string 225 | Name of the R package where this dataset resides. 226 | package_url: string 227 | Package url. If `None` it tries to obtain it from the package name. 228 | version: string 229 | If `package_url` is not specified, the version of the package to 230 | download. By default is the latest one. 231 | folder_name: string 232 | Name of the folder where the downloaded package is stored. By default, 233 | is the last component of `package_url`. 234 | subdir: string 235 | Subdirectory of the package containing the datasets. By default is 236 | 'data'. 237 | converter: rdata.conversion.Converter 238 | Object used to translate R objects into Python objects. 239 | 240 | Returns 241 | ------- 242 | data: dict 243 | Dictionary-like object with all the data and metadata. 
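    Examples
    --------
    A hypothetical call (network access required); the dataset and package
    names mirror the ``datasets`` registry defined later in this module::

        from skdatasets.repositories.cran import fetch_dataset

        # Mapping from R object names to their converted Python values.
        data = fetch_dataset("geyser.rda", "MASS")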
244 | 245 | """ 246 | 247 | if converter is None: 248 | converter = rdata.conversion.SimpleConverter() 249 | 250 | data_path = _download_package_data( 251 | package_name, 252 | dataset_name=dataset_name, 253 | package_url=package_url, 254 | version=version, 255 | folder_name=folder_name, 256 | subdir=subdir, 257 | ) 258 | 259 | file_path = data_path / dataset_name 260 | 261 | if not file_path.suffix: 262 | possible_names = list(data_path.glob(dataset_name + ".*")) 263 | if len(possible_names) != 1: 264 | raise FileNotFoundError( 265 | f"Dataset {dataset_name} not found in " f"package {package_name}", 266 | ) 267 | 268 | file_path = data_path / possible_names[0] 269 | 270 | parsed = rdata.parser.parse_file(file_path) 271 | 272 | return converter.convert(parsed) 273 | 274 | 275 | def fetch_package( 276 | package_name: str, 277 | *, 278 | package_url: str | None = None, 279 | version: str | None = None, 280 | folder_name: str | None = None, 281 | subdir: str | None = None, 282 | converter: rdata.conversion.Converter | None = None, 283 | ignore_errors: bool = False, 284 | ) -> Mapping[str, Any]: 285 | """ 286 | Fetch all datasets from a R package. 287 | 288 | Only .rda datasets in community packages can be downloaded for now. 289 | 290 | R datasets do not have a fixed structure, so this function does not 291 | attempt to force one. 292 | 293 | Parameters 294 | ---------- 295 | package_name: string 296 | Name of the R package. 297 | package_url: string 298 | Package url. If `None` it tries to obtain it from the package name. 299 | version: string 300 | If `package_url` is not specified, the version of the package to 301 | download. By default is the latest one. 302 | folder_name: string 303 | Name of the folder where the downloaded package is stored. By default, 304 | is the last component of `package_url`. 305 | subdir: string 306 | Subdirectory of the package containing the datasets. By default is 307 | 'data'. 308 | converter: rdata.conversion.Converter 309 | Object used to translate R objects into Python objects. 310 | ignore_errors: boolean 311 | If True, ignore the datasets producing errors and return the 312 | remaining ones. 313 | 314 | Returns 315 | ------- 316 | data: dict 317 | Dictionary-like object with all the data and metadata. 
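    Examples
    --------
    A hypothetical call (network access required); ``MASS`` is only an
    illustrative CRAN package name::

        from skdatasets.repositories.cran import fetch_package

        # All .rda/.RData datasets shipped with the package, skipping any
        # that fail to load.
        data = fetch_package("MASS", ignore_errors=True)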
318 | 319 | """ 320 | 321 | if converter is None: 322 | converter = rdata.conversion.SimpleConverter() 323 | 324 | data_path = _download_package_data( 325 | package_name, 326 | package_url=package_url, 327 | version=version, 328 | folder_name=folder_name, 329 | subdir=subdir, 330 | ) 331 | 332 | if not data_path.exists(): 333 | return {} 334 | 335 | all_datasets = {} 336 | 337 | for dataset in data_path.iterdir(): 338 | 339 | if dataset.suffix.lower() in [".rda", ".rdata"]: 340 | try: 341 | parsed = rdata.parser.parse_file(dataset) 342 | 343 | converted = converter.convert(parsed) 344 | 345 | all_datasets.update(converted) 346 | except Exception: 347 | if not ignore_errors: 348 | raise 349 | else: 350 | warnings.warn( 351 | f"Error loading dataset {dataset.name}", 352 | stacklevel=2, 353 | ) 354 | 355 | return all_datasets 356 | 357 | 358 | class _DatasetArguments(TypedDict): 359 | load_args: Tuple[Sequence[Any], Mapping[str, Any]] 360 | sklearn_args: Tuple[Sequence[Any], Mapping[str, Any]] 361 | 362 | 363 | datasets: Mapping[str, _DatasetArguments] = { 364 | "geyser": { 365 | "load_args": (["geyser.rda", "MASS"], {}), 366 | "sklearn_args": ([], {"target_name": "waiting"}), 367 | }, 368 | } 369 | 370 | 371 | def _to_sklearn( 372 | dataset: Mapping[str, Any], 373 | *, 374 | target_name: str, 375 | ) -> Bunch: 376 | """Transform R datasets to Sklearn format, if possible""" 377 | assert len(dataset.keys()) == 1 378 | name = tuple(dataset.keys())[0] 379 | obj = dataset[name] 380 | 381 | if isinstance(obj, pd.DataFrame): 382 | feature_names = list(obj.keys()) 383 | feature_names.remove(target_name) 384 | X = pd.get_dummies(obj[feature_names]).values 385 | y = obj[target_name].values 386 | else: 387 | raise ValueError( 388 | "Dataset not automatically convertible to Sklearn format", 389 | ) 390 | 391 | return Bunch( 392 | data=X, 393 | target=y, 394 | train_indices=[], 395 | validation_indices=[], 396 | test_indices=[], 397 | inner_cv=None, 398 | outer_cv=None, 399 | target_names=target_name, 400 | feature_names=feature_names, 401 | ) 402 | 403 | 404 | @overload 405 | def fetch( 406 | name: str, 407 | *, 408 | return_X_y: Literal[False] = False, 409 | ) -> Bunch: 410 | pass 411 | 412 | 413 | @overload 414 | def fetch( 415 | name: str, 416 | *, 417 | return_X_y: Literal[True], 418 | ) -> Tuple[np.typing.NDArray[float], np.typing.NDArray[Any]]: 419 | pass 420 | 421 | 422 | def fetch( 423 | name: str, 424 | *, 425 | return_X_y: bool = False, 426 | ) -> Bunch | Tuple[np.typing.NDArray[float], np.typing.NDArray[Any]]: 427 | """ 428 | Load a dataset. 429 | 430 | Parameters 431 | ---------- 432 | name : string 433 | Dataset name. 434 | return_X_y : bool, default=False 435 | If True, returns ``(data, target)`` instead of a Bunch object. 436 | 437 | Returns 438 | ------- 439 | data : Bunch 440 | Dictionary-like object with all the data and metadata. 
441 | 442 | (data, target) : tuple if ``return_X_y`` is True 443 | 444 | """ 445 | load_args = datasets[name]["load_args"] 446 | dataset = fetch_dataset(*load_args[0], **load_args[1]) 447 | 448 | sklearn_args = datasets[name]["sklearn_args"] 449 | sklearn_dataset = _to_sklearn(dataset, *sklearn_args[0], **sklearn_args[1]) 450 | 451 | if return_X_y: 452 | return sklearn_dataset.data, sklearn_dataset.target 453 | 454 | return sklearn_dataset 455 | -------------------------------------------------------------------------------- /skdatasets/repositories/forex.py: -------------------------------------------------------------------------------- 1 | """ 2 | Forex datasets (http://forex-python.readthedocs.io). 3 | 4 | @author: David Diaz Vico 5 | @license: MIT 6 | """ 7 | 8 | import time 9 | from datetime import date, timedelta 10 | 11 | import numpy as np 12 | from sklearn.utils import Bunch 13 | 14 | from forex_python.bitcoin import BtcConverter 15 | from forex_python.converter import CurrencyRates 16 | 17 | 18 | def _fetch(get_rate, start=date(2015, 1, 1), end=date.today()): 19 | """Fetch dataset.""" 20 | data = [] 21 | delta = end - start 22 | for d in range(delta.days + 1): 23 | day = start + timedelta(days=d) 24 | rate = get_rate(day) 25 | data.append(rate) 26 | return np.asarray(data).reshape((-1, 1)) 27 | 28 | 29 | def _load_bitcoin(start=date(2015, 1, 1), end=date.today(), currency="EUR"): 30 | """Load bitcoin dataset""" 31 | btcc = BtcConverter() 32 | 33 | def get_rate(day): 34 | return btcc.get_previous_price(currency, day) 35 | 36 | return _fetch(get_rate, start=start, end=end) 37 | 38 | 39 | def _load_forex( 40 | start=date(2015, 1, 1), end=date.today(), currency_1="USD", currency_2="EUR" 41 | ): 42 | """Load forex dataset.""" 43 | cr = CurrencyRates() 44 | 45 | def get_rate(day): 46 | time.sleep(0.1) 47 | return cr.get_rate(currency_1, currency_2, day) 48 | 49 | return _fetch(get_rate, start=start, end=end) 50 | 51 | 52 | def fetch( 53 | start=date(2015, 1, 1), 54 | end=date.today(), 55 | currency_1="USD", 56 | currency_2="EUR", 57 | return_X_y=False, 58 | ): 59 | """Fetch Forex datasets. 60 | 61 | Fetches the ECB Forex and Coindesk Bitcoin datasets. More info at 62 | http://forex-python.readthedocs.io. 63 | 64 | Parameters 65 | ---------- 66 | start : date, default=2015-01-01 67 | Initial date. 68 | end : date, default=today 69 | Final date. 70 | currency_1 : str, default='USD' 71 | Currency 1. 72 | currency_2 : str, default='EUR' 73 | Currency 2. 74 | return_X_y : bool, default=False 75 | If True, returns ``(data, target)`` instead of a Bunch object. 76 | 77 | Returns 78 | ------- 79 | data : Bunch 80 | Dictionary-like object with all the data and metadata. 
81 | 82 | (data, target) : tuple if ``return_X_y`` is True 83 | 84 | """ 85 | if currency_1 == "BTC": 86 | X = _load_bitcoin(start=start, end=end, currency=currency_2) 87 | descr = "BTC-" + str(currency_2) 88 | elif currency_2 == "BTC": 89 | X = _load_bitcoin(start=start, end=end, currency=currency_1) 90 | descr = "BTC-" + str(currency_1) 91 | else: 92 | X = _load_forex( 93 | start=start, end=end, currency_1=currency_1, currency_2=currency_2 94 | ) 95 | descr = str(currency_1) + "-" + str(currency_2) 96 | descr = descr + start.strftime("%Y-%m-%d") + "-" + end.strftime("%Y-%m-%d") 97 | 98 | if return_X_y: 99 | return X, None 100 | 101 | return Bunch( 102 | data=X, 103 | target=None, 104 | train_indices=[], 105 | validation_indices=[], 106 | test_indices=[], 107 | inner_cv=None, 108 | outer_cv=None, 109 | DESCR=descr, 110 | ) 111 | -------------------------------------------------------------------------------- /skdatasets/repositories/keel.py: -------------------------------------------------------------------------------- 1 | """ 2 | Keel datasets (http://sci2s.ugr.es/keel). 3 | 4 | @author: David Diaz Vico 5 | @license: MIT 6 | """ 7 | from __future__ import annotations 8 | 9 | import io 10 | import os 11 | from pathlib import Path 12 | from types import MappingProxyType 13 | from typing import ( 14 | AbstractSet, 15 | Any, 16 | Final, 17 | Iterator, 18 | Literal, 19 | Optional, 20 | Sequence, 21 | Tuple, 22 | Union, 23 | overload, 24 | ) 25 | from zipfile import ZipFile 26 | 27 | import numpy as np 28 | import pandas as pd 29 | from sklearn.utils import Bunch 30 | 31 | from .base import fetch_file 32 | 33 | BASE_URL = "http://sci2s.ugr.es/keel" 34 | COLLECTIONS: Final = frozenset( 35 | ( 36 | "classification", 37 | "missing", 38 | "imbalanced", 39 | "multiInstance", 40 | "multilabel", 41 | "textClassification", 42 | "classNoise", 43 | "attributeNoise", 44 | "semisupervised", 45 | "regression", 46 | "timeseries", 47 | "unsupervised", 48 | "lowQuality", 49 | ) 50 | ) 51 | 52 | 53 | # WTFs 54 | IMBALANCED_URLS: Final = ( 55 | "keel-dataset/datasets/imbalanced/imb_IRhigherThan9", 56 | "keel-dataset/datasets/imbalanced/imb_IRhigherThan9p1", 57 | "keel-dataset/datasets/imbalanced/imb_IRhigherThan9p2", 58 | "keel-dataset/datasets/imbalanced/imb_IRhigherThan9p3", 59 | "dataset/data/imbalanced", 60 | "keel-dataset/datasets/imbalanced/imb_noisyBordExamples", 61 | "keel-dataset/datasets/imbalanced/preprocessed", 62 | ) 63 | 64 | IRREGULAR_DESCR_IMBALANCED_URLS: Final = ( 65 | "keel-dataset/datasets/imbalanced/imb_IRhigherThan9", 66 | "keel-dataset/datasets/imbalanced/imb_IRhigherThan9p1", 67 | "keel-dataset/datasets/imbalanced/imb_IRhigherThan9p2", 68 | "keel-dataset/datasets/imbalanced/imb_IRhigherThan9p3", 69 | ) 70 | 71 | INCORRECT_DESCR_IMBALANCED_URLS: Final = MappingProxyType( 72 | {"semisupervised": "classification"}, 73 | ) 74 | 75 | 76 | class KeelOuterCV(object): 77 | """Iterable over already separated CV partitions of the dataset.""" 78 | 79 | def __init__( 80 | self, 81 | Xs: Sequence[np.typing.NDArray[float]], 82 | ys: Sequence[np.typing.NDArray[Union[int, float]]], 83 | Xs_test: Sequence[np.typing.NDArray[float]], 84 | ys_test: Sequence[np.typing.NDArray[Union[int, float]]], 85 | ) -> None: 86 | self.Xs = Xs 87 | self.ys = ys 88 | self.Xs_test = Xs_test 89 | self.ys_test = ys_test 90 | 91 | def __iter__( 92 | self, 93 | ) -> Iterator[ 94 | Tuple[ 95 | np.typing.NDArray[float], 96 | np.typing.NDArray[Union[int, float]], 97 | np.typing.NDArray[float], 98 | np.typing.NDArray[Union[int, 
float]], 99 | ] 100 | ]: 101 | return zip(self.Xs, self.ys, self.Xs_test, self.ys_test) 102 | 103 | 104 | def _load_Xy( 105 | zipfile: Path, 106 | csvfile: str, 107 | sep: str = ",", 108 | header: Optional[int] = None, 109 | engine: str = "python", 110 | na_values: AbstractSet[str] = frozenset(("?")), 111 | **kwargs: Any, 112 | ) -> Tuple[np.typing.NDArray[float], np.typing.NDArray[Union[int, float]]]: 113 | """Load a zipped csv file with target in the last column.""" 114 | with ZipFile(zipfile) as z: 115 | with z.open(csvfile) as c: 116 | s = io.StringIO(c.read().decode(encoding="utf8")) 117 | data = pd.read_csv( 118 | s, 119 | sep=sep, 120 | header=header, 121 | engine=engine, 122 | na_values=na_values, 123 | **kwargs, 124 | ) 125 | data.columns = data.columns.astype(str) 126 | X = pd.get_dummies(data.iloc[:, :-1]) 127 | y = pd.factorize(data.iloc[:, -1].tolist(), sort=True)[0] 128 | return X, y 129 | 130 | 131 | def _load_descr( 132 | collection: str, 133 | name: str, 134 | data_home: Optional[str] = None, 135 | ) -> Tuple[int, str]: 136 | """Load a dataset description.""" 137 | subfolder = os.path.join("keel", collection) 138 | filename = name + "-names.txt" 139 | if collection == "imbalanced": 140 | for url in IMBALANCED_URLS: 141 | if url in IRREGULAR_DESCR_IMBALANCED_URLS: 142 | url = BASE_URL + "/" + url + "/" + "names" + "/" + filename 143 | else: 144 | url = BASE_URL + "/" + url + "/" + filename 145 | try: 146 | f = fetch_file( 147 | dataname=name, 148 | urlname=url, 149 | subfolder=subfolder, 150 | data_home=data_home, 151 | ) 152 | break 153 | except Exception: 154 | pass 155 | else: 156 | collection = ( 157 | INCORRECT_DESCR_IMBALANCED_URLS[collection] 158 | if collection in INCORRECT_DESCR_IMBALANCED_URLS 159 | else collection 160 | ) 161 | url = f"{BASE_URL}/dataset/data/{collection}/{filename}" 162 | f = fetch_file( 163 | dataname=name, 164 | urlname=url, 165 | subfolder=subfolder, 166 | data_home=data_home, 167 | ) 168 | with open(f) as rst_file: 169 | fdescr = rst_file.read() 170 | nattrs = fdescr.count("@attribute") 171 | return nattrs, fdescr 172 | 173 | 174 | def _fetch_keel_zip( 175 | collection: str, 176 | name: str, 177 | filename: str, 178 | data_home: Optional[str] = None, 179 | ) -> Path: 180 | """Fetch Keel dataset zip file.""" 181 | subfolder = os.path.join("keel", collection) 182 | if collection == "imbalanced": 183 | for url in IMBALANCED_URLS: 184 | url = BASE_URL + "/" + url + "/" + filename 185 | try: 186 | return fetch_file( 187 | dataname=name, 188 | urlname=url, 189 | subfolder=subfolder, 190 | data_home=data_home, 191 | ) 192 | except Exception: 193 | pass 194 | else: 195 | url = f"{BASE_URL}/dataset/data/{collection}/{filename}" 196 | return fetch_file( 197 | dataname=name, 198 | urlname=url, 199 | subfolder=subfolder, 200 | data_home=data_home, 201 | ) 202 | raise ValueError("Dataset not found") 203 | 204 | 205 | def _load_folds( 206 | collection: str, 207 | name: str, 208 | nfolds: Literal[None, 1, 5, 10], 209 | dobscv: bool, 210 | nattrs: int, 211 | data_home: Optional[str] = None, 212 | ) -> Tuple[ 213 | np.typing.NDArray[float], 214 | np.typing.NDArray[Union[int, float]], 215 | Optional[KeelOuterCV], 216 | ]: 217 | """Load a dataset folds.""" 218 | filename = name + ".zip" 219 | f = _fetch_keel_zip(collection, name, filename, data_home=data_home) 220 | X, y = _load_Xy(f, name + ".dat", skiprows=nattrs + 4) 221 | cv = None 222 | if nfolds in (5, 10): 223 | fold = "dobscv" if dobscv else "fold" 224 | filename = name + "-" + str(nfolds) + "-" + fold 
+ ".zip" 225 | f = _fetch_keel_zip(collection, name, filename, data_home=data_home) 226 | Xs = [] 227 | ys = [] 228 | Xs_test = [] 229 | ys_test = [] 230 | for i in range(nfolds): 231 | if dobscv: 232 | # Zipfiles always use fordward slashes, even in Windows. 233 | _name = f"{name}/{name}-{nfolds}dobscv-{i + 1}" 234 | else: 235 | _name = f"{name}-{nfolds}-{i + 1}" 236 | X_fold, y_fold = _load_Xy(f, _name + "tra.dat", skiprows=nattrs + 4) 237 | X_test_fold, y_test_fold = _load_Xy( 238 | f, 239 | _name + "tst.dat", 240 | skiprows=nattrs + 4, 241 | ) 242 | Xs.append(X_fold) 243 | ys.append(y_fold) 244 | Xs_test.append(X_test_fold) 245 | ys_test.append(y_test_fold) 246 | 247 | cv = KeelOuterCV(Xs, ys, Xs_test, ys_test) 248 | return X, y, cv 249 | 250 | 251 | @overload 252 | def fetch( 253 | collection: str, 254 | name: str, 255 | data_home: Optional[str] = None, 256 | nfolds: Literal[None, 1, 5, 10] = None, 257 | dobscv: bool = False, 258 | *, 259 | return_X_y: Literal[False] = False, 260 | ) -> Bunch: 261 | pass 262 | 263 | 264 | @overload 265 | def fetch( 266 | collection: str, 267 | name: str, 268 | data_home: Optional[str] = None, 269 | nfolds: Literal[None, 1, 5, 10] = None, 270 | dobscv: bool = False, 271 | *, 272 | return_X_y: Literal[True], 273 | ) -> Tuple[np.typing.NDArray[float], np.typing.NDArray[Union[int, float]]]: 274 | pass 275 | 276 | 277 | def fetch( 278 | collection: str, 279 | name: str, 280 | data_home: Optional[str] = None, 281 | nfolds: Literal[None, 1, 5, 10] = None, 282 | dobscv: bool = False, 283 | *, 284 | return_X_y: bool = False, 285 | ) -> Union[ 286 | Bunch, 287 | Tuple[np.typing.NDArray[float], np.typing.NDArray[Union[int, float]]], 288 | ]: 289 | """ 290 | Fetch Keel dataset. 291 | 292 | Fetch a Keel dataset by collection and name. More info at 293 | http://sci2s.ugr.es/keel. 294 | 295 | Parameters 296 | ---------- 297 | collection : string 298 | Collection name. 299 | name : string 300 | Dataset name. 301 | data_home : string or None, default None 302 | Specify another download and cache folder for the data sets. By default 303 | all scikit-learn data is stored in ‘~/scikit_learn_data’ subfolders. 304 | nfolds : int, default=None 305 | Number of folds. Depending on the dataset, valid values are 306 | {None, 1, 5, 10}. 307 | dobscv : bool, default=False 308 | If folds are in {5, 10}, indicates that the cv folds are distribution 309 | optimally balanced stratified. Only available for some datasets. 310 | return_X_y : bool, default=False 311 | If True, returns ``(data, target)`` instead of a Bunch object. 312 | kwargs : dict 313 | Optional key-value arguments 314 | 315 | Returns 316 | ------- 317 | data : Bunch 318 | Dictionary-like object with all the data and metadata. 
319 | 320 | (data, target) : tuple if ``return_X_y`` is True 321 | 322 | """ 323 | if collection not in COLLECTIONS: 324 | raise ValueError("Avaliable collections are " + str(list(COLLECTIONS))) 325 | nattrs, DESCR = _load_descr(collection, name, data_home=data_home) 326 | X, y, cv = _load_folds( 327 | collection, 328 | name, 329 | nfolds, 330 | dobscv, 331 | nattrs, 332 | data_home=data_home, 333 | ) 334 | 335 | if return_X_y: 336 | return X, y 337 | 338 | return Bunch( 339 | data=X, 340 | target=y, 341 | train_indices=[], 342 | validation_indices=[], 343 | test_indices=[], 344 | inner_cv=None, 345 | outer_cv=cv, 346 | DESCR=DESCR, 347 | ) 348 | -------------------------------------------------------------------------------- /skdatasets/repositories/keras.py: -------------------------------------------------------------------------------- 1 | """ 2 | Keras datasets (https://keras.io/datasets). 3 | 4 | @author: David Diaz Vico 5 | @license: MIT 6 | """ 7 | 8 | from __future__ import annotations 9 | 10 | from typing import Any, Literal, Tuple, overload 11 | 12 | import numpy as np 13 | from sklearn.utils import Bunch 14 | 15 | from tensorflow.keras.datasets import ( 16 | boston_housing, 17 | cifar10, 18 | cifar100, 19 | fashion_mnist, 20 | imdb, 21 | mnist, 22 | reuters, 23 | ) 24 | 25 | DATASETS = { 26 | "boston_housing": boston_housing.load_data, 27 | "cifar10": cifar10.load_data, 28 | "cifar100": cifar100.load_data, 29 | "fashion_mnist": fashion_mnist.load_data, 30 | "imdb": imdb.load_data, 31 | "mnist": mnist.load_data, 32 | "reuters": reuters.load_data, 33 | } 34 | 35 | 36 | @overload 37 | def fetch( 38 | name: str, 39 | *, 40 | return_X_y: Literal[False] = False, 41 | **kwargs: Any, 42 | ) -> Bunch: 43 | pass 44 | 45 | 46 | @overload 47 | def fetch( 48 | name: str, 49 | *, 50 | return_X_y: Literal[True], 51 | **kwargs: Any, 52 | ) -> Tuple[np.typing.NDArray[float], np.typing.NDArray[int]]: 53 | pass 54 | 55 | 56 | def fetch( 57 | name: str, 58 | *, 59 | return_X_y: bool = False, 60 | **kwargs: Any, 61 | ) -> Bunch | Tuple[np.typing.NDArray[float], np.typing.NDArray[int]]: 62 | """ 63 | Fetch Keras dataset. 64 | 65 | Fetch a Keras dataset by name. More info at https://keras.io/datasets. 66 | 67 | Parameters 68 | ---------- 69 | name : string 70 | Dataset name. 71 | return_X_y : bool, default=False 72 | If True, returns ``(data, target)`` instead of a Bunch object. 73 | **kwargs : dict 74 | Optional key-value arguments. See https://keras.io/datasets. 75 | 76 | Returns 77 | ------- 78 | data : Bunch 79 | Dictionary-like object with all the data and metadata. 
80 | 81 | (data, target) : tuple if ``return_X_y`` is True 82 | 83 | """ 84 | (X_train, y_train), (X_test, y_test) = DATASETS[name](**kwargs) 85 | if len(X_train.shape) > 2: 86 | name = name + " " + str(X_train.shape[1:]) + " shaped" 87 | X_max = np.iinfo(X_train[0][0].dtype).max 88 | n_features = np.prod(X_train.shape[1:]) 89 | X_train = X_train.reshape([X_train.shape[0], n_features]) / X_max 90 | X_test = X_test.reshape([X_test.shape[0], n_features]) / X_max 91 | 92 | X = np.concatenate((X_train, X_test)) 93 | y = np.concatenate((y_train, y_test)) 94 | 95 | if return_X_y: 96 | return X, y 97 | 98 | return Bunch( 99 | data=X, 100 | target=y, 101 | train_indices=list(range(len(X_train))), 102 | validation_indices=[], 103 | test_indices=list(range(len(X_train), len(X))), 104 | inner_cv=None, 105 | outer_cv=None, 106 | DESCR=name, 107 | ) 108 | -------------------------------------------------------------------------------- /skdatasets/repositories/libsvm.py: -------------------------------------------------------------------------------- 1 | """ 2 | LIBSVM datasets (https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets). 3 | 4 | @author: David Diaz Vico 5 | @license: MIT 6 | """ 7 | from __future__ import annotations 8 | 9 | import os 10 | from typing import Final, Literal, Sequence, Tuple, overload 11 | 12 | import numpy as np 13 | import scipy as sp 14 | from sklearn.datasets import load_svmlight_file, load_svmlight_files 15 | from sklearn.model_selection import PredefinedSplit 16 | from sklearn.utils import Bunch 17 | 18 | from .base import DatasetNotFoundError, fetch_file 19 | 20 | BASE_URL: Final = "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets" 21 | COLLECTIONS: Final = frozenset( 22 | ( 23 | "binary", 24 | "multiclass", 25 | "regression", 26 | "string", 27 | ) 28 | ) 29 | 30 | 31 | def _fetch_partition( 32 | collection: str, 33 | name: str, 34 | partition: str, 35 | data_home: str | None = None, 36 | ) -> str | None: 37 | """Fetch dataset partition.""" 38 | subfolder = os.path.join("libsvm", collection) 39 | dataname = name.replace("/", "-") 40 | 41 | url = f"{BASE_URL}/{collection}/{name}{partition}" 42 | 43 | for data_url in (f"{url}.bz2", url): 44 | try: 45 | return os.fspath( 46 | fetch_file( 47 | dataname, 48 | urlname=data_url, 49 | subfolder=subfolder, 50 | data_home=data_home, 51 | ), 52 | ) 53 | except DatasetNotFoundError: 54 | pass 55 | 56 | return None 57 | 58 | 59 | def _load( 60 | collection: str, 61 | name: str, 62 | data_home: str | None = None, 63 | ) -> Tuple[ 64 | np.typing.NDArray[float], 65 | np.typing.NDArray[int | float], 66 | Sequence[int], 67 | Sequence[int], 68 | Sequence[int], 69 | PredefinedSplit, 70 | ]: 71 | """Load dataset.""" 72 | filename = _fetch_partition(collection, name, "", data_home) 73 | filename_tr = _fetch_partition(collection, name, ".tr", data_home) 74 | filename_val = _fetch_partition(collection, name, ".val", data_home) 75 | filename_t = _fetch_partition(collection, name, ".t", data_home) 76 | filename_r = _fetch_partition(collection, name, ".r", data_home) 77 | 78 | if ( 79 | (filename_tr is not None) 80 | and (filename_val is not None) 81 | and (filename_t is not None) 82 | ): 83 | 84 | _, _, X_tr, y_tr, X_val, y_val, X_test, y_test = load_svmlight_files( 85 | [ 86 | filename, 87 | filename_tr, 88 | filename_val, 89 | filename_t, 90 | ] 91 | ) 92 | 93 | cv = PredefinedSplit([-1] * X_tr.shape[0] + [0] * X_val.shape[0]) 94 | 95 | X = sp.sparse.vstack((X_tr, X_val, X_test)) 96 | y = np.hstack((y_tr, y_val, y_test)) 97 | 98 | # 
Compute indices 99 | train_indices = list(range(X_tr.shape[0])) 100 | validation_indices = list( 101 | range( 102 | X_tr.shape[0], 103 | X_tr.shape[0] + X_val.shape[0], 104 | ) 105 | ) 106 | test_indices = list(range(X_tr.shape[0] + X_val.shape[0], X.shape[0])) 107 | 108 | elif (filename_tr is not None) and (filename_val is not None): 109 | 110 | _, _, X_tr, y_tr, X_val, y_val = load_svmlight_files( 111 | [ 112 | filename, 113 | filename_tr, 114 | filename_val, 115 | ] 116 | ) 117 | 118 | cv = PredefinedSplit([-1] * X_tr.shape[0] + [0] * X_val.shape[0]) 119 | 120 | X = sp.sparse.vstack((X_tr, X_val)) 121 | y = np.hstack((y_tr, y_val)) 122 | 123 | # Compute indices 124 | train_indices = list(range(X_tr.shape[0])) 125 | validation_indices = list(range(X_tr.shape[0], X.shape[0])) 126 | test_indices = [] 127 | 128 | elif (filename_t is not None) and (filename_r is not None): 129 | 130 | X_tr, y_tr, X_test, y_test, X_remaining, y_remaining = load_svmlight_files( 131 | [ 132 | filename, 133 | filename_t, 134 | filename_r, 135 | ] 136 | ) 137 | 138 | X = sp.sparse.vstack((X_tr, X_test, X_remaining)) 139 | y = np.hstack((y_tr, y_test, y_remaining)) 140 | 141 | # Compute indices 142 | train_indices = list(range(X_tr.shape[0])) 143 | validation_indices = [] 144 | test_indices = list( 145 | range( 146 | X_tr.shape[0], 147 | X_tr.shape[0] + X_test.shape[0], 148 | ), 149 | ) 150 | 151 | cv = None 152 | 153 | elif filename_t is not None: 154 | 155 | X_tr, y_tr, X_test, y_test = load_svmlight_files( 156 | [ 157 | filename, 158 | filename_t, 159 | ] 160 | ) 161 | 162 | X = sp.sparse.vstack((X_tr, X_test)) 163 | y = np.hstack((y_tr, y_test)) 164 | 165 | # Compute indices 166 | train_indices = list(range(X_tr.shape[0])) 167 | validation_indices = [] 168 | test_indices = list(range(X_tr.shape[0], X.shape[0])) 169 | 170 | cv = None 171 | 172 | else: 173 | 174 | X, y = load_svmlight_file(filename) 175 | 176 | # Compute indices 177 | train_indices = [] 178 | validation_indices = [] 179 | test_indices = [] 180 | 181 | cv = None 182 | 183 | return X, y, train_indices, validation_indices, test_indices, cv 184 | 185 | 186 | @overload 187 | def fetch( 188 | collection: str, 189 | name: str, 190 | *, 191 | data_home: str | None = None, 192 | return_X_y: Literal[False] = False, 193 | ) -> Bunch: 194 | pass 195 | 196 | 197 | @overload 198 | def fetch( 199 | collection: str, 200 | name: str, 201 | *, 202 | data_home: str | None = None, 203 | return_X_y: Literal[True], 204 | ) -> Tuple[np.typing.NDArray[float], np.typing.NDArray[int | float]]: 205 | pass 206 | 207 | 208 | def fetch( 209 | collection: str, 210 | name: str, 211 | *, 212 | data_home: str | None = None, 213 | return_X_y: bool = False, 214 | ) -> Bunch | Tuple[np.typing.NDArray[float], np.typing.NDArray[int | float]]: 215 | """ 216 | Fetch LIBSVM dataset. 217 | 218 | Fetch a LIBSVM dataset by collection and name. More info at 219 | https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets. 220 | 221 | Parameters 222 | ---------- 223 | collection : string 224 | Collection name. 225 | name : string 226 | Dataset name. 227 | data_home : string or None, default None 228 | Specify another download and cache folder for the data sets. By default 229 | all scikit-learn data is stored in ‘~/scikit_learn_data’ subfolders. 230 | return_X_y : bool, default=False 231 | If True, returns ``(data, target)`` instead of a Bunch object. 232 | 233 | Returns 234 | ------- 235 | data : Bunch 236 | Dictionary-like object with all the data and metadata. 
237 | 238 | (data, target) : tuple if ``return_X_y`` is True 239 | 240 | """ 241 | if collection not in COLLECTIONS: 242 | raise Exception("Avaliable collections are " + str(list(COLLECTIONS))) 243 | 244 | X, y, train_indices, validation_indices, test_indices, cv = _load( 245 | collection, 246 | name, 247 | data_home=data_home, 248 | ) 249 | 250 | if return_X_y: 251 | return X, y 252 | 253 | return Bunch( 254 | data=X, 255 | target=y, 256 | train_indices=train_indices, 257 | validation_indices=validation_indices, 258 | test_indices=test_indices, 259 | inner_cv=cv, 260 | outer_cv=None, 261 | DESCR=name, 262 | ) 263 | -------------------------------------------------------------------------------- /skdatasets/repositories/physionet.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import ast 4 | import math 5 | import re 6 | import urllib 7 | from html.parser import HTMLParser 8 | from pathlib import Path 9 | from typing import ( 10 | Any, 11 | Final, 12 | List, 13 | Literal, 14 | Mapping, 15 | Sequence, 16 | Tuple, 17 | overload, 18 | ) 19 | 20 | import numpy as np 21 | import pandas as pd 22 | import wfdb.io 23 | from sklearn.utils import Bunch 24 | 25 | from skdatasets.repositories.base import dataset_from_dataframe 26 | 27 | from .base import DatasetNotFoundError, fetch_zip 28 | 29 | BASE_URL: Final = "https://physionet.org/static/published-projects" 30 | INFO_STRING_SEMICOLONS_ONE_STR: Final = r"(\S*): (\S*)\s*" 31 | INFO_STRING_SEMICOLONS_SEVERAL_REGEX: Final = re.compile( 32 | rf"(?:{INFO_STRING_SEMICOLONS_ONE_STR})+", 33 | ) 34 | INFO_STRING_SEMICOLONS_ONE_REGEX: Final = re.compile( 35 | INFO_STRING_SEMICOLONS_ONE_STR, 36 | ) 37 | 38 | 39 | class _ZipNameHTMLParser(HTMLParser): 40 | """Class for parsing the zip name in PhysioNet directory listing.""" 41 | 42 | def __init__(self, *, convert_charrefs: bool = True) -> None: 43 | super().__init__(convert_charrefs=convert_charrefs) 44 | 45 | self.zip_name: str | None = None 46 | 47 | def handle_starttag( 48 | self, 49 | tag: str, 50 | attrs: List[Tuple[str, str | None]], 51 | ) -> None: 52 | if tag == "a": 53 | for attr in attrs: 54 | if attr[0] == "href" and attr[1] and attr[1].endswith(".zip"): 55 | self.zip_name = attr[1] 56 | 57 | 58 | def _get_zip_name_online(dataset_name: str) -> str: 59 | """Get the zip name of the dataset.""" 60 | parser = _ZipNameHTMLParser() 61 | 62 | url_request = urllib.request.Request(url=f"{BASE_URL}/{dataset_name}") 63 | try: 64 | with urllib.request.urlopen(url_request) as url_file: 65 | url_content = url_file.read().decode("utf-8") 66 | except urllib.error.HTTPError as e: 67 | if e.code == 404: 68 | raise DatasetNotFoundError(dataset_name) from e 69 | raise 70 | 71 | parser.feed(url_content) 72 | 73 | if parser.zip_name is None: 74 | raise ValueError(f"No zip file found for dataset '{dataset_name}'") 75 | 76 | return parser.zip_name 77 | 78 | 79 | def _parse_info_string_value(value: str) -> Any: 80 | if value.lower() == "nan": 81 | return math.nan 82 | try: 83 | value = ast.literal_eval(value) 84 | except Exception: 85 | pass 86 | 87 | return value 88 | 89 | 90 | def _get_info_strings(comments: Sequence[str]) -> Mapping[str, Any]: 91 | 92 | info_strings_semicolons = {} 93 | info_strings_spaces = {} 94 | 95 | for comment in comments: 96 | if comment[0] not in {"-", "#"}: 97 | if re.fullmatch(INFO_STRING_SEMICOLONS_SEVERAL_REGEX, comment): 98 | for result in re.finditer( 99 | INFO_STRING_SEMICOLONS_ONE_REGEX, 100 | comment, 
101 | ): 102 | key = result.group(1) 103 | if key[0] == "<" and key[-1] == ">": 104 | key = key[1:-1] 105 | info_strings_semicolons[key] = _parse_info_string_value( 106 | result.group(2) 107 | ) 108 | else: 109 | split = comment.rsplit(maxsplit=1) 110 | if len(split) == 2: 111 | key, value = split 112 | info_strings_spaces[key] = _parse_info_string_value(value) 113 | 114 | if info_strings_semicolons: 115 | return info_strings_semicolons 116 | 117 | # Check for absurd things in spaces 118 | if len(info_strings_spaces) == 1 or any( 119 | key.count(" ") > 3 for key in info_strings_spaces 120 | ): 121 | return {} 122 | 123 | return info_strings_spaces 124 | 125 | 126 | def _join_info_dicts( 127 | dicts: Sequence[Mapping[str, Any]], 128 | ) -> Mapping[str, np.typing.NDArray[Any]]: 129 | 130 | joined = {} 131 | 132 | n_keys = len(dicts[0]) 133 | if not all(len(d) == n_keys for d in dicts): 134 | return {} 135 | 136 | for key in dicts[0]: 137 | joined[key] = np.array([d[key] for d in dicts]) 138 | 139 | return joined 140 | 141 | 142 | def _constant_attrs(register: wfdb.Record) -> Sequence[Any]: 143 | return (register.n_sig, register.sig_name, register.units, register.fs) 144 | 145 | 146 | @overload 147 | def fetch( 148 | name: str, 149 | *, 150 | data_home: str | None = None, 151 | return_X_y: Literal[False] = False, 152 | as_frame: bool = False, 153 | target_column: str | Sequence[str] | None = None, 154 | ) -> Bunch: 155 | pass 156 | 157 | 158 | @overload 159 | def fetch( 160 | name: str, 161 | *, 162 | data_home: str | None = None, 163 | return_X_y: Literal[True], 164 | as_frame: Literal[False] = False, 165 | target_column: None = None, 166 | ) -> Tuple[np.typing.NDArray[Any], None]: 167 | pass 168 | 169 | 170 | @overload 171 | def fetch( 172 | name: str, 173 | *, 174 | data_home: str | None = None, 175 | return_X_y: Literal[True], 176 | as_frame: Literal[False] = False, 177 | target_column: str | Sequence[str], 178 | ) -> Tuple[np.typing.NDArray[Any], np.typing.NDArray[Any]]: 179 | pass 180 | 181 | 182 | @overload 183 | def fetch( 184 | name: str, 185 | *, 186 | data_home: str | None = None, 187 | return_X_y: Literal[True], 188 | as_frame: Literal[True], 189 | target_column: None = None, 190 | ) -> Tuple[pd.DataFrame, None]: 191 | pass 192 | 193 | 194 | @overload 195 | def fetch( 196 | name: str, 197 | *, 198 | data_home: str | None = None, 199 | return_X_y: Literal[True], 200 | as_frame: Literal[True], 201 | target_column: str, 202 | ) -> Tuple[pd.DataFrame, pd.Series]: 203 | pass 204 | 205 | 206 | @overload 207 | def fetch( 208 | name: str, 209 | *, 210 | data_home: str | None = None, 211 | return_X_y: Literal[True], 212 | as_frame: Literal[True], 213 | target_column: Sequence[str], 214 | ) -> Tuple[pd.DataFrame, pd.DataFrame]: 215 | pass 216 | 217 | 218 | def fetch( 219 | name: str, 220 | *, 221 | data_home: str | None = None, 222 | return_X_y: bool = False, 223 | as_frame: bool = False, 224 | target_column: str | Sequence[str] | None = None, 225 | ) -> ( 226 | Bunch 227 | | Tuple[np.typing.NDArray[Any], np.typing.NDArray[Any] | None] 228 | | Tuple[pd.DataFrame, pd.Series | pd.DataFrame | None] 229 | ): 230 | 231 | zip_name = _get_zip_name_online(name) 232 | 233 | path = fetch_zip( 234 | dataname=name, 235 | urlname=f"{BASE_URL}/{name}/{zip_name}", 236 | subfolder="physionet", 237 | data_home=data_home, 238 | ) 239 | 240 | subpath = path / Path(zip_name).stem 241 | if subpath.exists(): 242 | path = subpath 243 | 244 | with open(path / "RECORDS") as records_file: 245 | records = [ 246 | 
wfdb.io.rdrecord(str(path / record_name.rstrip("\n"))) 247 | for record_name in records_file 248 | ] 249 | 250 | info_strings = [_get_info_strings(r.comments) for r in records] 251 | info = _join_info_dicts(info_strings) 252 | 253 | assert all(_constant_attrs(r) == _constant_attrs(records[0]) for r in records) 254 | data = { 255 | "signal": [r.p_signal for r in records], 256 | } 257 | 258 | dataframe = pd.DataFrame( 259 | {**info, **data}, 260 | index=[r.record_name for r in records], 261 | ) 262 | dataframe["signal"].attrs.update( 263 | sig_name=records[0].sig_name, 264 | units=records[0].units, 265 | fs=records[0].fs, 266 | ) 267 | 268 | return dataset_from_dataframe( 269 | dataframe, 270 | return_X_y=return_X_y, 271 | as_frame=as_frame, 272 | target_column=target_column, 273 | ) 274 | -------------------------------------------------------------------------------- /skdatasets/repositories/raetsch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Gunnar Raetsch benchmark datasets 3 | (https://github.com/tdiethe/gunnar_raetsch_benchmark_datasets). 4 | 5 | @author: David Diaz Vico 6 | @license: MIT 7 | """ 8 | from __future__ import annotations 9 | 10 | import hashlib 11 | from pathlib import Path 12 | from typing import ( 13 | Final, 14 | Iterator, 15 | Literal, 16 | Optional, 17 | Sequence, 18 | Tuple, 19 | Union, 20 | overload, 21 | ) 22 | 23 | import numpy as np 24 | from scipy.io import loadmat 25 | from sklearn.utils import Bunch 26 | 27 | from .base import fetch_file 28 | 29 | DATASETS: Final = frozenset( 30 | ( 31 | "banana", 32 | "breast_cancer", 33 | "diabetis", 34 | "flare_solar", 35 | "german", 36 | "heart", 37 | "image", 38 | "ringnorm", 39 | "splice", 40 | "thyroid", 41 | "titanic", 42 | "twonorm", 43 | "waveform", 44 | ) 45 | ) 46 | 47 | 48 | class RaetschOuterCV(object): 49 | """Iterable over already separated CV partitions of the dataset.""" 50 | 51 | def __init__( 52 | self, 53 | X: np.typing.NDArray[float], 54 | y: np.typing.NDArray[Union[int, float]], 55 | train_splits: Sequence[np.typing.NDArray[int]], 56 | test_splits: Sequence[np.typing.NDArray[int]], 57 | ) -> None: 58 | self.X = X 59 | self.y = y 60 | self.train_splits = train_splits 61 | self.test_splits = test_splits 62 | 63 | def __iter__( 64 | self, 65 | ) -> Iterator[ 66 | Tuple[ 67 | np.typing.NDArray[float], 68 | np.typing.NDArray[Union[int, float]], 69 | np.typing.NDArray[float], 70 | np.typing.NDArray[Union[int, float]], 71 | ] 72 | ]: 73 | return ( 74 | (self.X[tr - 1], self.y[tr - 1], self.X[ts - 1], self.y[ts - 1]) 75 | for tr, ts in zip(self.train_splits, self.test_splits) 76 | ) 77 | 78 | 79 | def _fetch_remote(data_home: Optional[str] = None) -> Path: 80 | """ 81 | Helper function to download the remote dataset into path. 82 | 83 | Fetch the remote dataset, save into path using remote's filename and ensure 84 | its integrity based on the SHA256 Checksum of the downloaded file. 85 | 86 | Parameters 87 | ---------- 88 | dirname : string 89 | Directory to save the file to. 90 | 91 | Returns 92 | ------- 93 | file_path: string 94 | Full path of the created file. 
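    Examples
    --------
    Illustrative only (not executed: it downloads ``benchmarks.mat`` from
    the repository above and verifies its SHA256 checksum):

    >>> file_path = _fetch_remote()  # doctest: +SKIP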
95 | """ 96 | file_path = fetch_file( 97 | "raetsch", 98 | "https://github.com/tdiethe/gunnar_raetsch_benchmark_datasets" 99 | "/raw/master/benchmarks.mat", 100 | data_home=data_home, 101 | ) 102 | sha256hash = hashlib.sha256() 103 | with open(file_path, "rb") as f: 104 | while True: 105 | buffer = f.read(8192) 106 | if not buffer: 107 | break 108 | sha256hash.update(buffer) 109 | checksum = sha256hash.hexdigest() 110 | remote_checksum = "47c19e4bc4716edc4077cfa5ea61edf4d02af4ec51a0ecfe035626ae8b561c75" 111 | if remote_checksum != checksum: 112 | raise IOError( 113 | f"{file_path} has an SHA256 checksum ({checksum}) differing " 114 | f"from expected ({remote_checksum}), file may be corrupted.", 115 | ) 116 | return file_path 117 | 118 | 119 | @overload 120 | def fetch( 121 | name: str, 122 | data_home: Optional[str] = None, 123 | *, 124 | return_X_y: Literal[False] = False, 125 | ) -> Bunch: 126 | pass 127 | 128 | 129 | @overload 130 | def fetch( 131 | name: str, 132 | data_home: Optional[str] = None, 133 | *, 134 | return_X_y: Literal[True], 135 | ) -> Tuple[np.typing.NDArray[float], np.typing.NDArray[Union[int, float]]]: 136 | pass 137 | 138 | 139 | def fetch( 140 | name: str, 141 | data_home: Optional[str] = None, 142 | *, 143 | return_X_y: bool = False, 144 | ) -> Union[ 145 | Bunch, 146 | Tuple[np.typing.NDArray[float], np.typing.NDArray[Union[int, float]]], 147 | ]: 148 | """Fetch Gunnar Raetsch's dataset. 149 | 150 | Fetch a Gunnar Raetsch's benchmark dataset by name. Availabe datasets are 151 | 'banana', 'breast_cancer', 'diabetis', 'flare_solar', 'german', 'heart', 152 | 'image', 'ringnorm', 'splice', 'thyroid', 'titanic', 'twonorm' and 153 | 'waveform'. More info at 154 | https://github.com/tdiethe/gunnar_raetsch_benchmark_datasets. 155 | 156 | Parameters 157 | ---------- 158 | name : string 159 | Dataset name. 160 | data_home : string or None, default None 161 | Specify another download and cache folder for the data sets. By default 162 | all scikit-learn data is stored in ‘~/scikit_learn_data’ subfolders. 163 | return_X_y : bool, default=False 164 | If True, returns ``(data, target)`` instead of a Bunch object. 165 | 166 | Returns 167 | ------- 168 | data : Bunch 169 | Dictionary-like object with all the data and metadata. 170 | 171 | (data, target) : tuple if ``return_X_y`` is True 172 | 173 | """ 174 | if name not in DATASETS: 175 | raise Exception("Avaliable datasets are " + str(list(DATASETS))) 176 | filename = _fetch_remote(data_home=data_home) 177 | X, y, train_splits, test_splits = loadmat(filename)[name][0][0] 178 | if len(y.shape) == 2 and y.shape[1] == 1: 179 | y = y.ravel() 180 | 181 | cv = RaetschOuterCV(X, y, train_splits, test_splits) 182 | 183 | if return_X_y: 184 | return X, y 185 | 186 | return Bunch( 187 | data=X, 188 | target=y, 189 | train_indices=[], 190 | validation_indices=[], 191 | test_indices=[], 192 | inner_cv=None, 193 | outer_cv=cv, 194 | DESCR=name, 195 | ) 196 | -------------------------------------------------------------------------------- /skdatasets/repositories/sklearn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Scikit-learn datasets (http://scikit-learn.org/stable/datasets/index.html). 
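
A minimal example, mirroring the repository's own test for the iris
loader (requires scikit-learn; no download is needed):

>>> from skdatasets.repositories.sklearn import fetch
>>> X, y = fetch("iris", return_X_y=True)
>>> X.shape
(150, 4)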
3 | 4 | @author: David Diaz Vico 5 | @license: MIT 6 | """ 7 | 8 | from sklearn.datasets import ( 9 | fetch_20newsgroups, 10 | fetch_20newsgroups_vectorized, 11 | fetch_california_housing, 12 | fetch_covtype, 13 | fetch_kddcup99, 14 | fetch_lfw_pairs, 15 | fetch_lfw_people, 16 | fetch_olivetti_faces, 17 | fetch_rcv1, 18 | load_breast_cancer, 19 | load_diabetes, 20 | load_digits, 21 | load_iris, 22 | load_linnerud, 23 | load_wine, 24 | make_biclusters, 25 | make_blobs, 26 | make_checkerboard, 27 | make_circles, 28 | make_classification, 29 | make_friedman1, 30 | make_friedman2, 31 | make_friedman3, 32 | make_gaussian_quantiles, 33 | make_hastie_10_2, 34 | make_low_rank_matrix, 35 | make_moons, 36 | make_multilabel_classification, 37 | make_regression, 38 | make_s_curve, 39 | make_sparse_coded_signal, 40 | make_sparse_spd_matrix, 41 | make_sparse_uncorrelated, 42 | make_spd_matrix, 43 | make_swiss_roll, 44 | ) 45 | 46 | DATASETS = { 47 | "20newsgroups": fetch_20newsgroups, 48 | "20newsgroups_vectorized": fetch_20newsgroups_vectorized, 49 | "biclusters": make_biclusters, 50 | "blobs": make_blobs, 51 | "breast_cancer": load_breast_cancer, 52 | "california_housing": fetch_california_housing, 53 | "checkerboard": make_checkerboard, 54 | "circles": make_circles, 55 | "classification": make_classification, 56 | "covtype": fetch_covtype, 57 | "diabetes": load_diabetes, 58 | "digits": load_digits, 59 | "friedman1": make_friedman1, 60 | "friedman2": make_friedman2, 61 | "friedman3": make_friedman3, 62 | "gaussian_quantiles": make_gaussian_quantiles, 63 | "hastie_10_2": make_hastie_10_2, 64 | "iris": load_iris, 65 | "kddcup99": fetch_kddcup99, 66 | "lfw_people": fetch_lfw_people, 67 | "lfw_pairs": fetch_lfw_pairs, 68 | "linnerud": load_linnerud, 69 | "low_rank_matrix": make_low_rank_matrix, 70 | "moons": make_moons, 71 | "multilabel_classification": make_multilabel_classification, 72 | "olivetti_faces": fetch_olivetti_faces, 73 | "rcv1": fetch_rcv1, 74 | "regression": make_regression, 75 | "s_curve": make_s_curve, 76 | "sparse_coded_signal": make_sparse_coded_signal, 77 | "sparse_spd_matrix": make_sparse_spd_matrix, 78 | "sparse_uncorrelated": make_sparse_uncorrelated, 79 | "spd_matrix": make_spd_matrix, 80 | "swiss_roll": make_swiss_roll, 81 | "wine": load_wine, 82 | } 83 | 84 | 85 | def fetch(name, *, return_X_y=False, **kwargs): 86 | """Fetch Scikit-learn dataset. 87 | 88 | Fetch a Scikit-learn dataset by name. More info at 89 | http://scikit-learn.org/stable/datasets/index.html. 90 | 91 | Parameters 92 | ---------- 93 | name : string 94 | Dataset name. 95 | return_X_y : bool, default=False 96 | If True, returns ``(data, target)`` instead of a Bunch object. 97 | **kwargs : dict 98 | Optional key-value arguments. See 99 | scikit-learn.org/stable/modules/classes.html#module-sklearn.datasets. 100 | 101 | Returns 102 | ------- 103 | data : Bunch 104 | Dictionary-like object with all the data and metadata. 
105 | 106 | (data, target) : tuple if ``return_X_y`` is True 107 | 108 | """ 109 | if return_X_y: 110 | kwargs["return_X_y"] = True 111 | 112 | data = DATASETS[name](**kwargs) 113 | 114 | if not return_X_y: 115 | data.train_indices = [] 116 | data.validation_indices = [] 117 | data.test_indices = [] 118 | data.inner_cv = None 119 | data.outer_cv = None 120 | 121 | return data 122 | -------------------------------------------------------------------------------- /skdatasets/repositories/uci.py: -------------------------------------------------------------------------------- 1 | """ 2 | UCI datasets (https://archive.ics.uci.edu/ml/datasets.html). 3 | 4 | @author: David Diaz Vico 5 | @license: MIT 6 | """ 7 | from __future__ import annotations 8 | 9 | from pathlib import Path 10 | from typing import Any, Literal, Optional, Tuple, Union, overload 11 | 12 | import numpy as np 13 | from sklearn.preprocessing import OrdinalEncoder 14 | from sklearn.utils import Bunch 15 | 16 | from .base import fetch_file 17 | 18 | BASE_URL = "https://archive.ics.uci.edu/ml/machine-learning-databases" 19 | 20 | 21 | def _load_csv( 22 | fname: Path, 23 | **kwargs: Any, 24 | ) -> Tuple[np.typing.NDArray[float], np.typing.NDArray[Union[float, int, str]],]: 25 | """Load a csv with targets in the last column and features in the rest.""" 26 | data = np.genfromtxt( 27 | fname, 28 | dtype=str, 29 | delimiter=",", 30 | encoding=None, 31 | **kwargs, 32 | ) 33 | X = data[:, :-1] 34 | try: 35 | X = X.astype(float) 36 | except ValueError: 37 | pass 38 | 39 | y = data[:, -1] 40 | 41 | return X, y 42 | 43 | 44 | def _fetch( 45 | name: str, 46 | data_home: Optional[str] = None, 47 | ) -> Tuple[ 48 | np.typing.NDArray[float], 49 | np.typing.NDArray[Union[float, int]], 50 | Optional[np.typing.NDArray[float]], 51 | Optional[np.typing.NDArray[Union[float, int]]], 52 | str, 53 | np.typing.NDArray[str], 54 | ]: 55 | """Fetch dataset.""" 56 | subfolder = "uci" 57 | filename_str = name + ".data" 58 | url = BASE_URL + "/" + name + "/" + filename_str 59 | 60 | filename = fetch_file( 61 | dataname=name, 62 | urlname=url, 63 | subfolder=subfolder, 64 | data_home=data_home, 65 | ) 66 | X, y = _load_csv(filename) 67 | target_names = None 68 | ordinal_encoder = OrdinalEncoder(dtype=np.int64) 69 | if y.dtype.type is np.str_: 70 | y = ordinal_encoder.fit_transform(y.reshape(-1, 1))[:, 0] 71 | target_names = ordinal_encoder.categories_[0] 72 | try: 73 | filename_str = name + ".test" 74 | url = BASE_URL + "/" + name + "/" + filename_str 75 | filename = fetch_file( 76 | dataname=name, 77 | urlname=url, 78 | subfolder=subfolder, 79 | data_home=data_home, 80 | ) 81 | X_test: Optional[np.typing.NDArray[float]] 82 | y_test: Optional[np.typing.NDArray[Union[float, int, str]]] 83 | X_test, y_test = _load_csv(filename) 84 | 85 | if y.dtype.type is np.str_: 86 | y_test = ordinal_encoder.transform(y_test.reshape(-1, 1))[:, 0] 87 | 88 | except Exception: 89 | X_test = None 90 | y_test = None 91 | try: 92 | filename_str = name + ".names" 93 | url = BASE_URL + "/" + name + "/" + filename_str 94 | filename = fetch_file( 95 | dataname=name, 96 | urlname=url, 97 | subfolder=subfolder, 98 | data_home=data_home, 99 | ) 100 | except Exception: 101 | filename_str = name + ".info" 102 | url = BASE_URL + "/" + name + "/" + filename_str 103 | filename = fetch_file( 104 | dataname=name, 105 | urlname=url, 106 | subfolder=subfolder, 107 | data_home=data_home, 108 | ) 109 | with open(filename) as rst_file: 110 | fdescr = rst_file.read() 111 | return X, y, X_test, y_test, 
fdescr, target_names 112 | 113 | 114 | @overload 115 | def fetch( 116 | name: str, 117 | data_home: Optional[str] = None, 118 | *, 119 | return_X_y: Literal[False] = False, 120 | ) -> Bunch: 121 | pass 122 | 123 | 124 | @overload 125 | def fetch( 126 | name: str, 127 | data_home: Optional[str] = None, 128 | *, 129 | return_X_y: Literal[True], 130 | ) -> Tuple[np.typing.NDArray[float], np.typing.NDArray[float]]: 131 | pass 132 | 133 | 134 | def fetch( 135 | name: str, 136 | data_home: Optional[str] = None, 137 | *, 138 | return_X_y: bool = False, 139 | ) -> Union[Bunch, Tuple[np.typing.NDArray[float], np.typing.NDArray[float]],]: 140 | """ 141 | Fetch UCI dataset. 142 | 143 | Fetch a UCI dataset by name. More info at 144 | https://archive.ics.uci.edu/ml/datasets.html. 145 | 146 | Parameters 147 | ---------- 148 | name : string 149 | Dataset name. 150 | data_home : string or None, default None 151 | Specify another download and cache folder for the data sets. By default 152 | all scikit-learn data is stored in ‘~/scikit_learn_data’ subfolders. 153 | return_X_y : bool, default=False 154 | If True, returns ``(data, target)`` instead of a Bunch object. 155 | 156 | Returns 157 | ------- 158 | data : Bunch 159 | Dictionary-like object with all the data and metadata. 160 | 161 | (data, target) : tuple if ``return_X_y`` is True 162 | 163 | """ 164 | X_train, y_train, X_test, y_test, DESCR, target_names = _fetch( 165 | name, 166 | data_home=data_home, 167 | ) 168 | 169 | if X_test is None or y_test is None: 170 | X = X_train 171 | y = y_train 172 | 173 | train_indices = None 174 | test_indices = None 175 | else: 176 | X = np.concatenate((X_train, X_test)) 177 | y = np.concatenate((y_train, y_test)) 178 | 179 | train_indices = list(range(len(X_train))) 180 | test_indices = list(range(len(X_train), len(X))) 181 | 182 | if return_X_y: 183 | return X, y 184 | 185 | return Bunch( 186 | data=X, 187 | target=y, 188 | train_indices=train_indices, 189 | validation_indices=[], 190 | test_indices=test_indices, 191 | inner_cv=None, 192 | outer_cv=None, 193 | DESCR=DESCR, 194 | target_names=target_names, 195 | ) 196 | -------------------------------------------------------------------------------- /skdatasets/repositories/ucr.py: -------------------------------------------------------------------------------- 1 | """ 2 | Datasets from the UCR time series database. 
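
A minimal sketch of the intended usage, mirroring the repository's
GunPoint test (not executed here: the first call downloads the data):

>>> from skdatasets.repositories.ucr import fetch
>>> X, y = fetch("GunPoint", return_X_y=True)  # doctest: +SKIP
>>> X.shape  # doctest: +SKIP
(200, 150)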
3 | 4 | @author: Carlos Ramos Carreño 5 | @license: MIT 6 | """ 7 | from __future__ import annotations 8 | 9 | from pathlib import Path 10 | from typing import Final, Literal, Optional, Sequence, Tuple, Union, overload 11 | 12 | import numpy as np 13 | import scipy.io.arff 14 | from sklearn.utils import Bunch 15 | 16 | from .base import fetch_zip as _fetch_zip 17 | 18 | BASE_URL: Final = "https://www.timeseriesclassification.com/aeon-toolkit/" 19 | 20 | 21 | def _target_conversion( 22 | target: np.typing.NDArray[Union[int, str]], 23 | ) -> Tuple[np.typing.NDArray[int], Sequence[str]]: 24 | try: 25 | target_data = target.astype(int) 26 | target_names = np.unique(target_data).astype(str).tolist() 27 | except ValueError: 28 | target_names = np.unique(target).tolist() 29 | target_data = np.searchsorted(target_names, target) 30 | 31 | return target_data, target_names 32 | 33 | 34 | def data_to_matrix( 35 | struct_array: np.typing.NDArray[object], 36 | ) -> np.typing.NDArray[float]: 37 | fields = struct_array.dtype.fields 38 | assert fields 39 | if len(fields.items()) == 1 and list(fields.items())[0][1][0] == np.dtype( 40 | np.object_ 41 | ): 42 | attribute = struct_array[list(fields.items())[0][0]] 43 | 44 | n_instances = len(attribute) 45 | n_curves = len(attribute[0]) 46 | n_points = len(attribute[0][0]) 47 | 48 | attribute_new = np.zeros(n_instances, dtype=np.object_) 49 | 50 | for i in range(n_instances): 51 | 52 | transformed_matrix = np.zeros((n_curves, n_points)) 53 | 54 | for j in range(n_curves): 55 | for k in range(n_points): 56 | transformed_matrix[j][k] = attribute[i][j][k] 57 | attribute_new[i] = transformed_matrix 58 | 59 | return attribute_new 60 | 61 | else: 62 | return np.array(struct_array.tolist()) 63 | 64 | 65 | @overload 66 | def fetch( 67 | name: str, 68 | *, 69 | data_home: Optional[str] = None, 70 | return_X_y: Literal[False] = False, 71 | ) -> Bunch: 72 | pass 73 | 74 | 75 | @overload 76 | def fetch( 77 | name: str, 78 | *, 79 | data_home: Optional[str] = None, 80 | return_X_y: Literal[True], 81 | ) -> Tuple[np.typing.NDArray[float], np.typing.NDArray[int]]: 82 | pass 83 | 84 | 85 | def fetch( 86 | name: str, 87 | *, 88 | data_home: Optional[str] = None, 89 | return_X_y: bool = False, 90 | ) -> Union[Bunch, Tuple[np.typing.NDArray[float], np.typing.NDArray[int]], ]: 91 | """ 92 | Fetch UCR dataset. 93 | 94 | Fetch a UCR dataset by name. More info at 95 | http://www.timeseriesclassification.com/. 96 | 97 | Parameters 98 | ---------- 99 | name : string 100 | Dataset name. 101 | data_home : string or None, default None 102 | Specify another download and cache folder for the data sets. By default 103 | all scikit-learn data is stored in ‘~/scikit_learn_data’ subfolders. 104 | return_X_y : bool, default=False 105 | If True, returns ``(data, target)`` instead of a Bunch object. 106 | 107 | Returns 108 | ------- 109 | data : Bunch 110 | Dictionary-like object with all the data and metadata. 
111 | 112 | (data, target) : tuple if ``return_X_y`` is True 113 | 114 | """ 115 | url = BASE_URL + name 116 | 117 | data_path = _fetch_zip( 118 | name, 119 | urlname=url + ".zip", 120 | subfolder="ucr", 121 | data_home=data_home, 122 | ) 123 | 124 | description_filenames = [name, name + "Description", name + "_Info"] 125 | 126 | path_file_descr: Optional[Path] 127 | for f in description_filenames: 128 | path_file_descr = (data_path / f).with_suffix(".txt") 129 | if path_file_descr.exists(): 130 | break 131 | else: 132 | # No description is found 133 | path_file_descr = None 134 | 135 | path_file_train = (data_path / (name + "_TRAIN")).with_suffix(".arff") 136 | path_file_test = (data_path / (name + "_TEST")).with_suffix(".arff") 137 | 138 | DESCR = ( 139 | path_file_descr.read_text( 140 | errors="surrogateescape") if path_file_descr else "" 141 | ) 142 | train = scipy.io.arff.loadarff(path_file_train) 143 | test = scipy.io.arff.loadarff(path_file_test) 144 | dataset_name = train[1].name 145 | column_names = np.array(train[1].names()) 146 | target_column_name = column_names[-1] 147 | feature_names = column_names[column_names != target_column_name].tolist() 148 | target_column = train[0][target_column_name].astype(str) 149 | test_target_column = test[0][target_column_name].astype(str) 150 | y_train, target_names = _target_conversion(target_column) 151 | y_test, target_names_test = _target_conversion(test_target_column) 152 | assert target_names == target_names_test 153 | X_train = data_to_matrix(train[0][feature_names]) 154 | X_test = data_to_matrix(test[0][feature_names]) 155 | 156 | X = np.concatenate((X_train, X_test)) 157 | y = np.concatenate((y_train, y_test)) 158 | 159 | if return_X_y: 160 | return X, y 161 | 162 | return Bunch( 163 | data=X, 164 | target=y, 165 | train_indices=list(range(len(X_train))), 166 | validation_indices=[], 167 | test_indices=list(range(len(X_train), len(X))), 168 | name=dataset_name, 169 | DESCR=DESCR, 170 | feature_names=feature_names, 171 | target_names=target_names, 172 | ) 173 | -------------------------------------------------------------------------------- /skdatasets/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daviddiazvico/scikit-datasets/d3b25a490aa6f9bb99e72d2b0cf17cbb73702b1f/skdatasets/tests/__init__.py -------------------------------------------------------------------------------- /skdatasets/tests/repositories/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests. 
3 | 4 | @author: David Diaz Vico 5 | @license: MIT 6 | """ 7 | 8 | import numpy as np 9 | from sklearn.linear_model import Ridge 10 | from sklearn.model_selection import GridSearchCV, cross_validate 11 | from sklearn.pipeline import Pipeline 12 | from sklearn.preprocessing import StandardScaler 13 | 14 | 15 | def check_estimator(data): 16 | """Check that the dataset can be used to cross-validate an estimator.""" 17 | estimator = GridSearchCV( 18 | Pipeline( 19 | [("tr", StandardScaler(with_mean=False)), ("pred", Ridge(max_iter=4))] 20 | ), 21 | {"pred__alpha": [0.33, 0.66]}, 22 | cv=data.inner_cv, 23 | error_score=np.nan, 24 | ) 25 | if data.train_indices and data.test_indices: 26 | 27 | train_indices = data.train_indices 28 | 29 | train_indices += data.validation_indices 30 | 31 | estimator.fit( 32 | data.data[train_indices], 33 | y=data.target[train_indices], 34 | ) 35 | estimator.score(data.data[data.test_indices], y=data.target[data.test_indices]) 36 | else: 37 | if hasattr(data.outer_cv, "__iter__"): 38 | for X, y, X_test, y_test in data.outer_cv: 39 | estimator.fit(X, y=y) 40 | estimator.score(X_test, y=y_test) 41 | else: 42 | cross_validate(estimator, data.data, y=data.target, cv=data.outer_cv) 43 | -------------------------------------------------------------------------------- /skdatasets/tests/repositories/test_cran.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests. 3 | 4 | @author: Carlos Ramos Carreño 5 | @license: MIT 6 | """ 7 | 8 | from skdatasets.repositories.cran import fetch 9 | 10 | 11 | def test_cran_geyser(): 12 | """Tests CRAN geyser dataset.""" 13 | fetch("geyser") 14 | 15 | 16 | def test_cran_geyser_return_X_y(): 17 | """Tests CRAN geyser dataset.""" 18 | X, y = fetch("geyser", return_X_y=True) 19 | -------------------------------------------------------------------------------- /skdatasets/tests/repositories/test_forex.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test the Forex loader. 3 | 4 | @author: David Diaz Vico 5 | @license: MIT 6 | """ 7 | 8 | from datetime import date 9 | 10 | from skdatasets.repositories.forex import fetch 11 | 12 | 13 | def test_forex_usd_eur(): 14 | """Tests forex USD-EUR dataset.""" 15 | data = fetch( 16 | start=date(2015, 1, 1), 17 | end=date(2015, 1, 31), 18 | currency_1="USD", 19 | currency_2="EUR", 20 | ) 21 | assert data.data.shape == (31, 1) 22 | 23 | 24 | def test_forex_usd_eur_return_X_y(): 25 | """Tests forex USD-EUR dataset.""" 26 | X, y = fetch( 27 | start=date(2015, 1, 1), 28 | end=date(2015, 1, 31), 29 | currency_1="USD", 30 | currency_2="EUR", 31 | return_X_y=True, 32 | ) 33 | assert X.shape == (31, 1) 34 | assert y is None 35 | 36 | 37 | def test_forex_btc_eur(): 38 | """Tests forex BTC-EUR dataset.""" 39 | data = fetch( 40 | start=date(2015, 1, 1), 41 | end=date(2015, 1, 31), 42 | currency_1="BTC", 43 | currency_2="EUR", 44 | ) 45 | assert data.data.shape == (31, 1) 46 | -------------------------------------------------------------------------------- /skdatasets/tests/repositories/test_keel.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test the Keel loader. 3 | 4 | @author: David Diaz Vico 5 | @license: MIT 6 | """ 7 | 8 | from skdatasets.repositories.keel import fetch 9 | 10 | from . 
import check_estimator 11 | 12 | 13 | def check(data, shape, splits=1): 14 | """Check dataset properties.""" 15 | assert data.data.shape == shape 16 | assert data.target.shape[0] == shape[0] 17 | if splits > 1: 18 | assert len(list(data.outer_cv)) == splits 19 | else: 20 | assert data.outer_cv is None 21 | assert not data.train_indices 22 | assert not data.validation_indices 23 | assert not data.test_indices 24 | assert data.inner_cv is None 25 | check_estimator(data) 26 | 27 | 28 | def test_fetch_keel_abalone9_18(): 29 | """Tests Keel abalone9-18 dataset.""" 30 | data = fetch("imbalanced", "abalone9-18") 31 | check(data, (731, 10)) 32 | 33 | 34 | def test_fetch_keel_abalone9_18_return_X_y(): 35 | """Tests Keel abalone9-18 dataset.""" 36 | X, y = fetch("imbalanced", "abalone9-18", return_X_y=True) 37 | assert X.shape == (731, 10) 38 | assert y.shape == (731,) 39 | 40 | 41 | def test_fetch_keel_abalone9_18_folds(): 42 | """Tests Keel abalone9-18 dataset with folds.""" 43 | data = fetch("imbalanced", "abalone9-18", nfolds=5) 44 | check(data, (731, 10), splits=5) 45 | 46 | 47 | def test_fetch_keel_banana(): 48 | """Tests Keel banana dataset.""" 49 | data = fetch("classification", "banana") 50 | check(data, (5300, 2)) 51 | 52 | 53 | def test_fetch_keel_banana_folds(): 54 | """Tests Keel banana dataset with folds.""" 55 | data = fetch("classification", "banana", nfolds=5) 56 | check(data, (5300, 2), splits=5) 57 | 58 | 59 | def test_fetch_keel_banana_dobscv(): 60 | """Tests Keel banana dataset with dobscv folds.""" 61 | data = fetch("classification", "banana", nfolds=5, dobscv=True) 62 | check(data, (5300, 2), splits=5) 63 | -------------------------------------------------------------------------------- /skdatasets/tests/repositories/test_keras.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test the Keras loader. 3 | 4 | @author: David Diaz Vico 5 | @license: MIT 6 | """ 7 | 8 | from skdatasets.repositories.keras import fetch 9 | 10 | 11 | def check(data, n_samples_train, n_samples_test, n_features): 12 | """Check dataset properties.""" 13 | assert data.data.shape == (n_samples_train + n_samples_test, n_features) 14 | assert data.target.shape[0] == n_samples_train + n_samples_test 15 | assert len(data.train_indices) == n_samples_train 16 | assert len(data.test_indices) == n_samples_test 17 | assert not data.validation_indices 18 | 19 | 20 | def test_keras_mnist(): 21 | """Tests keras MNIST dataset.""" 22 | data = fetch("mnist") 23 | check(data, n_samples_train=60000, n_samples_test=10000, n_features=28 * 28) 24 | 25 | 26 | def test_keras_mnist_return_X_y(): 27 | """Tests keras MNIST dataset.""" 28 | X, y = fetch("mnist", return_X_y=True) 29 | assert X.shape == (70000, 28 * 28) 30 | assert y.shape == (70000,) 31 | -------------------------------------------------------------------------------- /skdatasets/tests/repositories/test_libsvm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test the LIBSVM loader. 3 | 4 | @author: David Diaz Vico 5 | @license: MIT 6 | """ 7 | 8 | from skdatasets.repositories.libsvm import fetch 9 | 10 | from . 
import check_estimator 11 | 12 | 13 | def check( 14 | data, 15 | n_features, 16 | n_samples=None, 17 | n_samples_train=None, 18 | n_samples_validation=None, 19 | n_samples_test=None, 20 | n_samples_remaining=None, 21 | estimator=True, 22 | ): 23 | """Check dataset properties.""" 24 | if n_samples is None: 25 | n_samples = sum( 26 | n 27 | for n in [ 28 | n_samples_train, 29 | n_samples_validation, 30 | n_samples_test, 31 | n_samples_remaining, 32 | ] 33 | if n is not None 34 | ) 35 | 36 | assert data.data.shape == (n_samples, n_features) 37 | assert data.target.shape[0] == n_samples 38 | 39 | if n_samples_train is None: 40 | assert not data.train_indices 41 | else: 42 | assert len(data.train_indices) == n_samples_train 43 | 44 | if n_samples_validation is None: 45 | assert not data.validation_indices 46 | else: 47 | assert len(data.validation_indices) == n_samples_validation 48 | 49 | if n_samples_test is None: 50 | assert not data.test_indices 51 | else: 52 | assert len(data.test_indices) == n_samples_test 53 | 54 | if n_samples_validation is None: 55 | assert data.inner_cv is None 56 | else: 57 | assert data.inner_cv is not None 58 | 59 | assert data.outer_cv is None 60 | 61 | if estimator: 62 | check_estimator(data) 63 | 64 | 65 | def test_fetch_libsvm_australian(): 66 | """Tests LIBSVM australian dataset.""" 67 | data = fetch("binary", "australian") 68 | check(data, n_samples=690, n_features=14) 69 | 70 | 71 | def test_fetch_libsvm_australian_return_X_y(): 72 | """Tests LIBSVM australian dataset.""" 73 | X, y = fetch("binary", "australian", return_X_y=True) 74 | assert X.shape == (690, 14) 75 | assert y.shape == (690,) 76 | 77 | 78 | def test_fetch_libsvm_liver_disorders(): 79 | """Tests LIBSVM liver-disorders dataset.""" 80 | data = fetch("binary", "liver-disorders") 81 | check(data, n_samples_train=145, n_samples_test=200, n_features=5) 82 | 83 | 84 | def test_fetch_libsvm_duke(): 85 | """Tests LIBSVM duke dataset.""" 86 | data = fetch("binary", "duke") 87 | check( 88 | data, 89 | n_samples_train=38, 90 | n_samples_validation=4, 91 | n_features=7129, 92 | estimator=False, 93 | ) 94 | 95 | 96 | def test_fetch_libsvm_cod_rna(): 97 | """Tests LIBSVM cod-rna dataset.""" 98 | data = fetch("binary", "cod-rna") 99 | check( 100 | data, 101 | n_samples_train=59535, 102 | n_samples_test=271617, 103 | n_samples_remaining=157413, 104 | n_features=8, 105 | ) 106 | 107 | 108 | def test_fetch_libsvm_satimage(): 109 | """Tests LIBSVM satimage dataset.""" 110 | data = fetch("multiclass", "satimage.scale") 111 | check( 112 | data, 113 | n_samples_train=3104, 114 | n_samples_test=2000, 115 | n_samples_validation=1331, 116 | n_features=36, 117 | ) 118 | -------------------------------------------------------------------------------- /skdatasets/tests/repositories/test_physionet.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from skdatasets.repositories.physionet import fetch 4 | 5 | 6 | def test_fetch_ctu_uhb_ctgdb() -> None: 7 | """Tests ctu_uhb dataset.""" 8 | X, y = fetch( 9 | "ctu-uhb-ctgdb", 10 | return_X_y=True, 11 | target_column=["pH", "BDecf", "pCO2", "BE", "Apgar1", "Apgar5"], 12 | ) 13 | assert X.shape == (552, 30) 14 | assert y.shape == (552, 6) 15 | 16 | 17 | def test_fetch_ctu_uhb_ctgdb_single_target() -> None: 18 | """Tests ctu_uhb dataset with one target.""" 19 | X, y = fetch( 20 | "ctu-uhb-ctgdb", 21 | return_X_y=True, 22 | target_column="pH", 23 | ) 24 | assert X.shape == (552, 35) 25 | assert y.shape 
== (552,) 26 | 27 | 28 | def test_fetch_ctu_uhb_ctgdb_bunch() -> None: 29 | """Tests ctu_uhb dataset returning Bunch.""" 30 | bunch = fetch( 31 | "ctu-uhb-ctgdb", 32 | as_frame=True, 33 | target_column=["pH", "BDecf", "pCO2", "BE", "Apgar1", "Apgar5"], 34 | ) 35 | assert bunch.data.shape == (552, 30) 36 | assert bunch.target.shape == (552, 6) 37 | assert bunch.frame.shape == (552, 36) 38 | 39 | 40 | def test_fetch_macecgdb() -> None: 41 | """Tests macecgdb dataset.""" 42 | bunch = fetch( 43 | "macecgdb", 44 | as_frame=True, 45 | ) 46 | assert bunch.data.shape == (27, 5) 47 | assert bunch.target is None 48 | assert bunch.frame.shape == (27, 5) 49 | -------------------------------------------------------------------------------- /skdatasets/tests/repositories/test_raetsch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test the Raetsch loader. 3 | 4 | @author: David Diaz Vico 5 | @license: MIT 6 | """ 7 | 8 | from skdatasets.repositories.raetsch import fetch 9 | 10 | from . import check_estimator 11 | 12 | 13 | def check(data, shape, splits=100): 14 | """Check dataset properties.""" 15 | assert data.data.shape == shape 16 | assert data.target.shape[0] == shape[0] 17 | assert len(list(data.outer_cv)) == splits 18 | check_estimator(data) 19 | 20 | 21 | def test_fetch_raetsch_banana(): 22 | """Tests Gunnar Raetsch banana dataset.""" 23 | data = fetch("banana") 24 | check(data, (5300, 2), splits=100) 25 | 26 | 27 | def test_fetch_raetsch_banana_return_X_y(): 28 | """Tests Gunnar Raetsch banana dataset.""" 29 | X, y = fetch("banana", return_X_y=True) 30 | assert X.shape == (5300, 2) 31 | assert y.shape == (5300,) 32 | -------------------------------------------------------------------------------- /skdatasets/tests/repositories/test_sklearn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test the Scikit-learn loader. 3 | 4 | @author: David Diaz Vico 5 | @license: MIT 6 | """ 7 | 8 | from skdatasets.repositories.sklearn import fetch 9 | 10 | from . import check_estimator 11 | 12 | 13 | def test_sklearn_iris(): 14 | """Tests Scikit-learn iris dataset.""" 15 | data = fetch("iris") 16 | assert data.data.shape == (150, 4) 17 | check_estimator(data) 18 | 19 | 20 | def test_sklearn_iris_return_X_y(): 21 | """Tests Scikit-learn iris dataset.""" 22 | X, y = fetch("iris", return_X_y=True) 23 | assert X.shape == (150, 4) 24 | assert y.shape == (150,) 25 | -------------------------------------------------------------------------------- /skdatasets/tests/repositories/test_uci.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test the UCI loader.
3 | 4 | @author: David Diaz Vico 5 | @license: MIT 6 | """ 7 | 8 | from skdatasets.repositories.uci import fetch 9 | 10 | 11 | def test_fetch_uci_wine(): 12 | """Tests UCI wine dataset.""" 13 | data = fetch("wine") 14 | assert data.data.shape == (178, 13) 15 | assert data.target.shape[0] == data.data.shape[0] 16 | assert not data.train_indices 17 | assert not data.validation_indices 18 | assert not data.test_indices 19 | assert data.inner_cv is None 20 | assert data.outer_cv is None 21 | 22 | 23 | def test_fetch_uci_wine_return_X_y(): 24 | """Tests UCI wine dataset.""" 25 | X, y = fetch("wine", return_X_y=True) 26 | assert X.shape == (178, 13) 27 | assert y.shape == (178,) 28 | -------------------------------------------------------------------------------- /skdatasets/tests/repositories/test_ucr.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test the UCR loader. 3 | 4 | @author: David Diaz Vico 5 | @license: MIT 6 | """ 7 | 8 | from skdatasets.repositories.ucr import fetch 9 | 10 | 11 | def test_fetch_ucr_gunpoint(): 12 | """Tests UCR GunPoint dataset.""" 13 | data = fetch("GunPoint") 14 | assert data.data.shape == (200, 150) 15 | assert len(data.train_indices) == 50 16 | assert len(data.test_indices) == 150 17 | 18 | 19 | def test_fetch_ucr_gunpoint_return_X_y(): 20 | """Tests UCR GunPoint dataset.""" 21 | X, y = fetch("GunPoint", return_X_y=True) 22 | assert X.shape == (200, 150) 23 | assert y.shape == (200,) 24 | 25 | 26 | def test_fetch_ucr_basicmotions(): 27 | """Tests UCR BasicMotions dataset.""" 28 | data = fetch("BasicMotions") 29 | assert data.data.shape == (80,) 30 | assert len(data.train_indices) == 40 31 | assert len(data.test_indices) == 40 32 | -------------------------------------------------------------------------------- /skdatasets/tests/utils/LinearRegression.json: -------------------------------------------------------------------------------- 1 | { 2 | "py/object": "sklearn.model_selection.GridSearchCV", 3 | "py/state": { 4 | "estimator": { 5 | "py/object": "sklearn.linear_model.LinearRegression", 6 | "py/state": { 7 | "fit_intercept": true, 8 | "normalize": false, 9 | "copy_X": true, 10 | "n_jobs": null 11 | } 12 | }, 13 | "param_grid": { 14 | "fit_intercept": [true, false], 15 | "normalize": [true, false] 16 | }, 17 | "scoring": null, 18 | "n_jobs": 1, 19 | "pre_dispatch": "2*n_jobs", 20 | "iid": true, 21 | "cv": null, 22 | "refit": true, 23 | "verbose": 0, 24 | "error_score": 0.0, 25 | "return_train_score": false 26 | } 27 | } -------------------------------------------------------------------------------- /skdatasets/tests/utils/LinearRegressionCustom.json: -------------------------------------------------------------------------------- 1 | { 2 | "py/object": "sklearn.model_selection.GridSearchCV", 3 | "py/state": { 4 | "estimator": { 5 | "py/object": "skdatasets.tests.utils.linear_model.LinearRegressionCustom", 6 | "py/state": { 7 | "fit_intercept": true, 8 | "normalize": false, 9 | "copy_X": true, 10 | "n_jobs": null 11 | } 12 | }, 13 | "param_grid": { 14 | "fit_intercept": [true, false], 15 | "normalize": [true, false] 16 | }, 17 | "scoring": null, 18 | "n_jobs": 1, 19 | "pre_dispatch": "2*n_jobs", 20 | "iid": true, 21 | "cv": null, 22 | "refit": true, 23 | "verbose": 0, 24 | "error_score": 0.0, 25 | "return_train_score": false 26 | } 27 | } -------------------------------------------------------------------------------- /skdatasets/tests/utils/MLPClassifier.json:
-------------------------------------------------------------------------------- 1 | { 2 | "py/object": "sklearn.model_selection.GridSearchCV", 3 | "py/state": { 4 | "estimator": { 5 | "py/object": "sklearn.pipeline.Pipeline", 6 | "py/state": { 7 | "steps": [ 8 | [ 9 | "std", 10 | { 11 | "py/object": "sklearn.preprocessing.StandardScaler", 12 | "py/state": { 13 | "copy": true, 14 | "with_mean": false, 15 | "with_std": true 16 | } 17 | } 18 | ], 19 | [ 20 | "mlp", 21 | { 22 | "py/object": "sklearn.neural_network.MLPClassifier", 23 | "py/state": { 24 | "hidden_layer_sizes": [10], 25 | "activation": "relu", 26 | "solver": "adam", 27 | "alpha": 0.0001, 28 | "batch_size": "auto", 29 | "learning_rate": "constant", 30 | "learning_rate_init": 0.001, 31 | "power_t": 0.5, 32 | "max_iter": 2, 33 | "shuffle": true, 34 | "random_state": null, 35 | "tol": 0.0001, 36 | "verbose": false, 37 | "warm_start": false, 38 | "momentum": 0.9, 39 | "nesterovs_momentum": true, 40 | "early_stopping": false, 41 | "validation_fraction": 0.1, 42 | "beta_1": 0.9, 43 | "beta_2": 0.999, 44 | "epsilon": 1e-08, 45 | "n_iter_no_change": 10 46 | } 47 | } 48 | ] 49 | ], 50 | "memory": null, 51 | "verbose": false 52 | } 53 | }, 54 | "param_grid": { 55 | "mlp__alpha": [1e-4, 1e-3, 1e-2, 1e-1, 1e0], 56 | "mlp__learning_rate_init": [0.001] 57 | }, 58 | "scoring": null, 59 | "n_jobs": 1, 60 | "pre_dispatch": "2*n_jobs", 61 | "iid": true, 62 | "cv": null, 63 | "refit": true, 64 | "verbose": 0, 65 | "error_score": 0.0, 66 | "return_train_score": false 67 | } 68 | } -------------------------------------------------------------------------------- /skdatasets/tests/utils/MLPRegressor.json: -------------------------------------------------------------------------------- 1 | { 2 | "py/object": "sklearn.model_selection.GridSearchCV", 3 | "py/state": { 4 | "estimator": { 5 | "py/object": "sklearn.pipeline.Pipeline", 6 | "py/state": { 7 | "steps": [ 8 | [ 9 | "std", 10 | { 11 | "py/object": "sklearn.preprocessing.StandardScaler", 12 | "py/state": { 13 | "copy": true, 14 | "with_mean": false, 15 | "with_std": true 16 | } 17 | } 18 | ], 19 | [ 20 | "mlp", 21 | { 22 | "py/object": "sklearn.neural_network.MLPRegressor", 23 | "py/state": { 24 | "hidden_layer_sizes": [10], 25 | "activation": "relu", 26 | "solver": "adam", 27 | "alpha": 0.0001, 28 | "batch_size": "auto", 29 | "learning_rate": "constant", 30 | "learning_rate_init": 0.001, 31 | "power_t": 0.5, 32 | "max_iter": 2, 33 | "shuffle": true, 34 | "random_state": null, 35 | "tol": 0.0001, 36 | "verbose": false, 37 | "warm_start": false, 38 | "momentum": 0.9, 39 | "nesterovs_momentum": true, 40 | "early_stopping": false, 41 | "validation_fraction": 0.1, 42 | "beta_1": 0.9, 43 | "beta_2": 0.999, 44 | "epsilon": 1e-08, 45 | "n_iter_no_change": 10 46 | } 47 | } 48 | ] 49 | ], 50 | "memory": null, 51 | "verbose": false 52 | } 53 | }, 54 | "param_grid": { 55 | "mlp__alpha": [1e-4, 1e-3, 1e-2, 1e-1, 1e0], 56 | "mlp__learning_rate_init": [0.001] 57 | }, 58 | "scoring": null, 59 | "n_jobs": 1, 60 | "pre_dispatch": "2*n_jobs", 61 | "iid": true, 62 | "cv": null, 63 | "refit": true, 64 | "verbose": 0, 65 | "error_score": 0.0, 66 | "return_train_score": false 67 | } 68 | } -------------------------------------------------------------------------------- /skdatasets/tests/utils/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/daviddiazvico/scikit-datasets/d3b25a490aa6f9bb99e72d2b0cf17cbb73702b1f/skdatasets/tests/utils/__init__.py -------------------------------------------------------------------------------- /skdatasets/tests/utils/linear_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: David Diaz Vico 3 | @license: MIT 4 | """ 5 | 6 | from sklearn.linear_model import LinearRegression 7 | 8 | 9 | LinearRegressionCustom = LinearRegression 10 | -------------------------------------------------------------------------------- /skdatasets/tests/utils/run.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | 3 | """ 4 | @author: David Diaz Vico 5 | @license: MIT 6 | """ 7 | 8 | import argparse 9 | from sacred.observers import FileStorageObserver 10 | 11 | from skdatasets import fetch 12 | from skdatasets.utils.estimator import json2estimator 13 | from skdatasets.utils.experiment import experiment 14 | 15 | 16 | def main( 17 | dataset=fetch, estimator=json2estimator, observers=[FileStorageObserver(".results")] 18 | ): 19 | parser = argparse.ArgumentParser(description="Run an experiment.") 20 | parser.add_argument("-r", "--repository", type=str, help="repository") 21 | parser.add_argument("-c", "--collection", type=str, default=None, help="collection") 22 | parser.add_argument("-d", "--dataset", type=str, help="dataset") 23 | parser.add_argument("-e", "--estimator", type=str, help="estimator") 24 | args = parser.parse_args() 25 | e = experiment(dataset, estimator) 26 | e.observers.extend(observers) 27 | e.run( 28 | config_updates={ 29 | "dataset": { 30 | "repository": args.repository, 31 | "collection": args.collection, 32 | "dataset": args.dataset, 33 | }, 34 | "estimator": {"estimator": args.estimator}, 35 | } 36 | ) 37 | 38 | 39 | if __name__ == "__main__": 40 | main() 41 | -------------------------------------------------------------------------------- /skdatasets/tests/utils/test_estimator.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: David Diaz Vico 3 | @license: MIT 4 | """ 5 | 6 | from sklearn.model_selection import GridSearchCV 7 | 8 | from skdatasets.utils.estimator import json2estimator 9 | 10 | 11 | def test_json2estimator(): 12 | """Tests instantiation of estimator from a json file.""" 13 | import sklearn 14 | 15 | e = json2estimator("skdatasets/tests/utils/LinearRegression.json") 16 | assert type(e) == GridSearchCV 17 | 18 | 19 | def test_json2estimator_custom(): 20 | """Tests instantiation of a custom estimator from a json file.""" 21 | import skdatasets 22 | 23 | e = json2estimator("skdatasets/tests/utils/LinearRegressionCustom.json") 24 | assert type(e) == GridSearchCV 25 | -------------------------------------------------------------------------------- /skdatasets/tests/utils/test_experiment.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: David Diaz Vico 3 | @license: MIT 4 | """ 5 | from __future__ import annotations 6 | 7 | import pytest 8 | import tempfile 9 | from typing import TYPE_CHECKING, Iterable, Tuple, Union 10 | 11 | import numpy as np 12 | import pytest 13 | from sacred.observers import FileStorageObserver 14 | from sklearn.datasets import load_diabetes, load_iris, load_wine 15 | from sklearn.model_selection import GridSearchCV, train_test_split 16 | from sklearn.neighbors import KNeighborsClassifier 17 | from 
sklearn.tree import DecisionTreeRegressor 18 | from sklearn.utils import Bunch 19 | 20 | from skdatasets.utils.experiment import ( 21 | ScorerLike, 22 | create_experiments, 23 | experiment, 24 | fetch_scores, 25 | run_experiments, 26 | ) 27 | 28 | if TYPE_CHECKING: 29 | from skdatasets.utils.experiment import CVLike 30 | 31 | ExplicitSplitType = Tuple[ 32 | np.typing.NDArray[float], 33 | np.typing.NDArray[Union[float, int]], 34 | np.typing.NDArray[float], 35 | np.typing.NDArray[Union[float, int]], 36 | ] 37 | 38 | 39 | def _dataset( 40 | inner_cv: CVLike = None, 41 | outer_cv: CVLike = None, 42 | ) -> Bunch: 43 | data = load_diabetes() 44 | if outer_cv is None: 45 | X, X_test, y, y_test = train_test_split(data.data, data.target) 46 | data.data = X 47 | data.target = y 48 | data.data_test = X_test 49 | data.target_test = y_test 50 | data.outer_cv = None 51 | else: 52 | data.data_test = data.target_test = None 53 | data.outer_cv = outer_cv 54 | data.inner_cv = inner_cv 55 | 56 | return data 57 | 58 | 59 | def _estimator(cv: CVLike) -> GridSearchCV: 60 | return GridSearchCV( 61 | DecisionTreeRegressor(), 62 | {"max_depth": [2, 4]}, 63 | cv=cv, 64 | ) 65 | 66 | 67 | def _experiment( 68 | inner_cv: CVLike, 69 | outer_cv: CVLike | Iterable[ExplicitSplitType], 70 | ) -> None: 71 | with tempfile.TemporaryDirectory() as tmpdirname: 72 | e = experiment(_dataset, _estimator) 73 | e.observers.append(FileStorageObserver(tmpdirname)) 74 | e.run( 75 | config_updates={ 76 | "dataset": { 77 | "inner_cv": inner_cv, 78 | "outer_cv": outer_cv, 79 | }, 80 | }, 81 | ) 82 | 83 | 84 | @pytest.mark.skip(reason="Waiting for Sacred to be fixed.") 85 | def test_nested_cv() -> None: 86 | """Tests nested CV experiment.""" 87 | _experiment(3, 3) 88 | 89 | 90 | @pytest.mark.skip(reason="Waiting for Sacred to be fixed.") 91 | def test_inner_cv() -> None: 92 | """Tests inner CV experiment.""" 93 | _experiment(3, None) 94 | 95 | 96 | @pytest.mark.skip(reason="Waiting for Sacred to be fixed.") 97 | def test_explicit_inner_folds() -> None: 98 | """Tests explicit inner folds experiment.""" 99 | X, y = load_diabetes(return_X_y=True) 100 | _experiment( 101 | [ 102 | (np.arange(10), np.arange(10, 20)), 103 | (np.arange(10, 20), np.arange(20, 30)), 104 | (np.arange(20, 30), np.arange(30, 40)), 105 | ], 106 | 3, 107 | ) 108 | 109 | 110 | @pytest.mark.skip(reason="Waiting for Sacred to be fixed.") 111 | def test_explicit_outer_folds_indexes() -> None: 112 | """Tests explicit outer folds experiment.""" 113 | X, y = load_diabetes(return_X_y=True) 114 | _experiment( 115 | 3, 116 | [ 117 | (np.arange(10), np.arange(10, 20)), 118 | (np.arange(10, 20), np.arange(20, 30)), 119 | (np.arange(20, 30), np.arange(30, 40)), 120 | ], 121 | ) 122 | 123 | 124 | @pytest.mark.skip(reason="Waiting for Sacred to be fixed.") 125 | def test_explicit_outer_folds() -> None: 126 | """Tests explicit outer folds experiment.""" 127 | X, y = load_diabetes(return_X_y=True) 128 | _experiment( 129 | 3, 130 | [ 131 | (X[:10], y[:10], X[10:20], y[10:20]), 132 | (X[10:20], y[10:20], X[20:30], y[20:30]), 133 | (X[20:30], y[20:30], X[30:40], y[30:40]), 134 | ], 135 | ) 136 | 137 | 138 | @pytest.mark.skip(reason="Waiting for Sacred to be fixed.") 139 | def test_explicit_nested_folds() -> None: 140 | """Tests explicit nested folds experiment.""" 141 | X, y = load_diabetes(return_X_y=True) 142 | _experiment( 143 | [ 144 | (np.arange(3, 10), np.arange(3)), 145 | (np.concatenate((np.arange(3), np.arange(7, 10))), np.arange(3, 7)), 146 | (np.arange(7, 10), np.arange(7)), 147 
| ], 148 | [ 149 | (np.arange(10), np.arange(10, 20)), 150 | (np.arange(10, 20), np.arange(20, 30)), 151 | (np.arange(20, 30), np.arange(30, 40)), 152 | ], 153 | ) 154 | 155 | 156 | @pytest.mark.parametrize( 157 | ["scoring", "expected_mean", "expected_std"], 158 | [ 159 | ( 160 | None, 161 | [ 162 | [0.96666667, 0.97333333, 0.98], 163 | [0.70285714, 0.69126984, 0.68063492], 164 | ], 165 | [ 166 | [0.02108185, 0.02494438, 0.01632993], 167 | [0.07920396, 0.04877951, 0.0662983], 168 | ], 169 | ), 170 | ( 171 | "recall_micro", 172 | [ 173 | [0.96666667, 0.97333333, 0.98], 174 | [0.70285714, 0.69126984, 0.68063492], 175 | ], 176 | [ 177 | [0.02108185, 0.02494438, 0.01632993], 178 | [0.07920396, 0.04877951, 0.0662983], 179 | ], 180 | ), 181 | ], 182 | ) 183 | def test_create_experiments_basic( 184 | scoring: ScorerLike[np.typing.NDArray[np.float_], np.typing.NDArray[np.int_]], 185 | expected_mean: np.typing.NDArray[np.float_], 186 | expected_std: np.typing.NDArray[np.float_], 187 | ) -> None: 188 | 189 | with tempfile.TemporaryDirectory() as tmpdirname: 190 | experiments = create_experiments( 191 | estimators={ 192 | "knn-3": KNeighborsClassifier(n_neighbors=3), 193 | "knn-5": KNeighborsClassifier(n_neighbors=5), 194 | "knn-7": KNeighborsClassifier(n_neighbors=7), 195 | }, 196 | datasets={ 197 | "iris": load_iris(), 198 | "wine": load_wine(), 199 | }, 200 | scoring=scoring, 201 | storage=tmpdirname, 202 | ) 203 | 204 | ids = run_experiments(experiments) 205 | 206 | scores = fetch_scores( 207 | storage=tmpdirname, 208 | ids=ids, 209 | ) 210 | 211 | assert scores.dataset_names == ("iris", "wine") 212 | assert scores.estimator_names == ("knn-3", "knn-5", "knn-7") 213 | np.testing.assert_allclose( 214 | scores.scores_mean, 215 | expected_mean, 216 | ) 217 | np.testing.assert_allclose( 218 | scores.scores_std, 219 | expected_std, 220 | rtol=1e-6, 221 | ) 222 | -------------------------------------------------------------------------------- /skdatasets/tests/utils/test_run.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: David Diaz Vico 3 | @license: MIT 4 | """ 5 | 6 | import subprocess 7 | 8 | 9 | def test_binary_classification(): 10 | """Tests binary classification experiment.""" 11 | ret = subprocess.call( 12 | [ 13 | "skdatasets/tests/utils/run.py", 14 | "-r", 15 | "keel", 16 | "-c", 17 | "imbalanced", 18 | "-d", 19 | "abalone9-18", 20 | "-e", 21 | "skdatasets/tests/utils/MLPClassifier.json", 22 | ] 23 | ) 24 | assert ret >= 0 25 | ret = subprocess.call( 26 | [ 27 | "skdatasets/tests/utils/run.py", 28 | "-r", 29 | "libsvm", 30 | "-c", 31 | "binary", 32 | "-d", 33 | "breast-cancer", 34 | "-e", 35 | "skdatasets/tests/utils/MLPClassifier.json", 36 | ] 37 | ) 38 | assert ret >= 0 39 | ret = subprocess.call( 40 | [ 41 | "skdatasets/tests/utils/run.py", 42 | "-r", 43 | "raetsch", 44 | "-d", 45 | "banana", 46 | "-e", 47 | "skdatasets/tests/utils/MLPClassifier.json", 48 | ] 49 | ) 50 | assert ret >= 0 51 | 52 | 53 | def test_multiclass_classification(): 54 | """Tests multiclass classification experiment.""" 55 | ret = subprocess.call( 56 | [ 57 | "skdatasets/tests/utils/run.py", 58 | "-r", 59 | "sklearn", 60 | "-d", 61 | "iris", 62 | "-e", 63 | "skdatasets/tests/utils/MLPClassifier.json", 64 | ] 65 | ) 66 | assert ret >= 0 67 | ret = subprocess.call( 68 | [ 69 | "skdatasets/tests/utils/run.py", 70 | "-r", 71 | "uci", 72 | "-d", 73 | "wine", 74 | "-e", 75 | "skdatasets/tests/utils/MLPClassifier.json", 76 | ] 77 | ) 78 | assert ret >= 0 79 | ret = 
subprocess.call( 80 | [ 81 | "skdatasets/tests/utils/run.py", 82 | "-r", 83 | "libsvm", 84 | "-c", 85 | "multiclass", 86 | "-d", 87 | "shuttle", 88 | "-e", 89 | "skdatasets/tests/utils/MLPClassifier.json", 90 | ] 91 | ) 92 | assert ret >= 0 93 | ret = subprocess.call( 94 | [ 95 | "skdatasets/tests/utils/run.py", 96 | "-r", 97 | "libsvm", 98 | "-c", 99 | "multiclass", 100 | "-d", 101 | "usps", 102 | "-e", 103 | "skdatasets/tests/utils/MLPClassifier.json", 104 | ] 105 | ) 106 | assert ret >= 0 107 | 108 | 109 | def test_regression(): 110 | """Tests regression experiment.""" 111 | ret = subprocess.call( 112 | [ 113 | "skdatasets/tests/utils/run.py", 114 | "-r", 115 | "libsvm", 116 | "-c", 117 | "regression", 118 | "-d", 119 | "housing", 120 | "-e", 121 | "skdatasets/tests/utils/MLPRegressor.json", 122 | ] 123 | ) 124 | assert ret >= 0 125 | -------------------------------------------------------------------------------- /skdatasets/tests/utils/test_scores.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: David Diaz Vico 3 | @license: MIT 4 | """ 5 | 6 | import numpy as np 7 | 8 | from skdatasets.utils.scores import hypotheses_table, scores_table 9 | 10 | datasets = [ 11 | "a4a", 12 | "a8a", 13 | "combined", 14 | "dna", 15 | "ijcnn1", 16 | "letter", 17 | "pendigits", 18 | "satimage", 19 | "shuttle", 20 | "usps", 21 | "w7a", 22 | "w8a", 23 | ] 24 | estimators = [ 25 | "LogisticRegression", 26 | "MLPClassifier0", 27 | "MLPClassifier1", 28 | "MLPClassifier2", 29 | "MLPClassifier3", 30 | "MLPClassifier4", 31 | "MLPClassifier5", 32 | ] 33 | scores = np.asarray( 34 | ( 35 | (89.79, 89.78, 89.76, 89.88, 89.85, 89.91, 89.93), 36 | (90.73, 90.73, 90.73, 90.85, 90.83, 90.81, 90.80), 37 | (92.36, 92.31, 94.58, 94.82, 94.84, 94.92, 94.89), 38 | (99.28, 99.27, 99.28, 99.26, 99.27, 99.25, 99.25), 39 | (91.34, 91.34, 99.29, 99.33, 99.34, 99.53, 99.54), 40 | (98.07, 98.04, 99.94, 99.95, 99.96, 99.96, 99.95), 41 | (99.17, 99.08, 99.87, 99.87, 99.88, 99.90, 99.89), 42 | (96.67, 96.28, 98.84, 98.87, 98.90, 98.87, 98.92), 43 | (95.85, 92.83, 99.88, 99.93, 99.96, 99.98, 99.99), 44 | (99.12, 99.11, 99.65, 99.58, 99.58, 99.65, 99.60), 45 | (95.93, 95.40, 94.58, 96.31, 96.34, 96.58, 96.50), 46 | (95.80, 95.99, 95.35, 96.20, 96.22, 96.36, 96.71), 47 | ) 48 | ) 49 | 50 | 51 | def test_scores_table() -> None: 52 | """Tests scores table.""" 53 | scores_table(scores, datasets=datasets, estimators=estimators) 54 | scores_table( 55 | scores, 56 | stds=scores / 10.0, 57 | datasets=datasets, 58 | estimators=estimators, 59 | ) 60 | 61 | 62 | def test_hypotheses_table() -> None: 63 | """Tests hypotheses table.""" 64 | for multitest in ("kruskal", "friedmanchisquare", None): 65 | for test in ("mannwhitneyu", "wilcoxon"): 66 | hypotheses_table(scores, estimators, multitest=multitest, test=test) 67 | for correction in ( 68 | "bonferroni", 69 | "sidak", 70 | "holm-sidak", 71 | "holm", 72 | "simes-hochberg", 73 | "hommel", 74 | "fdr_bh", 75 | "fdr_by", 76 | "fdr_tsbh", 77 | "fdr_tsbky", 78 | ): 79 | hypotheses_table( 80 | scores, 81 | estimators, 82 | multitest=multitest, 83 | test=test, 84 | correction=correction, 85 | ) 86 | -------------------------------------------------------------------------------- /skdatasets/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daviddiazvico/scikit-datasets/d3b25a490aa6f9bb99e72d2b0cf17cbb73702b1f/skdatasets/utils/__init__.py 
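The test module above drives scores_table and hypotheses_table with a fixed matrix of mean scores. As a quick orientation only (this sketch is not part of the repository), the snippet below replays the same call pattern on a small slice of that matrix; it assumes a reasonably recent pandas, since the styler returned by scores_table is rendered with Styler.to_html(), and every keyword value is taken from calls already exercised in test_scores.py.

import numpy as np

from skdatasets.utils.scores import hypotheses_table, scores_table

# Subset of the datasets, estimators and mean scores defined in test_scores.py.
datasets = ["a4a", "a8a", "combined", "dna", "ijcnn1", "letter"]
estimators = ["LogisticRegression", "MLPClassifier0", "MLPClassifier1"]
scores = np.asarray(
    (
        (89.79, 89.78, 89.76),
        (90.73, 90.73, 90.73),
        (92.36, 92.31, 94.58),
        (99.28, 99.27, 99.28),
        (91.34, 91.34, 99.29),
        (98.07, 98.04, 99.94),
    )
)

# Styled table of mean scores with per-dataset ranks and the default
# "Average rank" summary row; default_style="latex" selects the LaTeX styling
# defined later in skdatasets/utils/scores.py instead of the HTML one.
styler = scores_table(scores, datasets=datasets, estimators=estimators)
html = styler.to_html()

# Pairwise hypothesis tests between estimators, mirroring one of the
# (multitest, test, correction) combinations looped over in test_hypotheses_table.
hypotheses_table(
    scores,
    estimators,
    multitest="friedmanchisquare",
    test="wilcoxon",
    correction="holm",
)

Any other combination drawn from the CorrectionLike, MultitestLike and TestLike literals in skdatasets/utils/scores.py should work the same way; the example only demonstrates the calling convention, not a recommended analysis.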
-------------------------------------------------------------------------------- /skdatasets/utils/estimator.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: David Diaz Vico 3 | @license: MIT 4 | """ 5 | 6 | import jsonpickle 7 | 8 | 9 | def json2estimator(estimator, **kwargs): 10 | """Instantiate a Scikit-learn estimator from a json file. 11 | 12 | Instantiate a Scikit-learn estimator from a json file passing its path as 13 | argument. 14 | 15 | Parameters 16 | ---------- 17 | estimator : str 18 | Path of the json file containing the estimator specification. 19 | **kwargs : dict 20 | Dictionary of optional keyword arguments. 21 | 22 | Returns 23 | ------- 24 | estimator : Estimator 25 | Instantiated Scikit-learn estimator. 26 | 27 | """ 28 | with open(estimator, "r") as definition: 29 | estimator = jsonpickle.decode(definition.read()) 30 | for k, v in kwargs.items(): 31 | setattr(estimator, k, v) 32 | return estimator 33 | -------------------------------------------------------------------------------- /skdatasets/utils/experiment.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: David Diaz Vico 3 | @license: MIT 4 | """ 5 | from __future__ import annotations 6 | 7 | import itertools 8 | from contextlib import contextmanager 9 | from dataclasses import dataclass 10 | from time import perf_counter, sleep 11 | from typing import ( 12 | Any, 13 | Callable, 14 | Dict, 15 | Iterable, 16 | Iterator, 17 | List, 18 | Literal, 19 | Mapping, 20 | Protocol, 21 | Sequence, 22 | Tuple, 23 | TypeVar, 24 | Union, 25 | ) 26 | from warnings import warn 27 | 28 | import numpy as np 29 | from sacred import Experiment, Ingredient 30 | from sacred.observers import FileStorageObserver, MongoObserver, RunObserver 31 | from sklearn.base import BaseEstimator, is_classifier 32 | from sklearn.metrics import check_scoring 33 | from sklearn.model_selection import check_cv 34 | from sklearn.utils import Bunch 35 | 36 | from incense import ExperimentLoader, FileSystemExperimentLoader 37 | from incense.experiment import FileSystemExperiment 38 | 39 | SelfType = TypeVar("SelfType") 40 | 41 | 42 | class DataLike(Protocol): 43 | def __getitem__( 44 | self: SelfType, 45 | key: np.typing.NDArray[int], 46 | ) -> SelfType: 47 | pass 48 | 49 | def __len__(self) -> int: 50 | pass 51 | 52 | 53 | DataType = TypeVar("DataType", bound=DataLike, contravariant=True) 54 | TargetType = TypeVar("TargetType", bound=DataLike) 55 | IndicesType = Tuple[np.typing.NDArray[int], np.typing.NDArray[int]] 56 | ExplicitSplitType = Tuple[ 57 | np.typing.NDArray[float], 58 | np.typing.NDArray[Union[float, int]], 59 | np.typing.NDArray[float], 60 | np.typing.NDArray[Union[float, int]], 61 | ] 62 | 63 | ConfigLike = Union[ 64 | Mapping[str, Any], 65 | str, 66 | ] 67 | ScorerLike = Union[ 68 | str, 69 | Callable[[BaseEstimator, DataType, TargetType], float], 70 | None, 71 | ] 72 | 73 | 74 | class EstimatorProtocol(Protocol[DataType, TargetType]): 75 | def fit(self: SelfType, X: DataType, y: TargetType) -> SelfType: 76 | pass 77 | 78 | def predict(self, X: DataType) -> TargetType: 79 | pass 80 | 81 | 82 | class CVSplitter(Protocol): 83 | def split( 84 | self, 85 | X: np.typing.NDArray[float], 86 | y: None = None, 87 | groups: None = None, 88 | ) -> Iterable[IndicesType]: 89 | pass 90 | 91 | def get_n_splits( 92 | self, 93 | X: np.typing.NDArray[float], 94 | y: None = None, 95 | groups: None = None, 96 | ) -> int: 97 | pass 98 | 99 | 100 | 
CVLike = Union[ 101 | CVSplitter, 102 | Iterable[IndicesType], 103 | int, 104 | None, 105 | ] 106 | 107 | EstimatorLike = Union[ 108 | EstimatorProtocol[Any, Any], 109 | Callable[..., EstimatorProtocol[Any, Any]], 110 | Tuple[Callable[..., EstimatorProtocol[Any, Any]], ConfigLike], 111 | ] 112 | 113 | DatasetLike = Union[ 114 | Bunch, 115 | Callable[..., Bunch], 116 | Tuple[Callable[..., Bunch], ConfigLike], 117 | ] 118 | 119 | 120 | @dataclass 121 | class ScoresInfo: 122 | r""" 123 | Class containing the scores of several related experiments. 124 | 125 | Attributes 126 | ---------- 127 | dataset_names : Sequence of :external:class:`str` 128 | Name of the datasets, with the same order in which are present 129 | in the rows of the scores. 130 | estimator_names : Sequence of :external:class:`str` 131 | Name of the estimators, with the same order in which are present 132 | in the columns of the scores. 133 | scores : :external:class:`numpy.ndarray` 134 | Test scores. It has size ``n_datasets`` :math:`\times` ``n_estimators`` 135 | :math:`\times` ``n_partitions``. 136 | scores_mean : :external:class:`numpy.ndarray` 137 | Test score means. It has size ``n_datasets`` 138 | :math:`\times` ``n_estimators``. 139 | scores_std : :external:class:`numpy.ndarray` 140 | Test score standard deviations. It has size ``n_datasets`` 141 | :math:`\times` ``n_estimators``. 142 | 143 | See Also 144 | -------- 145 | fetch_scores 146 | 147 | """ 148 | dataset_names: Sequence[str] 149 | estimator_names: Sequence[str] 150 | scores: np.typing.NDArray[float] 151 | scores_mean: np.typing.NDArray[float] 152 | scores_std: np.typing.NDArray[float] 153 | 154 | 155 | def _append_info(experiment: Experiment, name: str, value: Any) -> None: 156 | info_list = experiment.info.get(name, []) 157 | info_list.append(value) 158 | experiment.info[name] = info_list 159 | 160 | 161 | @contextmanager 162 | def _add_timing(experiment: Experiment, name: str) -> Iterator[None]: 163 | initial_time = perf_counter() 164 | try: 165 | yield None 166 | finally: 167 | final_time = perf_counter() 168 | elapsed_time = final_time - initial_time 169 | _append_info(experiment, name, elapsed_time) 170 | 171 | 172 | def _iterate_outer_cv( 173 | outer_cv: CVLike | Iterable[Tuple[DataType, TargetType, DataType, TargetType]], 174 | estimator: EstimatorProtocol[DataType, TargetType], 175 | X: DataType, 176 | y: TargetType, 177 | ) -> Iterable[Tuple[DataType, TargetType, DataType, TargetType]]: 178 | """Iterate over multiple partitions.""" 179 | if isinstance(outer_cv, Iterable): 180 | outer_cv, cv_copy = itertools.tee(outer_cv) 181 | if len(next(cv_copy)) == 4: 182 | yield from outer_cv 183 | 184 | cv = check_cv(outer_cv, y, classifier=is_classifier(estimator)) 185 | yield from ( 186 | (X[train], y[train], X[test], y[test]) for train, test in cv.split(X, y) 187 | ) 188 | 189 | 190 | def _benchmark_from_data( 191 | experiment: Experiment, 192 | *, 193 | estimator: BaseEstimator, 194 | X_train: DataType, 195 | y_train: TargetType, 196 | X_test: DataType, 197 | y_test: TargetType, 198 | scoring: ScorerLike[DataType, TargetType] = None, 199 | save_estimator: bool = False, 200 | save_train: bool = False, 201 | ) -> None: 202 | 203 | scoring_fun = check_scoring(estimator, scoring) 204 | 205 | with _add_timing(experiment, "fit_time"): 206 | estimator.fit(X_train, y_train) 207 | 208 | if save_estimator: 209 | _append_info(experiment, "fitted_estimator", estimator) 210 | 211 | best_params = getattr(estimator, "best_params_", None) 212 | if best_params: 213 | 
_append_info(experiment, "search_best_params", best_params) 214 | 215 | best_score = getattr(estimator, "best_score_", None) 216 | if best_score is not None: 217 | _append_info(experiment, "search_best_score", best_score) 218 | 219 | with _add_timing(experiment, "score_time"): 220 | test_score = scoring_fun(estimator, X_test, y_test) 221 | 222 | _append_info(experiment, "test_score", float(test_score)) 223 | 224 | if save_train: 225 | train_score = scoring_fun(estimator, X_train, y_train) 226 | _append_info(experiment, "train_score", float(train_score)) 227 | 228 | for output in ("transform", "predict"): 229 | method = getattr(estimator, output, None) 230 | if method is not None: 231 | with _add_timing(experiment, f"{output}_time"): 232 | _append_info(experiment, f"{output}", method(X_test)) 233 | 234 | 235 | def _compute_means(experiment: Experiment) -> None: 236 | 237 | experiment.info["score_mean"] = float(np.nanmean(experiment.info["test_score"])) 238 | experiment.info["score_std"] = float(np.nanstd(experiment.info["test_score"])) 239 | 240 | 241 | def _benchmark_one( 242 | experiment: Experiment, 243 | *, 244 | estimator: BaseEstimator, 245 | data: Bunch, 246 | scoring: ScorerLike[DataType, TargetType] = None, 247 | save_estimator: bool = False, 248 | save_train: bool = False, 249 | ) -> None: 250 | """Use only one predefined partition.""" 251 | X = data.data 252 | y = data.target 253 | 254 | train_indices = getattr(data, "train_indices", []) 255 | validation_indices = getattr(data, "validation_indices", []) 256 | test_indices = getattr(data, "test_indices", []) 257 | 258 | X_train_val = X[train_indices + validation_indices] if train_indices else X 259 | y_train_val = y[train_indices + validation_indices] if train_indices else y 260 | 261 | X_test = X[test_indices] 262 | y_test = y[test_indices] 263 | 264 | _benchmark_from_data( 265 | experiment=experiment, 266 | estimator=estimator, 267 | X_train=X_train_val, 268 | y_train=y_train_val, 269 | X_test=X_test, 270 | y_test=y_test, 271 | scoring=scoring, 272 | save_estimator=save_estimator, 273 | save_train=save_train, 274 | ) 275 | 276 | _compute_means(experiment) 277 | 278 | 279 | def _benchmark_partitions( 280 | experiment: Experiment, 281 | *, 282 | estimator: BaseEstimator, 283 | data: Bunch, 284 | scoring: ScorerLike[DataType, TargetType] = None, 285 | save_estimator: bool = False, 286 | save_train: bool = False, 287 | outer_cv: CVLike | Literal["dataset"] = None, 288 | ) -> None: 289 | """Use several partitions.""" 290 | outer_cv = data.outer_cv if outer_cv == "dataset" else outer_cv 291 | 292 | for X_train, y_train, X_test, y_test in _iterate_outer_cv( 293 | outer_cv=outer_cv, 294 | estimator=estimator, 295 | X=data.data, 296 | y=data.target, 297 | ): 298 | 299 | _benchmark_from_data( 300 | experiment=experiment, 301 | estimator=estimator, 302 | X_train=X_train, 303 | y_train=y_train, 304 | X_test=X_test, 305 | y_test=y_test, 306 | scoring=scoring, 307 | save_estimator=save_estimator, 308 | save_train=save_train, 309 | ) 310 | 311 | _compute_means(experiment) 312 | 313 | 314 | def _benchmark( 315 | experiment: Experiment, 316 | *, 317 | estimator: BaseEstimator, 318 | data: Bunch, 319 | scoring: ScorerLike[DataType, TargetType] = None, 320 | save_estimator: bool = False, 321 | save_train: bool = False, 322 | outer_cv: CVLike | Literal[False, "dataset"] = None, 323 | ) -> None: 324 | """Run the experiment.""" 325 | if outer_cv is False: 326 | _benchmark_one( 327 | experiment=experiment, 328 | estimator=estimator, 329 | data=data, 330 |
scoring=scoring, 331 | save_estimator=save_estimator, 332 | save_train=save_train, 333 | ) 334 | else: 335 | _benchmark_partitions( 336 | experiment=experiment, 337 | estimator=estimator, 338 | data=data, 339 | scoring=scoring, 340 | save_estimator=save_estimator, 341 | save_train=save_train, 342 | outer_cv=outer_cv, 343 | ) 344 | 345 | 346 | def experiment( 347 | dataset: Callable[..., Bunch], 348 | estimator: Callable[..., BaseEstimator], 349 | *, 350 | scoring: ScorerLike[DataType, TargetType] = None, 351 | save_estimator: bool = False, 352 | save_train: bool = False, 353 | ) -> Experiment: 354 | """ 355 | Prepare a Scikit-learn experiment as a Sacred experiment. 356 | 357 | Prepare a Scikit-learn experiment indicating a dataset and an estimator and 358 | return it as a Sacred experiment. 359 | 360 | Parameters 361 | ---------- 362 | dataset : function 363 | Dataset fetch function. Might receive any argument. Must return a 364 | :external:class:`sklearn.utils.Bunch` with ``data``, ``target`` 365 | (might be ``None``), ``inner_cv`` (might be ``None``) and ``outer_cv`` 366 | (might be ``None``). 367 | estimator : function 368 | Estimator initialization function. Might receive any keyword argument. 369 | Must return an initialized sklearn-compatible estimator. 370 | 371 | Returns 372 | ------- 373 | experiment : Experiment 374 | Sacred experiment, ready to be run. 375 | 376 | """ 377 | dataset_ingredient = Ingredient("dataset") 378 | dataset = dataset_ingredient.capture(dataset) 379 | estimator_ingredient = Ingredient("estimator") 380 | estimator = estimator_ingredient.capture(estimator) 381 | experiment = Experiment( 382 | ingredients=( 383 | dataset_ingredient, 384 | estimator_ingredient, 385 | ), 386 | ) 387 | 388 | @experiment.main 389 | def run() -> None: 390 | """Run the experiment.""" 391 | data = dataset() 392 | 393 | # Metaparameter search 394 | cv = getattr(data, "inner_cv", None) 395 | 396 | try: 397 | e = estimator(cv=cv) 398 | except TypeError as exception: 399 | warn(f"The estimator does not accept cv: {exception}") 400 | e = estimator() 401 | 402 | # Model assessment 403 | _benchmark( 404 | experiment=experiment, 405 | estimator=e, 406 | data=data, 407 | scoring=scoring, 408 | save_estimator=save_estimator, 409 | save_train=save_train, 410 | ) 411 | 412 | # Ensure that everything is in the info dict at the end 413 | # See https://github.com/IDSIA/sacred/issues/830 414 | sleep(experiment.current_run.beat_interval + 1) 415 | 416 | return experiment 417 | 418 | 419 | def _get_estimator_function( 420 | experiment: Experiment, 421 | estimator: EstimatorLike, 422 | ) -> Callable[..., EstimatorProtocol[Any, Any]]: 423 | 424 | if hasattr(estimator, "fit"): 425 | 426 | def estimator_function() -> EstimatorProtocol: 427 | return estimator 428 | 429 | else: 430 | estimator_function = estimator 431 | 432 | return experiment.capture(estimator_function) 433 | 434 | 435 | def _get_dataset_function( 436 | experiment: Experiment, 437 | dataset: DatasetLike, 438 | ) -> Callable[..., Bunch]: 439 | 440 | if callable(dataset): 441 | dataset_function = dataset 442 | else: 443 | 444 | def dataset_function() -> Bunch: 445 | return dataset 446 | 447 | return experiment.capture(dataset_function) 448 | 449 | 450 | def _create_one_experiment( 451 | *, 452 | estimator_name: str, 453 | estimator: EstimatorLike, 454 | dataset_name: str, 455 | dataset: DatasetLike, 456 | storage: RunObserver, 457 | config: ConfigLike, 458 | inner_cv: CVLike | Literal[False, "dataset"] = None, 459 | outer_cv: CVLike | 
Literal[False, "dataset"] = None, 460 | scoring: ScorerLike[DataType, TargetType] = None, 461 | save_estimator: bool = False, 462 | save_train: bool = False, 463 | ) -> Experiment: 464 | experiment = Experiment() 465 | 466 | experiment.add_config(config) 467 | 468 | experiment.add_config({"estimator_name": estimator_name}) 469 | if isinstance(estimator, tuple): 470 | estimator, estimator_config = estimator 471 | experiment.add_config(estimator_config) 472 | 473 | experiment.add_config({"dataset_name": dataset_name}) 474 | if isinstance(dataset, tuple): 475 | dataset, dataset_config = dataset 476 | experiment.add_config(dataset_config) 477 | 478 | experiment.observers.append(storage) 479 | 480 | estimator_function = _get_estimator_function(experiment, estimator) 481 | dataset_function = _get_dataset_function(experiment, dataset) 482 | 483 | @experiment.main 484 | def run() -> None: 485 | """Run the experiment.""" 486 | dataset = dataset_function() 487 | 488 | # Metaparameter search 489 | cv = dataset.inner_cv if inner_cv == "dataset" else inner_cv 490 | 491 | estimator = estimator_function() 492 | if hasattr(estimator, "cv") and cv is not False: 493 | estimator.cv = cv 494 | 495 | # Model assessment 496 | _benchmark( 497 | experiment=experiment, 498 | estimator=estimator, 499 | data=dataset, 500 | scoring=scoring, 501 | save_estimator=save_estimator, 502 | save_train=save_train, 503 | outer_cv=outer_cv, 504 | ) 505 | 506 | return experiment 507 | 508 | 509 | def create_experiments( 510 | *, 511 | datasets: Mapping[str, DatasetLike], 512 | estimators: Mapping[str, EstimatorLike], 513 | storage: RunObserver | str, 514 | config: ConfigLike | None = None, 515 | inner_cv: CVLike | Literal[False, "dataset"] = False, 516 | outer_cv: CVLike | Literal[False, "dataset"] = None, 517 | scoring: ScorerLike[DataType, TargetType] = None, 518 | save_estimator: bool = False, 519 | save_train: bool = False, 520 | ) -> Sequence[Experiment]: 521 | """ 522 | Create several Sacred experiments. 523 | 524 | It receives a set of estimators and datasets, and creates Sacred experiment 525 | objects for them. 526 | 527 | Parameters 528 | ---------- 529 | datasets : Mapping 530 | Mapping where each key is the name for a dataset and each value 531 | is either: 532 | 533 | * A :external:class:`sklearn.utils.Bunch` with the fields explained 534 | in :doc:`/structure`. Only ``data`` and ``target`` are 535 | mandatory. 536 | * A function receiving arbitrary config values and returning a 537 | :external:class:`sklearn.utils.Bunch` object like the one explained 538 | above. 539 | * A tuple with such a function and additional configuration (either 540 | a mapping or a filename). 541 | estimators : Mapping 542 | Mapping where each key is the name for an estimator and each value 543 | is either: 544 | 545 | * A scikit-learn compatible estimator. 546 | * A function receiving arbitrary config values and returning a 547 | scikit-learn compatible estimator. 548 | * A tuple with such a function and additional configuration (either 549 | a mapping or a filename). 550 | storage : :external:class:`sacred.observers.RunObserver` or :class:`str` 551 | Where the experiments will be stored. Either a Sacred observer, for 552 | example to store in a Mongo database, or the name of a directory, to 553 | use a file observer. 554 | config : Mapping, :class:`str` or ``None``, default ``None`` 555 | A mapping or filename with additional configuration for the experiment.
556 | inner_cv : CV-like object, ``"dataset"`` or ``False``, default ``False`` 557 | For estimators that perform cross validation (they have a ``cv`` 558 | parameter) this sets the cross validation strategy, as follows: 559 | 560 | * If ``False``, the original value of ``cv`` is unchanged. 561 | * If ``"dataset"``, the :external:class:`sklearn.utils.Bunch` objects 562 | for the datasets must have an ``inner_cv`` attribute, which will 563 | be the one used. 564 | * Otherwise, ``cv`` is changed to this value. 565 | outer_cv : CV-like object, ``"dataset"`` or ``False``, default ``None`` 566 | The strategy used to evaluate different partitions of the data, as 567 | follows: 568 | 569 | * If ``False``, use only one partition: the one specified in the 570 | dataset. Thus the :external:class:`sklearn.utils.Bunch` objects 571 | for the datasets should define at least a train and a test 572 | partition. 573 | * If ``"dataset"``, the :external:class:`sklearn.utils.Bunch` objects 574 | for the datasets must have an ``outer_cv`` attribute, which will 575 | be the one used. 576 | * Otherwise, this will be passed to 577 | :external:func:`sklearn.model_selection.check_cv` and the resulting 578 | cross validator will be used to define the partitions. 579 | scoring : string, callable or ``None``, default ``None`` 580 | Scoring method used to measure the performance of the estimator. 581 | If a callable, it should have the signature `scorer(estimator, X, y)`. 582 | If ``None``, it uses the ``score`` method of the estimator. 583 | save_estimator : bool, default ``False`` 584 | Whether to save the fitted estimator. This is useful for debugging 585 | and for obtaining extra information in some cases, but for some 586 | estimators it can consume a lot of storage. 587 | save_train : bool, default ``False`` 588 | If ``True``, compute and store also the score over the train data. 589 | 590 | Returns 591 | ------- 592 | experiments : Sequence of :external:class:`sacred.Experiment` 593 | Sequence of Sacred experiments, ready to be run. 594 | 595 | See Also 596 | -------- 597 | run_experiments 598 | fetch_scores 599 | 600 | """ 601 | if isinstance(storage, str): 602 | storage = FileStorageObserver(storage) 603 | 604 | if config is None: 605 | config = {} 606 | 607 | return [ 608 | _create_one_experiment( 609 | estimator_name=estimator_name, 610 | estimator=estimator, 611 | dataset_name=dataset_name, 612 | dataset=dataset, 613 | storage=storage, 614 | config=config, 615 | inner_cv=inner_cv, 616 | outer_cv=outer_cv, 617 | scoring=scoring, 618 | save_estimator=save_estimator, 619 | save_train=save_train, 620 | ) 621 | for estimator_name, estimator in estimators.items() 622 | for dataset_name, dataset in datasets.items() 623 | ] 624 | 625 | 626 | def run_experiments( 627 | experiments: Sequence[Experiment], 628 | ) -> Sequence[int]: 629 | """ 630 | Run Sacred experiments. 631 | 632 | Parameters 633 | ---------- 634 | experiments : Sequence of :external:class:`sacred.Experiment` 635 | Sequence of Sacred experiments to be run. 636 | 637 | Returns 638 | ------- 639 | ids : Sequence of :external:class:`int` 640 | Sequence of identifiers for each experiment.
641 | 642 | See Also 643 | -------- 644 | create_experiments 645 | fetch_scores 646 | 647 | """ 648 | return [e.run()._id for e in experiments] 649 | 650 | 651 | def _loader_from_observer( 652 | storage: RunObserver | str, 653 | ) -> ExperimentLoader | FileSystemExperimentLoader: 654 | 655 | if isinstance(storage, str): 656 | return FileSystemExperimentLoader(storage) 657 | elif isinstance(storage, FileStorageObserver): 658 | return FileSystemExperimentLoader(storage.basedir) 659 | elif isinstance(storage, MongoObserver): 660 | database = storage.runs.database 661 | client = database.client 662 | url, port = list( 663 | client.topology_description.server_descriptions().keys(), 664 | )[0] 665 | 666 | return ExperimentLoader( 667 | mongo_uri=f"mongodb://{url}:{port}/", 668 | db_name=database.name, 669 | unpickle=False, 670 | ) 671 | 672 | raise ValueError(f"Observer {storage} is not supported.") 673 | 674 | 675 | def _get_experiments( 676 | *, 677 | storage: RunObserver | str, 678 | ids: Sequence[int] | None = None, 679 | dataset_names: Sequence[str] | None = None, 680 | estimator_names: Sequence[str] | None = None, 681 | ) -> Sequence[Experiment]: 682 | 683 | loader = _loader_from_observer(storage) 684 | 685 | if ( 686 | (ids, dataset_names, estimator_names) == (None, None, None) 687 | or isinstance(loader, FileSystemExperimentLoader) 688 | and ids is None 689 | ): 690 | find_all_fun = getattr( 691 | loader, 692 | "find_all", 693 | lambda: [ 694 | FileSystemExperiment.from_run_dir(run_dir) 695 | for run_dir in loader._runs_dir.iterdir() 696 | ], 697 | ) 698 | 699 | experiments = find_all_fun() 700 | 701 | elif (dataset_names, estimator_names) == (None, None) or isinstance( 702 | loader, FileSystemExperimentLoader 703 | ): 704 | load_ids_fun = getattr( 705 | loader, 706 | "find_by_ids", 707 | lambda id_seq: [ 708 | loader.find_by_id(experiment_id) for experiment_id in id_seq 709 | ], 710 | ) 711 | 712 | experiments = load_ids_fun(ids) 713 | 714 | else: 715 | 716 | conditions: List[ 717 | Mapping[ 718 | str, 719 | Mapping[str, Sequence[Any]], 720 | ] 721 | ] = [] 722 | 723 | if ids is not None: 724 | conditions.append({"_id": {"$in": ids}}) 725 | 726 | if estimator_names is not None: 727 | conditions.append({"config.estimator_name": {"$in": estimator_names}}) 728 | 729 | if dataset_names is not None: 730 | conditions.append({"config.dataset_name": {"$in": dataset_names}}) 731 | 732 | query = {"$and": conditions} 733 | 734 | experiments = loader.find(query) 735 | 736 | if isinstance(loader, FileSystemExperimentLoader): 737 | # Filter experiments by dataset and estimator names 738 | experiments = [ 739 | e 740 | for e in experiments 741 | if ( 742 | ( 743 | estimator_names is None 744 | or e.config["estimator_name"] in estimator_names 745 | ) 746 | and (dataset_names is None or e.config["dataset_name"] in dataset_names) 747 | ) 748 | ] 749 | 750 | return experiments 751 | 752 | 753 | def fetch_scores( 754 | *, 755 | storage: RunObserver | str, 756 | ids: Sequence[int] | None = None, 757 | dataset_names: Sequence[str] | None = None, 758 | estimator_names: Sequence[str] | None = None, 759 | ) -> ScoresInfo: 760 | """ 761 | Fetch scores from Sacred experiments. 762 | 763 | By default, it retrieves every experiment. The parameters ``ids``, 764 | ``estimator_names`` and ``dataset_names`` can be used to restrict the 765 | number of experiments returned. 
766 | 767 | Parameters 768 | ---------- 769 | storage : :external:class:`sacred.observers.RunObserver` or :class:`str` 770 | Where the experiments are stored. Either a Sacred observer, for 771 | example for a Mongo database, or the name of a directory, to 772 | use a file observer. 773 | ids : Sequence of :external:class:`int` or ``None``, default ``None`` 774 | If not ``None``, return only experiments whose id is contained 775 | in the sequence. 776 | dataset_names : Sequence of :class:`str` or ``None``, default ``None`` 777 | If not ``None``, return only experiments whose dataset names are 778 | contained in the sequence. 779 | The order of the names is also the one used for datasets when 780 | combining the results. 781 | estimator_names : Sequence of :class:`str` or ``None``, default ``None`` 782 | If not ``None``, return only experiments whose estimator names are 783 | contained in the sequence. 784 | The order of the names is also the one used for estimators when 785 | combining the results. 786 | 787 | Returns 788 | ------- 789 | info : :class:`ScoresInfo` 790 | Class containing information about experiments scores. 791 | 792 | See Also 793 | -------- 794 | run_experiments 795 | fetch_scores 796 | 797 | """ 798 | 799 | experiments = _get_experiments( 800 | storage=storage, 801 | ids=ids, 802 | dataset_names=dataset_names, 803 | estimator_names=estimator_names, 804 | ) 805 | 806 | dict_experiments: Dict[ 807 | str, 808 | Dict[str, Tuple[np.typing.NDArray[float], float, float]], 809 | ] = {} 810 | estimator_list = [] 811 | dataset_list = [] 812 | 813 | nobs = 0 814 | 815 | for experiment in experiments: 816 | estimator_name = experiment.config["estimator_name"] 817 | if estimator_name not in estimator_list: 818 | estimator_list.append(estimator_name) 819 | dataset_name = experiment.config["dataset_name"] 820 | if dataset_name not in dataset_list: 821 | dataset_list.append(dataset_name) 822 | scores = experiment.info.get("test_score", np.array([])) 823 | score_mean = experiment.info.get("score_mean", np.nan) 824 | score_std = experiment.info.get("score_std", np.nan) 825 | 826 | nobs = max(nobs, len(scores)) 827 | 828 | assert np.isnan(score_mean) or score_mean == np.mean(scores) 829 | assert np.isnan(score_std) or score_std == np.std(scores) 830 | 831 | if estimator_name not in dict_experiments: 832 | dict_experiments[estimator_name] = {} 833 | 834 | if dataset_name in dict_experiments[estimator_name]: 835 | raise ValueError( 836 | f"Repeated experiment: ({estimator_name}, {dataset_name})", 837 | ) 838 | 839 | dict_experiments[estimator_name][dataset_name] = ( 840 | scores, 841 | score_mean, 842 | score_std, 843 | ) 844 | 845 | estimator_names = ( 846 | tuple(estimator_list) if estimator_names is None else estimator_names 847 | ) 848 | dataset_names = tuple(dataset_list) if dataset_names is None else dataset_names 849 | matrix_shape = (len(dataset_names), len(estimator_names)) 850 | 851 | scores = np.full(matrix_shape + (nobs,), np.nan) 852 | scores_mean = np.full(matrix_shape, np.nan) 853 | scores_std = np.full(matrix_shape, np.nan) 854 | 855 | for i, dataset_name in enumerate(dataset_names): 856 | for j, estimator_name in enumerate(estimator_names): 857 | dict_estimator = dict_experiments.get(estimator_name, {}) 858 | s, mean, std = dict_estimator.get( 859 | dataset_name, 860 | (np.array([]), np.nan, np.nan), 861 | ) 862 | if len(s) == nobs: 863 | scores[i, j] = s 864 | scores_mean[i, j] = mean 865 | scores_std[i, j] = std 866 | 867 | scores = np.array(scores.tolist()) 868 | 869 | 
return ScoresInfo( 870 | dataset_names=dataset_names, 871 | estimator_names=estimator_names, 872 | scores=scores, 873 | scores_mean=scores_mean, 874 | scores_std=scores_std, 875 | ) 876 | -------------------------------------------------------------------------------- /skdatasets/utils/scores.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: David Diaz Vico 3 | @license: MIT 4 | """ 5 | from __future__ import annotations 6 | 7 | import itertools as it 8 | from dataclasses import dataclass 9 | from functools import reduce 10 | from typing import Any, Callable, Literal, Mapping, Optional, Sequence, Tuple 11 | 12 | import numpy as np 13 | import pandas as pd 14 | from scipy.stats import ( 15 | friedmanchisquare, 16 | kruskal, 17 | mannwhitneyu, 18 | rankdata, 19 | wilcoxon, 20 | ) 21 | from scipy.stats import ttest_ind_from_stats, ttest_rel 22 | from statsmodels.sandbox.stats.multicomp import multipletests 23 | 24 | CorrectionLike = Literal[ 25 | None, 26 | "bonferroni", 27 | "sidak", 28 | "holm-sidak", 29 | "holm", 30 | "simes-hochberg", 31 | "hommel", 32 | "fdr_bh", 33 | "fdr_by", 34 | "fdr_tsbh", 35 | "fdr_tsbky", 36 | ] 37 | 38 | MultitestLike = Literal["kruskal", "friedmanchisquare"] 39 | 40 | TestLike = Literal["mannwhitneyu", "wilcoxon"] 41 | 42 | 43 | @dataclass 44 | class SummaryRow: 45 | values: np.typing.NDArray[Any] 46 | greater_is_better: bool | None = None 47 | 48 | 49 | @dataclass 50 | class ScoreCell: 51 | mean: float 52 | std: float | None 53 | rank: int 54 | significant: bool 55 | 56 | 57 | def average_rank( 58 | ranks: np.typing.NDArray[np.integer[Any]], 59 | **kwargs: Any, 60 | ) -> SummaryRow: 61 | """Compute rank averages.""" 62 | return SummaryRow( 63 | values=np.mean(ranks, axis=0), 64 | greater_is_better=False, 65 | ) 66 | 67 | 68 | def average_mean_score( 69 | means: np.typing.NDArray[np.floating[Any]], 70 | greater_is_better: bool, 71 | **kwargs: Any, 72 | ) -> SummaryRow: 73 | """Compute score mean averages.""" 74 | return SummaryRow( 75 | values=np.mean(means, axis=0), 76 | greater_is_better=greater_is_better, 77 | ) 78 | 79 | 80 | def _is_significant( 81 | scores1: np.typing.NDArray[np.floating[Any]], 82 | scores2: np.typing.NDArray[np.floating[Any]], 83 | mean1: np.typing.NDArray[np.floating[Any]], 84 | mean2: np.typing.NDArray[np.floating[Any]], 85 | std1: np.typing.NDArray[np.floating[Any]], 86 | std2: np.typing.NDArray[np.floating[Any]], 87 | *, 88 | nobs: int | None = None, 89 | two_sided: bool = True, 90 | paired_test: bool = False, 91 | significancy_level: float = 0.05, 92 | ) -> bool: 93 | 94 | alternative = "two-sided" if two_sided else "greater" 95 | 96 | if paired_test: 97 | assert scores1.ndim == 1 98 | assert scores2.ndim == 1 99 | 100 | _, pvalue = ttest_rel( 101 | scores1, 102 | scores2, 103 | axis=-1, 104 | alternative=alternative, 105 | ) 106 | 107 | else: 108 | assert nobs 109 | 110 | _, pvalue = ttest_ind_from_stats( 111 | mean1=mean1, 112 | std1=std1, 113 | nobs1=nobs, 114 | mean2=mean2, 115 | std2=std2, 116 | nobs2=nobs, 117 | equal_var=False, 118 | alternative=alternative, 119 | ) 120 | 121 | return pvalue < significancy_level 122 | 123 | 124 | def _all_significants( 125 | scores: np.typing.NDArray[np.floating[Any]], 126 | means: np.typing.NDArray[np.floating[Any]], 127 | stds: np.typing.NDArray[np.floating[Any]] | None, 128 | ranks: np.typing.NDArray[np.integer[Any]], 129 | *, 130 | nobs: int | None = None, 131 | two_sided: bool = True, 132 | paired_test: bool = False, 133 |
significancy_level: float = 0, 134 | ) -> np.typing.NDArray[np.bool_]: 135 | 136 | significant_matrix = np.zeros_like(ranks, dtype=np.bool_) 137 | 138 | if stds is None or significancy_level <= 0: 139 | return significant_matrix 140 | 141 | for row, (scores_row, mean_row, std_row, rank_row) in enumerate( 142 | zip(scores, means, stds, ranks), 143 | ): 144 | for column, (scores1, mean1, std1, rank1) in enumerate( 145 | zip(scores_row, mean_row, std_row, rank_row), 146 | ): 147 | # Compare every element with all those ranked immediately after it (rank + 1). 148 | # It must be significantly better than all of them. 149 | index2 = np.flatnonzero(rank_row == (rank1 + 1)) 150 | 151 | is_significant = len(index2) > 0 and all( 152 | _is_significant( 153 | scores1, 154 | scores_row[idx], 155 | mean1, 156 | mean_row[idx], 157 | std1, 158 | std_row[idx], 159 | nobs=nobs, 160 | two_sided=two_sided, 161 | paired_test=paired_test, 162 | significancy_level=significancy_level, 163 | ) 164 | for idx in index2 165 | ) 166 | 167 | if is_significant: 168 | significant_matrix[row, column] = True 169 | 170 | return significant_matrix 171 | 172 | 173 | def _set_style_classes( 174 | table: pd.DataFrame, 175 | *, 176 | all_ranks: np.typing.NDArray[np.integer[Any]], 177 | significants: np.typing.NDArray[np.bool_], 178 | n_summary_rows: int, 179 | ) -> pd.io.formats.style.Styler: 180 | rank_class_names = np.char.add( 181 | "rank", 182 | all_ranks.astype(str), 183 | ) 184 | 185 | is_summary_row = np.zeros_like(all_ranks, dtype=np.bool_) 186 | is_summary_row[-n_summary_rows:, :] = True 187 | 188 | summary_rows_class_name = np.char.multiply( 189 | "summary", 190 | is_summary_row.astype(int), 191 | ) 192 | 193 | significant_class_name = np.char.multiply( 194 | "significant", 195 | np.insert( 196 | significants, 197 | (len(significants),) * n_summary_rows, 198 | 0, 199 | axis=0, 200 | ).astype(int), 201 | ) 202 | 203 | styler = table.style.set_td_classes( 204 | pd.DataFrame( 205 | reduce( 206 | np.char.add, 207 | ( 208 | rank_class_names, 209 | " ", 210 | summary_rows_class_name, 211 | " ", 212 | significant_class_name, 213 | ), 214 | ), 215 | index=table.index, 216 | columns=table.columns, 217 | ), 218 | ) 219 | 220 | return styler 221 | 222 | 223 | def _set_style_formatter( 224 | styler: pd.io.formats.style.Styler, 225 | *, 226 | precision: int, 227 | show_rank: bool = True, 228 | ) -> pd.io.formats.style.Styler: 229 | def _formatter( 230 | data: object, 231 | ) -> str: 232 | if isinstance(data, str): 233 | return data 234 | elif isinstance(data, int): 235 | return str(data) 236 | elif isinstance(data, float): 237 | return f"{data:.{precision}f}" 238 | elif isinstance(data, ScoreCell): 239 | str_repr = f"{data.mean:.{precision}f}" 240 | if data.std is not None: 241 | str_repr += f" ± {data.std:.{precision}f}" 242 | if show_rank: 243 | precision_rank = 0 if isinstance(data.rank, int) else precision 244 | str_repr += f" ({data.rank:.{precision_rank}f})" 245 | return str_repr 246 | else: 247 | return "" 248 | 249 | return styler.format( 250 | _formatter, 251 | ) 252 | 253 | 254 | def _set_default_style_html( 255 | styler: pd.io.formats.style.Styler, 256 | *, 257 | n_summary_rows: int, 258 | ) -> pd.io.formats.style.Styler: 259 | 260 | last_rows_mask = np.zeros(len(styler.data), dtype=int) 261 | last_rows_mask[-n_summary_rows:] = 1 262 | 263 | styler = styler.set_table_styles( 264 | [ 265 | { 266 | "selector": ".summary", 267 | "props": [("font-style", "italic")], 268 | }, 269 | { 270 | "selector": ".rank1", 271 | "props":
[("font-weight", "bold")], 272 | }, 273 | { 274 | "selector": ".rank2", 275 | "props": [("text-decoration", "underline")], 276 | }, 277 | { 278 | "selector": ".significant::after", 279 | "props": [ 280 | ("content", '"*"'), 281 | ("width", "0px"), 282 | ("display", "inline-block"), 283 | ], 284 | }, 285 | { 286 | "selector": ".col_heading", 287 | "props": [("font-weight", "bold")], 288 | }, 289 | ], 290 | ) 291 | 292 | styler = styler.apply_index( 293 | lambda _: np.char.multiply( 294 | "font-style: italic; font-weight: bold", 295 | last_rows_mask, 296 | ), 297 | axis=0, 298 | ) 299 | 300 | styler = styler.apply_index( 301 | lambda idx: ["font-weight: bold"] * len(idx), 302 | axis=1, 303 | ) 304 | 305 | return styler 306 | 307 | 308 | def _set_style_from_class( 309 | styler: pd.io.formats.style.Styler, 310 | class_name: str, 311 | style: str, 312 | ) -> pd.io.formats.style.Styler: 313 | style_matrix = np.full(styler.data.shape, style) 314 | 315 | for row in range(style_matrix.shape[0]): 316 | for column in range(style_matrix.shape[1]): 317 | classes = styler.cell_context.get( 318 | (row, column), 319 | "", 320 | ).split() 321 | 322 | if class_name not in classes: 323 | style_matrix[row, column] = "" 324 | 325 | return styler.apply(lambda x: style_matrix, axis=None) 326 | 327 | 328 | def _set_default_style_latex( 329 | styler: pd.io.formats.style.Styler, 330 | *, 331 | n_summary_rows: int, 332 | ) -> pd.io.formats.style.Styler: 333 | 334 | last_rows_mask = np.zeros(len(styler.data), dtype=int) 335 | last_rows_mask[-n_summary_rows:] = 1 336 | 337 | styler.set_table_styles( 338 | [ 339 | { 340 | "selector": r"newcommand{\summary}", 341 | "props": r":[1]{\textit{#1}};", 342 | }, 343 | { 344 | "selector": r"newcommand{\significant}", 345 | "props": r":[1]{#1*};", 346 | }, 347 | { 348 | "selector": r"newcommand{\rank}", 349 | "props": ( 350 | r":[2]{\ifnum#1=1 \textbf{#2} \else " 351 | r"\ifnum#1=2 \underline{#2} \else #2 \fi\fi};" 352 | ), 353 | }, 354 | ], 355 | overwrite=False, 356 | ) 357 | 358 | for rank in range(1, styler.data.shape[1] + 1): 359 | styler = _set_style_from_class( 360 | styler, 361 | f"rank{rank}", 362 | f"rank{{{rank}}}:--rwrap; ", 363 | ) 364 | 365 | for class_name in ("summary", "significant"): 366 | 367 | styler = _set_style_from_class( 368 | styler, 369 | class_name, 370 | f"{class_name}:--rwrap; ", 371 | ) 372 | 373 | styler = styler.apply_index( 374 | lambda _: np.char.multiply( 375 | "textbf:--rwrap;summary:--rwrap;", 376 | last_rows_mask, 377 | ), 378 | axis=0, 379 | ) 380 | 381 | styler = styler.apply_index( 382 | lambda idx: ["textbf:--rwrap"] * len(idx), 383 | axis=1, 384 | ) 385 | 386 | return styler 387 | 388 | 389 | def _set_default_style( 390 | styler: pd.io.formats.style.Styler, 391 | *, 392 | n_summary_rows: int, 393 | default_style: Literal["html", "latex", None], 394 | ) -> pd.io.formats.style.Styler: 395 | 396 | if default_style == "html": 397 | styler = _set_default_style_html( 398 | styler, 399 | n_summary_rows=n_summary_rows, 400 | ) 401 | elif default_style == "latex": 402 | styler = _set_default_style_latex( 403 | styler, 404 | n_summary_rows=n_summary_rows, 405 | ) 406 | 407 | return styler 408 | 409 | 410 | def scores_table( 411 | scores: np.typing.ArrayLike, 412 | stds: np.typing.ArrayLike | None = None, 413 | *, 414 | datasets: Sequence[str], 415 | estimators: Sequence[str], 416 | nobs: int | None = None, 417 | greater_is_better: bool = True, 418 | method: Literal["average", "min", "max", "dense", "ordinal"] = "min", 419 | significancy_level: 
float = 0, 420 | paired_test: bool = False, 421 | two_sided: bool = True, 422 | default_style: Literal["html", "latex", None] = "html", 423 | precision: int = 2, 424 | show_rank: bool = True, 425 | summary_rows: Sequence[Tuple[str, Callable[..., SummaryRow]]] = ( 426 | ("Average rank", average_rank), 427 | ), 428 | ) -> pd.io.formats.style.Styler: 429 | """ 430 | Scores table. 431 | 432 | Prints a table where each row represents a dataset and each column 433 | represents an estimator. 434 | 435 | Parameters 436 | ---------- 437 | scores: array-like 438 | Matrix of scores where each column represents a model. 439 | Either the full matrix with all experiment results or the 440 | matrix with the mean scores can be passed. 441 | stds: array-like, default=None 442 | Matrix of standard deviations where each column represents a 443 | model. If ``scores`` is the full matrix with all results 444 | this is automatically computed from it and should not be passed. 445 | datasets: sequence of :external:class:`str` 446 | List of dataset names. 447 | estimators: sequence of :external:class:`str` 448 | List of estimator names. 449 | nobs: :external:class:`int` 450 | Number of repetitions of the experiments. Used only for computing 451 | significances when ``scores`` is not the full matrix. 452 | greater_is_better: boolean, default=True 453 | Whether a greater score is better (score) or worse 454 | (loss). 455 | method: {'average', 'min', 'max', 'dense', 'ordinal'}, default='min' 456 | Method used to solve ties. 457 | significancy_level: :external:class:`float`, default=0 458 | Significance level for considering a result significant. If nonzero, 459 | significance is calculated using a t-test. In that case, if 460 | ``paired_test`` is ``True``, ``scores`` should be the full matrix 461 | and a paired test is performed. Otherwise, the t-test assumes 462 | independence, and either ``scores`` should be the full matrix 463 | or ``nobs`` should be passed. 464 | paired_test: :external:class:`bool`, default=False 465 | Whether to perform a paired test or a test assuming independence. 466 | If ``True``, ``scores`` should be the full matrix. 467 | Otherwise, either ``scores`` should be the full matrix 468 | or ``nobs`` should be passed. 469 | two_sided: :external:class:`bool`, default=True 470 | Whether to perform a two-sided t-test or a one-sided t-test. 471 | default_style: {'html', 'latex', None}, default='html' 472 | Default style for the table. Use ``None`` for no style. Note that 473 | the CSS classes and textual formatting are always set. 474 | precision: :external:class:`int` 475 | Number of decimals used for floating point numbers. 476 | summary_rows: sequence 477 | List of (name, callable) tuples for additional summary rows. 478 | By default, the rank average is computed. 479 | 480 | Returns 481 | ------- 482 | table: :external:class:`pandas.io.formats.style.Styler` 483 | Table of mean and standard deviation of each estimator-dataset 484 | pair. A ranking of estimators is also generated.
485 | 486 | """ 487 | scores = np.asanyarray(scores) 488 | stds = None if stds is None else np.asanyarray(stds) 489 | 490 | assert scores.ndim in {2, 3} 491 | means = scores if scores.ndim == 2 else np.mean(scores, axis=-1) 492 | if scores.ndim == 3: 493 | assert stds is None 494 | assert nobs is None 495 | stds = np.std(scores, axis=-1) 496 | nobs = scores.shape[-1] 497 | 498 | ranks = np.asarray( 499 | [ 500 | rankdata(-m, method=method) 501 | if greater_is_better 502 | else rankdata(m, method=method) 503 | for m in means.round(precision) 504 | ] 505 | ) 506 | 507 | significants = _all_significants( 508 | scores, 509 | means, 510 | stds, 511 | ranks, 512 | nobs=nobs, 513 | two_sided=two_sided, 514 | paired_test=paired_test, 515 | significancy_level=significancy_level, 516 | ) 517 | 518 | table = pd.DataFrame(data=means, index=datasets, columns=estimators) 519 | for i, d in enumerate(datasets): 520 | for j, e in enumerate(estimators): 521 | table.loc[d, e] = ScoreCell( 522 | mean=means[i, j], 523 | std=None if stds is None else stds[i, j], 524 | rank=int(ranks[i, j]), 525 | significant=significants[i, j], 526 | ) 527 | 528 | # Create additional summary rows 529 | additional_ranks = [] 530 | for name, summary_fun in summary_rows: 531 | row = summary_fun( 532 | scores=scores, 533 | means=means, 534 | stds=stds, 535 | ranks=ranks, 536 | greater_is_better=greater_is_better, 537 | ) 538 | table.loc[name] = row.values 539 | 540 | if row.greater_is_better is None: 541 | additional_ranks.append(np.full(len(row.values), -1)) 542 | else: 543 | additional_ranks.append( 544 | rankdata(-row.values, method=method) 545 | if row.greater_is_better 546 | else rankdata(row.values, method=method), 547 | ) 548 | 549 | styler = _set_style_classes( 550 | table, 551 | all_ranks=np.vstack([ranks] + additional_ranks), 552 | significants=significants, 553 | n_summary_rows=len(summary_rows), 554 | ) 555 | 556 | styler = _set_style_formatter( 557 | styler, 558 | precision=precision, 559 | show_rank=show_rank, 560 | ) 561 | 562 | return _set_default_style( 563 | styler, 564 | n_summary_rows=len(summary_rows), 565 | default_style=default_style, 566 | ) 567 | 568 | 569 | def hypotheses_table( 570 | samples: np.typing.ArrayLike, 571 | models: Sequence[str], 572 | *, 573 | alpha: float = 0.05, 574 | multitest: Optional[MultitestLike] = None, 575 | test: TestLike = "wilcoxon", 576 | correction: CorrectionLike = None, 577 | multitest_args: Optional[Mapping[str, Any]] = None, 578 | test_args: Optional[Mapping[str, Any]] = None, 579 | ) -> Tuple[Optional[pd.DataFrame], Optional[pd.DataFrame]]: 580 | """ 581 | Hypotheses table. 582 | 583 | Prints a hypothesis table with a selected test and correction. 584 | 585 | Parameters 586 | ---------- 587 | samples: array-like 588 | Matrix of samples where each column represents a model. 589 | models: array-like 590 | Model names. 591 | alpha: float in [0, 1], default=0.05 592 | Significance level. 593 | multitest: {'kruskal', 'friedmanchisquare'}, default=None 594 | Ranking multitest used. 595 | test: {'mannwhitneyu', 'wilcoxon'}, default='wilcoxon' 596 | Ranking test used. 597 | correction: {'bonferroni', 'sidak', 'holm-sidak', 'holm', \ 598 | 'simes-hochberg', 'hommel', 'fdr_bh', 'fdr_by', 'fdr_tsbh', \ 599 | 'fdr_tsbky'}, default=None 600 | Method used to adjust the p-values. 601 | multitest_args: dict 602 | Optional multitest arguments. 603 | test_args: dict 604 | Optional test arguments.
605 | 606 | Returns 607 | ------- 608 | multitest_table: :external:class:`pandas.DataFrame` or None 609 | Table of the p-value and rejection/non-rejection for the 610 | multitest hypothesis. ``None`` if no multitest was requested. 611 | test_table: :external:class:`pandas.DataFrame` or None 612 | Table of p-values and rejection/non-rejection for each test 613 | hypothesis. ``None`` if the multitest did not reject. 614 | 615 | """ 616 | if multitest_args is None: 617 | multitest_args = {} 618 | 619 | if test_args is None: 620 | test_args = {} 621 | 622 | samples = np.asanyarray(samples) 623 | 624 | versus = list(it.combinations(range(len(models)), 2)) 625 | comparisons = [ 626 | f"{models[first]} vs {models[second]}" for first, second in versus 627 | ] 628 | 629 | multitests = { 630 | "kruskal": kruskal, 631 | "friedmanchisquare": friedmanchisquare, 632 | } 633 | tests = { 634 | "mannwhitneyu": mannwhitneyu, 635 | "wilcoxon": wilcoxon, 636 | } 637 | 638 | multitest_table = None 639 | if multitest is not None: 640 | multitest_table = pd.DataFrame( 641 | index=[multitest], 642 | columns=["p-value", "Hypothesis"], 643 | ) 644 | _, pvalue = multitests[multitest]( 645 | *samples.T, 646 | **multitest_args, 647 | ) 648 | reject_str = "Rejected" if pvalue <= alpha else "Not rejected" 649 | multitest_table.loc[multitest] = ["{0:.2f}".format(pvalue), reject_str] 650 | 651 | # If the multitest does not detect a significant difference, 652 | # the individual tests are not meaningful, so skip them. 653 | if pvalue > alpha: 654 | return multitest_table, None 655 | 656 | pvalues = [ 657 | tests[test]( 658 | samples[:, first], 659 | samples[:, second], 660 | **test_args, 661 | )[1] 662 | for first, second in versus 663 | ] 664 | 665 | if correction is not None: 666 | reject_bool, pvalues, _, _ = multipletests( 667 | pvalues, 668 | alpha, 669 | method=correction, 670 | ) 671 | reject = ["Rejected" if r else "Not rejected" for r in reject_bool] 672 | else: 673 | reject = [ 674 | "Rejected" if pvalue <= alpha else "Not rejected" for pvalue in pvalues 675 | ] 676 | 677 | data = [("{0:.2f}".format(p), r) for p, r in zip(pvalues, reject)] 678 | 679 | test_table = pd.DataFrame( 680 | data, 681 | index=comparisons, 682 | columns=["p-value", "Hypothesis"], 683 | ) 684 | 685 | return multitest_table, test_table 686 | --------------------------------------------------------------------------------
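Usage sketch (not part of the package sources): a minimal example of the two public helpers in skdatasets/utils/scores.py. The dataset names, estimator names, and the randomly generated score matrix are hypothetical illustration data; only the function signatures defined above are assumed.

import numpy as np

from skdatasets.utils.scores import hypotheses_table, scores_table

# Hypothetical data: 4 datasets x 3 estimators x 5 repetitions.
rng = np.random.default_rng(seed=0)
scores = rng.uniform(0.7, 1.0, size=(4, 3, 5))
datasets = ["dataset1", "dataset2", "dataset3", "dataset4"]
estimators = ["estimator1", "estimator2", "estimator3"]

# Render the mean ± std table with ranks. Passing the full 3-d matrix lets
# the function compute stds and nobs itself and enables the paired t-test.
styler = scores_table(
    scores,
    datasets=datasets,
    estimators=estimators,
    significancy_level=0.05,
    paired_test=True,
)
print(styler.to_html())  # or default_style="latex" plus styler.to_latex()

# Omnibus Friedman test on the mean scores, followed by pairwise Wilcoxon
# tests with Holm correction. test_table is None whenever the omnibus test
# does not reject the null hypothesis.
multitest_table, test_table = hypotheses_table(
    scores.mean(axis=-1),
    estimators,
    multitest="friedmanchisquare",
    test="wilcoxon",
    correction="holm",
)
print(multitest_table)
print(test_table)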