├── .bumpversion.cfg ├── .coveragerc ├── .gitattributes ├── .github ├── CONTRIBUTING.rst └── PULL_REQUEST_TEMPLATE.md ├── .gitignore ├── .nojekyll ├── .pre-commit-config.yaml ├── .travis.yml ├── LICENSE ├── MANIFEST.in ├── README.rst ├── build_tools ├── conda │ ├── build.bat │ ├── build.sh │ └── meta.yaml └── travis │ ├── conda_upload.sh │ ├── install.sh │ ├── install_conda.sh │ ├── make_docs.sh │ ├── success.sh │ └── test.sh ├── docs ├── Makefile ├── _static │ ├── copybutton.js │ └── custom.css ├── conf.py ├── content │ ├── api.rst │ ├── api │ │ ├── common.rst │ │ ├── datasets.rst │ │ ├── model_selection.rst │ │ ├── preprocessing.rst │ │ ├── target.rst │ │ └── utils.rst │ ├── installation.rst │ ├── intro.rst │ ├── pipeline.rst │ ├── target.rst │ ├── transformers.rst │ ├── whatsnew.rst │ └── wrappers.rst ├── index.rst └── make.bat ├── environment.yml ├── examples ├── README.txt └── plot_activity_recognition.py ├── pyproject.toml ├── requirements.txt ├── requirements_dev.txt ├── setup.py ├── sklearn_xarray ├── __init__.py ├── common │ ├── __init__.py │ ├── base.py │ └── wrappers.py ├── datasets.py ├── externals │ ├── __init__.py │ └── _numpy_groupies_np.py ├── model_selection.py ├── preprocessing.py ├── target.py └── utils.py └── tests ├── __init__.py ├── mocks.py ├── test_common.py ├── test_datasets.py ├── test_model_selection.py ├── test_preprocessing.py ├── test_target.py └── test_utils.py /.bumpversion.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 0.4.0 3 | 4 | [bumpversion:file:setup.py] 5 | search = version="{current_version}" 6 | replace = version="{new_version}" 7 | 8 | [bumpversion:file:sklearn_xarray/__init__.py] 9 | search = __version__ = "{current_version}" 10 | replace = __version__ = "{new_version}" 11 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | omit = 3 | *externals* 4 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto -------------------------------------------------------------------------------- /.github/CONTRIBUTING.rst: -------------------------------------------------------------------------------- 1 | Contributing to sklearn-xarray 2 | ============================== 3 | 4 | These guidelines have been largely adopted from the 5 | `scikit-learn contribution guidelines `_. 6 | 7 | 8 | How to contribute 9 | ----------------- 10 | 11 | The preferred workflow for contributing to sklearn-xarray is to fork the 12 | repository, clone, and develop on a branch. Steps: 13 | 14 | #. Fork the `project repository `_ 15 | by clicking on the 'Fork' button near the top right of the page. This creates 16 | a copy of the code under your GitHub user account. For more details on 17 | how to fork a repository see `this guide `_. 18 | 19 | #. Clone your fork of the sklearn-xarray repo from your GitHub account to your 20 | local disk:: 21 | 22 | $ git clone git@github.com:YourLogin/sklearn-xarray.git 23 | $ cd sklearn-xarray 24 | 25 | #. Create a ``feature`` branch to hold your development changes:: 26 | 27 | $ git checkout -b my-feature 28 | 29 | Always use a ``feature`` branch. It's good practice to never work on the 30 | ``master`` branch! 31 | 32 | #. 
Develop the feature on your feature branch. Add changed files using 33 | ``git add`` and then ``git commit`` files:: 34 | 35 | $ git add modified_files 36 | $ git commit 37 | 38 | to record your changes in Git, then push the changes to your GitHub 39 | account with:: 40 | 41 | $ git push -u origin my-feature 42 | 43 | #. Follow `these instructions `_ 44 | to create a pull request from your fork. This will send an email to the 45 | committers. 46 | 47 | 48 | Pull Request Checklist 49 | ---------------------- 50 | 51 | We recommended that your contribution complies with the following rules 52 | before you submit a pull request: 53 | 54 | - Follow the scikit-learn 55 | `coding-guidelines `_. 56 | 57 | - Use, when applicable, the validation tools and scripts in the 58 | ``sklearn-xarray.utils`` submodule as well as ``sklearn.utils`` from 59 | scikit-learn. 60 | 61 | - Give your pull request a helpful title that summarises what your 62 | contribution does. In some cases `Fix ` is enough. 63 | `Fix #` is not enough. 64 | 65 | - Often pull requests resolve one or more other issues (or pull requests). 66 | If merging your pull request means that some other issues/PRs should 67 | be closed, you should `use keywords to create link to them `_ 68 | (e.g., `Fixes #1234`; multiple issues/PRs are allowed as long as each one 69 | is preceded by a keyword). Upon merging, those issues/PRs will 70 | automatically be closed by GitHub. If your pull request is simply related 71 | to some other issues/PRs, create a link to them without using the keywords 72 | (e.g., `See also #1234`). 73 | 74 | - All public methods should have informative docstrings with sample 75 | usage presented as doctests when appropriate. 76 | 77 | - If you add a new module, add a file ``doc/content/api/your_module.rst`` that 78 | looks as follows:: 79 | 80 | 81 | =================== 82 | 83 | .. automodule:: sklearn_xarray.your_module 84 | :members: 85 | 86 | 87 | - Update ``doc/content/api.rst`` by adding a section for your module that looks 88 | like this:: 89 | 90 | 91 | ------------------ 92 | 93 | .. py:currentmodule:: sklearn_xarray.your_module 94 | 95 | Module: :py:mod:`sklearn_xarray.your_module` 96 | 97 | .. autosummary:: 98 | :nosignatures: 99 | 100 | YourClass1 101 | YourClass2 102 | your_function_1 103 | your_function_2 104 | 105 | and add ``api/your_module.rst`` to the toctree at the end of the file. If you 106 | add new classes or functions to an existing module you just have to 107 | add their names to the ``autosummary`` list like in the snippet above. 108 | 109 | - If you add new functionality you should demonstrate it by adding an example 110 | in the ``examples`` folder. Take a look at the source code of the existing 111 | examples as a syntax reference. 112 | 113 | - Documentation and high-coverage tests are necessary for enhancements to be 114 | accepted. Bug-fixes or new features should be provided with 115 | `non-regression tests `_. 116 | These tests verify the correct behavior of the fix or feature. In this 117 | manner, further modifications on the code base are granted to be consistent 118 | with the desired behavior. 119 | For the Bug-fixes case, at the time of the PR, these tests should fail for 120 | the code base in master and pass for the PR code. 
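  As an illustration, here is a minimal sketch of what such a non-regression
  test in ``tests/`` could look like (the scenario is invented; it only checks
  that a wrapped estimator keeps the ``sample`` coordinate)::

      import numpy as np
      import xarray as xr
      from sklearn.preprocessing import StandardScaler

      from sklearn_xarray import wrap


      def test_wrapper_preserves_sample_coord():
          """Non-regression test: wrapping must not drop the sample coordinate."""
          X = xr.DataArray(
              np.random.rand(10, 3),
              dims=("sample", "feature"),
              coords={"sample": range(10), "feature": ["a", "b", "c"]},
          )

          Xt = wrap(StandardScaler()).fit_transform(X)

          # the result should still carry the original sample coordinate
          assert np.all(Xt.sample.values == X.sample.values)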
121 | 122 | 123 | You can also check for common programming errors with the following 124 | tools: 125 | 126 | - Code with good unittest **coverage** (at least 80%), check with:: 127 | 128 | $ pip install pytest pytest-cov 129 | $ pytest --cov=sklearn_xarray 130 | 131 | - No flake8 warnings, check with:: 132 | 133 | $ pip install flake8 134 | $ flake8 sklearn_xarray tests --ignore=E203,W503,W504 --exclude=**/externals 135 | 136 | - Format code with black:: 137 | 138 | $ pip install black==19.10b0 139 | $ black . 140 | 141 | - pre-commit will run flake8 and black before each commit:: 142 | 143 | $ pip install pre-commit 144 | $ pre-commit install 145 | 146 | 147 | Filing bugs 148 | ----------- 149 | We use GitHub issues to track all bugs and feature requests; feel free to 150 | open an issue if you have found a bug or wish to see a feature implemented. 151 | 152 | It is recommended to check that your issue complies with the 153 | following rules before submitting: 154 | 155 | - Verify that your issue is not being currently addressed by other 156 | `issues `_ 157 | or `pull requests `_. 158 | 159 | - Please ensure all code snippets and error messages are formatted in 160 | appropriate code blocks. 161 | See `Creating and highlighting code blocks `_. 162 | 163 | - Please include your operating system type and version number, as well 164 | as your Python, scikit-learn, numpy, and scipy versions. This information 165 | can be found by running the following code snippet:: 166 | 167 | import platform; print(platform.platform()) 168 | import sys; print("Python", sys.version) 169 | import numpy; print("NumPy", numpy.__version__) 170 | import scipy; print("SciPy", scipy.__version__) 171 | import sklearn; print("Scikit-Learn", sklearn.__version__) 172 | 173 | - Please be specific about what estimators and/or functions are involved 174 | and the shape of the data, as appropriate; please include a 175 | `reproducible `_ code snippet 176 | or link to a `gist `_. If an exception is raised, 177 | please provide the traceback. 178 | 179 | 180 | New contributor tips 181 | -------------------- 182 | 183 | A great way to start contributing to sklearn-xarray is to pick an item from the 184 | list of `good first issues `_. 185 | Issues that might be a little more complicated to tackle are marked with 186 | `help wanted `_. 187 | 188 | 189 | Documentation 190 | ------------- 191 | 192 | We are glad to accept any sort of documentation: function docstrings, 193 | reStructuredText documents (like this one), tutorials, etc. 194 | reStructuredText documents live in the source code repository under the 195 | doc/ directory. 196 | 197 | You can edit the documentation using any text editor and then generate 198 | the HTML output by typing ``make html`` from the doc/ directory. 199 | Alternatively, ``make`` can be used to quickly generate the 200 | documentation without the example gallery. The resulting HTML files will 201 | be placed in ``_build/html/`` and are viewable in a web browser. 202 | 203 | For building the documentation, you will need 204 | `sphinx `_, 205 | `matplotlib `_, and 206 | `pillow `_. 
207 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | Make sure you have: 2 | - [ ] added the necessary updates to `doc/content/whatsnew.rst` 3 | - [ ] added documentation for new features 4 | - [ ] run bumpversion and reset conda build number if you are preparing a new release 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | data/ 3 | docs/modules/ 4 | docs/_build/ 5 | docs/auto_examples/ 6 | coverage/ 7 | 8 | # So far, all html files are auto-generated 9 | *.html 10 | 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | .hypothesis/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # celery beat schedule file 86 | celerybeat-schedule 87 | 88 | # SageMath parsed files 89 | *.sage.py 90 | 91 | # Environments 92 | .env 93 | .venv 94 | env/ 95 | venv/ 96 | ENV/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | -------------------------------------------------------------------------------- /.nojekyll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/phausamann/sklearn-xarray/0a8e61222a89e02665f444233e2bb2eb2bef7184/.nojekyll -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 19.10b0 4 | hooks: 5 | - id: black 6 | 7 | - repo: https://gitlab.com/pycqa/flake8 8 | rev: 3.7.9 9 | hooks: 10 | - id: flake8 11 | exclude: sklearn_xarray/externals 12 | args: ['--ignore=E203,W503,W504'] 13 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | # Config file for automatic testing at travis-ci.org 2 | os: linux 3 | dist: bionic 4 | language: python 5 | 6 | cache: 7 | apt: true 8 | directories: 9 | - "$HOME/.cache/pip" 10 | - "$HOME/download" 11 | 12 | env: 13 | 
global: 14 | - TEST_DIR=/tmp/test_dir/ 15 | - PKG_NAME=sklearn-xarray 16 | - MODULE=sklearn_xarray 17 | - USER=phausamann 18 | 19 | jobs: 20 | include: 21 | - python: 3.6 22 | env: COVERAGE="true" BLACK="true" FLAKE8="true" PAGES="true" PYPI="true" 23 | - python: 3.7 24 | - python: 3.8 25 | - env: CONDA_BLD_PATH=~/conda-bld 26 | install: 27 | - source build_tools/travis/install_conda.sh 28 | - conda install -y conda-build anaconda-client 29 | - conda config --set anaconda_upload no 30 | script: 31 | - conda build build_tools/conda 32 | after_success: 33 | - chmod +x build_tools/travis/conda_upload.sh 34 | 35 | install: 36 | - source build_tools/travis/install_conda.sh 37 | - source build_tools/travis/install.sh 38 | 39 | script: 40 | - bash build_tools/travis/test.sh 41 | - bash build_tools/travis/make_docs.sh 42 | 43 | after_success: 44 | - bash build_tools/travis/success.sh 45 | 46 | deploy: 47 | - provider: pages 48 | local-dir: docs/_build/html 49 | skip-cleanup: true 50 | github-token: $GITHUB_TOKEN # Set in the settings page of your repository, as a secure variable 51 | keep-history: true 52 | on: 53 | condition: "$PAGES = true" 54 | branch: master 55 | 56 | - provider: pypi 57 | skip_cleanup: true 58 | user: $USER 59 | password: 60 | secure: cRtfmsupJcyrZ1EU+NJ1eng0Abn9LeDVHLf1xZ1/1sg3qq6PwJwxNFJHTMin0sIJXERFAGG3btRnFqiwYsrxF7OdWObDYZN3G9riKKhS2Z5bSanWyrQk4XF/s9haONHKv2falsZ6nnux9GDMod+ojPzedNGagISLsLixHMRZmYFnUAJtdzDOm6PoNTui0+0C3bHoAIPu+FZJ1rPV1xmGM+4YGLg/j3yFt6SIY0XYY9d2torXSwD1E0+8V/kPxTcyNCQVE9LlFP3v9xLt2wYq7ehjGbetehSZyJxjchjtgABBMBkGTKqBwb3pgagaRmC9KVatpRVVVSLJRZAbFOmfK3QkZzrVzDVwOEloVhhUxUAm3rZDbZHmvHO0maS5VkpDAb3lE1edLziLiD0qqLBSuy5Tru+uELa6IO6gO8r/dA8usnKAcNWHIjrpLd3W7P+btjrmrSx8ReYs9PitKFiCLgleoAJGZFoSN0sOIAimCzvIsCvJyjlbHOvDyb+ziqvxu66yz/hBmupGibIT2529pyVW713gBOyrIvsLqzX3uDw6aYMTSi4aYp5+sfkCA5RE8Fc6PEPnqj6LbWjBF6bgelj3wUc9J4ZniuSWFMDKmBhk/p/j9CRg7RYQ5g+lK5E0oJma0vThqx8MKDivVk4oOMD8txcA0g1DJn8oE6i4ptc= 61 | on: 62 | condition: "$PYPI = true" 63 | branch: master 64 | tags: true 65 | 66 | - provider: script 67 | skip_cleanup: true 68 | script: build_tools/travis/conda_upload.sh 69 | on: 70 | condition: "$CONDA_BLD_PATH = $HOME/conda-bld" 71 | branch: master 72 | tags: true 73 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2017, phausamann 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements.txt 2 | include README.rst 3 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | .. -*- mode: rst -*- 2 | 3 | |Travis|_ |Coverage|_ |PyPI|_ |Black|_ 4 | 5 | .. |Travis| image:: https://travis-ci.org/phausamann/sklearn-xarray.svg?branch=master 6 | .. _Travis: https://travis-ci.org/phausamann/sklearn-xarray 7 | 8 | .. |Coverage| image:: https://coveralls.io/repos/github/phausamann/sklearn-xarray/badge.svg?branch=master 9 | .. _Coverage: https://coveralls.io/github/phausamann/sklearn-xarray?branch=master 10 | 11 | .. |PyPI| image:: https://badge.fury.io/py/sklearn-xarray.svg 12 | .. _PyPI: https://badge.fury.io/py/sklearn-xarray 13 | 14 | .. |Black| image:: https://img.shields.io/badge/code%20style-black-000000.svg 15 | .. _Black: https://github.com/psf/black 16 | 17 | sklearn-xarray 18 | ============== 19 | 20 | **sklearn-xarray** is an open-source python package that combines the 21 | n-dimensional labeled arrays of xarray_ with the machine learning and model 22 | selection tools of scikit-learn_. The package contains wrappers that allow 23 | the user to apply scikit-learn estimators to xarray types without losing their 24 | labels. 25 | 26 | .. _scikit-learn: http://scikit-learn.org/stable/ 27 | .. _xarray: http://xarray.pydata.org 28 | 29 | 30 | Documentation 31 | ------------- 32 | 33 | The package documentation can be found at 34 | https://phausamann.github.io/sklearn-xarray/ 35 | 36 | 37 | Features 38 | ---------- 39 | 40 | - Makes sklearn estimators compatible with xarray DataArrays and Datasets. 41 | - Allows for estimators to change the number of samples. 42 | - Adds a large number of pre-processing transformers. 43 | 44 | 45 | Installation 46 | ------------- 47 | 48 | The package can be installed with ``pip``:: 49 | 50 | $ pip install sklearn-xarray 51 | 52 | or with ``conda``:: 53 | 54 | $ conda install -c phausamann sklearn-xarray 55 | 56 | 57 | Example 58 | ------- 59 | 60 | The `activity recognition example`_ demonstrates how to use the 61 | package for cross-validated grid search for an activity recognition task. 62 | You can also download the example as a jupyter notebook. 63 | 64 | .. _activity recognition example: https://phausamann.github.io/sklearn-xarray/auto_examples/plot_activity_recognition.html 65 | 66 | 67 | Contributing 68 | ------------ 69 | 70 | Please read the `contribution guide `_ 71 | if you want to contribute to this project. 
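Quick example
-------------

The core of the package is the ``wrap`` function. A short sketch adapted from
the package documentation (``n_components=10`` is an arbitrary choice)::

    from sklearn.decomposition import PCA
    from sklearn_xarray import wrap
    from sklearn_xarray.datasets import load_digits_dataarray

    X = load_digits_dataarray()  # DataArray with "sample" and "feature" dims
    Xt = wrap(PCA(n_components=10), reshapes="feature").fit_transform(X)
    # Xt is still a DataArray and keeps its "sample" coordinate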
72 | -------------------------------------------------------------------------------- /build_tools/conda/build.bat: -------------------------------------------------------------------------------- 1 | "%PYTHON%" setup.py install --single-version-externally-managed --record=record.txt 2 | if errorlevel 1 exit 1 -------------------------------------------------------------------------------- /build_tools/conda/build.sh: -------------------------------------------------------------------------------- 1 | echo "Building" 2 | $PYTHON setup.py install --single-version-externally-managed --record=record.txt # Python command to install the script. -------------------------------------------------------------------------------- /build_tools/conda/meta.yaml: -------------------------------------------------------------------------------- 1 | {% set data = load_setup_py_data() %} 2 | 3 | package: 4 | name: sklearn-xarray 5 | version: {{ data['version'] }} 6 | 7 | source: 8 | path: ../.. 9 | 10 | build: 11 | noarch: python 12 | number: 0 13 | ignore_run_exports: 14 | - python_abi 15 | 16 | requirements: 17 | build: 18 | - python 19 | - setuptools 20 | run: 21 | - python >=3.6 22 | - numpy 23 | - scipy 24 | - scikit-learn 25 | - xarray 26 | - pandas 27 | 28 | test: 29 | source_files: 30 | - tests 31 | - .coveragerc 32 | requires: 33 | - pytest 34 | - pytest-cov 35 | commands: 36 | - pytest --cov=sklearn_xarray 37 | 38 | about: 39 | home: https://github.com/phausamann/sklearn-xarray 40 | license: BSD-3 41 | license_file: LICENSE 42 | -------------------------------------------------------------------------------- /build_tools/travis/conda_upload.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ -f "$HOME/miniconda3/etc/profile.d/conda.sh" ]; then 4 | source "$HOME/miniconda3/etc/profile.d/conda.sh" 5 | else 6 | source "$HOME/miniconda/etc/profile.d/conda.sh" 7 | fi 8 | 9 | conda activate base 10 | anaconda -t $ANACONDA_TOKEN upload -u $USER $CONDA_BLD_PATH/noarch/$PKG_NAME*.tar.bz2 --force 11 | -------------------------------------------------------------------------------- /build_tools/travis/install.sh: -------------------------------------------------------------------------------- 1 | # Configure the conda environment and put it in the path using the 2 | # provided versions 3 | conda create -n testenv -y -c conda-forge \ 4 | python=$TRAVIS_PYTHON_VERSION \ 5 | dask-ml \ 6 | --file requirements.txt \ 7 | --file requirements_dev.txt 8 | 9 | source activate testenv 10 | 11 | if [[ "$COVERAGE" == "true" ]]; then 12 | conda install -y -c conda-forge pytest-cov coveralls 13 | fi 14 | 15 | python --version 16 | python -c "import numpy; print('numpy %s' % numpy.__version__)" 17 | python -c "import scipy; print('scipy %s' % scipy.__version__)" 18 | python setup.py develop 19 | -------------------------------------------------------------------------------- /build_tools/travis/install_conda.sh: -------------------------------------------------------------------------------- 1 | # Deactivate the travis-provided virtual environment and setup a 2 | # conda-based environment instead 3 | deactivate 4 | 5 | # Use the miniconda installer for faster download / install of conda 6 | # itself 7 | pushd . 8 | cd 9 | mkdir -p download 10 | cd download 11 | echo "Cached in $HOME/download :" 12 | ls -l 13 | echo 14 | if [[ ! 
-f miniconda.sh ]] 15 | then 16 | wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh \ 17 | -O miniconda.sh 18 | fi 19 | chmod +x miniconda.sh && ./miniconda.sh -b 20 | cd .. 21 | export PATH=/home/travis/miniconda3/bin:$PATH 22 | conda update --yes conda 23 | popd 24 | -------------------------------------------------------------------------------- /build_tools/travis/make_docs.sh: -------------------------------------------------------------------------------- 1 | mkdir -p docs/modules/generated 2 | 3 | cd docs 4 | set -o pipefail && make html doctest 2>&1 | tee ~/log.txt 5 | cd .. 6 | 7 | cat ~/log.txt && if grep -q "Traceback (most recent call last):" ~/log.txt; \ 8 | then false; else true; fi 9 | 10 | cp .nojekyll docs/_build/html/.nojekyll 11 | -------------------------------------------------------------------------------- /build_tools/travis/success.sh: -------------------------------------------------------------------------------- 1 | set -e 2 | 3 | if [[ "$COVERAGE" == "true" ]]; then 4 | # Need to run coveralls from a git checkout, so we copy .coverage 5 | # from TEST_DIR where nosetests has been run 6 | cp $TEST_DIR/.coverage $TRAVIS_BUILD_DIR 7 | cd $TRAVIS_BUILD_DIR 8 | # Ignore coveralls failures as the coveralls server is not 9 | # very reliable but we don't want travis to report a failure 10 | # in the github UI just because the coverage report failed to 11 | # be published. 12 | coveralls || echo "Coveralls upload failed" 13 | fi -------------------------------------------------------------------------------- /build_tools/travis/test.sh: -------------------------------------------------------------------------------- 1 | set -e 2 | 3 | if [[ "$BLACK" == "true" ]]; then 4 | conda install -y -c conda-forge black=19.10b0 5 | black --check . 6 | fi 7 | 8 | if [[ "$FLAKE8" == "true" ]]; then 9 | conda install -y -c conda-forge flake8=3.7.9 10 | flake8 --ignore=E203,W503,W504 --exclude=**/externals 11 | fi 12 | 13 | # Get into a temp directory to run test from the installed package and 14 | # check if we do not leave artifacts 15 | mkdir -p $TEST_DIR 16 | cp .coveragerc $TEST_DIR/.coveragerc 17 | cp -r tests $TEST_DIR 18 | 19 | wd=$(pwd) 20 | cd $TEST_DIR 21 | 22 | if [[ "$COVERAGE" == "true" ]]; then 23 | pytest --cov=$MODULE 24 | else 25 | pytest 26 | fi 27 | 28 | cd $wd 29 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = -v 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = _build 9 | 10 | # User-friendly check for sphinx-build 11 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) 12 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) 13 | endif 14 | 15 | # Internal variables. 16 | PAPEROPT_a4 = -D latex_paper_size=a4 17 | PAPEROPT_letter = -D latex_paper_size=letter 18 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 19 | # the i18n builder cannot share the environment and doctrees with the others 20 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 
21 | 22 | .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext 23 | 24 | help: 25 | @echo "Please use \`make ' where is one of" 26 | @echo " html to make standalone HTML files" 27 | @echo " dirhtml to make HTML files named index.html in directories" 28 | @echo " singlehtml to make a single large HTML file" 29 | @echo " pickle to make pickle files" 30 | @echo " json to make JSON files" 31 | @echo " htmlhelp to make HTML files and a HTML help project" 32 | @echo " qthelp to make HTML files and a qthelp project" 33 | @echo " devhelp to make HTML files and a Devhelp project" 34 | @echo " epub to make an epub" 35 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 36 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 37 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" 38 | @echo " text to make text files" 39 | @echo " man to make manual pages" 40 | @echo " texinfo to make Texinfo files" 41 | @echo " info to make Texinfo files and run them through makeinfo" 42 | @echo " gettext to make PO message catalogs" 43 | @echo " changes to make an overview of all changed/added/deprecated items" 44 | @echo " xml to make Docutils-native XML files" 45 | @echo " pseudoxml to make pseudoxml-XML files for display purposes" 46 | @echo " linkcheck to check all external links for integrity" 47 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 48 | 49 | clean: 50 | -rm -rf $(BUILDDIR)/* 51 | -rm -rf auto_examples/ 52 | -rm -rf generated/* 53 | -rm -rf modules/generated/* 54 | 55 | html: 56 | # These two lines make the build a bit more lengthy, and the 57 | # the embedding of images more robust 58 | rm -rf $(BUILDDIR)/html/_images 59 | #rm -rf _build/doctrees/ 60 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 61 | @echo 62 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 63 | 64 | dirhtml: 65 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 66 | @echo 67 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 68 | 69 | singlehtml: 70 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 71 | @echo 72 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 73 | 74 | pickle: 75 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 76 | @echo 77 | @echo "Build finished; now you can process the pickle files." 78 | 79 | json: 80 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 81 | @echo 82 | @echo "Build finished; now you can process the JSON files." 83 | 84 | htmlhelp: 85 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 86 | @echo 87 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 88 | ".hhp project file in $(BUILDDIR)/htmlhelp." 89 | 90 | qthelp: 91 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 92 | @echo 93 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 94 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 95 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/project-template.qhcp" 96 | @echo "To view the help file:" 97 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/project-template.qhc" 98 | 99 | devhelp: 100 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 101 | @echo 102 | @echo "Build finished." 
103 | @echo "To view the help file:" 104 | @echo "# mkdir -p $$HOME/.local/share/devhelp/project-template" 105 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/project-template" 106 | @echo "# devhelp" 107 | 108 | epub: 109 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 110 | @echo 111 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 112 | 113 | latex: 114 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 115 | @echo 116 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 117 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 118 | "(use \`make latexpdf' here to do that automatically)." 119 | 120 | latexpdf: 121 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 122 | @echo "Running LaTeX files through pdflatex..." 123 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 124 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 125 | 126 | latexpdfja: 127 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 128 | @echo "Running LaTeX files through platex and dvipdfmx..." 129 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja 130 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 131 | 132 | text: 133 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 134 | @echo 135 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 136 | 137 | man: 138 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 139 | @echo 140 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 141 | 142 | texinfo: 143 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 144 | @echo 145 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 146 | @echo "Run \`make' in that directory to run these through makeinfo" \ 147 | "(use \`make info' here to do that automatically)." 148 | 149 | info: 150 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 151 | @echo "Running Texinfo files through makeinfo..." 152 | make -C $(BUILDDIR)/texinfo info 153 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 154 | 155 | gettext: 156 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 157 | @echo 158 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." 159 | 160 | changes: 161 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 162 | @echo 163 | @echo "The overview file is in $(BUILDDIR)/changes." 164 | 165 | linkcheck: 166 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 167 | @echo 168 | @echo "Link check complete; look for any errors in the above output " \ 169 | "or in $(BUILDDIR)/linkcheck/output.txt." 170 | 171 | doctest: 172 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 173 | @echo "Testing of doctests in the sources finished, look at the " \ 174 | "results in $(BUILDDIR)/doctest/output.txt." 175 | 176 | xml: 177 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml 178 | @echo 179 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml." 180 | 181 | pseudoxml: 182 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml 183 | @echo 184 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." 185 | -------------------------------------------------------------------------------- /docs/_static/copybutton.js: -------------------------------------------------------------------------------- 1 | // Copyright 2014 PSF. 
Licensed under the PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 2 | // File originates from the cpython source found in Doc/tools/sphinxext/static/copybutton.js 3 | 4 | $(document).ready(function() { 5 | /* Add a [>>>] button on the top-right corner of code samples to hide 6 | * the >>> and ... prompts and the output and thus make the code 7 | * copyable. */ 8 | var div = $('.highlight-python .highlight,' + 9 | '.highlight-python3 .highlight,' + 10 | '.highlight-pycon .highlight,' + 11 | '.highlight-pycon3 .highlight') 12 | var pre = div.find('pre'); 13 | 14 | // get the styles from the current theme 15 | pre.parent().parent().css('position', 'relative'); 16 | var hide_text = 'Hide the prompts and output'; 17 | var show_text = 'Show the prompts and output'; 18 | var border_width = pre.css('border-top-width'); 19 | var border_style = pre.css('border-top-style'); 20 | var border_color = pre.css('border-top-color'); 21 | var button_styles = { 22 | 'cursor':'pointer', 'position': 'absolute', 'top': '0', 'right': '0', 23 | 'border-color': border_color, 'border-style': border_style, 24 | 'border-width': border_width, 'color': border_color, 'text-size': '75%', 25 | 'font-family': 'monospace', 'padding-left': '0.2em', 'padding-right': '0.2em', 26 | 'border-radius': '0 3px 0 0' 27 | } 28 | 29 | // create and add the button to all the code blocks that contain >>> 30 | div.each(function(index) { 31 | var jthis = $(this); 32 | if (jthis.find('.gp').length > 0) { 33 | var button = $('>>>'); 34 | button.css(button_styles) 35 | button.attr('title', hide_text); 36 | button.data('hidden', 'false'); 37 | jthis.prepend(button); 38 | } 39 | // tracebacks (.gt) contain bare text elements that need to be 40 | // wrapped in a span to work with .nextUntil() (see later) 41 | jthis.find('pre:has(.gt)').contents().filter(function() { 42 | return ((this.nodeType == 3) && (this.data.trim().length > 0)); 43 | }).wrap(''); 44 | }); 45 | 46 | // define the behavior of the button when it's clicked 47 | $('.copybutton').click(function(e){ 48 | e.preventDefault(); 49 | var button = $(this); 50 | if (button.data('hidden') === 'false') { 51 | // hide the code output 52 | button.parent().find('.go, .gp, .gt').hide(); 53 | button.next('pre').find('.gt').nextUntil('.gp, .go').css('visibility', 'hidden'); 54 | button.css('text-decoration', 'line-through'); 55 | button.attr('title', show_text); 56 | button.data('hidden', 'true'); 57 | } else { 58 | // show the code output 59 | button.parent().find('.go, .gp, .gt').show(); 60 | button.next('pre').find('.gt').nextUntil('.gp, .go').css('visibility', 'visible'); 61 | button.css('text-decoration', 'none'); 62 | button.attr('title', hide_text); 63 | button.data('hidden', 'false'); 64 | } 65 | }); 66 | }); 67 | 68 | -------------------------------------------------------------------------------- /docs/_static/custom.css: -------------------------------------------------------------------------------- 1 | .classifier:before { 2 | font-style: normal; 3 | margin: 0.5em; 4 | content: ":"; 5 | } -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # sklearn-xarray documentation build configuration file, created by 4 | # sphinx-quickstart on Mon Jan 18 14:44:12 2016. 5 | # 6 | # This file is execfile()d with the current directory set to its 7 | # containing dir. 
8 | # 9 | # Note that not all possible configuration values are present in this 10 | # autogenerated file. 11 | # 12 | # All configuration values have a default; values that are commented out 13 | # serve to show the default. 14 | 15 | import sys 16 | import os 17 | 18 | # pngmath / imgmath compatibility layer for different sphinx versions 19 | import sphinx 20 | import sklearn_xarray 21 | from distutils.version import LooseVersion 22 | 23 | import sphinx_rtd_theme 24 | 25 | # If extensions (or modules to document with autodoc) are in another directory, 26 | # add these directories to sys.path here. If the directory is relative to the 27 | # documentation root, use os.path.abspath to make it absolute, like shown here. 28 | # sys.path.insert(0, os.path.abspath('.')) 29 | 30 | # -- General configuration --------------------------------------------------- 31 | 32 | # Try to override the matplotlib configuration as early as possible 33 | try: 34 | import gen_rst # noqa 35 | except ImportError: 36 | pass 37 | # -- General configuration ------------------------------------------------ 38 | 39 | # If your documentation needs a minimal Sphinx version, state it here. 40 | # needs_sphinx = '1.0' 41 | 42 | sys.path.append(os.path.join(os.path.dirname(__name__), "..")) 43 | 44 | # Add any Sphinx extension module names here, as strings. They can be 45 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 46 | # ones. 47 | extensions = [ 48 | "sphinx.ext.autodoc", 49 | "sphinx.ext.doctest", 50 | "sphinx.ext.intersphinx", 51 | "sphinx.ext.todo", 52 | "numpydoc", 53 | "sphinx.ext.ifconfig", 54 | "sphinx.ext.viewcode", 55 | "sphinx_gallery.gen_gallery", 56 | "sphinx.ext.autosummary", 57 | ] 58 | 59 | if LooseVersion(sphinx.__version__) < LooseVersion("1.4"): 60 | extensions.append("sphinx.ext.pngmath") 61 | else: 62 | extensions.append("sphinx.ext.imgmath") 63 | 64 | sphinx_gallery_conf = { 65 | # path to your examples scripts 66 | "examples_dirs": "../examples", 67 | # path where to save gallery generated examples 68 | "gallery_dirs": "auto_examples", 69 | } 70 | 71 | # Add any paths that contain templates here, relative to this directory. 72 | templates_path = ["_templates"] 73 | 74 | # The suffix of source filenames. 75 | source_suffix = ".rst" 76 | 77 | # The encoding of source files. 78 | # source_encoding = 'utf-8-sig' 79 | 80 | # Generate the plots for the gallery 81 | plot_gallery = True 82 | 83 | # The master toctree document. 84 | master_doc = "index" 85 | 86 | # General information about the project. 87 | project = u"sklearn-xarray" 88 | copyright = u"2020, Peter Hausamann" 89 | 90 | # The version info for the project you're documenting, acts as replacement for 91 | # |version| and |release|, also used in various other places throughout the 92 | # built documents. 93 | # 94 | # The short X.Y version. 95 | version = sklearn_xarray.__version__ 96 | # The full version, including alpha/beta/rc tags. 97 | release = sklearn_xarray.__version__ 98 | 99 | # The language for content autogenerated by Sphinx. Refer to documentation 100 | # for a list of supported languages. 101 | # language = None 102 | 103 | # There are two options for replacing |today|: either, you set today to some 104 | # non-false value, then it is used: 105 | # today = '' 106 | # Else, today_fmt is used as the format for a strftime call. 107 | # today_fmt = '%B %d, %Y' 108 | 109 | # List of patterns, relative to source directory, that match files and 110 | # directories to ignore when looking for source files. 
111 | exclude_patterns = ["_build"] 112 | 113 | # The reST default role (used for this markup: `text`) to use for all 114 | # documents. 115 | # default_role = None 116 | 117 | # If true, '()' will be appended to :func: etc. cross-reference text. 118 | # add_function_parentheses = True 119 | 120 | # If true, the current module name will be prepended to all description 121 | # unit titles (such as .. function::). 122 | add_module_names = False 123 | 124 | # If true, sectionauthor and moduleauthor directives will be shown in the 125 | # output. They are ignored by default. 126 | # show_authors = False 127 | 128 | # The name of the Pygments (syntax highlighting) style to use. 129 | pygments_style = "sphinx" 130 | 131 | # A list of ignored prefixes for module index sorting. 132 | # modindex_common_prefix = [] 133 | 134 | # If true, keep warnings as "system message" paragraphs in the built documents. 135 | # keep_warnings = False 136 | 137 | 138 | # -- Options for HTML output ---------------------------------------------- 139 | 140 | # The theme to use for HTML and HTML Help pages. See the documentation for 141 | # a list of builtin themes. 142 | html_theme = "sphinx_rtd_theme" 143 | 144 | # Theme options are theme-specific and customize the look and feel of a theme 145 | # further. For a list of options available for each theme, see the 146 | # documentation. 147 | # html_theme_options = {} 148 | 149 | # Add any paths that contain custom themes here, relative to this directory. 150 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 151 | 152 | # The name for this set of Sphinx documents. If None, it defaults to 153 | # " v documentation". 154 | # html_title = None 155 | 156 | # A shorter title for the navigation bar. Default is the same as html_title. 157 | # html_short_title = None 158 | 159 | # The name of an image file (relative to this directory) to place at the top 160 | # of the sidebar. 161 | # html_logo = None 162 | 163 | # The name of an image file (within the static path) to use as favicon of the 164 | # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 165 | # pixels large. 166 | # html_favicon = None 167 | 168 | # Add any paths that contain custom static files (such as style sheets) here, 169 | # relative to this directory. They are copied after the builtin static files, 170 | # so a file named "default.css" will overwrite the builtin "default.css". 171 | html_static_path = ["_static"] 172 | html_css_files = [ 173 | "custom.css", 174 | ] 175 | 176 | # Add any extra paths that contain custom files (such as robots.txt or 177 | # .htaccess) here, relative to this directory. These files are copied 178 | # directly to the root of the documentation. 179 | # html_extra_path = [] 180 | 181 | # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, 182 | # using the given strftime format. 183 | # html_last_updated_fmt = '%b %d, %Y' 184 | 185 | # If true, SmartyPants will be used to convert quotes and dashes to 186 | # typographically correct entities. 187 | # html_use_smartypants = True 188 | 189 | # Custom sidebar templates, maps document names to template names. 190 | # html_sidebars = {} 191 | 192 | # Additional templates that should be rendered to pages, maps page names to 193 | # template names. 194 | # html_additional_pages = {} 195 | 196 | # If false, no module index is generated. 197 | # html_domain_indices = True 198 | 199 | # If false, no index is generated. 
200 | # html_use_index = True 201 | 202 | # If true, the index is split into individual pages for each letter. 203 | # html_split_index = False 204 | 205 | # If true, links to the reST sources are added to the pages. 206 | # html_show_sourcelink = True 207 | 208 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. 209 | # html_show_sphinx = True 210 | 211 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. 212 | # html_show_copyright = True 213 | 214 | # If true, an OpenSearch description file will be output, and all pages will 215 | # contain a tag referring to it. The value of this option must be the 216 | # base URL from which the finished HTML is served. 217 | # html_use_opensearch = '' 218 | 219 | # This is the file name suffix for HTML files (e.g. ".xhtml"). 220 | # html_file_suffix = None 221 | 222 | # Output file base name for HTML help builder. 223 | htmlhelp_basename = "sklearn-xarraydoc" 224 | 225 | # -- Options for LaTeX output --------------------------------------------- 226 | 227 | latex_elements = { 228 | # The paper size ('letterpaper' or 'a4paper'). 229 | # 'papersize': 'letterpaper', 230 | # The font size ('10pt', '11pt' or '12pt'). 231 | # 'pointsize': '10pt', 232 | # Additional stuff for the LaTeX preamble. 233 | # 'preamble': '', 234 | } 235 | 236 | # Grouping the document tree into LaTeX files. List of tuples 237 | # (source start file, target name, title, 238 | # author, documentclass [howto, manual, or own class]). 239 | latex_documents = [ 240 | ( 241 | "index", 242 | "sklearn-xarray.tex", 243 | u"sklearn-xarray Documentation", 244 | u"Peter Hausamann", 245 | "manual", 246 | ), 247 | ] 248 | 249 | # The name of an image file (relative to this directory) to place at the top of 250 | # the title page. 251 | # latex_logo = None 252 | 253 | # For "manual" documents, if this is true, then toplevel headings are parts, 254 | # not chapters. 255 | # latex_use_parts = False 256 | 257 | # If true, show page references after internal links. 258 | # latex_show_pagerefs = False 259 | 260 | # If true, show URL addresses after external links. 261 | # latex_show_urls = False 262 | 263 | # Documents to append as an appendix to all manuals. 264 | # latex_appendices = [] 265 | 266 | # If false, no module index is generated. 267 | # latex_domain_indices = True 268 | 269 | 270 | # -- Options for manual page output --------------------------------------- 271 | 272 | # One entry per manual page. List of tuples 273 | # (source start file, name, description, authors, manual section). 274 | man_pages = [ 275 | ( 276 | "index", 277 | "sklearn-xarray", 278 | u"sklearn-xarray Documentation", 279 | [u"Peter Hausamann"], 280 | 1, 281 | ) 282 | ] 283 | 284 | # If true, show URL addresses after external links. 285 | # man_show_urls = False 286 | 287 | 288 | # -- Options for Texinfo output ------------------------------------------- 289 | 290 | # Grouping the document tree into Texinfo files. 
List of tuples 291 | # (source start file, target name, title, author, 292 | # dir menu entry, description, category) 293 | texinfo_documents = [ 294 | ( 295 | "index", 296 | "sklearn-xarray", 297 | u"sklearn-xarray Documentation", 298 | u"Peter Hausamann", 299 | "sklearn-xarray", 300 | "Integrate xarray and sklearn.", 301 | "Miscellaneous", 302 | ), 303 | ] 304 | 305 | 306 | # def generate_example_rst(app, what, name, obj, options, lines): 307 | # # generate empty examples files, so that we don't get 308 | # # inclusion errors if there are no examples for a class / module 309 | # examples_path = os.path.join(app.srcdir, "modules", "generated", 310 | # "%s.examples" % name) 311 | # if not os.path.exists(examples_path): 312 | # # touch file 313 | # open(examples_path, 'w').close() 314 | 315 | 316 | def setup(app): 317 | app.add_javascript("copybutton.js") 318 | 319 | 320 | # app.connect('autodoc-process-docstring', generate_example_rst) 321 | 322 | 323 | # Documents to append as an appendix to all manuals. 324 | # texinfo_appendices = [] 325 | 326 | # If false, no module index is generated. 327 | # texinfo_domain_indices = True 328 | 329 | # How to display URL addresses: 'footnote', 'no', or 'inline'. 330 | # texinfo_show_urls = 'footnote' 331 | 332 | # If true, do not generate a @detailmenu in the "Top" node's menu. 333 | # texinfo_no_detailmenu = False 334 | 335 | 336 | # Example configuration for intersphinx: refer to the Python standard library. 337 | intersphinx_mapping = {"http://docs.python.org/": None} 338 | -------------------------------------------------------------------------------- /docs/content/api.rst: -------------------------------------------------------------------------------- 1 | API Reference 2 | ============= 3 | 4 | .. _API/Wrappers: 5 | 6 | Wrappers 7 | -------- 8 | 9 | .. py:currentmodule:: sklearn_xarray 10 | 11 | Module: :py:mod:`sklearn_xarray` 12 | 13 | .. autosummary:: 14 | :nosignatures: 15 | 16 | wrap 17 | ClassifierWrapper 18 | RegressorWrapper 19 | TransformerWrapper 20 | 21 | 22 | Target 23 | ------ 24 | 25 | .. py:currentmodule:: sklearn_xarray 26 | 27 | Module: :py:mod:`sklearn_xarray` 28 | 29 | .. autosummary:: 30 | :nosignatures: 31 | 32 | Target 33 | 34 | 35 | .. _API/Pre-processing: 36 | 37 | Pre-processing 38 | -------------- 39 | 40 | .. py:currentmodule:: sklearn_xarray.preprocessing 41 | 42 | Module: :py:mod:`sklearn_xarray.preprocessing` 43 | 44 | Object interface: 45 | 46 | .. autosummary:: 47 | :nosignatures: 48 | 49 | Concatenator 50 | Featurizer 51 | Reducer 52 | Resampler 53 | Sanitizer 54 | Segmenter 55 | Selector 56 | Splitter 57 | Transposer 58 | 59 | 60 | Functional interface: 61 | 62 | .. autosummary:: 63 | :nosignatures: 64 | 65 | concatenate 66 | featurize 67 | preprocess 68 | reduce 69 | resample 70 | sanitize 71 | segment 72 | select 73 | split 74 | transpose 75 | 76 | 77 | Model selection 78 | --------------- 79 | 80 | .. py:currentmodule:: sklearn_xarray.model_selection 81 | 82 | Module: :py:mod:`sklearn_xarray.model_selection` 83 | 84 | .. autosummary:: 85 | :nosignatures: 86 | 87 | CrossValidatorWrapper 88 | 89 | 90 | Utility functions 91 | ----------------- 92 | 93 | .. py:currentmodule:: sklearn_xarray.utils 94 | 95 | Module: :py:mod:`sklearn_xarray.utils` 96 | 97 | .. autosummary:: 98 | :nosignatures: 99 | 100 | convert_to_ndarray 101 | get_group_indices 102 | segment_array 103 | is_dataarray 104 | is_dataset 105 | is_target 106 | 107 | 108 | Datasets 109 | -------- 110 | 111 | .. 
py:currentmodule:: sklearn_xarray.datasets 112 | 113 | Module: :py:mod:`sklearn_xarray.datasets` 114 | 115 | .. autosummary:: 116 | :nosignatures: 117 | 118 | load_dummy_dataarray 119 | load_dummy_dataset 120 | load_digits_dataarray 121 | load_wisdm_dataarray 122 | 123 | 124 | List of modules 125 | --------------- 126 | 127 | .. toctree:: 128 | 129 | api/common 130 | api/preprocessing 131 | api/model_selection 132 | api/utils 133 | api/datasets 134 | -------------------------------------------------------------------------------- /docs/content/api/common.rst: -------------------------------------------------------------------------------- 1 | Top-level functions and classes 2 | =============================== 3 | 4 | .. automodule:: sklearn_xarray 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/content/api/datasets.rst: -------------------------------------------------------------------------------- 1 | Datasets 2 | ======== 3 | 4 | .. automodule:: sklearn_xarray.datasets 5 | :members: -------------------------------------------------------------------------------- /docs/content/api/model_selection.rst: -------------------------------------------------------------------------------- 1 | Model Selection 2 | =============== 3 | 4 | .. automodule:: sklearn_xarray.model_selection 5 | :members: -------------------------------------------------------------------------------- /docs/content/api/preprocessing.rst: -------------------------------------------------------------------------------- 1 | Preprocessing 2 | ============= 3 | 4 | .. automodule:: sklearn_xarray.preprocessing 5 | :members: -------------------------------------------------------------------------------- /docs/content/api/target.rst: -------------------------------------------------------------------------------- 1 | Target 2 | ====== 3 | 4 | .. automodule:: sklearn_xarray.target 5 | :members: -------------------------------------------------------------------------------- /docs/content/api/utils.rst: -------------------------------------------------------------------------------- 1 | Utils 2 | ===== 3 | 4 | .. automodule:: sklearn_xarray.utils 5 | :members: -------------------------------------------------------------------------------- /docs/content/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | The package can be installed with ``pip``:: 5 | 6 | $ pip install sklearn-xarray 7 | 8 | or with ``conda``:: 9 | 10 | $ conda install -c phausamann sklearn-xarray 11 | 12 | 13 | Testing 14 | ------- 15 | 16 | To run the unit tests, install ``pytest`` and run:: 17 | 18 | $ pytest 19 | 20 | 21 | Building the docs 22 | ----------------- 23 | 24 | To build the documentation, install ``sphinx``, ``sphinx-gallery``, 25 | ``sphinx_rtd_theme`` and ``numpydoc`` and run:: 26 | 27 | $ make -C docs html 28 | 29 | -------------------------------------------------------------------------------- /docs/content/intro.rst: -------------------------------------------------------------------------------- 1 | Introduction 2 | ============ 3 | 4 | scikit-learn_ is an amazing machine learning package, but it has some 5 | drawbacks. For example, data arrays are limited to 2D numpy arrays and the the 6 | API does not support estimators changing the number of samples in the data. 
7 | 8 | xarray_ provides a rich framework for labeled n-dimensional data structures, 9 | unfortunately most scikit-learn estimators will strip these structures of their 10 | labels and only return numpy arrays. 11 | 12 | sklearn-xarray tries to establish a bridge between the two packages that 13 | allows the user to integrate xarray data types into the scikit-learn 14 | framework with minor overhead. 15 | 16 | .. _scikit-learn: http://scikit-learn.org/stable/ 17 | .. _xarray: http://xarray.pydata.org -------------------------------------------------------------------------------- /docs/content/pipeline.rst: -------------------------------------------------------------------------------- 1 | Pipelining and cross-validation 2 | =============================== 3 | 4 | 5 | Integration with sklearn pipelines 6 | ---------------------------------- 7 | 8 | Wrapped estimators can be used in sklearn pipelines without any additional 9 | overhead: 10 | 11 | .. doctest:: 12 | 13 | >>> from sklearn_xarray import wrap, Target 14 | >>> from sklearn_xarray.datasets import load_digits_dataarray 15 | >>> from sklearn.pipeline import Pipeline 16 | >>> from sklearn.decomposition import PCA 17 | >>> from sklearn.linear_model.logistic import LogisticRegression 18 | >>> 19 | >>> X = load_digits_dataarray() 20 | >>> y = Target(coord='digit')(X) 21 | >>> 22 | >>> pipeline = Pipeline([ 23 | ... ('pca', wrap(PCA(n_components=50), reshapes='feature')), 24 | ... ('cls', wrap(LogisticRegression(), reshapes='feature')) 25 | ... ]) 26 | >>> 27 | >>> pipeline.fit(X, y) # doctest:+ELLIPSIS 28 | Pipeline(...) 29 | >>> pipeline.score(X, y) 30 | 1.0 31 | 32 | 33 | Cross-validated grid search 34 | --------------------------- 35 | 36 | .. py:currentmodule:: sklearn_xarray.model_selection 37 | 38 | .. note:: 39 | This feature is currently only available for DataArrays. 40 | 41 | The module :py:mod:`sklearn_xarray.model_selection` contains the 42 | :py:class:`CrossValidatorWrapper` class that wraps a cross-validator instance 43 | from ``sklearn.model_selection``. With such a wrapped cross-validator, it is 44 | possible to use xarray data types with a ``GridSearchCV`` estimator: 45 | 46 | .. doctest:: 47 | 48 | >>> from sklearn_xarray.model_selection import CrossValidatorWrapper 49 | >>> from sklearn.model_selection import GridSearchCV, KFold 50 | >>> 51 | >>> cv = CrossValidatorWrapper(KFold()) 52 | >>> pipeline = Pipeline([ 53 | ... ('pca', wrap(PCA(), reshapes='feature')), 54 | ... ('cls', wrap(LogisticRegression(), reshapes='feature')) 55 | ... ]) 56 | >>> 57 | >>> gridsearch = GridSearchCV( 58 | ... pipeline, cv=cv, param_grid={'pca__n_components': [20, 40, 60]} 59 | ... ) 60 | >>> 61 | >>> gridsearch.fit(X, y) # doctest:+ELLIPSIS 62 | GridSearchCV(...) 63 | >>> gridsearch.best_params_ 64 | {'pca__n_components': 20} 65 | >>> gridsearch.best_score_ 66 | 0.9182110801609408 67 | 68 | 69 | -------------------------------------------------------------------------------- /docs/content/target.rst: -------------------------------------------------------------------------------- 1 | Using coordinates as targets 2 | ============================ 3 | 4 | .. py:currentmodule:: sklearn_xarray.target 5 | 6 | With sklearn-xarray you can easily point an sklearn estimator to a 7 | coordinate in an xarray DataArray or Dataset in order to use it as a target 8 | for supervised learning. This is achieved with a :py:class:`Target` object: 9 | 10 | .. 
doctest:: 11 | 12 | >>> from sklearn_xarray import wrap, Target 13 | >>> from sklearn_xarray.datasets import load_digits_dataarray 14 | >>> from sklearn.linear_model.logistic import LogisticRegression 15 | >>> 16 | >>> X = load_digits_dataarray() 17 | >>> y = Target(coord='digit')(X) 18 | >>> X 19 | 20 | array([[ 0., 0., 5., ..., 0., 0., 0.], 21 | [ 0., 0., 0., ..., 10., 0., 0.], 22 | [ 0., 0., 0., ..., 16., 9., 0.], 23 | ..., 24 | [ 0., 0., 1., ..., 6., 0., 0.], 25 | [ 0., 0., 2., ..., 12., 0., 0.], 26 | [ 0., 0., 10., ..., 12., 1., 0.]]) 27 | Coordinates: 28 | * sample (sample) int64 0 1 2 3 4 5 6 ... 1790 1791 1792 1793 1794 1795 1796 29 | * feature (feature) int64 0 1 2 3 4 5 6 7 8 9 ... 55 56 57 58 59 60 61 62 63 30 | digit (sample) int64 0 1 2 3 4 5 6 7 8 9 0 1 ... 7 9 5 4 8 8 4 9 0 8 9 8 31 | >>> y 32 | sklearn_xarray.Target with data: 33 | 34 | array([0, 1, 2, ..., 8, 9, 8]) 35 | Coordinates: 36 | * sample (sample) int64 0 1 2 3 4 5 6 ... 1790 1791 1792 1793 1794 1795 1796 37 | digit (sample) int64 0 1 2 3 4 5 6 7 8 9 0 1 ... 7 9 5 4 8 8 4 9 0 8 9 8 38 | 39 | 40 | The target can point to any DataArray or Dataset that contains the specified 41 | coordinate, simply by calling the target with the Dataset/DataArray as an 42 | argument. When you construct a target without specifying a coordinate, the 43 | target data will be the Dataset/DataArray itself. 44 | 45 | The Target object can be used as a target for a wrapped estimator in accordance 46 | with sklearn's usual syntax: 47 | 48 | .. doctest:: 49 | 50 | >>> wrapper = wrap(LogisticRegression()) 51 | >>> wrapper.fit(X, y) # doctest:+ELLIPSIS 52 | EstimatorWrapper(...) 53 | >>> wrapper.score(X, y) 54 | 1.0 55 | 56 | .. note:: 57 | You don't have to assign the Target to any data, the wrapper's fit method 58 | will automatically call ``y(X)``. 59 | 60 | Pre-processing 61 | -------------- 62 | 63 | In some cases, it is necessary to pre-process the coordinate before it can be 64 | used as a target. For this, the constructor takes a ``transform_func`` parameter 65 | which can be used with the ``fit_transform`` method of transformers in 66 | ``sklearn.preprocessing`` (and also any other object implementing the sklearn 67 | transformer interface): 68 | 69 | .. doctest:: 70 | 71 | >>> from sklearn.neural_network import MLPClassifier 72 | >>> from sklearn.preprocessing import LabelBinarizer 73 | >>> 74 | >>> y = Target(coord='digit', transform_func=LabelBinarizer().fit_transform)(X) 75 | >>> wrapper = wrap(MLPClassifier()) 76 | >>> wrapper.fit(X, y) # doctest:+ELLIPSIS 77 | EstimatorWrapper(...) 78 | 79 | Indexing 80 | -------- 81 | 82 | A :py:class:`Target` object can be indexed in the same way as the underlying 83 | coordinate and interfaces with ``numpy`` by providing an ``__array__`` 84 | attribute which returns ``numpy.array()`` of the (transformed) coordinate. 85 | 86 | 87 | Multi-dimensional coordinates 88 | ----------------------------- 89 | 90 | In some cases, the target coordinates span multiple dimensions, but the 91 | transformer expects a lower-dimensional input. With the ``dim`` parameter of 92 | the :py:class:`Target` class you can specify which of the dimensions to keep. 93 | You can also specify the callable ``reduce_func`` to perform the reduction of 94 | the other dimensions (e.g. ``numpy.mean``). Otherwise, the coordinate will 95 | be reduced to the first element along each dimension that is not ``dim``. 
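As a rough sketch (the array and its two-dimensional ``label`` coordinate below are invented for illustration, so this is not tested doctest output)::

    import numpy as np
    import xarray as xr
    from sklearn_xarray import Target

    X = xr.DataArray(
        np.random.random((8, 4)),
        coords={
            "sample": range(8),
            "feature": range(4),
            # a coordinate spanning both dimensions
            "label": (("sample", "feature"), np.random.randint(0, 2, (8, 4))),
        },
        dims=("sample", "feature"),
    )

    # keep only the sample dimension, reducing over the remaining
    # dimensions with numpy.mean
    y = Target(coord="label", dim="sample", reduce_func=np.mean)(X)

The resulting target then has one entry per sample and can be passed to a wrapped estimator like the targets in the examples above.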
96 | 97 | 98 | Lazy evaluation 99 | --------------- 100 | 101 | When you construct a target with a transformer and ``lazy=True``, the 102 | transformation will only be performed when the target's data is actually 103 | accessed. This can significantly improve performance when working with large 104 | datasets in a pipeline, because the target is assigned in each step of the 105 | pipeline. 106 | 107 | .. note:: 108 | When you index a target with lazy evaluation, the transformation is 109 | performed regardless of whether ``lazy`` was set. 110 | -------------------------------------------------------------------------------- /docs/content/transformers.rst: -------------------------------------------------------------------------------- 1 | Custom transformers 2 | =================== 3 | 4 | sklearn-xarray provides a wealth of newly defined transformers that exploit 5 | xarray's powerful array manipulation syntax. Refer to :ref:`API/Pre-processing` 6 | for a full list. 7 | 8 | 9 | Transformers changing the number of samples 10 | ------------------------------------------- 11 | 12 | There are several transformers that change the number of samples in the data, 13 | namely: 14 | 15 | .. py:currentmodule:: sklearn_xarray.preprocessing 16 | 17 | .. autosummary:: 18 | :nosignatures: 19 | 20 | Resampler 21 | Sanitizer 22 | Segmenter 23 | Splitter 24 | 25 | These kinds of transformers are usually disallowed by sklearn, because the 26 | package does not provide any mechanism for also changing the number of samples 27 | of the target in a pipelined supervised learning task. sklearn-xarray 28 | circumvents this restriction with the :py:class:`Target` class. 29 | 30 | We look at an example where the digits dataset is loaded but some of the 31 | samples are corrupted and contain ``nan`` values. The :py:class:`Sanitizer` 32 | transformer removes these samples from the dataset: 33 | 34 | .. doctest:: 35 | 36 | >>> from sklearn_xarray import wrap, Target 37 | >>> from sklearn_xarray.preprocessing import Sanitizer 38 | >>> from sklearn_xarray.datasets import load_digits_dataarray 39 | >>> from sklearn.pipeline import Pipeline 40 | >>> from sklearn.linear_model.logistic import LogisticRegression 41 | >>> 42 | >>> X = load_digits_dataarray(nan_probability=0.1) 43 | >>> y = Target(coord='digit')(X) 44 | >>> 45 | >>> pipeline = Pipeline([ 46 | ... ('san', Sanitizer()), 47 | ... ('cls', wrap(LogisticRegression(), reshapes='feature')) 48 | ... ]) 49 | >>> 50 | >>> pipeline.fit(X, y) # doctest:+ELLIPSIS 51 | Pipeline(...) 52 | 53 | If we had used ``y = X.digit`` instead of the :py:class:`Target` syntax, we 54 | would have gotten:: 55 | 56 | ValueError: Found input variables with inconsistent numbers of samples: [1635, 1797] 57 | 58 | 59 | Groupwise transformations 60 | ------------------------- 61 | 62 | When you apply transformers to your data that change the number of samples, 63 | there are cases when you don't want to apply the resampling operation to your 64 | whole dataset, but rather to groups of data. 65 | 66 | One example is the WISDM activity recognition dataset found in the 67 | :py:mod:`sklearn_xarray.datasets` module. It contains time series accelerometer 68 | data from different subjects performing different activities. If, for 69 | example, we wanted to split this dataset into segments of 20 samples, we 70 | should do this in groups of subject/activity pairs, because otherwise we 71 | could get non-continuous samples from different recording times in the same 72 | segment.
In order to perform transformations in a groupwise manner, we 73 | specify the ``groupby`` parameter: 74 | 75 | .. doctest:: 76 | 77 | >>> from sklearn_xarray.datasets import load_wisdm_dataarray 78 | >>> from sklearn_xarray.preprocessing import Segmenter 79 | >>> 80 | >>> segmenter = Segmenter( 81 | ... new_len=20, new_dim='timepoint', groupby=['subject', 'activity'] 82 | ... ) 83 | >>> 84 | >>> X = load_wisdm_dataarray() 85 | >>> Xt = segmenter.fit_transform(X) 86 | >>> Xt # doctest:+ELLIPSIS doctest:+NORMALIZE_WHITESPACE 87 | 88 | array([[[ -0.15 , 0.11 , ..., -2.26 , -1.46 ], 89 | [ 9.15 , 9.19 , ..., 9.72 , 9.81 ], 90 | [ -0.34 , 2.76 , ..., 2.03 , 2.15 ]], 91 | [[ 0.27 , -3.06 , ..., -2.56 , -2.6 ], 92 | [ 12.57 , 13.18 , ..., 14.56 , 8.96 ], 93 | [ 5.37 , 6.47 , ..., 0.31 , -3.3 ]], 94 | ..., 95 | [[ -0.3 , 0.27 , ..., 0.42 , 3.17 ], 96 | [ 8.08 , 6.63 , ..., 10.5 , 9.23 ], 97 | [ 0.994285, 0.994285, ..., -5.175732, -4.671779]], 98 | [[ 5.33 , 6.44 , ..., -4.14 , -4.9 ], 99 | [ 8.39 , 9.04 , ..., 6.21 , 6.55 ], 100 | [ -4.794363, -2.179256, ..., 5.938472, 3.827318]]]) 101 | Coordinates: 102 | * axis (axis) >> from sklearn_xarray import wrap 26 | >>> from sklearn_xarray.datasets import load_dummy_dataarray 27 | >>> from sklearn.preprocessing import StandardScaler 28 | >>> 29 | >>> X = load_dummy_dataarray() 30 | >>> Xt = wrap(StandardScaler()).fit_transform(X) 31 | 32 | The :py:func:`wrap` function will return an object with the corresponding 33 | methods for each type of estimator (e.g. ``predict`` for classifiers and 34 | regressors). 35 | 36 | .. note:: 37 | 38 | xarray references axes by name rather than by order. Therefore, you can 39 | specify the ``sample_dim`` parameter of the wrapper to refer to the 40 | dimension in your data that represents the samples. By default, the 41 | wrapper will assume that the first dimension in the array is the sample 42 | dimension. 43 | 44 | When we run the example, we see that the data in the array is scaled, but the 45 | coordinates and dimensions have not changed: 46 | 47 | .. doctest:: 48 | 49 | >>> X # doctest:+SKIP 50 | 51 | array([[ 0.565986, 0.196107, 0.935981, ..., 0.702356, 0.806494, 0.801178], 52 | [ 0.551611, 0.277749, 0.27546 , ..., 0.646887, 0.616391, 0.227552], 53 | [ 0.451261, 0.205744, 0.60436 , ..., 0.426333, 0.008449, 0.763937], 54 | ..., 55 | [ 0.019217, 0.112844, 0.894421, ..., 0.675889, 0.4957 , 0.740349], 56 | [ 0.542255, 0.053288, 0.483674, ..., 0.481905, 0.064586, 0.843511], 57 | [ 0.607809, 0.425632, 0.702882, ..., 0.521591, 0.315032, 0.4258 ]]) 58 | Coordinates: 59 | * sample (sample) int32 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 ... 60 | * feature (feature) int32 0 1 2 3 4 5 6 7 8 9 61 | 62 | .. doctest:: 63 | 64 | >>> Xt # doctest:+SKIP 65 | 66 | array([[ 0.128639, -0.947769, 1.625452, ..., 0.525571, 1.07678 , 1.062118], 67 | [ 0.077973, -0.673463, -0.631625, ..., 0.321261, 0.408263, -0.942871], 68 | [-0.275702, -0.91539 , 0.492264, ..., -0.491108, -1.729624, 0.931952], 69 | ..., 70 | [-1.7984 , -1.227519, 1.483434, ..., 0.428084, -0.016158, 0.849506], 71 | [ 0.045001, -1.427621, 0.079865, ..., -0.286418, -1.532214, 1.210086], 72 | [ 0.27604 , -0.176596, 0.828923, ..., -0.140244, -0.651494, -0.249936]]) 73 | Coordinates: 74 | * sample (sample) int32 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 ... 
75 | * feature (feature) int32 0 1 2 3 4 5 6 7 8 9 76 | 77 | 78 | Estimators changing the shape of the data 79 | ----------------------------------------- 80 | 81 | Many sklearn estimators will change the number of features during 82 | transformation or prediction. In this case, the coordinates along the feature 83 | dimension no longer correspond to those of the original array. Therefore, the 84 | wrapper will omit the coordinates along this dimension. You can specify which 85 | dimension is changed with the ``reshapes`` parameter: 86 | 87 | .. doctest:: 88 | 89 | >>> from sklearn.decomposition import PCA 90 | >>> Xt = wrap(PCA(n_components=5), reshapes='feature').fit_transform(X) 91 | >>> Xt # doctest:+SKIP 92 | 93 | array([[ 0.438773, -0.100947, 0.106754, 0.236872, -0.128751], 94 | [-0.40433 , -0.580941, 0.588425, -0.305739, -0.120676], 95 | [ 0.343535, -0.334365, 0.659667, 0.111196, 0.308099], 96 | ..., 97 | [ 0.519982, 0.38072 , 0.133793, -0.064086, 0.108029], 98 | [-0.099056, -0.086161, -0.115271, -0.053594, -0.736321], 99 | [-0.358513, -0.327132, -0.635314, -0.310221, -0.017318]]) 100 | Coordinates: 101 | * sample (sample) int32 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 ... 102 | Dimensions without coordinates: feature 103 | 104 | .. todo:: 105 | reshapes dict 106 | 107 | 108 | Accessing fitted estimators 109 | --------------------------- 110 | 111 | The ``estimator`` attribute of the wrapper will always hold the unfitted 112 | estimator that was passed initially. After calling ``fit`` the fitted estimator 113 | will be stored in the ``estimator_`` attribute: 114 | 115 | .. doctest:: 116 | 117 | >>> wrapper = wrap(StandardScaler()) 118 | >>> wrapper.fit(X) 119 | EstimatorWrapper(copy=True, estimator=StandardScaler(), with_mean=True, 120 | with_std=True) 121 | >>> wrapper.estimator_.mean_ # doctest:+SKIP 122 | array([ 0.46156856, 0.47165326, 0.48397815, 0.48958361, 0.4730579 , 123 | 0.522414 , 0.46496134, 0.52299264, 0.48772645, 0.49043086]) 124 | 125 | The wrapper also directly reflects the fitted attributes: 126 | 127 | .. doctest:: 128 | 129 | >>> wrapper.mean_ # doctest:+SKIP 130 | array([ 0.46156856, 0.47165326, 0.48397815, 0.48958361, 0.4730579 , 131 | 0.522414 , 0.46496134, 0.52299264, 0.48772645, 0.49043086]) 132 | 133 | 134 | Wrapping estimators for Datasets 135 | -------------------------------- 136 | 137 | .. py:currentmodule:: sklearn_xarray.dataset 138 | 139 | The syntax for Datasets is exactly the same as for DataArrays. Note that the 140 | wrapper will fit one estimator for each variable in the Dataset. The fitted 141 | estimators are stored in the attribute ``estimator_dict_``: 142 | 143 | .. doctest:: 144 | 145 | >>> from sklearn_xarray.datasets import load_dummy_dataset 146 | >>> 147 | >>> X = load_dummy_dataset() 148 | >>> wrapper = wrap(StandardScaler()) 149 | >>> wrapper.fit(X) 150 | EstimatorWrapper(copy=True, estimator=StandardScaler(), with_mean=True, 151 | with_std=True) 152 | >>> wrapper.estimator_dict_ 153 | {'var_1': StandardScaler()} 154 | 155 | The wrapper also directly reflects the fitted attributes as dictionaries with 156 | one entry for each variable: 157 | 158 | .. 
doctest:: 159 | 160 | >>> wrapper.mean_['var_1'] # doctest:+SKIP 161 | array([ 0.46156856, 0.47165326, 0.48397815, 0.48958361, 0.4730579 , 162 | 0.522414 , 0.46496134, 0.52299264, 0.48772645, 0.49043086]) 163 | 164 | 165 | Wrapping dask-ml estimators 166 | --------------------------- 167 | 168 | The dask-ml_ package re-implements a number of scikit-learn estimators for 169 | use with dask_ on-disk arrays. You can wrap these estimators in the same way 170 | in order to work with dask-backed DataArrays and Datasets: 171 | 172 | .. doctest:: 173 | 174 | >>> from dask_ml.preprocessing import StandardScaler 175 | >>> import xarray as xr 176 | >>> import numpy as np 177 | >>> import dask.array as da 178 | >>> 179 | >>> X = xr.DataArray( 180 | ... da.from_array(np.random.random((100, 10)), chunks=(10, 10)), 181 | ... coords={'sample': range(100), 'feature': range(10)}, 182 | ... dims=('sample', 'feature') 183 | ... ) 184 | >>> Xt = wrap(StandardScaler()).fit_transform(X) 185 | >>> type(Xt.data) 186 | 187 | 188 | 189 | .. _dask-ml: http://dask-ml.readthedocs.io/en/latest/index.html 190 | .. _dask: http://dask.pydata.org/en/latest/ 191 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. documentation master 2 | 3 | sklearn-xarray: Using scikit-learn with xarray 4 | ============================================== 5 | 6 | **sklearn-xarray** is an open-source python package that combines the 7 | n-dimensional labeled arrays of xarray_ with the machine learning and model 8 | selection tools of scikit-learn_. The package contains wrappers that allow 9 | the user to apply scikit-learn estimators to xarray types without losing their 10 | labels. 11 | 12 | .. _scikit-learn: http://scikit-learn.org/stable/ 13 | .. _xarray: http://xarray.pydata.org 14 | 15 | The code repository is hosted on GitHub_. 16 | 17 | .. _GitHub: https://github.com/phausamann/sklearn-xarray 18 | 19 | Documentation 20 | ------------- 21 | 22 | .. toctree:: 23 | :maxdepth: 1 24 | 25 | content/whatsnew 26 | content/intro 27 | content/installation 28 | content/wrappers 29 | content/target 30 | content/pipeline 31 | content/transformers 32 | 33 | auto_examples/index 34 | content/api 35 | 36 | 37 | Indices and tables 38 | ------------------ 39 | 40 | * :ref:`genindex` 41 | * :ref:`search` 42 | 43 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | REM Command file for Sphinx documentation 4 | 5 | if "%SPHINXBUILD%" == "" ( 6 | set SPHINXBUILD=sphinx-build 7 | ) 8 | set BUILDDIR=_build 9 | set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . 10 | set I18NSPHINXOPTS=%SPHINXOPTS% . 11 | if NOT "%PAPER%" == "" ( 12 | set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% 13 | set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% 14 | ) 15 | 16 | if "%1" == "" goto help 17 | 18 | if "%1" == "help" ( 19 | :help 20 | echo.Please use `make ^` where ^ is one of 21 | echo. html to make standalone HTML files 22 | echo. dirhtml to make HTML files named index.html in directories 23 | echo. singlehtml to make a single large HTML file 24 | echo. pickle to make pickle files 25 | echo. json to make JSON files 26 | echo. htmlhelp to make HTML files and a HTML help project 27 | echo. qthelp to make HTML files and a qthelp project 28 | echo. 
devhelp to make HTML files and a Devhelp project 29 | echo. epub to make an epub 30 | echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter 31 | echo. text to make text files 32 | echo. man to make manual pages 33 | echo. texinfo to make Texinfo files 34 | echo. gettext to make PO message catalogs 35 | echo. changes to make an overview over all changed/added/deprecated items 36 | echo. xml to make Docutils-native XML files 37 | echo. pseudoxml to make pseudoxml-XML files for display purposes 38 | echo. linkcheck to check all external links for integrity 39 | echo. doctest to run all doctests embedded in the documentation if enabled 40 | goto end 41 | ) 42 | 43 | if "%1" == "clean" ( 44 | for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i 45 | del /q /s %BUILDDIR%\* 46 | goto end 47 | ) 48 | 49 | 50 | %SPHINXBUILD% 2> nul 51 | if errorlevel 9009 ( 52 | echo. 53 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 54 | echo.installed, then set the SPHINXBUILD environment variable to point 55 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 56 | echo.may add the Sphinx directory to PATH. 57 | echo. 58 | echo.If you don't have Sphinx installed, grab it from 59 | echo.http://sphinx-doc.org/ 60 | exit /b 1 61 | ) 62 | 63 | if "%1" == "html" ( 64 | %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html 65 | if errorlevel 1 exit /b 1 66 | echo. 67 | echo.Build finished. The HTML pages are in %BUILDDIR%/html. 68 | goto end 69 | ) 70 | 71 | if "%1" == "dirhtml" ( 72 | %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml 73 | if errorlevel 1 exit /b 1 74 | echo. 75 | echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. 76 | goto end 77 | ) 78 | 79 | if "%1" == "singlehtml" ( 80 | %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml 81 | if errorlevel 1 exit /b 1 82 | echo. 83 | echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. 84 | goto end 85 | ) 86 | 87 | if "%1" == "pickle" ( 88 | %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle 89 | if errorlevel 1 exit /b 1 90 | echo. 91 | echo.Build finished; now you can process the pickle files. 92 | goto end 93 | ) 94 | 95 | if "%1" == "json" ( 96 | %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json 97 | if errorlevel 1 exit /b 1 98 | echo. 99 | echo.Build finished; now you can process the JSON files. 100 | goto end 101 | ) 102 | 103 | if "%1" == "htmlhelp" ( 104 | %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp 105 | if errorlevel 1 exit /b 1 106 | echo. 107 | echo.Build finished; now you can run HTML Help Workshop with the ^ 108 | .hhp project file in %BUILDDIR%/htmlhelp. 109 | goto end 110 | ) 111 | 112 | if "%1" == "qthelp" ( 113 | %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp 114 | if errorlevel 1 exit /b 1 115 | echo. 116 | echo.Build finished; now you can run "qcollectiongenerator" with the ^ 117 | .qhcp project file in %BUILDDIR%/qthelp, like this: 118 | echo.^> qcollectiongenerator %BUILDDIR%\qthelp\project-template.qhcp 119 | echo.To view the help file: 120 | echo.^> assistant -collectionFile %BUILDDIR%\qthelp\project-template.ghc 121 | goto end 122 | ) 123 | 124 | if "%1" == "devhelp" ( 125 | %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp 126 | if errorlevel 1 exit /b 1 127 | echo. 128 | echo.Build finished. 129 | goto end 130 | ) 131 | 132 | if "%1" == "epub" ( 133 | %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub 134 | if errorlevel 1 exit /b 1 135 | echo. 136 | echo.Build finished. 
The epub file is in %BUILDDIR%/epub. 137 | goto end 138 | ) 139 | 140 | if "%1" == "latex" ( 141 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 142 | if errorlevel 1 exit /b 1 143 | echo. 144 | echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. 145 | goto end 146 | ) 147 | 148 | if "%1" == "latexpdf" ( 149 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 150 | cd %BUILDDIR%/latex 151 | make all-pdf 152 | cd %BUILDDIR%/.. 153 | echo. 154 | echo.Build finished; the PDF files are in %BUILDDIR%/latex. 155 | goto end 156 | ) 157 | 158 | if "%1" == "latexpdfja" ( 159 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 160 | cd %BUILDDIR%/latex 161 | make all-pdf-ja 162 | cd %BUILDDIR%/.. 163 | echo. 164 | echo.Build finished; the PDF files are in %BUILDDIR%/latex. 165 | goto end 166 | ) 167 | 168 | if "%1" == "text" ( 169 | %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text 170 | if errorlevel 1 exit /b 1 171 | echo. 172 | echo.Build finished. The text files are in %BUILDDIR%/text. 173 | goto end 174 | ) 175 | 176 | if "%1" == "man" ( 177 | %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man 178 | if errorlevel 1 exit /b 1 179 | echo. 180 | echo.Build finished. The manual pages are in %BUILDDIR%/man. 181 | goto end 182 | ) 183 | 184 | if "%1" == "texinfo" ( 185 | %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo 186 | if errorlevel 1 exit /b 1 187 | echo. 188 | echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. 189 | goto end 190 | ) 191 | 192 | if "%1" == "gettext" ( 193 | %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale 194 | if errorlevel 1 exit /b 1 195 | echo. 196 | echo.Build finished. The message catalogs are in %BUILDDIR%/locale. 197 | goto end 198 | ) 199 | 200 | if "%1" == "changes" ( 201 | %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes 202 | if errorlevel 1 exit /b 1 203 | echo. 204 | echo.The overview file is in %BUILDDIR%/changes. 205 | goto end 206 | ) 207 | 208 | if "%1" == "linkcheck" ( 209 | %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck 210 | if errorlevel 1 exit /b 1 211 | echo. 212 | echo.Link check complete; look for any errors in the above output ^ 213 | or in %BUILDDIR%/linkcheck/output.txt. 214 | goto end 215 | ) 216 | 217 | if "%1" == "doctest" ( 218 | %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest 219 | if errorlevel 1 exit /b 1 220 | echo. 221 | echo.Testing of doctests in the sources finished, look at the ^ 222 | results in %BUILDDIR%/doctest/output.txt. 223 | goto end 224 | ) 225 | 226 | if "%1" == "xml" ( 227 | %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml 228 | if errorlevel 1 exit /b 1 229 | echo. 230 | echo.Build finished. The XML files are in %BUILDDIR%/xml. 231 | goto end 232 | ) 233 | 234 | if "%1" == "pseudoxml" ( 235 | %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml 236 | if errorlevel 1 exit /b 1 237 | echo. 238 | echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. 
239 | goto end 240 | ) 241 | 242 | :end 243 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: sklearn-xarray 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - python=3.6 6 | - numpy 7 | - scipy 8 | - scikit-learn 9 | - pandas 10 | - xarray 11 | - pytest 12 | - matplotlib 13 | - sphinx 14 | - pillow 15 | - sphinx-gallery 16 | - sphinx_rtd_theme 17 | - numpydoc 18 | - bump2version 19 | - pre-commit 20 | - flake8 21 | - black=19.10b0 22 | -------------------------------------------------------------------------------- /examples/README.txt: -------------------------------------------------------------------------------- 1 | .. _general_examples: 2 | 3 | General examples 4 | ================ 5 | 6 | Introductory examples. 7 | -------------------------------------------------------------------------------- /examples/plot_activity_recognition.py: -------------------------------------------------------------------------------- 1 | """ 2 | Activity recognition from accelerometer data 3 | ============================================ 4 | 5 | This demo shows how the **sklearn-xarray** package works with the ``Pipeline`` 6 | and ``GridSearchCV`` methods from scikit-learn, providing a metadata-aware 7 | grid-searchable pipeline mechanism. 8 | 9 | The package combines the metadata-handling capabilities of xarray with the 10 | machine-learning framework of sklearn. It enables the user to apply 11 | preprocessing steps group by group, use transformers that change the number 12 | of samples, use metadata directly as labels for classification tasks and more. 13 | 14 | The example performs activity recognition from raw accelerometer data with a 15 | Gaussian naive Bayes classifier. It uses the 16 | `WISDM`_ activity prediction dataset which contains the activities 17 | walking, jogging, walking upstairs, walking downstairs, sitting and standing 18 | from 36 different subjects. 19 | 20 | .. _WISDM: http://www.cis.fordham.edu/wisdm/dataset.php 21 | """ 22 | 23 | from __future__ import print_function 24 | 25 | import numpy as np 26 | 27 | from sklearn_xarray import wrap, Target 28 | from sklearn_xarray.preprocessing import Splitter, Sanitizer, Featurizer 29 | from sklearn_xarray.model_selection import CrossValidatorWrapper 30 | from sklearn_xarray.datasets import load_wisdm_dataarray 31 | 32 | from sklearn.preprocessing import StandardScaler, LabelEncoder 33 | from sklearn.decomposition import PCA 34 | from sklearn.naive_bayes import GaussianNB 35 | from sklearn.model_selection import GroupShuffleSplit, GridSearchCV 36 | from sklearn.pipeline import Pipeline 37 | 38 | import matplotlib.pyplot as plt 39 | 40 | ############################################################################## 41 | # First, we load the dataset and plot an example of one subject performing 42 | # the 'Walking' activity. 43 | # 44 | # .. tip:: 45 | # 46 | #     In the Jupyter notebook version, change the first cell to ``%matplotlib 47 | #     notebook`` in order to get an interactive plot that you can zoom and pan.
48 | 49 | X = load_wisdm_dataarray() 50 | 51 | X_plot = X[np.logical_and(X.activity == "Walking", X.subject == 1)] 52 | X_plot = X_plot[:500] / 9.81 53 | X_plot["sample"] = (X_plot.sample - X_plot.sample[0]) / np.timedelta64(1, "s") 54 | 55 | f, axarr = plt.subplots(3, 1, sharex=True) 56 | 57 | axarr[0].plot(X_plot.sample, X_plot.sel(axis="x"), color="#1f77b4") 58 | axarr[0].set_title("Acceleration along x-axis") 59 | 60 | axarr[1].plot(X_plot.sample, X_plot.sel(axis="y"), color="#ff7f0e") 61 | axarr[1].set_ylabel("Acceleration [g]") 62 | axarr[1].set_title("Acceleration along y-axis") 63 | 64 | axarr[2].plot(X_plot.sample, X_plot.sel(axis="z"), color="#2ca02c") 65 | axarr[2].set_xlabel("Time [s]") 66 | axarr[2].set_title("Acceleration along z-axis") 67 | 68 | 69 | ############################################################################## 70 | # Then we define a pipeline with various preprocessing steps and a classifier. 71 | # 72 | # The preprocessing consists of splitting the data into segments, removing 73 | # segments with `nan` values and standardizing. Since the accelerometer data is 74 | # three-dimensional but the standardizer and classifier expect a 75 | # one-dimensional feature vector, we have to vectorize the samples. 76 | # 77 | # Finally, we use PCA and a naive Bayes classifier for classification. 78 | 79 | pl = Pipeline( 80 | [ 81 | ( 82 | "splitter", 83 | Splitter( 84 | groupby=["subject", "activity"], 85 | new_dim="timepoint", 86 | new_len=30, 87 | ), 88 | ), 89 | ("sanitizer", Sanitizer()), 90 | ("featurizer", Featurizer()), 91 | ("scaler", wrap(StandardScaler)), 92 | ("pca", wrap(PCA, reshapes="feature")), 93 | ("cls", wrap(GaussianNB, reshapes="feature")), 94 | ] 95 | ) 96 | 97 | ############################################################################## 98 | # Since we want to use cross-validated grid search to find the best model 99 | # parameters, we define a cross-validator. In order to make sure the model 100 | # performs subject-independent recognition, we use a `GroupShuffleSplit` 101 | # cross-validator that ensures that the same subject will not appear in both 102 | # training and validation set. 103 | 104 | cv = CrossValidatorWrapper( 105 | GroupShuffleSplit(n_splits=2, test_size=0.5), groupby=["subject"] 106 | ) 107 | 108 | ############################################################################## 109 | # The grid search will try different numbers of PCA components to find the best 110 | # parameters for this task. 111 | # 112 | # .. tip:: 113 | # 114 | # To use multi-processing, set ``n_jobs=-1``. 115 | 116 | gs = GridSearchCV( 117 | pl, cv=cv, n_jobs=1, verbose=1, param_grid={"pca__n_components": [10, 20]} 118 | ) 119 | 120 | ############################################################################## 121 | # The label to classify is the activity which we convert to an integer 122 | # representation for the classification. 123 | 124 | y = Target( 125 | coord="activity", transform_func=LabelEncoder().fit_transform, dim="sample" 126 | )(X) 127 | 128 | ############################################################################## 129 | # Finally, we run the grid search and print out the best parameter combination. 130 | 131 | if __name__ == "__main__": # in order for n_jobs=-1 to work on Windows 132 | gs.fit(X, y) 133 | print("Best parameters: {0}".format(gs.best_params_)) 134 | print("Accuracy: {0}".format(gs.best_score_)) 135 | 136 | ############################################################################## 137 | # .. 
note:: 138 | # 139 | #     The performance of this classifier is obviously pretty bad; 140 | #     it was chosen for execution speed, not accuracy. 141 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 79 3 | target_version = ['py27'] 4 | include = '\.pyi?$' 5 | exclude = ''' 6 | ( 7 |   /( 8 |       \.eggs         # exclude a few common directories in the 9 |     | \.git           # root of the project 10 |     | \.hg 11 |     | \.mypy_cache 12 |     | \.tox 13 |     | \.venv 14 |     | _build 15 |     | buck-out 16 |     | build 17 |     | dist 18 |   )/ 19 | ) 20 | ''' 21 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-learn==0.23.1 2 | xarray==0.15.1 3 | pandas==1.0.4 4 | numpy==1.18.5 5 | scipy==1.4.1 6 | -------------------------------------------------------------------------------- /requirements_dev.txt: -------------------------------------------------------------------------------- 1 | pytest==5.4.3 2 | sphinx==2.4.4 3 | sphinx_rtd_theme==0.4.3 4 | sphinx-gallery==0.7.0 5 | numpydoc==1.0.0 6 | matplotlib==3.2.1 7 | pillow==7.1.2 8 | bump2version==1.0.0 9 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | INSTALL_REQUIRES = ["numpy", "scipy", "scikit-learn", "pandas", "xarray"] 4 | 5 | with open("README.rst") as f: 6 |     readme = f.read() 7 | 8 | setup( 9 |     name="sklearn-xarray", 10 |     version="0.4.0", 11 |     description="xarray integration with sklearn", 12 |     long_description=readme, 13 |     author="Peter Hausamann", 14 |     packages=find_packages(), 15 |     install_requires=INSTALL_REQUIRES, 16 |     author_email="peter.hausamann@tum.de", 17 | ) 18 | -------------------------------------------------------------------------------- /sklearn_xarray/__init__.py: -------------------------------------------------------------------------------- 1 | """ ``sklearn_xarray`` """ 2 | 3 | from sklearn_xarray.common.wrappers import ( 4 |     wrap, 5 |     EstimatorWrapper, 6 |     ClassifierWrapper, 7 |     RegressorWrapper, 8 |     TransformerWrapper, 9 | ) 10 | from sklearn_xarray.target import Target 11 | 12 | import os 13 | 14 | 15 | __all__ = [ 16 |     "wrap", 17 |     "EstimatorWrapper", 18 |     "ClassifierWrapper", 19 |     "RegressorWrapper", 20 |     "TransformerWrapper", 21 |     "Target", 22 | ] 23 | 24 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 25 | 26 | __version__ = "0.4.0" 27 | -------------------------------------------------------------------------------- /sklearn_xarray/common/__init__.py: -------------------------------------------------------------------------------- 1 | """ ``sklearn_xarray.common`` """ 2 | -------------------------------------------------------------------------------- /sklearn_xarray/common/base.py: -------------------------------------------------------------------------------- 1 | """ ``sklearn_xarray.common.base`` """ 2 | 3 | import numpy as np 4 | import xarray as xr 5 | 6 | from sklearn.base import clone, BaseEstimator 7 | from sklearn.utils.validation import check_is_fitted, check_array, check_X_y 8 | 9 | from sklearn_xarray.utils import is_dataarray, is_dataset, is_target 10 | 11 | 12 | class _CommonEstimatorWrapper(BaseEstimator): 13 |     """ Base class for DataArray and Dataset
wrappers. """ 14 | 15 | @staticmethod 16 | def _transpose_y(X, y, order): 17 | """ Transpose y. """ 18 | 19 | if y.ndim == X.ndim: 20 | y = np.transpose(np.array(y), order) 21 | elif y.ndim == 1: 22 | y = np.array(y) 23 | else: 24 | raise ValueError("Could not figure out how to transpose y.") 25 | 26 | return y 27 | 28 | def _get_transpose_order(self, X): 29 | """ Get the transpose order that puts the sample dim first. """ 30 | 31 | sample_axis = X.dims.index(self.sample_dim) 32 | order = list(range(len(X.dims))) 33 | order.remove(sample_axis) 34 | order.insert(0, sample_axis) 35 | 36 | return order 37 | 38 | def _update_dims(self, X_in, X_out): 39 | """ Update the dimensions of a reshaped DataArray. """ 40 | 41 | dims_new = list(X_in.dims) 42 | 43 | # dict syntax 44 | if hasattr(self.reshapes, "items"): 45 | 46 | # check if new dims are dropped by estimator 47 | all_old_dims = [] 48 | for _, old_dims in self.reshapes.items(): 49 | all_old_dims += old_dims 50 | 51 | if X_out.ndim == X_in.ndim - len(all_old_dims) + len( 52 | self.reshapes 53 | ): 54 | drop_new_dims = False 55 | elif X_out.ndim == X_in.ndim - len(all_old_dims): 56 | drop_new_dims = True 57 | else: 58 | raise ValueError( 59 | "Inconsistent dimensions returned by estimator" 60 | ) 61 | 62 | for new_dim, old_dims in self.reshapes.items(): 63 | for d in old_dims: 64 | dims_new.remove(d) 65 | if not drop_new_dims: 66 | dims_new.append(new_dim) 67 | 68 | # string syntax 69 | else: 70 | # check if dim is dropped by estimator 71 | if X_out.ndim < X_in.ndim: 72 | dims_new.remove(self.reshapes) 73 | 74 | return dims_new 75 | 76 | def _restore_dims(self, X_in, X_out): 77 | """ Restore the dimensions of a reshaped DataArray. """ 78 | 79 | # dict syntax 80 | if hasattr(self.reshapes, "items"): 81 | 82 | # check if new dims are dropped by estimator 83 | all_old_dims = [] 84 | for _, old_dims in self.reshapes.items(): 85 | all_old_dims += old_dims 86 | 87 | if X_in.ndim == X_out.ndim - len(all_old_dims) + len( 88 | self.reshapes 89 | ): 90 | drop_new_dims = False 91 | elif X_in.ndim == X_out.ndim - len(all_old_dims): 92 | drop_new_dims = True 93 | else: 94 | raise ValueError( 95 | "Inconsistent dimensions returned by estimator" 96 | ) 97 | 98 | # get new dims 99 | dims_new = list(X_in.dims) 100 | dims_old = [] 101 | for d in dims_new: 102 | if d in self.reshapes: 103 | dims_old += self.reshapes[d] 104 | else: 105 | dims_old.append(d) 106 | 107 | if drop_new_dims: 108 | # TODO: figure out where to insert the dropped dims 109 | for d in all_old_dims: 110 | if d not in dims_old: 111 | dims_old.append(d) 112 | 113 | # string syntax 114 | else: 115 | dims_old = list(X_in.dims) 116 | # check if dim is dropped by estimator 117 | if X_out.ndim < X_in.ndim: 118 | # TODO: figure out where to insert the dropped dim 119 | dims_old.append(self.reshapes) 120 | 121 | return dims_old 122 | 123 | def _update_coords(self, X): 124 | """ Update the coordinates of a reshaped DataArray. 
""" 125 | 126 | coords_new = dict() 127 | 128 | # dict syntax 129 | if hasattr(self.reshapes, "items"): 130 | 131 | all_old_dims = [] 132 | for _, old_dims in self.reshapes.items(): 133 | all_old_dims += old_dims 134 | 135 | # drop all coords along the reshaped dimensions 136 | for c in X.coords: 137 | old_dims_in_c = [x for x in X[c].dims if x in all_old_dims] 138 | if any(old_dims_in_c) and c not in all_old_dims: 139 | c_t = X[c].isel(**{d: 0 for d in old_dims_in_c}) 140 | new_dims = [d for d in X[c].dims if d not in all_old_dims] 141 | coords_new[c] = (new_dims, c_t.drop(old_dims_in_c)) 142 | elif c not in all_old_dims: 143 | coords_new[c] = X[c] 144 | 145 | # string syntax 146 | else: 147 | # drop all coords along the reshaped dimensions 148 | for c in X.coords: 149 | if self.reshapes in X[c].dims and c != self.reshapes: 150 | c_t = X[c].isel(**{self.reshapes: 0}) 151 | new_dims = [d for d in X[c].dims if d != self.reshapes] 152 | coords_new[c] = (new_dims, c_t.drop(self.reshapes)) 153 | elif c != self.reshapes: 154 | coords_new[c] = X[c] 155 | 156 | return coords_new 157 | 158 | def _call_array_method(self, estimator, method, X): 159 | """ Call a method (predict, transform, ...) for DataArray input. """ 160 | 161 | if self.sample_dim is not None: 162 | # transpose to sample dim first, predict and transpose back 163 | order = self._get_transpose_order(X) 164 | X_arr = np.transpose(X.data, order) 165 | y = getattr(estimator, method)(X_arr) 166 | if y.ndim == X.ndim: 167 | y = np.transpose(y, np.argsort(order)) 168 | else: 169 | y = getattr(estimator, method)(X.data) 170 | 171 | # update dims 172 | if method == "inverse_transform": 173 | dims_new = self._restore_dims(X, y) 174 | else: 175 | dims_new = self._update_dims(X, y) 176 | 177 | return y, dims_new 178 | 179 | def _call_fitted(self, method, X): 180 | """ Call a method of a fitted estimator (predict, transform, ...). """ 181 | 182 | check_is_fitted(self, ["type_"]) 183 | 184 | if self.type_ == "DataArray": 185 | 186 | if not is_dataarray(X): 187 | raise ValueError( 188 | "This wrapper was fitted for DataArray inputs, but the " 189 | "provided X does not seem to be a DataArray." 190 | ) 191 | 192 | check_is_fitted(self, ["estimator_"]) 193 | 194 | if self.reshapes is not None: 195 | data, dims = self._call_array_method( 196 | self.estimator_, method, X 197 | ) 198 | coords = self._update_coords(X) 199 | return xr.DataArray(data, coords=coords, dims=dims) 200 | else: 201 | return xr.DataArray( 202 | getattr(self.estimator_, method)(X.data), 203 | coords=X.coords, 204 | dims=X.dims, 205 | ) 206 | 207 | elif self.type_ == "Dataset": 208 | 209 | if not is_dataset(X): 210 | raise ValueError( 211 | "This wrapper was fitted for Dataset inputs, but the " 212 | "provided X does not seem to be a Dataset." 
213 | ) 214 | 215 | check_is_fitted(self, ["estimator_dict_"]) 216 | 217 | if self.reshapes is not None: 218 | data_vars = dict() 219 | for v, e in self.estimator_dict_.items(): 220 | yp_v, dims = self._call_array_method(e, method, X[v]) 221 | data_vars[v] = (dims, yp_v) 222 | coords = self._update_coords(X) 223 | return xr.Dataset(data_vars, coords=coords) 224 | else: 225 | data_vars = { 226 | v: (X[v].dims, getattr(e, method)(X[v].data)) 227 | for v, e in self.estimator_dict_.items() 228 | } 229 | return xr.Dataset(data_vars, coords=X.coords) 230 | 231 | elif self.type_ == "other": 232 | 233 | check_is_fitted(self, ["estimator_"]) 234 | 235 | return getattr(self.estimator_, method)(X) 236 | 237 | else: 238 | raise ValueError("Unexpected type_.") 239 | 240 | def _fit(self, X, y=None, **fit_params): 241 | """ Tranpose if necessary and fit. """ 242 | 243 | if self.sample_dim is not None: 244 | order = self._get_transpose_order(X) 245 | X_arr = np.transpose(X.data, order) 246 | if y is not None: 247 | y = self._transpose_y(X, y, order) 248 | else: 249 | X_arr = X.data 250 | 251 | estimator_ = self._make_estimator().fit(X_arr, y, **fit_params) 252 | 253 | return estimator_ 254 | 255 | def _partial_fit(self, estimator, X, y=None, **fit_params): 256 | """ Tranpose if necessary and partial_fit. """ 257 | 258 | if self.sample_dim is not None: 259 | order = self._get_transpose_order(X) 260 | X_arr = np.transpose(X.data, order) 261 | if y is not None: 262 | y = self._transpose_y(X, y, order) 263 | else: 264 | X_arr = X.data 265 | 266 | return estimator.partial_fit(X_arr, y, **fit_params) 267 | 268 | def _fit_transform(self, estimator, X, y=None, **fit_params): 269 | """ Fit & transform with ``estimator`` and update coords and dims. """ 270 | 271 | if self.sample_dim is not None: 272 | # transpose to sample dim first, transform and transpose back 273 | order = self._get_transpose_order(X) 274 | X_arr = np.transpose(X.data, order) 275 | if y is not None: 276 | y = self._transpose_y(X, y, order) 277 | Xt = estimator.fit_transform(X_arr, y, **fit_params) 278 | if Xt.ndim == X.ndim: 279 | # TODO: handle the other case 280 | Xt = np.transpose(Xt, np.argsort(order)) 281 | else: 282 | Xt = estimator.fit_transform(X.data, y, **fit_params) 283 | 284 | # update dims 285 | dims_new = self._update_dims(X, Xt) 286 | 287 | return Xt, dims_new 288 | 289 | 290 | # -- Wrapper methods -- 291 | def partial_fit(self, X, y=None, **fit_params): 292 | """ A wrapper around the partial_fit function. 293 | 294 | Parameters 295 | ---------- 296 | X : xarray DataArray, Dataset or other array-like 297 | The input samples. 298 | 299 | y : xarray DataArray, Dataset or other array-like 300 | The target values. 301 | """ 302 | 303 | if self.estimator is None: 304 | raise ValueError("You must specify an estimator instance to wrap.") 305 | 306 | if is_target(y): 307 | y = y(X) 308 | 309 | if is_dataarray(X): 310 | 311 | if not hasattr(self, "type_"): 312 | self.type_ = "DataArray" 313 | self.estimator_ = self._fit(X, y, **fit_params) 314 | elif self.type_ == "DataArray": 315 | self.estimator_ = self._partial_fit( 316 | self.estimator_, X, y, **fit_params 317 | ) 318 | else: 319 | raise ValueError( 320 | "This wrapper was not fitted for DataArray inputs." 
321 | ) 322 | 323 | # TODO: check if this needs to be removed for compat wrappers 324 | for v in vars(self.estimator_): 325 | if v.endswith("_") and not v.startswith("_"): 326 | setattr(self, v, getattr(self.estimator_, v)) 327 | 328 | elif is_dataset(X): 329 | 330 | if not hasattr(self, "type_"): 331 | self.type_ = "Dataset" 332 | self.estimator_dict_ = { 333 | v: self._fit(X[v], y, **fit_params) for v in X.data_vars 334 | } 335 | elif self.type_ == "Dataset": 336 | self.estimator_dict_ = { 337 | v: self._partial_fit( 338 | self.estimator_dict_[v], X[v], y, **fit_params 339 | ) 340 | for v in X.data_vars 341 | } 342 | else: 343 | raise ValueError("This wrapper was not fitted for Dataset inputs.") 344 | 345 | # TODO: check if this needs to be removed for compat wrappers 346 | for e_name, e in self.estimator_dict_.items(): 347 | for v in vars(e): 348 | if v.endswith("_") and not v.startswith("_"): 349 | if hasattr(self, v): 350 | getattr(self, v).update({e_name: getattr(e, v)}) 351 | else: 352 | setattr(self, v, {e_name: getattr(e, v)}) 353 | 354 | else: 355 | 356 | if not hasattr(self, "type_"): 357 | self.type_ = "other" 358 | if y is None: 359 | X = check_array(X) 360 | else: 361 | X, y = check_X_y(X, y) 362 | self.estimator_ = clone(self.estimator).fit(X, y, **fit_params) 363 | elif self.type_ == "other": 364 | self.estimator_ = self.estimator_.partial_fit(X, y, **fit_params) 365 | else: 366 | raise ValueError("This wrapper was not fitted for other inputs.") 367 | 368 | # TODO: check if this needs to be removed for compat wrappers 369 | for v in vars(self.estimator_): 370 | if v.endswith("_") and not v.startswith("_"): 371 | setattr(self, v, getattr(self.estimator_, v)) 372 | 373 | return self 374 | 375 | 376 | def predict(self, X): 377 | """ A wrapper around the prediction function. 378 | 379 | Parameters 380 | ---------- 381 | X : xarray DataArray, Dataset or other array-like 382 | The input samples. 383 | 384 | Returns 385 | ------- 386 | y : xarray DataArray, Dataset or other array-like 387 | The predicted output. 388 | """ 389 | 390 | return self._call_fitted("predict", X) 391 | 392 | 393 | def predict_proba(self, X): 394 | """ A wrapper around the predict_proba function. 395 | 396 | Parameters 397 | ---------- 398 | X : xarray DataArray, Dataset or other array-like 399 | The input samples. 400 | 401 | Returns 402 | ------- 403 | y : xarray DataArray, Dataset or other array-like 404 | The predicted output. 405 | """ 406 | 407 | return self._call_fitted("predict_proba", X) 408 | 409 | 410 | def predict_log_proba(self, X): 411 | """ A wrapper around the predict_log_proba function. 412 | 413 | Parameters 414 | ---------- 415 | X : xarray DataArray, Dataset or other array-like 416 | The input samples. 417 | 418 | Returns 419 | ------- 420 | y : xarray DataArray, Dataset or other array-like 421 | The predicted output. 422 | """ 423 | 424 | return self._call_fitted("predict_log_proba", X) 425 | 426 | 427 | def decision_function(self, X): 428 | """ A wrapper around the decision_function function. 429 | 430 | Parameters 431 | ---------- 432 | X : xarray DataArray, Dataset or other array-like 433 | The input samples. 434 | 435 | Returns 436 | ------- 437 | y : xarray DataArray, Dataset or other array-like 438 | The predicted output. 439 | """ 440 | 441 | return self._call_fitted("decision_function", X) 442 | 443 | 444 | def transform(self, X): 445 | """ A wrapper around the transformation function. 
446 | 447 | Parameters 448 | ---------- 449 | X : xarray DataArray, Dataset or other array-like 450 | The input samples. 451 | 452 | Returns 453 | ------- 454 | Xt : xarray DataArray, Dataset or other array-like 455 | The transformed output. 456 | """ 457 | 458 | return self._call_fitted("transform", X) 459 | 460 | 461 | def inverse_transform(self, X): 462 | """ A wrapper around the inverse transformation function. 463 | 464 | Parameters 465 | ---------- 466 | X : xarray DataArray, Dataset or other array-like 467 | The input samples. 468 | 469 | Returns 470 | ------- 471 | Xt : xarray DataArray, Dataset or other array-like 472 | The transformed output. 473 | """ 474 | 475 | return self._call_fitted("inverse_transform", X) 476 | 477 | 478 | def fit_transform(self, X, y=None, **fit_params): 479 | """ A wrapper around the fit_transform function. 480 | 481 | Parameters 482 | ---------- 483 | X : xarray DataArray, Dataset or other array-like 484 | The input samples. 485 | 486 | y : xarray DataArray, Dataset or other array-like 487 | The target values. 488 | 489 | Returns 490 | ------- 491 | Xt : xarray DataArray, Dataset or other array-like 492 | The transformed output. 493 | """ 494 | 495 | if self.estimator is None: 496 | raise ValueError("You must specify an estimator instance to wrap.") 497 | 498 | if is_target(y): 499 | y = y(X) 500 | 501 | if is_dataarray(X): 502 | 503 | self.type_ = "DataArray" 504 | self.estimator_ = clone(self.estimator) 505 | 506 | if self.reshapes is not None: 507 | data, dims = self._fit_transform( 508 | self.estimator_, X, y, **fit_params 509 | ) 510 | coords = self._update_coords(X) 511 | return xr.DataArray(data, coords=coords, dims=dims) 512 | else: 513 | return xr.DataArray( 514 | self.estimator_.fit_transform(X.data, y, **fit_params), 515 | coords=X.coords, 516 | dims=X.dims, 517 | ) 518 | 519 | elif is_dataset(X): 520 | 521 | self.type_ = "Dataset" 522 | self.estimator_dict_ = {v: clone(self.estimator) for v in X.data_vars} 523 | 524 | if self.reshapes is not None: 525 | data_vars = dict() 526 | for v, e in self.estimator_dict_.items(): 527 | yp_v, dims = self._fit_transform(e, X[v], y, **fit_params) 528 | data_vars[v] = (dims, yp_v) 529 | coords = self._update_coords(X) 530 | return xr.Dataset(data_vars, coords=coords) 531 | else: 532 | data_vars = { 533 | v: (X[v].dims, e.fit_transform(X[v].data, y, **fit_params)) 534 | for v, e in self.estimator_dict_.items() 535 | } 536 | return xr.Dataset(data_vars, coords=X.coords) 537 | 538 | else: 539 | 540 | self.type_ = "other" 541 | if y is None: 542 | X = check_array(X) 543 | else: 544 | X, y = check_X_y(X, y) 545 | 546 | self.estimator_ = clone(self.estimator) 547 | Xt = self.estimator_.fit_transform(X, y, **fit_params) 548 | 549 | for v in vars(self.estimator_): 550 | if v.endswith("_") and not v.startswith("_"): 551 | setattr(self, v, getattr(self.estimator_, v)) 552 | 553 | return Xt 554 | 555 | 556 | def score(self, X, y, sample_weight=None): 557 | """ Returns the score of the prediction. 558 | 559 | Parameters 560 | ---------- 561 | X : xarray Dataset or Dataset 562 | The training set. 563 | 564 | y : xarray Dataset or Dataset 565 | The target values. 566 | 567 | sample_weight : array-like, shape = [n_samples], optional 568 | Sample weights. 569 | 570 | Returns 571 | ------- 572 | score : float 573 | Score of self.predict(X) wrt. y. 
574 | """ 575 | 576 | if self.type_ == "DataArray": 577 | 578 | if not is_dataarray(X): 579 | raise ValueError( 580 | "This wrapper was fitted for DataArray inputs, but the " 581 | "provided X does not seem to be a DataArray." 582 | ) 583 | 584 | check_is_fitted(self, ["estimator_"]) 585 | 586 | if is_target(y): 587 | y = y(X) 588 | 589 | return self.estimator_.score(X, y, sample_weight) 590 | 591 | elif self.type_ == "Dataset": 592 | 593 | if not is_dataset(X): 594 | raise ValueError( 595 | "This wrapper was fitted for Dataset inputs, but the " 596 | "provided X does not seem to be a Dataset." 597 | ) 598 | 599 | check_is_fitted(self, ["estimator_dict_"]) 600 | 601 | # TODO: this probably has to be done for each data_var individually 602 | if is_target(y): 603 | y = y(X) 604 | 605 | score_list = [ 606 | e.score(X[v], y, sample_weight) 607 | for v, e in self.estimator_dict_.items() 608 | ] 609 | 610 | return np.mean(score_list) 611 | 612 | elif self.type_ == "other": 613 | 614 | check_is_fitted(self, ["estimator_"]) 615 | 616 | return self.estimator_.score(X, y, sample_weight) 617 | 618 | else: 619 | raise ValueError("Unexpected type_.") 620 | 621 | 622 | # -- Wrapper mixins -- 623 | class _ImplementsPartialFitMixin(_CommonEstimatorWrapper): 624 | 625 | partial_fit = partial_fit 626 | 627 | 628 | class _ImplementsPredictMixin(_CommonEstimatorWrapper): 629 | 630 | predict = predict 631 | 632 | 633 | class _ImplementsPredictProbaMixin(_CommonEstimatorWrapper): 634 | 635 | predict_proba = predict_proba 636 | 637 | 638 | class _ImplementsPredictLogProbaMixin(_CommonEstimatorWrapper): 639 | 640 | predict_log_proba = predict_log_proba 641 | 642 | 643 | class _ImplementsDecisionFunctionMixin(_CommonEstimatorWrapper): 644 | 645 | decision_function = decision_function 646 | 647 | 648 | class _ImplementsTransformMixin(_CommonEstimatorWrapper): 649 | 650 | transform = transform 651 | 652 | 653 | class _ImplementsInverseTransformMixin(_CommonEstimatorWrapper): 654 | 655 | inverse_transform = inverse_transform 656 | 657 | 658 | class _ImplementsFitTransformMixin(_CommonEstimatorWrapper): 659 | 660 | fit_transform = fit_transform 661 | 662 | 663 | class _ImplementsScoreMixin(_CommonEstimatorWrapper): 664 | 665 | score = score 666 | -------------------------------------------------------------------------------- /sklearn_xarray/common/wrappers.py: -------------------------------------------------------------------------------- 1 | """ ``sklearn_xarray.common.wrappers`` """ 2 | 3 | from types import MethodType 4 | 5 | import warnings 6 | 7 | import six 8 | from sklearn.base import BaseEstimator, clone 9 | from sklearn.utils.validation import check_X_y, check_array 10 | 11 | from .base import ( 12 | partial_fit, 13 | predict, 14 | predict_proba, 15 | predict_log_proba, 16 | decision_function, 17 | transform, 18 | inverse_transform, 19 | fit_transform, 20 | score, 21 | _CommonEstimatorWrapper, 22 | _ImplementsPredictMixin, 23 | _ImplementsScoreMixin, 24 | _ImplementsTransformMixin, 25 | _ImplementsFitTransformMixin, 26 | _ImplementsInverseTransformMixin, 27 | ) 28 | 29 | from sklearn_xarray.utils import is_dataarray, is_dataset, is_target 30 | 31 | 32 | # mapping from wrapped methods to wrapper methods 33 | _method_map = { 34 | "partial_fit": partial_fit, 35 | "predict": predict, 36 | "predict_proba": predict_proba, 37 | "predict_log_proba": predict_log_proba, 38 | "decision_function": decision_function, 39 | "transform": transform, 40 | "inverse_transform": inverse_transform, 41 | "fit_transform": 
fit_transform, 42 | "score": score, 43 | } 44 | 45 | 46 | def wrap(estimator, reshapes=None, sample_dim=None, compat=False, **kwargs): 47 | """ Wrap an sklearn estimator for xarray objects. 48 | 49 | Parameters 50 | ---------- 51 | estimator : sklearn estimator class or instance 52 | The estimator this instance wraps around. 53 | 54 | reshapes : str or dict, optional 55 | The dimension(s) reshaped by this estimator. Any coordinates in the 56 | DataArray along these dimensions will be dropped. If the estimator 57 | drops this dimension (e.g. a binary classifier returning a 1D vector), 58 | the dimension itself will also be dropped. 59 | 60 | You can specify multiple dimensions mapping to multiple new dimensions 61 | with a dict whose items are lists of reshaped dimensions, e.g. 62 | ``{'new_feature': ['feature_1', 'feature_2'], ...}`` 63 | 64 | sample_dim : str, optional 65 | The name of the dimension that represents the samples. By default, 66 | the wrapper will assume that this is the first dimension in the array. 67 | 68 | compat : bool, default False 69 | If True, the method will return a ``CompatEstimatorWrapper`` instead 70 | of an ``EstimatorWrapper``. This might be necessary when the 71 | estimator defines parameters with the same name as the wrapper. 72 | 73 | Returns 74 | ------- 75 | A wrapped estimator. 76 | """ 77 | 78 | if compat: 79 | return CompatEstimatorWrapper( 80 | estimator=estimator, 81 | reshapes=reshapes, 82 | sample_dim=sample_dim, 83 | **kwargs 84 | ) 85 | else: 86 | return EstimatorWrapper( 87 | estimator=estimator, 88 | reshapes=reshapes, 89 | sample_dim=sample_dim, 90 | **kwargs 91 | ) 92 | 93 | 94 | class EstimatorWrapper(_CommonEstimatorWrapper): 95 | """ A wrapper around sklearn estimators compatible with xarray objects. 96 | 97 | Parameters 98 | ---------- 99 | estimator : sklearn estimator 100 | The estimator instance this instance wraps around. 101 | 102 | reshapes : str or dict, optional 103 | The dimension(s) reshaped by this estimator. Any coordinates in the 104 | DataArray along these dimensions will be dropped. If the estimator 105 | drops this dimension (e.g. a binary classifier returning a 1D vector), 106 | the dimension itself will also be dropped. 107 | 108 | You can specify multiple dimensions mapping to multiple new dimensions 109 | with a dict whose items are lists of reshaped dimensions, e.g. 110 | ``{'new_feature': ['feature_1', 'feature_2'], ...}`` 111 | 112 | sample_dim : str, optional 113 | The name of the dimension that represents the samples. By default, 114 | the wrapper will assume that this is the first dimension in the array. 
115 | """ 116 | 117 | def __init__( 118 | self, estimator=None, reshapes=None, sample_dim=None, **kwargs 119 | ): 120 | 121 | if "compat" in kwargs: 122 | kwargs.pop("compat") 123 | warnings.simplefilter("always", DeprecationWarning) 124 | warnings.warn( 125 | "The compat argument of EstimatorWrapper is deprecated and " 126 | "will be removed in a future version.", 127 | DeprecationWarning, 128 | ) 129 | warnings.simplefilter("ignore", DeprecationWarning) 130 | 131 | if isinstance(estimator, type): 132 | self.estimator = estimator(**kwargs) 133 | params = self.estimator.get_params() 134 | else: 135 | self.estimator = estimator 136 | params = estimator.get_params() 137 | params.update(kwargs) 138 | 139 | self.reshapes = reshapes 140 | self.sample_dim = sample_dim 141 | 142 | for p in params: 143 | setattr(self, p, params[p]) 144 | 145 | self._param_names = ( 146 | self._get_param_names() + self.estimator._get_param_names() 147 | ) 148 | 149 | self._decorate() 150 | 151 | def __getstate__(self): 152 | 153 | state = self.__dict__.copy() 154 | 155 | for m in _method_map: 156 | if hasattr(self.estimator, m): 157 | state.pop(m) 158 | 159 | return state 160 | 161 | def __setstate__(self, state): 162 | 163 | self.__dict__ = state 164 | self._decorate() 165 | 166 | def _decorate(self): 167 | """ Decorate this instance with wrapping methods for the estimator. """ 168 | 169 | # TODO: check if this needs to be removed for compat wrappers 170 | if hasattr(self.estimator, "_estimator_type"): 171 | setattr(self, "_estimator_type", self.estimator._estimator_type) 172 | 173 | for m in _method_map: 174 | if hasattr(self.estimator, m): 175 | if six.PY2: 176 | setattr( 177 | self, 178 | m, 179 | MethodType(_method_map[m], self, EstimatorWrapper), 180 | ) 181 | else: 182 | setattr(self, m, MethodType(_method_map[m], self)) 183 | 184 | def _make_estimator(self): 185 | """ Return an instance of the wrapped estimator. """ 186 | 187 | params = { 188 | p: getattr(self, p) for p in self.estimator._get_param_names() 189 | } 190 | 191 | return type(self.estimator)(**params) 192 | 193 | def _reset(self): 194 | """ Reset internal data-dependent state of the wrapper. 195 | 196 | __init__ parameters are not touched. 197 | """ 198 | 199 | for v in vars(self).copy(): 200 | if v.endswith("_") and not v.startswith("_"): 201 | delattr(self, v) 202 | 203 | def get_params(self, deep=True): 204 | """ Get parameters for this estimator. 205 | 206 | Parameters 207 | ---------- 208 | deep : boolean, optional 209 | If True, will return the parameters for this estimator and 210 | contained subobjects that are estimators. 211 | 212 | Returns 213 | ------- 214 | params : mapping of string to any 215 | Parameter names mapped to their values. 216 | """ 217 | 218 | # TODO: check if this causes problems for wrapped nested estimators 219 | params = BaseEstimator.get_params(self, deep=False) 220 | params.update({p: getattr(self, p) for p in self._param_names}) 221 | 222 | return params 223 | 224 | def set_params(self, **params): 225 | """ Set the parameters of this estimator. 226 | 227 | The method works on simple estimators as well as on nested objects 228 | (such as pipelines). The latter have parameters of the form 229 | ``__`` so that it's possible to update each 230 | component of a nested object. 
231 | 232 | Returns 233 | ------- 234 | self 235 | """ 236 | 237 | for p in self._param_names: 238 | if p in params: 239 | setattr(self, p, params[p]) 240 | 241 | return self 242 | 243 | def fit(self, X, y=None, **fit_params): 244 | """ A wrapper around the fitting function. 245 | 246 | Parameters 247 | ---------- 248 | X : xarray DataArray, Dataset or other array-like 249 | The training input samples. 250 | 251 | y : xarray DataArray, Dataset or other array-like 252 | The target values. 253 | 254 | Returns 255 | ------- 256 | Returns self. 257 | """ 258 | 259 | if self.estimator is None: 260 | raise ValueError("You must specify an estimator instance to wrap.") 261 | 262 | self._reset() 263 | 264 | if is_target(y): 265 | y = y(X) 266 | 267 | if is_dataarray(X): 268 | 269 | self.type_ = "DataArray" 270 | self.estimator_ = self._fit(X, y, **fit_params) 271 | 272 | # TODO: check if this needs to be removed for compat wrappers 273 | for v in vars(self.estimator_): 274 | if v.endswith("_") and not v.startswith("_"): 275 | setattr(self, v, getattr(self.estimator_, v)) 276 | 277 | elif is_dataset(X): 278 | 279 | self.type_ = "Dataset" 280 | self.estimator_dict_ = { 281 | v: self._fit(X[v], y, **fit_params) for v in X.data_vars 282 | } 283 | 284 | # TODO: check if this needs to be removed for compat wrappers 285 | for e_name, e in six.iteritems(self.estimator_dict_): 286 | for v in vars(e): 287 | if v.endswith("_") and not v.startswith("_"): 288 | if hasattr(self, v): 289 | getattr(self, v).update({e_name: getattr(e, v)}) 290 | else: 291 | setattr(self, v, {e_name: getattr(e, v)}) 292 | 293 | else: 294 | 295 | self.type_ = "other" 296 | if y is None: 297 | X = check_array(X) 298 | else: 299 | X, y = check_X_y(X, y) 300 | 301 | self.estimator_ = self._make_estimator().fit(X, y, **fit_params) 302 | 303 | # TODO: check if this needs to be removed for compat wrappers 304 | for v in vars(self.estimator_): 305 | if v.endswith("_") and not v.startswith("_"): 306 | setattr(self, v, getattr(self.estimator_, v)) 307 | 308 | return self 309 | 310 | 311 | class CompatEstimatorWrapper(EstimatorWrapper): 312 | """ A wrapper around sklearn estimators compatible with xarray objects. 313 | 314 | Parameters 315 | ---------- 316 | estimator : sklearn estimator 317 | The estimator instance this instance wraps around. 318 | 319 | reshapes : str or dict, optional 320 | The dimension(s) reshaped by this estimator. Any coordinates in the 321 | DataArray along these dimensions will be dropped. If the estimator 322 | drops this dimension (e.g. a binary classifier returning a 1D vector), 323 | the dimension itself will also be dropped. 324 | 325 | You can specify multiple dimensions mapping to multiple new dimensions 326 | with a dict whose items are lists of reshaped dimensions, e.g. 327 | ``{'new_feature': ['feature_1', 'feature_2'], ...}`` 328 | 329 | sample_dim : str, optional 330 | The name of the dimension that represents the samples. By default, 331 | the wrapper will assume that this is the first dimension in the array. 
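    A short sketch of how this wrapper is typically obtained (via ``wrap``
    with ``compat=True``; see ``wrap`` for when this is needed)::

        from sklearn.preprocessing import StandardScaler
        from sklearn_xarray import wrap

        # compat=True keeps the wrapped estimator's parameters nested under
        # ``estimator`` instead of exposing them on the wrapper itself.
        wrapper = wrap(StandardScaler(), compat=True)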
332 | """ 333 | 334 | def __init__( 335 | self, estimator=None, reshapes=None, sample_dim=None, **kwargs 336 | ): 337 | 338 | if isinstance(estimator, type): 339 | self.estimator = estimator(**kwargs) 340 | else: 341 | self.estimator = estimator 342 | self.estimator.set_params(**kwargs) 343 | 344 | self.reshapes = reshapes 345 | self.sample_dim = sample_dim 346 | 347 | if "compat" in kwargs: 348 | warnings.simplefilter("always", DeprecationWarning) 349 | warnings.warn( 350 | "The compat argument of EstimatorWrapper is deprecated and " 351 | "will be removed in a future version.", 352 | DeprecationWarning, 353 | ) 354 | warnings.simplefilter("ignore", DeprecationWarning) 355 | 356 | self._decorate() 357 | 358 | def _make_estimator(self): 359 | """ Return an instance of the wrapped estimator. """ 360 | 361 | return clone(self.estimator) 362 | 363 | def get_params(self, deep=True): 364 | """Get parameters for this estimator. 365 | 366 | Parameters 367 | ---------- 368 | deep : boolean, optional 369 | If True, will return the parameters for this estimator and 370 | contained subobjects that are estimators. 371 | 372 | Returns 373 | ------- 374 | params : mapping of string to any 375 | Parameter names mapped to their values. 376 | """ 377 | 378 | return BaseEstimator.get_params(self, deep=deep) 379 | 380 | def set_params(self, **params): 381 | """Set the parameters of this estimator. 382 | 383 | Valid parameter keys can be listed with ``get_params()``. 384 | 385 | Returns 386 | ------- 387 | self 388 | """ 389 | 390 | return BaseEstimator.set_params(self, **params) 391 | 392 | 393 | class TransformerWrapper( 394 | EstimatorWrapper, 395 | _ImplementsTransformMixin, 396 | _ImplementsFitTransformMixin, 397 | _ImplementsInverseTransformMixin, 398 | ): 399 | """ A wrapper around sklearn transformers compatible with xarray objects. 400 | 401 | Parameters 402 | ---------- 403 | estimator : sklearn estimator 404 | The estimator this instance wraps around. 405 | 406 | reshapes : str or dict, optional 407 | The dimension reshaped by this estimator. 408 | """ 409 | 410 | 411 | class RegressorWrapper( 412 | EstimatorWrapper, _ImplementsPredictMixin, _ImplementsScoreMixin 413 | ): 414 | """ A wrapper around sklearn regressors compatible with xarray objects. 415 | 416 | Parameters 417 | ---------- 418 | estimator : sklearn estimator 419 | The estimator this instance wraps around. 420 | 421 | reshapes : str or dict, optional 422 | The dimension reshaped by this estimator. 423 | """ 424 | 425 | _estimator_type = "regressor" 426 | 427 | 428 | class ClassifierWrapper( 429 | EstimatorWrapper, _ImplementsPredictMixin, _ImplementsScoreMixin 430 | ): 431 | """ A wrapper around sklearn classifiers compatible with xarray objects. 432 | 433 | Parameters 434 | ---------- 435 | estimator : sklearn estimator 436 | The estimator this instance wraps around. 437 | 438 | reshapes : str or dict, optional 439 | The dimension reshaped by this estimator. 440 | """ 441 | 442 | _estimator_type = "classifier" 443 | -------------------------------------------------------------------------------- /sklearn_xarray/datasets.py: -------------------------------------------------------------------------------- 1 | """ ``sklearn_xarray.datasets`` """ 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import xarray as xr 6 | 7 | 8 | def load_dummy_dataarray(): 9 | """ Load a DataArray for demonstration purposes. 
""" 10 | 11 | return xr.DataArray( 12 | np.random.random((100, 10)), 13 | coords={"sample": range(100), "feature": range(10)}, 14 | dims=("sample", "feature"), 15 | ) 16 | 17 | 18 | def load_dummy_dataset(): 19 | """ Load a Dataset for demonstration purposes. """ 20 | 21 | return xr.Dataset( 22 | {"var_1": (["sample", "feature"], np.random.random((100, 10)))}, 23 | coords={"sample": range(100), "feature": range(10)}, 24 | ) 25 | 26 | 27 | def load_digits_dataarray(load_images=False, nan_probability=0): 28 | """ Load a the 'digits' dataset from sklearn as a DataArray. 29 | 30 | Parameters 31 | ---------- 32 | load_images : bool, optional 33 | If true, the DataArray will contain the two-dimensional images as 34 | data instead of the vectorized samples. 35 | 36 | nan_probability : float between 0 and 1 37 | The probability with which a sample is injected with NaN values. For 38 | demonstration purposes only. 39 | """ 40 | 41 | from sklearn.datasets import load_digits 42 | 43 | if load_images: 44 | 45 | bunch = load_digits() 46 | data = bunch.images 47 | 48 | if nan_probability > 0: 49 | for i in range(data.shape[0]): 50 | if np.random.rand(1) < nan_probability: 51 | data[i, 0, 0] = np.nan 52 | 53 | return xr.DataArray( 54 | data, 55 | coords={ 56 | "sample": range(data.shape[0]), 57 | "row": range(data.shape[1]), 58 | "col": range(data.shape[2]), 59 | "digit": (["sample"], bunch.target), 60 | }, 61 | dims=("sample", "row", "col"), 62 | ) 63 | 64 | else: 65 | 66 | data, target = load_digits(return_X_y=True) 67 | 68 | if nan_probability > 0: 69 | for i in range(data.shape[0]): 70 | if np.random.rand(1) < nan_probability: 71 | data[i, 0] = np.nan 72 | 73 | return xr.DataArray( 74 | data, 75 | coords={ 76 | "sample": range(data.shape[0]), 77 | "feature": range(data.shape[1]), 78 | "digit": (["sample"], target), 79 | }, 80 | dims=("sample", "feature"), 81 | ) 82 | 83 | 84 | def load_wisdm_dataarray( 85 | url="http://www.cis.fordham.edu/wisdm/includes/" 86 | "datasets/latest/WISDM_ar_latest.tar.gz", 87 | file="WISDM_ar_v1.1/WISDM_ar_v1.1_raw.txt", 88 | folder="data/", 89 | tmp_file="widsm.tar.gz", 90 | ): 91 | """ Load the WISDM activity recognition dataset. 92 | 93 | Parameters 94 | ---------- 95 | url : str, optional 96 | The URL of the dataset. 97 | 98 | file : str, optional 99 | The file containing the data. 100 | 101 | folder : str, optional 102 | The folder where the data will be downloaded and extracted to. 103 | 104 | tmp_file : str, optional 105 | The name of the temporary .tar file in the folder. 106 | 107 | Returns 108 | ------- 109 | X: xarray DataArray 110 | The loaded dataset. 
111 | 112 | """ 113 | 114 | import os 115 | import tarfile 116 | import six.moves.urllib.request as ul 117 | 118 | if not os.path.isfile(os.path.join(folder, file)): 119 | ul.urlretrieve(url, tmp_file) 120 | tar = tarfile.open(tmp_file) 121 | tar.extractall(folder) 122 | tar.close() 123 | os.remove(tmp_file) 124 | 125 | column_names = ["subject", "activity", "timestamp", "x", "y", "z"] 126 | df = pd.read_csv( 127 | os.path.join(folder, file), 128 | header=None, 129 | names=column_names, 130 | comment=";", 131 | ) 132 | 133 | time = pd.date_range(start=0, periods=df.shape[0], freq="50ms") 134 | 135 | coords = { 136 | "subject": ("sample", df.subject), 137 | "activity": ("sample", df.activity), 138 | "sample": time, 139 | "axis": ["x", "y", "z"], 140 | } 141 | 142 | X = xr.DataArray(df.iloc[:, 3:6], coords=coords, dims=("sample", "axis")) 143 | 144 | return X 145 | -------------------------------------------------------------------------------- /sklearn_xarray/externals/__init__.py: -------------------------------------------------------------------------------- 1 | """" ``sklearn_xarray.externals`` """ 2 | 3 | try: 4 | import numpy_groupies 5 | except ImportError: 6 | from . import _numpy_groupies_np as numpy_groupies 7 | -------------------------------------------------------------------------------- /sklearn_xarray/externals/_numpy_groupies_np.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | 4 | _doc_str = """ 5 | See readme file at https://github.com/ml31415/numpy-groupies for a full 6 | description. Below we reproduce the "Full description of inputs" 7 | section from that readme, note that the text below makes references to 8 | other portions of the readme that are not shown here. 9 | 10 | group_idx: 11 | this is an array of non-negative integers, to be used as the "labels" 12 | with which to group the values in ``a``. Although we have so far 13 | assumed that ``group_idx`` is one-dimesnaional, and the same length as 14 | ``a``, it can in fact be two-dimensional (or some form of nested 15 | sequences that can be converted to 2D). When ``group_idx`` is 2D, the 16 | size of the 0th dimension corresponds to the number of dimesnions in 17 | the output, i.e. ``group_idx[i,j]`` gives the index into the ith 18 | dimension in the output 19 | for ``a[j]``. Note that ``a`` should still be 1D (or scalar), with 20 | length matching ``group_idx.shape[1]``. 21 | a: 22 | this is the array of values to be aggregated. See above for a 23 | simple demonstration of what this means. ``a`` will normally be a 24 | one-dimensional array, however it can also be a scalar in some cases. 25 | func: default='sum' 26 | the function to use for aggregation. See the section above for 27 | details. Note that the simplest way to specify the function is using a 28 | string (e.g. ``func='max'``) however a number of aliases are also 29 | defined (e.g. you can use the ``func=np.max``, or even ``func=max``, 30 | where ``max`` is the 31 | builtin function). To check the available aliases see ``utils.py``. 32 | size: default=None 33 | the shape of the output array. If ``None``, the maximum value in 34 | ``group_idx`` will set the size of the output. Note that for 35 | multidimensional output you need to list the size of each dimension 36 | here, or give ``None``. 37 | fill_value: default=0 38 | in the example above, group 2 does not have any data, so requires some 39 | kind of filling value - in this case the default of ``0`` is used. 
If 40 | you had set ``fill_value=nan`` or something else, that value would 41 | appear instead of ``0`` for the 2 element in the output. Note that 42 | there are some subtle interactions between what is permitted for 43 | ``fill_value`` and the input/output ``dtype`` - exceptions should be 44 | raised in most cases to alert the programmer if issues arise. 45 | order: default='C' 46 | this is relevant only for multidimensional output. It controls the 47 | layout of the output array in memory, can be ``'F'`` for Fortran-style. 48 | dtype: default=None 49 | the ``dtype`` of the output. By default something sensible is chosen 50 | based on the input, aggregation function, and ``fill_value``. 51 | ddof: default=0 52 | passed through into calculations of variance and standard deviation 53 | (see above). 54 | """ 55 | 56 | _funcs_common = ( 57 | "first last len mean var std allnan anynan max min argmax argmin".split() 58 | ) 59 | _no_separate_nan_version = {"sort", "rsort", "array", "allnan", "anynan"} 60 | 61 | _alias_str = { 62 | "or": "any", 63 | "and": "all", 64 | "add": "sum", 65 | "count": "len", 66 | "plus": "sum", 67 | "multiply": "prod", 68 | "product": "prod", 69 | "times": "prod", 70 | "amax": "max", 71 | "maximum": "max", 72 | "amin": "min", 73 | "minimum": "min", 74 | "split": "array", 75 | "splice": "array", 76 | "sorted": "sort", 77 | "asort": "sort", 78 | "asorted": "sort", 79 | "rsorted": "rsort", 80 | "dsort": "rsort", 81 | "dsorted": "rsort", 82 | } 83 | 84 | _alias_builtin = { 85 | all: "all", 86 | any: "any", 87 | len: "len", 88 | max: "max", 89 | min: "min", 90 | sum: "sum", 91 | sorted: "sort", 92 | slice: "array", 93 | list: "array", 94 | } 95 | 96 | 97 | def get_aliasing(*extra): 98 | """This assembles the dict mapping strings and functions to the list of 99 | supported function names: 100 | e.g. alias['add'] = 'sum' and alias[sorted] = 'sort' 101 | This function should only be called during import. 
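    A small sketch of the resulting mapping::

        alias = get_aliasing()
        # alias["add"]  -> "sum"    (string alias)
        # alias[sorted] -> "sort"   (builtin alias)
        # alias["mean"] -> "mean",  alias["nanmean"] -> "nanmean"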
102 | """ 103 | alias = dict((k, k) for k in _funcs_common) 104 | alias.update(_alias_str) 105 | alias.update((fn, fn) for fn in _alias_builtin.values()) 106 | alias.update(_alias_builtin) 107 | for d in extra: 108 | alias.update(d) 109 | alias.update((k, k) for k in set(alias.values())) 110 | # Treat nan-functions as firstclass member and add them directly 111 | for key in set(alias.values()): 112 | if key not in _no_separate_nan_version: 113 | key = "nan" + key 114 | alias[key] = key 115 | return alias 116 | 117 | 118 | aliasing_purepy = get_aliasing() 119 | 120 | 121 | def get_func(func, aliasing, implementations): 122 | """ Return the key of a found implementation or the func itself """ 123 | try: 124 | func_str = aliasing[func] 125 | except KeyError: 126 | if callable(func): 127 | return func 128 | else: 129 | if func_str in implementations: 130 | return func_str 131 | if ( 132 | func_str.startswith("nan") 133 | and func_str[3:] in _no_separate_nan_version 134 | ): 135 | raise ValueError("%s does not have a nan-version" % func_str[3:]) 136 | else: 137 | raise NotImplementedError("No such function available") 138 | raise ValueError( 139 | "func %s is neither a valid function string nor a " 140 | "callable object" % func 141 | ) 142 | 143 | 144 | def check_boolean(x): 145 | if x not in (0, 1): 146 | raise ValueError("Value not boolean") 147 | 148 | 149 | try: 150 | basestring # Attempt to evaluate basestring 151 | 152 | def isstr(s): 153 | return isinstance(s, basestring) 154 | 155 | 156 | except NameError: 157 | # Probably Python 3.x 158 | def isstr(s): 159 | return isinstance(s, str) 160 | 161 | 162 | try: 163 | import numpy as np 164 | except ImportError: 165 | pass 166 | else: 167 | _alias_numpy = { 168 | np.add: "sum", 169 | np.sum: "sum", 170 | np.any: "any", 171 | np.all: "all", 172 | np.multiply: "prod", 173 | np.prod: "prod", 174 | np.amin: "min", 175 | np.min: "min", 176 | np.minimum: "min", 177 | np.amax: "max", 178 | np.max: "max", 179 | np.maximum: "max", 180 | np.argmax: "argmax", 181 | np.argmin: "argmin", 182 | np.mean: "mean", 183 | np.std: "std", 184 | np.var: "var", 185 | np.array: "array", 186 | np.asarray: "array", 187 | np.sort: "sort", 188 | np.nansum: "nansum", 189 | np.nanprod: "nanprod", 190 | np.nanmean: "nanmean", 191 | np.nanvar: "nanvar", 192 | np.nanmax: "nanmax", 193 | np.nanmin: "nanmin", 194 | np.nanstd: "nanstd", 195 | np.nanargmax: "nanargmax", 196 | np.nanargmin: "nanargmin", 197 | } 198 | 199 | try: 200 | import bottleneck as bn 201 | except ImportError: 202 | _alias_bottleneck = {} 203 | else: 204 | _bn_funcs = "allnan anynan nansum nanmin nanmax nanmean nanvar nanstd" 205 | _alias_bottleneck = dict( 206 | (getattr(bn, fn), fn) for fn in _bn_funcs.split() 207 | ) 208 | 209 | aliasing = get_aliasing(_alias_numpy, _alias_bottleneck) 210 | 211 | def fill_untouched(idx, ret, fill_value): 212 | """any elements of ret not indexed by idx are set to fill_value.""" 213 | untouched = np.ones_like(ret, dtype=bool) 214 | untouched[idx] = False 215 | ret[untouched] = fill_value 216 | 217 | _next_int_dtype = dict( 218 | bool=np.int8, 219 | uint8=np.int16, 220 | int8=np.int16, 221 | uint16=np.int32, 222 | int16=np.int32, 223 | uint32=np.int64, 224 | int32=np.int64, 225 | ) 226 | 227 | _next_float_dtype = dict( 228 | float16=np.float32, 229 | float32=np.float64, 230 | float64=np.complex64, 231 | complex64=np.complex128, 232 | ) 233 | 234 | def minimum_dtype(x, dtype=np.bool_): 235 | """returns the "most basic" dtype which represents `x` properly, which 236 | 
provides at least the same value range as the specified dtype.""" 237 | 238 | def check_type(x, dtype): 239 | try: 240 | converted = dtype.type(x) 241 | except (ValueError, OverflowError): 242 | return False 243 | # False if some overflow has happened 244 | return converted == x or math.isnan(x) 245 | 246 | def type_loop(x, dtype, dtype_dict, default=None): 247 | while True: 248 | try: 249 | dtype = np.dtype(dtype_dict[dtype.name]) 250 | if check_type(x, dtype): 251 | return np.dtype(dtype) 252 | except KeyError: 253 | if default is not None: 254 | return np.dtype(default) 255 | raise ValueError("Can not determine dtype of %r" % x) 256 | 257 | dtype = np.dtype(dtype) 258 | if check_type(x, dtype): 259 | return dtype 260 | 261 | if np.issubdtype(dtype, np.inexact): 262 | return type_loop(x, dtype, _next_float_dtype) 263 | else: 264 | return type_loop(x, dtype, _next_int_dtype, default=np.float32) 265 | 266 | def minimum_dtype_scalar(x, dtype, a): 267 | if dtype is None: 268 | dtype = ( 269 | np.dtype(type(a)) if isinstance(a, (int, float)) else a.dtype 270 | ) 271 | return minimum_dtype(x, dtype) 272 | 273 | _forced_types = { 274 | "array": np.object, 275 | "all": np.bool_, 276 | "any": np.bool_, 277 | "nanall": np.bool_, 278 | "nanany": np.bool_, 279 | "len": np.int64, 280 | "nanlen": np.int64, 281 | "allnan": np.bool_, 282 | "anynan": np.bool_, 283 | } 284 | _forced_float_types = {"mean", "var", "std", "nanmean", "nanvar", "nanstd"} 285 | _forced_same_type = { 286 | "min", 287 | "max", 288 | "first", 289 | "last", 290 | "nanmin", 291 | "nanmax", 292 | "nanfirst", 293 | "nanlast", 294 | } 295 | 296 | def check_dtype(dtype, func_str, a, n): 297 | if np.isscalar(a) or not a.shape: 298 | if func_str not in ("sum", "prod", "len"): 299 | raise ValueError( 300 | "scalar inputs are supported only for 'sum', " 301 | "'prod' and 'len'" 302 | ) 303 | a_dtype = np.dtype(type(a)) 304 | else: 305 | a_dtype = a.dtype 306 | 307 | if dtype is not None: 308 | # dtype set by the user 309 | # Careful here: np.bool != np.bool_ ! 
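        # (np.bool is just an alias for the builtin bool, while np.bool_ is
        # the NumPy scalar type, which is why the check below uses np.bool_.)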
310 | if np.issubdtype(dtype, np.bool_) and not ( 311 | "all" in func_str or "any" in func_str 312 | ): 313 | raise TypeError( 314 | "function %s requires a more complex datatype " 315 | "than bool" % func_str 316 | ) 317 | if not np.issubdtype(dtype, np.integer) and func_str in ( 318 | "len", 319 | "nanlen", 320 | ): 321 | raise TypeError( 322 | "function %s requires an integer datatype" % func_str 323 | ) 324 | # TODO: Maybe have some more checks here 325 | return np.dtype(dtype) 326 | else: 327 | try: 328 | return np.dtype(_forced_types[func_str]) 329 | except KeyError: 330 | if func_str in _forced_float_types: 331 | if np.issubdtype(a_dtype, np.floating): 332 | return a_dtype 333 | else: 334 | return np.dtype(np.float64) 335 | else: 336 | if func_str == "sum": 337 | # Try to guess the minimally required int size 338 | if np.issubdtype(a_dtype, np.int64): 339 | # It's not getting bigger anymore 340 | # TODO: strictly speaking it might need float 341 | return np.dtype(np.int64) 342 | elif np.issubdtype(a_dtype, np.integer): 343 | maxval = np.iinfo(a_dtype).max * n 344 | return minimum_dtype(maxval, a_dtype) 345 | elif np.issubdtype(a_dtype, np.bool_): 346 | return minimum_dtype(n, a_dtype) 347 | else: 348 | # floating, inexact, whatever 349 | return a_dtype 350 | elif func_str in _forced_same_type: 351 | return a_dtype 352 | else: 353 | if isinstance(a_dtype, np.integer): 354 | return np.dtype(np.int64) 355 | else: 356 | return a_dtype 357 | 358 | def check_fill_value(fill_value, dtype): 359 | try: 360 | return dtype.type(fill_value) 361 | except ValueError: 362 | raise ValueError( 363 | "fill_value must be convertible into %s" % dtype.type.__name__ 364 | ) 365 | 366 | def check_group_idx(group_idx, a=None, check_min=True): 367 | if a is not None and group_idx.size != a.size: 368 | raise ValueError( 369 | "The size of group_idx must be the same as " "a.size" 370 | ) 371 | if not issubclass(group_idx.dtype.type, np.integer): 372 | raise TypeError("group_idx must be of integer type") 373 | if check_min and np.min(group_idx) < 0: 374 | raise ValueError("group_idx contains negative indices") 375 | 376 | def input_validation( 377 | group_idx, 378 | a, 379 | size=None, 380 | order="C", 381 | axis=None, 382 | ravel_group_idx=True, 383 | check_bounds=True, 384 | ): 385 | """ Do some fairly extensive checking of group_idx and a, trying to 386 | give the user as much help as possible with what is wrong. Also, 387 | convert ndim-indexing to 1d indexing. 388 | """ 389 | if not isinstance(a, (int, float, complex)): 390 | a = np.asanyarray(a) 391 | group_idx = np.asanyarray(group_idx) 392 | 393 | if not np.issubdtype(group_idx.dtype, np.integer): 394 | raise TypeError("group_idx must be of integer type") 395 | 396 | # This check works for multidimensional indexing as well 397 | if check_bounds and np.any(group_idx < 0): 398 | raise ValueError("negative indices not supported") 399 | 400 | ndim_idx = np.ndim(group_idx) 401 | ndim_a = np.ndim(a) 402 | 403 | # Deal with the axis arg: if present, then turn 1d indexing into 404 | # multi-dimensional indexing along the specified axis. 405 | if axis is None: 406 | if ndim_a > 1: 407 | raise ValueError( 408 | "a must be scalar or 1 dimensional, use .ravel to" 409 | " flatten. Alternatively specify axis." 
410 | ) 411 | elif axis >= ndim_a or axis < -ndim_a: 412 | raise ValueError("axis arg too large for np.ndim(a)") 413 | else: 414 | axis = axis if axis >= 0 else ndim_a + axis # negative indexing 415 | if ndim_idx > 1: 416 | # TODO: we could support a sequence of axis values for multiple 417 | # dimensions of group_idx. 418 | raise NotImplementedError( 419 | "only 1d indexing currently " "supported with axis arg." 420 | ) 421 | elif a.shape[axis] != len(group_idx): 422 | raise ValueError( 423 | "a.shape[axis] doesn't match length of group_idx." 424 | ) 425 | elif size is not None and not np.isscalar(size): 426 | raise NotImplementedError( 427 | "when using axis arg, size must be " "None or scalar." 428 | ) 429 | else: 430 | # Create the broadcast-ready multidimensional indexing. 431 | # Note the user could do this themselves, so this is 432 | # very much just a convenience. 433 | size_in = np.max(group_idx) + 1 if size is None else size 434 | group_idx_in = group_idx 435 | group_idx = [] 436 | size = [] 437 | for ii, s in enumerate(a.shape): 438 | ii_idx = group_idx_in if ii == axis else np.arange(s) 439 | ii_shape = [1] * ndim_a 440 | ii_shape[ii] = s 441 | group_idx.append(ii_idx.reshape(ii_shape)) 442 | size.append(size_in if ii == axis else s) 443 | # Use the indexing, and return. It's a bit simpler than 444 | # trying to keep all the logic below happy 445 | group_idx = np.ravel_multi_index( 446 | group_idx, size, order=order, mode="raise" 447 | ) 448 | flat_size = np.prod(size) 449 | ndim_idx = ndim_a 450 | return group_idx.ravel(), a.ravel(), flat_size, ndim_idx, size 451 | 452 | if ndim_idx == 1: 453 | if size is None: 454 | size = np.max(group_idx) + 1 455 | else: 456 | if not np.isscalar(size): 457 | raise ValueError("output size must be scalar or None") 458 | if check_bounds and np.any(group_idx > size - 1): 459 | raise ValueError( 460 | "one or more indices are too large for " 461 | "size %d" % size 462 | ) 463 | flat_size = size 464 | else: 465 | if size is None: 466 | size = np.max(group_idx, axis=1) + 1 467 | elif np.isscalar(size): 468 | raise ValueError( 469 | "output size must be of length %d" % len(group_idx) 470 | ) 471 | elif len(size) != len(group_idx): 472 | raise ValueError( 473 | "%d sizes given, but %d output dimensions " 474 | "specified in index" % (len(size), len(group_idx)) 475 | ) 476 | if ravel_group_idx: 477 | group_idx = np.ravel_multi_index( 478 | group_idx, size, order=order, mode="raise" 479 | ) 480 | flat_size = np.prod(size) 481 | 482 | if not (np.ndim(a) == 0 or len(a) == group_idx.size): 483 | raise ValueError( 484 | "group_idx and a must be of the same length, or a" 485 | " can be scalar" 486 | ) 487 | 488 | return group_idx, a, flat_size, ndim_idx, size 489 | 490 | 491 | def _sort(group_idx, a, size, fill_value, dtype=None, reversed_=False): 492 | if np.iscomplexobj(a): 493 | raise NotImplementedError( 494 | "a must be real, could use np.lexsort or " 495 | "sort with recarray for complex." 
496 | ) 497 | if not (np.isscalar(fill_value) or len(fill_value) == 0): 498 | raise ValueError("fill_value must be scalar or an empty sequence") 499 | if reversed_: 500 | order_group_idx = np.argsort(group_idx + -1j * a, kind="mergesort") 501 | else: 502 | order_group_idx = np.argsort(group_idx + 1j * a, kind="mergesort") 503 | counts = np.bincount(group_idx, minlength=size) 504 | if np.ndim(a) == 0: 505 | a = np.full(size, a, dtype=type(a)) 506 | ret = np.split(a[order_group_idx], np.cumsum(counts)[:-1]) 507 | ret = np.asarray(ret, dtype=object) 508 | if np.isscalar(fill_value): 509 | fill_untouched(group_idx, ret, fill_value) 510 | return ret 511 | 512 | 513 | def _rsort(group_idx, a, size, fill_value, dtype=None): 514 | return _sort(group_idx, a, size, fill_value, dtype=None, reversed_=True) 515 | 516 | 517 | def _array(group_idx, a, size, fill_value, dtype=None): 518 | """groups a into separate arrays, keeping the order intact.""" 519 | if fill_value is not None and not ( 520 | np.isscalar(fill_value) or len(fill_value) == 0 521 | ): 522 | raise ValueError( 523 | "fill_value must be None, a scalar or an empty " "sequence" 524 | ) 525 | order_group_idx = np.argsort(group_idx, kind="mergesort") 526 | counts = np.bincount(group_idx, minlength=size) 527 | ret = np.split(a[order_group_idx], np.cumsum(counts)[:-1]) 528 | ret = np.asanyarray(ret) 529 | if fill_value is None or np.isscalar(fill_value): 530 | fill_untouched(group_idx, ret, fill_value) 531 | return ret 532 | 533 | 534 | def _sum(group_idx, a, size, fill_value, dtype=None): 535 | dtype = minimum_dtype_scalar(fill_value, dtype, a) 536 | 537 | if np.ndim(a) == 0: 538 | ret = np.bincount(group_idx, minlength=size).astype(dtype) 539 | if a != 1: 540 | ret *= a 541 | else: 542 | if np.iscomplexobj(a): 543 | ret = np.empty(size, dtype=dtype) 544 | ret.real = np.bincount(group_idx, weights=a.real, minlength=size) 545 | ret.imag = np.bincount(group_idx, weights=a.imag, minlength=size) 546 | else: 547 | ret = np.bincount(group_idx, weights=a, minlength=size).astype( 548 | dtype 549 | ) 550 | 551 | if fill_value != 0: 552 | fill_untouched(group_idx, ret, fill_value) 553 | return ret 554 | 555 | 556 | def _len(group_idx, a, size, fill_value, dtype=None): 557 | return _sum(group_idx, 1, size, fill_value, dtype=int) 558 | 559 | 560 | def _last(group_idx, a, size, fill_value, dtype=None): 561 | dtype = minimum_dtype(fill_value, dtype or a.dtype) 562 | if fill_value == 0: 563 | ret = np.zeros(size, dtype=dtype) 564 | else: 565 | ret = np.full(size, fill_value, dtype=dtype) 566 | # repeated indexing gives last value, see: 567 | # the phrase "leaving behind the last value" on this page: 568 | # http://wiki.scipy.org/Tentative_NumPy_Tutorial 569 | ret[group_idx] = a 570 | return ret 571 | 572 | 573 | def _first(group_idx, a, size, fill_value, dtype=None): 574 | dtype = minimum_dtype(fill_value, dtype or a.dtype) 575 | if fill_value == 0: 576 | ret = np.zeros(size, dtype=dtype) 577 | else: 578 | ret = np.full(size, fill_value, dtype=dtype) 579 | ret[group_idx[::-1]] = a[::-1] # same trick as _last, but in reverse 580 | return ret 581 | 582 | 583 | def _prod(group_idx, a, size, fill_value, dtype=None): 584 | dtype = minimum_dtype_scalar(fill_value, dtype, a) 585 | ret = np.full(size, fill_value, dtype=dtype) 586 | if fill_value != 1: 587 | ret[group_idx] = 1 # product starts from 1 588 | np.multiply.at(ret, group_idx, a) 589 | return ret 590 | 591 | 592 | def _all(group_idx, a, size, fill_value, dtype=None): 593 | check_boolean(fill_value) 594 | ret 
= np.full(size, fill_value, dtype=bool) 595 | if not fill_value: 596 | ret[group_idx] = True 597 | ret[group_idx.compress(np.logical_not(a))] = False 598 | return ret 599 | 600 | 601 | def _any(group_idx, a, size, fill_value, dtype=None): 602 | check_boolean(fill_value) 603 | ret = np.full(size, fill_value, dtype=bool) 604 | if fill_value: 605 | ret[group_idx] = False 606 | ret[group_idx.compress(a)] = True 607 | return ret 608 | 609 | 610 | def _min(group_idx, a, size, fill_value, dtype=None): 611 | dtype = minimum_dtype(fill_value, dtype or a.dtype) 612 | dmax = ( 613 | np.iinfo(a.dtype).max 614 | if issubclass(a.dtype.type, np.integer) 615 | else np.finfo(a.dtype).max 616 | ) 617 | ret = np.full(size, fill_value, dtype=dtype) 618 | if fill_value != dmax: 619 | ret[group_idx] = dmax # min starts from maximum 620 | np.minimum.at(ret, group_idx, a) 621 | return ret 622 | 623 | 624 | def _max(group_idx, a, size, fill_value, dtype=None): 625 | dtype = minimum_dtype(fill_value, dtype or a.dtype) 626 | dmin = ( 627 | np.iinfo(a.dtype).min 628 | if issubclass(a.dtype.type, np.integer) 629 | else np.finfo(a.dtype).min 630 | ) 631 | ret = np.full(size, fill_value, dtype=dtype) 632 | if fill_value != dmin: 633 | ret[group_idx] = dmin # max starts from minimum 634 | np.maximum.at(ret, group_idx, a) 635 | return ret 636 | 637 | 638 | def _argmax(group_idx, a, size, fill_value, dtype=None): 639 | dtype = minimum_dtype(fill_value, dtype or int) 640 | dmin = ( 641 | np.iinfo(a.dtype).min 642 | if issubclass(a.dtype.type, np.integer) 643 | else np.finfo(a.dtype).min 644 | ) 645 | group_max = _max(group_idx, a, size, dmin) 646 | is_max = a == group_max[group_idx] 647 | ret = np.full(size, fill_value, dtype=dtype) 648 | group_idx_max = group_idx[is_max] 649 | (argmax,) = is_max.nonzero() 650 | ret[group_idx_max[::-1]] = argmax[ 651 | ::-1 652 | ] # reverse to ensure first value for each group wins 653 | return ret 654 | 655 | 656 | def _argmin(group_idx, a, size, fill_value, dtype=None): 657 | dtype = minimum_dtype(fill_value, dtype or int) 658 | dmax = ( 659 | np.iinfo(a.dtype).max 660 | if issubclass(a.dtype.type, np.integer) 661 | else np.finfo(a.dtype).max 662 | ) 663 | group_min = _min(group_idx, a, size, dmax) 664 | is_min = a == group_min[group_idx] 665 | ret = np.full(size, fill_value, dtype=dtype) 666 | group_idx_min = group_idx[is_min] 667 | (argmin,) = is_min.nonzero() 668 | ret[group_idx_min[::-1]] = argmin[ 669 | ::-1 670 | ] # reverse to ensure first value for each group wins 671 | return ret 672 | 673 | 674 | def _mean(group_idx, a, size, fill_value, dtype=np.dtype(np.float64)): 675 | if np.ndim(a) == 0: 676 | raise ValueError("cannot take mean with scalar a") 677 | counts = np.bincount(group_idx, minlength=size) 678 | if np.iscomplexobj(a): 679 | dtype = a.dtype # TODO: this is a bit clumsy 680 | sums = np.empty(size, dtype=dtype) 681 | sums.real = np.bincount(group_idx, weights=a.real, minlength=size) 682 | sums.imag = np.bincount(group_idx, weights=a.imag, minlength=size) 683 | else: 684 | sums = np.bincount(group_idx, weights=a, minlength=size).astype(dtype) 685 | 686 | with np.errstate(divide="ignore"): 687 | ret = sums.astype(dtype) / counts 688 | if not np.isnan(fill_value): 689 | ret[counts == 0] = fill_value 690 | return ret 691 | 692 | 693 | def _var( 694 | group_idx, 695 | a, 696 | size, 697 | fill_value, 698 | dtype=np.dtype(np.float64), 699 | sqrt=False, 700 | ddof=0, 701 | ): 702 | if np.ndim(a) == 0: 703 | raise ValueError("cannot take variance with scalar a") 704 | counts = 
np.bincount(group_idx, minlength=size) 705 | sums = np.bincount(group_idx, weights=a, minlength=size) 706 | with np.errstate(divide="ignore"): 707 | means = sums.astype(dtype) / counts 708 | ret = np.bincount( 709 | group_idx, (a - means[group_idx]) ** 2, minlength=size 710 | ) / (counts - ddof) 711 | if sqrt: 712 | ret = np.sqrt(ret) # this is now std not var 713 | if not np.isnan(fill_value): 714 | ret[counts == 0] = fill_value 715 | return ret 716 | 717 | 718 | def _std(group_idx, a, size, fill_value, dtype=np.dtype(np.float64), ddof=0): 719 | return _var( 720 | group_idx, a, size, fill_value, dtype=dtype, sqrt=True, ddof=ddof 721 | ) 722 | 723 | 724 | def _allnan(group_idx, a, size, fill_value, dtype=bool): 725 | return _all( 726 | group_idx, np.isnan(a), size, fill_value=fill_value, dtype=dtype 727 | ) 728 | 729 | 730 | def _anynan(group_idx, a, size, fill_value, dtype=bool): 731 | return _any( 732 | group_idx, np.isnan(a), size, fill_value=fill_value, dtype=dtype 733 | ) 734 | 735 | 736 | def _generic_callable( 737 | group_idx, a, size, fill_value, dtype=None, func=lambda g: g 738 | ): 739 | """groups a by inds, and then applies foo to each group in turn, placing 740 | the results in an array.""" 741 | groups = _array(group_idx, a, size, (), dtype=dtype) 742 | ret = np.full(size, fill_value, dtype=object) 743 | 744 | for i, grp in enumerate(groups): 745 | if np.ndim(grp) == 1 and len(grp) > 0: 746 | ret[i] = func(grp) 747 | return ret 748 | 749 | 750 | _impl_dict = dict( 751 | min=_min, 752 | max=_max, 753 | sum=_sum, 754 | prod=_prod, 755 | last=_last, 756 | first=_first, 757 | all=_all, 758 | any=_any, 759 | mean=_mean, 760 | std=_std, 761 | var=_var, 762 | anynan=_anynan, 763 | allnan=_allnan, 764 | sort=_sort, 765 | rsort=_rsort, 766 | array=_array, 767 | argmax=_argmax, 768 | argmin=_argmin, 769 | len=_len, 770 | ) 771 | _impl_dict.update( 772 | ("nan" + k, v) 773 | for k, v in list(_impl_dict.items()) 774 | if k not in _no_separate_nan_version 775 | ) 776 | 777 | 778 | def _aggregate_base( 779 | group_idx, 780 | a, 781 | func="sum", 782 | size=None, 783 | fill_value=0, 784 | order="C", 785 | dtype=None, 786 | axis=None, 787 | _impl_dict=_impl_dict, 788 | _nansqueeze=False, 789 | **kwargs 790 | ): 791 | group_idx, a, flat_size, ndim_idx, size = input_validation( 792 | group_idx, a, size=size, order=order, axis=axis 793 | ) 794 | func = get_func(func, aliasing, _impl_dict) 795 | if not isstr(func): 796 | # do simple grouping and execute function in loop 797 | ret = _generic_callable( 798 | group_idx, 799 | a, 800 | flat_size, 801 | fill_value, 802 | func=func, 803 | dtype=dtype, 804 | **kwargs 805 | ) 806 | else: 807 | # deal with nans and find the function 808 | if func.startswith("nan"): 809 | if np.ndim(a) == 0: 810 | raise ValueError("nan-version not supported for scalar input.") 811 | if _nansqueeze: 812 | good = ~np.isnan(a) 813 | a = a[good] 814 | group_idx = group_idx[good] 815 | 816 | dtype = check_dtype(dtype, func, a, flat_size) 817 | func = _impl_dict[func] 818 | ret = func( 819 | group_idx, 820 | a, 821 | flat_size, 822 | fill_value=fill_value, 823 | dtype=dtype, 824 | **kwargs 825 | ) 826 | 827 | # deal with ndimensional indexing 828 | if ndim_idx > 1: 829 | ret = ret.reshape(size, order=order) 830 | return ret 831 | 832 | 833 | def aggregate( 834 | group_idx, 835 | a, 836 | func="sum", 837 | size=None, 838 | fill_value=0, 839 | order="C", 840 | dtype=None, 841 | axis=None, 842 | **kwargs 843 | ): 844 | return _aggregate_base( 845 | group_idx, 846 | a, 847 | 
size=size, 848 | fill_value=fill_value, 849 | order=order, 850 | dtype=dtype, 851 | func=func, 852 | axis=axis, 853 | _impl_dict=_impl_dict, 854 | _nansqueeze=True, 855 | **kwargs 856 | ) 857 | 858 | 859 | aggregate.__doc__ = ( 860 | """ 861 | This is the pure numpy implementation of aggregate. 862 | """ 863 | + _doc_str 864 | ) 865 | -------------------------------------------------------------------------------- /sklearn_xarray/model_selection.py: -------------------------------------------------------------------------------- 1 | """ 2 | ``sklearn_xarray.model_selection`` 3 | """ 4 | 5 | import numpy as np 6 | 7 | 8 | class CrossValidatorWrapper(object): 9 | """ Wrap an sklearn cross validator for use with xarray. 10 | 11 | Parameters 12 | ---------- 13 | cross_validator : sklearn cross-validator 14 | An instance of a cross-validator. 15 | 16 | dim : str 17 | The dimension along which to perform the split. 18 | 19 | groupby : str or list 20 | Name of coordinate or list of coordinates by which the groups are 21 | determined. 22 | """ 23 | 24 | def __init__(self, cross_validator, dim="sample", groupby=None): 25 | 26 | self.cross_validator = cross_validator 27 | self.dim = dim 28 | self.groupby = groupby 29 | 30 | def get_n_splits(self, X=None, y=None, groups=None): 31 | """ Returns the number of splitting iterations in the cross-validator. 32 | 33 | Parameters 34 | ---------- 35 | X : object 36 | Always ignored, exists for compatibility. 37 | 38 | y : object 39 | Always ignored, exists for compatibility. 40 | 41 | groups : object 42 | Always ignored, exists for compatibility. 43 | 44 | Returns 45 | ------- 46 | n_splits : int 47 | Returns the number of splitting iterations in the cross-validator. 48 | """ 49 | 50 | return self.cross_validator.get_n_splits(X, y, groups) 51 | 52 | def split(self, X, y=None, groups=None): 53 | """ Generate indices to split data into training and test set. 54 | 55 | Parameters 56 | ---------- 57 | X : xarray DataArray or Dataset 58 | Training data, where n_samples is the number of samples 59 | and n_features is the number of features. 60 | 61 | y : array-like, shape (n_samples,) 62 | The target variable for supervised learning problems. 63 | 64 | groups : array-like, with shape (n_samples,), optional 65 | Group labels for the samples used while splitting the dataset into 66 | train/test set. 67 | 68 | Returns 69 | ------- 70 | train : ndarray 71 | The training set indices for that split. 72 | 73 | test : ndarray 74 | The testing set indices for that split. 75 | """ 76 | 77 | if self.groupby is not None: 78 | from .utils import get_group_indices 79 | 80 | groups = np.zeros(len(X[self.dim])) 81 | group_idx = get_group_indices(X, self.groupby, self.dim) 82 | for i in range(len(group_idx)): 83 | groups[group_idx[i]] = i 84 | 85 | return self.cross_validator.split(X[self.dim], y=y, groups=groups) 86 | -------------------------------------------------------------------------------- /sklearn_xarray/target.py: -------------------------------------------------------------------------------- 1 | """``sklearn_xarray.target``""" 2 | 3 | import numpy as np 4 | 5 | 6 | class Target(object): 7 | """ A pointer to xarray coordinates or variables to be used as a target. 8 | 9 | Parameters 10 | ---------- 11 | coord : str, optional 12 | The coordinate or variable that holds the data of the target. If not 13 | specified, the target will be the entire DataArray/Dataset. 
14 | 15 | transform_func : callable, optional 16 | A function that transforms the coordinate values to an 17 | sklearn-compatible type and shape. If not specified, the coordinate(s) 18 | will be used as-is. 19 | 20 | transformer : sklearn transformer, optional 21 | **Deprecated**, use ``transform_func=Transformer().fit_transform`` 22 | instead. 23 | 24 | lazy : bool, optional 25 | If true, the target coordinate is only transformed by the transformer 26 | when needed. The transformer can implement a ``get_transformed_shape`` 27 | method that returns the shape after the transformation of the provided 28 | coordinate without actually transforming the data. 29 | 30 | dim : str or sequence of str, optional 31 | When set, multi-dimensional coordinates will be reduced to this 32 | dimension/these dimensions. 33 | 34 | reduce_func : callable, optional 35 | A callable that reduces the coordinate(s) to the dimension(s) in 36 | ``dim``. If not specified, the values along dimensions not in ``dim`` 37 | will be reduced to the first element in each of these dimensions. 38 | """ 39 | 40 | def __init__( 41 | self, 42 | coord=None, 43 | transform_func=None, 44 | transformer=None, 45 | lazy=False, 46 | dim=None, 47 | reduce_func=None, 48 | ): 49 | 50 | self.transform_func = transform_func 51 | self.coord = coord 52 | self.lazy = lazy 53 | self.reduce_func = reduce_func 54 | self.dim = dim 55 | 56 | self.transformer = transformer 57 | if transformer is not None: 58 | import warnings 59 | 60 | warnings.simplefilter("always", DeprecationWarning) 61 | warnings.warn( 62 | "The transformer argument is deprecated and will be removed " 63 | "in a future version. Use " 64 | "transform_func=Transformer().fit_transform instead.", 65 | DeprecationWarning, 66 | ) 67 | warnings.simplefilter("ignore", DeprecationWarning) 68 | self.transform_func = self.transformer.fit_transform 69 | 70 | self.values = None 71 | 72 | def __getitem__(self, key): 73 | 74 | import copy 75 | 76 | self._check_assigned() 77 | 78 | new_obj = copy.copy(self) 79 | 80 | if self.lazy: 81 | new_obj.values = self.transform_func(self.values)[key] 82 | new_obj.lazy = False 83 | else: 84 | new_obj.values = self.values[key] 85 | 86 | return new_obj 87 | 88 | def __call__(self, X): 89 | 90 | return self.assign_to(X) 91 | 92 | def __str__(self): 93 | 94 | if self.values is None: 95 | if self.coord is None: 96 | return "Unassigned sklearn_xarray.Target without coordinate." 97 | else: 98 | return ( 99 | 'Unassigned sklearn_xarray.Target with coordinate "' 100 | + self.coord 101 | + '".' 102 | ) 103 | else: 104 | return "sklearn_xarray.Target with data:\n" + self.values.__str__() 105 | 106 | def __repr__(self): 107 | 108 | return self.__str__() 109 | 110 | def __array__(self, dtype=None): 111 | 112 | self._check_assigned() 113 | 114 | if not self.lazy or self.transform_func is None: 115 | return np.array(self.values, dtype=dtype) 116 | else: 117 | return np.array(self.transform_func(self.values), dtype=dtype) 118 | 119 | def _check_assigned(self): 120 | """ Check if this instance has been assigned data. """ 121 | 122 | if self.values is None and self.lazy: 123 | raise ValueError("This instance has not been assigned any data.") 124 | 125 | def _reduce(self, values): 126 | """ Reduce values to dimension(s). 
""" 127 | 128 | if self.dim is None: 129 | return values 130 | 131 | if isinstance(self.dim, str): 132 | dim = [self.dim] 133 | else: 134 | dim = self.dim 135 | 136 | if self.reduce_func is None: 137 | for d in values.dims: 138 | if d not in dim: 139 | values = values.isel(**{d: 0}) 140 | return values 141 | else: 142 | other_dims = [d for d in values.dims if d not in dim] 143 | return values.reduce(self.reduce_func, dim=other_dims) 144 | 145 | @property 146 | def shape(self): 147 | """ The shape of the transformed target. """ 148 | 149 | self._check_assigned() 150 | 151 | if ( 152 | self.lazy 153 | and self.transformer is not None 154 | and hasattr(self.transformer, "get_transformed_shape") 155 | ): 156 | return self.transformer.get_transformed_shape(self.values) 157 | else: 158 | return self.__array__().shape 159 | 160 | @property 161 | def ndim(self): 162 | """ The number of dimensions of the transformed target. """ 163 | 164 | self._check_assigned() 165 | 166 | if ( 167 | self.lazy 168 | and self.transformer is not None 169 | and hasattr(self.transformer, "get_transformed_shape") 170 | ): 171 | return len(self.transformer.get_transformed_shape(self.values)) 172 | else: 173 | return self.__array__().ndim 174 | 175 | def assign_to(self, X): 176 | """ Assign this target to a DataArray or Dataset. 177 | 178 | Parameters 179 | ---------- 180 | X : xarray DataArray or Dataset 181 | The data whose coordinate is used as the target. 182 | 183 | Returns 184 | ------- 185 | self: 186 | The target itself. 187 | """ 188 | 189 | if self.coord is not None: 190 | self.values = self._reduce(X[self.coord]) 191 | else: 192 | self.values = self._reduce(X) 193 | 194 | if not self.lazy and self.transform_func is not None: 195 | self.values = self.transform_func(self.values) 196 | 197 | return self 198 | -------------------------------------------------------------------------------- /sklearn_xarray/utils.py: -------------------------------------------------------------------------------- 1 | """ ``sklearn_xarray.utils`` """ 2 | 3 | 4 | import numpy as np 5 | 6 | from .target import Target 7 | 8 | 9 | def is_dataarray(X, require_attrs=None): 10 | """ Check whether an object is a DataArray. 11 | 12 | Parameters 13 | ---------- 14 | X : anything 15 | The object to be checked. 16 | 17 | require_attrs : list of str, optional 18 | The attributes the object has to have in order to pass as a DataArray. 19 | 20 | Returns 21 | ------- 22 | bool 23 | Whether the object is a DataArray or not. 24 | """ 25 | 26 | if require_attrs is None: 27 | require_attrs = ["values", "coords", "dims", "to_dataset"] 28 | 29 | return all([hasattr(X, name) for name in require_attrs]) 30 | 31 | 32 | def is_dataset(X, require_attrs=None): 33 | """ Check whether an object is a Dataset. 34 | 35 | Parameters 36 | ---------- 37 | X : anything 38 | The object to be checked. 39 | 40 | require_attrs : list of str, optional 41 | The attributes the object has to have in order to pass as a Dataset. 42 | 43 | Returns 44 | ------- 45 | bool 46 | Whether the object is a Dataset or not. 47 | """ 48 | 49 | if require_attrs is None: 50 | require_attrs = ["data_vars", "coords", "dims", "to_array"] 51 | 52 | return all([hasattr(X, name) for name in require_attrs]) 53 | 54 | 55 | def is_target(X, require_attrs=None): 56 | """ Check whether an object is a Target. 57 | 58 | Parameters 59 | ---------- 60 | X : anything 61 | The object to be checked. 
62 | 63 | require_attrs : list of str, optional 64 | The attributes the object has to have in order to pass as a Target. 65 | 66 | Returns 67 | ------- 68 | bool 69 | Whether the object is a Target or not. 70 | """ 71 | 72 | if require_attrs is None: 73 | require_attrs = ( 74 | name for name in vars(Target) if not name.startswith("_") 75 | ) 76 | 77 | return all([hasattr(X, name) for name in require_attrs]) 78 | 79 | 80 | def convert_to_ndarray(X, new_dim_last=True, new_dim_name="variable"): 81 | """ Convert xarray DataArray or Dataset to numpy ndarray. 82 | 83 | Parameters 84 | ---------- 85 | X : xarray DataArray or Dataset 86 | The input data. 87 | 88 | new_dim_last : bool, default true 89 | If true, put the new dimension last when converting a Dataset with 90 | multiple variables. 91 | 92 | new_dim_name : str, default 'variable' 93 | The name of the new dimension when converting a Dataset with multiple 94 | variables. 95 | 96 | Returns 97 | ------- 98 | X_arr : numpy ndarray 99 | The data as an ndarray. 100 | """ 101 | 102 | if is_dataset(X): 103 | 104 | if len(X.data_vars) == 1: 105 | X = X[tuple(X.data_vars)[0]] 106 | else: 107 | X = X.to_array(dim=new_dim_name) 108 | if new_dim_last: 109 | new_order = list(X.dims) 110 | new_order.append(new_dim_name) 111 | new_order.remove(new_dim_name) 112 | X = X.transpose(*new_order) 113 | 114 | return np.array(X) 115 | 116 | 117 | def get_group_indices(X, groupby, group_dim=None): 118 | """ Get logical index vectors for each group. 119 | 120 | Parameters 121 | ---------- 122 | X : xarray DataArray or Dataset 123 | The data structure for which to determine the indices. 124 | 125 | groupby : str or list 126 | Name of coordinate or list of coordinates by which the groups are 127 | determined. 128 | 129 | group_dim : str or None, optional 130 | Name of dimension along which the groups are indexed. 131 | 132 | Returns 133 | ------- 134 | idx: list of boolean numpy vectors 135 | List of logical indices for each group. 136 | """ 137 | 138 | import itertools 139 | 140 | if isinstance(groupby, str): 141 | groupby = [groupby] 142 | 143 | idx_groups = [] 144 | for g in groupby: 145 | if group_dim is None or group_dim not in X[g].dims: 146 | values = X[g].values 147 | else: 148 | other_dims = set(X[g].dims) - {group_dim} 149 | values = X[g].isel(**{d: 0 for d in other_dims}).values 150 | idx_groups.append([values == v for v in np.unique(values)]) 151 | 152 | idx_all = [np.all(e, axis=0) for e in itertools.product(*idx_groups)] 153 | 154 | return [i for i in idx_all if np.any(i)] 155 | 156 | 157 | def segment_array( 158 | arr, axis, new_len, step=1, new_axis=None, return_view=False 159 | ): 160 | """ Segment an array along some axis. 161 | 162 | Parameters 163 | ---------- 164 | arr : array-like 165 | The input array. 166 | 167 | axis : int 168 | The axis along which to segment. 169 | 170 | new_len : int 171 | The length of each segment. 172 | 173 | step : int, default 1 174 | The offset between the start of each segment. 175 | 176 | new_axis : int, optional 177 | The position where the newly created axis is to be inserted. By 178 | default, the axis will be added at the end of the array. 179 | 180 | return_view : bool, default False 181 | If True, return a view of the segmented array instead of a copy. 182 | 183 | Returns 184 | ------- 185 | arr_seg : array-like 186 | The segmented array. 
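    A small sketch of the windowing behaviour (length-4 windows whose starts
    are offset by 2 samples)::

        import numpy as np

        arr = np.arange(10)
        seg = segment_array(arr, axis=0, new_len=4, step=2)
        # seg.shape == (4, 4); the rows are
        # [0 1 2 3], [2 3 4 5], [4 5 6 7], [6 7 8 9]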
187 | """ 188 | 189 | from numpy.lib.stride_tricks import as_strided 190 | 191 | # handle the case that the segmented axis is singleton after segmentation 192 | if (arr.shape[axis] - new_len) // step == 0: 193 | idx = [slice(None)] * arr.ndim 194 | idx[axis] = slice(new_len) 195 | arr_seg = arr[tuple(idx)][..., np.newaxis] 196 | if new_axis is None: 197 | return np.moveaxis(arr_seg, (axis, -1), (-1, axis)) 198 | else: 199 | return np.moveaxis(arr_seg, (axis, -1), (new_axis, axis)) 200 | 201 | old_shape = np.array(arr.shape) 202 | 203 | assert ( 204 | new_len <= old_shape[axis] 205 | ), "new_len is bigger than input array in axis" 206 | seg_shape = old_shape.copy() 207 | seg_shape[axis] = new_len 208 | 209 | steps = np.ones_like(old_shape) 210 | if step: 211 | step = np.array(step, ndmin=1) 212 | assert step > 0, "Only positive steps allowed" 213 | steps[axis] = step 214 | 215 | arr_strides = np.array(arr.strides) 216 | 217 | shape = tuple((old_shape - seg_shape) // steps + 1) + tuple(seg_shape) 218 | strides = tuple(arr_strides * steps) + tuple(arr_strides) 219 | 220 | arr_seg = np.squeeze(as_strided(arr, shape=shape, strides=strides)) 221 | 222 | # squeeze will move the segmented axis to the first position 223 | arr_seg = np.moveaxis(arr_seg, 0, axis) 224 | 225 | # the new axis comes right after 226 | if new_axis is not None: 227 | arr_seg = np.moveaxis(arr_seg, axis + 1, new_axis) 228 | else: 229 | arr_seg = np.moveaxis(arr_seg, axis + 1, -1) 230 | 231 | if return_view: 232 | return arr_seg 233 | else: 234 | return arr_seg.copy() 235 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/phausamann/sklearn-xarray/0a8e61222a89e02665f444233e2bb2eb2bef7184/tests/__init__.py -------------------------------------------------------------------------------- /tests/mocks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from sklearn.base import BaseEstimator, TransformerMixin 4 | 5 | 6 | class DummyEstimator(BaseEstimator): 7 | """ A dummy estimator that returns the input as a numpy array.""" 8 | 9 | def __init__(self, demo_param="demo_param"): 10 | 11 | self.demo_param = demo_param 12 | 13 | def fit(self, X, y=None): 14 | 15 | return self 16 | 17 | def predict(self, X): 18 | 19 | return np.array(X) 20 | 21 | 22 | class DummyTransformer(BaseEstimator): 23 | """ A dummy estimator that returns the input as a numpy array.""" 24 | 25 | def __init__(self, demo_param="demo_param"): 26 | 27 | self.demo_param = demo_param 28 | 29 | def fit(self, X, y=None): 30 | 31 | return self 32 | 33 | def transform(self, X): 34 | 35 | return np.array(X) 36 | 37 | 38 | class ReshapingEstimator(BaseEstimator, TransformerMixin): 39 | """ A dummy estimator that changes the number of features.""" 40 | 41 | def __init__(self, new_shape=None): 42 | 43 | self.new_shape = new_shape 44 | 45 | def fit(self, X, y=None): 46 | 47 | self.shape_ = X.shape 48 | 49 | return self 50 | 51 | def predict(self, X): 52 | 53 | Xt = np.array(X) 54 | 55 | idx = [slice(None)] * Xt.ndim 56 | for i in range(len(self.new_shape)): 57 | if self.new_shape[i] > 0: 58 | idx[i] = slice(None, self.new_shape[i]) 59 | elif self.new_shape[i] == 0: 60 | idx[i] = 0 61 | 62 | return Xt[tuple(idx)] 63 | 64 | def transform(self, X): 65 | 66 | return self.predict(X) 67 | 68 | def inverse_transform(self, X): 69 | 70 | Xt = 
np.zeros(self.shape_) 71 | 72 | idx = [slice(None)] * Xt.ndim 73 | for i in range(len(self.new_shape)): 74 | if self.new_shape[i] > 0: 75 | idx[i] = slice(None, self.new_shape[i]) 76 | elif self.new_shape[i] == 0: 77 | idx[i] = 0 78 | 79 | Xt[tuple(idx)] = X 80 | 81 | return Xt 82 | -------------------------------------------------------------------------------- /tests/test_common.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import numpy as np 4 | import xarray as xr 5 | from xarray.testing import assert_equal, assert_allclose 6 | import numpy.testing as npt 7 | 8 | from sklearn_xarray import wrap 9 | 10 | from sklearn.base import clone 11 | from sklearn.preprocessing import StandardScaler, KernelCenterer 12 | from sklearn.linear_model import LinearRegression, LogisticRegression 13 | from sklearn.svm import SVC 14 | 15 | from tests.mocks import ( 16 | DummyEstimator, 17 | DummyTransformer, 18 | ReshapingEstimator, 19 | ) 20 | 21 | 22 | class EstimatorWrapperTests(TestCase): 23 | def setUp(self): 24 | 25 | self.X = xr.Dataset( 26 | { 27 | "var_2d": (["sample", "feat_1"], np.random.random((100, 10))), 28 | "var_3d": ( 29 | ["sample", "feat_1", "feat_2"], 30 | np.random.random((100, 10, 10)), 31 | ), 32 | }, 33 | { 34 | "sample": range(100), 35 | "feat_1": range(10), 36 | "feat_2": range(10), 37 | "dummy": (["sample", "feat_1"], np.random.random((100, 10))), 38 | }, 39 | ) 40 | 41 | def test_update_restore_dims(self): 42 | 43 | estimator = wrap( 44 | ReshapingEstimator(new_shape=(-1, 0, 5)), 45 | reshapes={"feature": ["feat_1", "feat_2"]}, 46 | ) 47 | 48 | X = self.X.var_3d 49 | 50 | estimator.fit(X) 51 | 52 | X_out = estimator.estimator_.transform(X.values) 53 | dims_new = estimator._update_dims(X, X_out) 54 | Xt = xr.DataArray(X_out, dims=dims_new) 55 | 56 | assert dims_new == ["sample", "feature"] 57 | 58 | Xr_out = estimator.estimator_.inverse_transform(X_out) 59 | dims_old = estimator._restore_dims(Xt, Xr_out) 60 | 61 | assert dims_old == ["sample", "feat_1", "feat_2"] 62 | 63 | def test_update_coords(self): 64 | 65 | pass 66 | 67 | def test_params(self): 68 | 69 | estimator = StandardScaler(with_mean=False) 70 | params = estimator.get_params() 71 | params.update( 72 | {"estimator": estimator, "reshapes": None, "sample_dim": None} 73 | ) 74 | 75 | # check params set in constructor 76 | wrapper = wrap(estimator) 77 | self.assertEqual(wrapper.get_params(), params) 78 | self.assertEqual(wrapper.with_mean, False) 79 | 80 | # check params set by attribute 81 | wrapper.with_std = False 82 | params.update({"with_std": False}) 83 | self.assertEqual(wrapper.get_params(), params) 84 | 85 | # check params set with set_params 86 | wrapper.set_params(copy=False) 87 | params.update({"copy": False}) 88 | self.assertEqual(wrapper.get_params(), params) 89 | 90 | def test_attributes(self): 91 | 92 | estimator = wrap(StandardScaler()) 93 | 94 | # check pass-through wrapper 95 | estimator.fit(self.X.var_2d.values) 96 | npt.assert_allclose(estimator.mean_, estimator.estimator_.mean_) 97 | 98 | # check DataArray wrapper 99 | estimator.fit(self.X.var_2d) 100 | npt.assert_allclose(estimator.mean_, estimator.estimator_.mean_) 101 | 102 | # check Dataset wrapper 103 | estimator.fit(self.X.var_2d.to_dataset()) 104 | npt.assert_allclose( 105 | estimator.mean_["var_2d"], 106 | estimator.estimator_dict_["var_2d"].mean_, 107 | ) 108 | 109 | 110 | class PublicInterfaceTests(TestCase): 111 | def setUp(self): 112 | 113 | self.X = 
xr.Dataset( 114 | { 115 | "var_2d": (["sample", "feat_1"], np.random.random((100, 10))), 116 | "var_3d": ( 117 | ["sample", "feat_1", "feat_2"], 118 | np.random.random((100, 10, 10)), 119 | ), 120 | }, 121 | { 122 | "sample": range(100), 123 | "feat_1": range(10), 124 | "feat_2": range(10), 125 | "dummy": (["sample", "feat_1"], np.random.random((100, 10))), 126 | }, 127 | ) 128 | 129 | def test_dummy_estimator(self): 130 | 131 | estimator = wrap(DummyEstimator()) 132 | 133 | # test DataArray 134 | X_da = self.X.var_2d 135 | 136 | estimator.fit(X_da) 137 | yp = estimator.predict(X_da) 138 | 139 | assert_equal(yp, X_da) 140 | 141 | # test Dataset 142 | X_ds = self.X 143 | 144 | estimator.fit(X_ds) 145 | yp = estimator.predict(X_ds) 146 | 147 | assert_equal(yp, X_ds) 148 | 149 | def test_dummy_transformer(self): 150 | 151 | estimator = wrap(DummyTransformer()) 152 | 153 | # test DataArray 154 | X_da = self.X.var_2d 155 | 156 | estimator.fit(X_da) 157 | yp = estimator.transform(X_da) 158 | 159 | assert_equal(yp, X_da) 160 | 161 | # test Dataset 162 | X_ds = self.X 163 | 164 | estimator.fit(X_ds) 165 | yp = estimator.transform(X_ds) 166 | 167 | assert_equal(yp, X_ds) 168 | 169 | def test_wrapped_transformer(self): 170 | 171 | estimator = wrap(StandardScaler()) 172 | 173 | # test DataArray 174 | X_da = self.X.var_2d 175 | 176 | estimator.partial_fit(X_da) 177 | 178 | assert_allclose( 179 | X_da, estimator.inverse_transform(estimator.transform(X_da)) 180 | ) 181 | 182 | # test Dataset 183 | X_ds = self.X.var_2d.to_dataset() 184 | 185 | estimator.fit(X_ds) 186 | 187 | assert_allclose( 188 | X_ds, estimator.inverse_transform(estimator.transform(X_ds)) 189 | ) 190 | 191 | def test_ndim_dummy_estimator(self): 192 | 193 | estimator = wrap(DummyEstimator()) 194 | 195 | # test DataArray 196 | X_da = self.X.var_3d 197 | 198 | estimator.fit(X_da) 199 | yp = estimator.predict(X_da) 200 | 201 | assert_equal(yp, X_da) 202 | 203 | # test Dataset 204 | X_ds = self.X 205 | 206 | estimator.fit(X_ds) 207 | yp = estimator.predict(X_ds) 208 | 209 | assert_equal(yp, X_ds) 210 | 211 | def test_reshaping_estimator(self): 212 | 213 | estimator = wrap( 214 | ReshapingEstimator(new_shape=(-1, 2)), reshapes="feat_1" 215 | ) 216 | 217 | # test DataArray 218 | X_da = self.X.var_2d 219 | 220 | y = X_da[:, :2].drop("feat_1") 221 | y["dummy"] = y.dummy[:, 0] 222 | 223 | estimator.fit(X_da) 224 | yp = estimator.predict(X_da) 225 | 226 | assert_allclose(yp, y) 227 | 228 | # test Dataset 229 | X_ds = self.X.var_2d.to_dataset() 230 | 231 | y = X_ds.var_2d[:, :2].drop("feat_1") 232 | y["dummy"] = y.dummy[:, 0] 233 | 234 | estimator.fit(X_ds) 235 | yp = estimator.predict(X_ds).var_2d 236 | 237 | assert_allclose(yp, y) 238 | 239 | def test_reshaping_transformer(self): 240 | 241 | estimator = wrap( 242 | ReshapingEstimator(new_shape=(-1, 2)), reshapes="feat_1" 243 | ) 244 | 245 | # test DataArray 246 | X_da = self.X.var_3d 247 | 248 | y = X_da[:, :2].drop("feat_1") 249 | y["dummy"] = y.dummy[:, 0] 250 | 251 | estimator.fit(X_da) 252 | yp = estimator.transform(X_da) 253 | 254 | assert_allclose(yp, y) 255 | 256 | # test Dataset 257 | X_ds = self.X.var_2d.to_dataset() 258 | 259 | y = X_ds.var_2d[:, :2].drop("feat_1") 260 | y["dummy"] = y.dummy[:, 0] 261 | 262 | estimator.fit(X_ds) 263 | yp = estimator.transform(X_ds).var_2d 264 | 265 | assert_allclose(yp, y) 266 | 267 | def test_reshaping_estimator_singleton(self): 268 | 269 | estimator = wrap( 270 | ReshapingEstimator(new_shape=(-1, 0)), reshapes="feat_1" 271 | ) 272 | 273 | # test 
DataArray 274 | X_da = self.X.var_2d 275 | 276 | y = X_da[:, 0].drop("feat_1") 277 | estimator.fit(X_da) 278 | yp = estimator.predict(X_da) 279 | 280 | assert_allclose(yp, y) 281 | 282 | # test Dataset 283 | X_ds = self.X 284 | 285 | y = X_ds.var_2d[:, 0].drop("feat_1") 286 | 287 | estimator.fit(X_ds) 288 | yp = estimator.predict(X_ds).var_2d 289 | 290 | assert_allclose(yp, y) 291 | 292 | def test_ndim_reshaping_estimator(self): 293 | 294 | estimator = wrap( 295 | ReshapingEstimator(new_shape=(-1, 5, 0)), 296 | reshapes={"feature": ["feat_1", "feat_2"]}, 297 | ) 298 | 299 | # test DataArray 300 | X_da = self.X.var_3d 301 | 302 | Xt = ( 303 | X_da[:, :5, 0] 304 | .drop(["feat_1", "feat_2"]) 305 | .rename({"feat_1": "feature"}) 306 | ) 307 | Xt["dummy"] = Xt.dummy[:, 0] 308 | 309 | estimator.fit(X_da) 310 | Xt_da = estimator.transform(X_da) 311 | estimator.inverse_transform(Xt_da) 312 | 313 | assert_allclose(Xt_da, Xt) 314 | 315 | # test Dataset 316 | X_ds = self.X.var_3d.to_dataset() 317 | 318 | y = X_ds.var_3d[:, :5, 0].drop(["feat_1", "feat_2"]) 319 | y = y.rename({"feat_1": "feature"}) 320 | y["dummy"] = y.dummy[:, 0] 321 | 322 | estimator.fit(X_ds) 323 | yp = estimator.predict(X_ds).var_3d 324 | 325 | assert_allclose(yp, y) 326 | 327 | def test_sample_dim(self): 328 | 329 | from sklearn.decomposition import PCA 330 | 331 | estimator = wrap( 332 | PCA(n_components=5), reshapes="feat_1", sample_dim="sample" 333 | ) 334 | 335 | # test DataArray 336 | X_da = self.X.var_2d 337 | 338 | Xt_da = estimator.fit_transform(X_da) 339 | Xr_da = estimator.inverse_transform(Xt_da) 340 | 341 | npt.assert_equal(Xt_da.shape, (100, 5)) 342 | npt.assert_equal(Xr_da.shape, (100, 10)) 343 | 344 | # test Dataset 345 | X_ds = self.X.var_2d.to_dataset() 346 | 347 | Xt = estimator.fit_transform(X_ds) 348 | 349 | npt.assert_equal(Xt.var_2d.shape, (100, 5)) 350 | 351 | def test_score(self): 352 | 353 | from sklearn.linear_model import LinearRegression 354 | 355 | estimator = wrap(LinearRegression, reshapes="feat_1") 356 | 357 | # test DataArray 358 | X_da = self.X.var_2d 359 | 360 | y = np.random.random(100) 361 | 362 | estimator.fit(X_da, y) 363 | 364 | estimator.score(X_da, y) 365 | 366 | # test Dataset 367 | X_ds = self.X.var_2d.to_dataset() 368 | 369 | wrapper = estimator.fit(X_ds, y) 370 | 371 | wrapper.score(X_ds, y) 372 | 373 | def test_partial_fit(self): 374 | 375 | estimator = wrap(StandardScaler()) 376 | 377 | # check pass-through wrapper 378 | estimator.partial_fit(self.X.var_2d.values) 379 | assert hasattr(estimator, "mean_") 380 | 381 | with self.assertRaises(ValueError): 382 | estimator.partial_fit(self.X.var_2d) 383 | with self.assertRaises(ValueError): 384 | estimator.partial_fit(self.X) 385 | 386 | # check DataArray wrapper 387 | estimator = clone(estimator) 388 | estimator.partial_fit(self.X.var_2d) 389 | 390 | with self.assertRaises(ValueError): 391 | estimator.partial_fit(self.X.var_2d.values) 392 | with self.assertRaises(ValueError): 393 | estimator.partial_fit(self.X) 394 | assert hasattr(estimator, "mean_") 395 | 396 | # check Dataset wrapper 397 | estimator = clone(estimator) 398 | estimator.partial_fit(self.X.var_2d.to_dataset()) 399 | 400 | with self.assertRaises(ValueError): 401 | estimator.partial_fit(self.X.var_2d.values) 402 | with self.assertRaises(ValueError): 403 | estimator.partial_fit(self.X.var_2d) 404 | assert hasattr(estimator, "mean_") 405 | 406 | 407 | def test_classifier(): 408 | 409 | lr = wrap(LogisticRegression) 410 | # wrappers don't pass check_estimator anymore because 
estimators 411 | # "should not set any attribute apart from parameters during init" 412 | assert hasattr(lr, "predict") 413 | assert hasattr(lr, "decision_function") 414 | 415 | lr = wrap(LogisticRegression) 416 | assert hasattr(lr, "C") 417 | 418 | svc_proba = wrap(SVC(probability=True)) 419 | # check_estimator(svc_proba) fails because the wrapper is not excluded 420 | # from tests that are known to fail for SVC... 421 | assert hasattr(svc_proba, "predict_proba") 422 | assert hasattr(svc_proba, "predict_log_proba") 423 | 424 | 425 | def test_regressor(): 426 | 427 | lr = wrap(LinearRegression, compat=True) 428 | assert hasattr(lr, "predict") 429 | assert hasattr(lr, "score") 430 | 431 | lr = wrap(LinearRegression) 432 | assert hasattr(lr, "normalize") 433 | 434 | 435 | def test_transformer(): 436 | 437 | wrap(KernelCenterer, compat=True) 438 | 439 | tr = wrap(KernelCenterer) 440 | assert hasattr(tr, "transform") 441 | 442 | ss = wrap(StandardScaler) 443 | # check_estimator(ss) fails because the wrapper is not excluded 444 | # from tests that are known to fail for StandardScaler... 445 | assert hasattr(ss, "partial_fit") 446 | assert hasattr(ss, "inverse_transform") 447 | assert hasattr(ss, "fit_transform") 448 | -------------------------------------------------------------------------------- /tests/test_datasets.py: -------------------------------------------------------------------------------- 1 | from sklearn_xarray.datasets import ( 2 | load_dummy_dataarray, 3 | load_dummy_dataset, 4 | load_digits_dataarray, 5 | load_wisdm_dataarray, 6 | ) 7 | 8 | import os 9 | from sklearn_xarray import ROOT_DIR 10 | 11 | 12 | def test_load_dummy_dataarray(): 13 | 14 | load_dummy_dataarray() 15 | 16 | 17 | def test_load_dummy_dataset(): 18 | 19 | load_dummy_dataset() 20 | 21 | 22 | def test_load_digits_dataarray(): 23 | 24 | load_digits_dataarray(nan_probability=0.1) 25 | 26 | load_digits_dataarray(load_images=True, nan_probability=0.1) 27 | 28 | 29 | def test_load_wisdm_dataarray(): 30 | 31 | load_wisdm_dataarray(folder=os.path.join(ROOT_DIR, "../data")) 32 | -------------------------------------------------------------------------------- /tests/test_model_selection.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import xarray as xr 3 | 4 | from sklearn_xarray.model_selection import CrossValidatorWrapper 5 | from sklearn.model_selection import KFold, GroupKFold 6 | 7 | 8 | def test_cross_validator(): 9 | 10 | X_da = xr.DataArray( 11 | np.random.random((100, 10)), 12 | coords={"sample": range(100), "feature": range(10)}, 13 | dims=["sample", "feature"], 14 | ) 15 | 16 | X_ds = xr.Dataset( 17 | {"var_1": (["sample", "feature"], np.random.random((100, 10)))}, 18 | coords={"sample": range(100), "feature": range(10)}, 19 | ) 20 | 21 | cv = CrossValidatorWrapper(KFold(n_splits=3)) 22 | 23 | assert cv.get_n_splits() == 3 24 | 25 | cv_list = list(cv.split(X_da)) 26 | assert cv_list[0][0].shape[0] + cv_list[0][1].shape[0] == 100 27 | 28 | cv_list = list(cv.split(X_ds)) 29 | assert cv_list[0][0].shape[0] + cv_list[0][1].shape[0] == 100 30 | 31 | 32 | def test_cross_validator_groupwise(): 33 | 34 | coord_1 = ["a"] * 51 + ["b"] * 49 35 | coord_2 = list(range(10)) * 10 36 | 37 | X_da = xr.DataArray( 38 | np.random.random((100, 10)), 39 | coords={ 40 | "sample": range(100), 41 | "feature": range(10), 42 | "coord_1": (["sample"], coord_1), 43 | "coord_2": (["sample"], coord_2), 44 | }, 45 | dims=["sample", "feature"], 46 | ) 47 | 48 | cv = 
CrossValidatorWrapper(GroupKFold(n_splits=2), groupby="coord_1") 49 | 50 | cv_list = list(cv.split(X_da)) 51 | 52 | assert np.any([c.size == 51 for c in cv_list[0]]) 53 | -------------------------------------------------------------------------------- /tests/test_preprocessing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import xarray as xr 3 | import xarray.testing as xrt 4 | import numpy.testing as npt 5 | 6 | from sklearn_xarray.preprocessing import ( 7 | preprocess, 8 | transpose, 9 | split, 10 | segment, 11 | resample, 12 | concatenate, 13 | featurize, 14 | select, 15 | sanitize, 16 | reduce, 17 | Splitter, 18 | ) 19 | 20 | 21 | def test_preprocess(): 22 | 23 | from sklearn.preprocessing import scale 24 | 25 | X_da = xr.DataArray( 26 | np.random.random((100, 10)), 27 | coords={"sample": range(100), "feature": range(10)}, 28 | dims=("sample", "feature"), 29 | ) 30 | 31 | Xt_da_gt = X_da 32 | Xt_da_gt.data = scale(X_da) 33 | 34 | Xt_da = preprocess(X_da, scale) 35 | 36 | xrt.assert_allclose(Xt_da, Xt_da_gt) 37 | 38 | X_ds = xr.Dataset( 39 | {"var_1": (["sample", "feature"], np.random.random((100, 10)))}, 40 | coords={"sample": range(100), "feature": range(10)}, 41 | ) 42 | 43 | Xt_ds = preprocess(X_ds, scale) 44 | 45 | xrt.assert_allclose(Xt_ds, X_ds.apply(scale)) 46 | 47 | 48 | def test_groupwise(): 49 | 50 | from sklearn.preprocessing import scale 51 | 52 | coord_1 = ["a"] * 51 + ["b"] * 49 53 | coord_2 = list(range(10)) * 10 54 | 55 | X_ds = xr.Dataset( 56 | {"var_1": (["sample", "feature"], np.random.random((100, 10)))}, 57 | coords={ 58 | "sample": range(100), 59 | "feature": range(10), 60 | "coord_1": (["sample"], coord_1), 61 | "coord_2": (["sample"], coord_2), 62 | }, 63 | ) 64 | 65 | # test wrapped sklearn estimator 66 | preprocess(X_ds, scale, groupby="coord_1") 67 | 68 | # test newly defined estimator 69 | Xt_ds2, estimator = split( 70 | X_ds, 71 | new_dim="split_sample", 72 | new_len=5, 73 | groupby="coord_1", 74 | keep_coords_as="initial_sample", 75 | return_estimator=True, 76 | ) 77 | 78 | assert Xt_ds2.var_1.shape == (19, 10, 5) 79 | 80 | Xt_ds2 = estimator.inverse_transform(Xt_ds2) 81 | 82 | assert Xt_ds2.var_1.shape == (95, 10) 83 | 84 | 85 | def test_transpose(): 86 | 87 | # test on DataArray 88 | X_da = xr.DataArray( 89 | np.random.random((100, 10)), 90 | coords={"sample": range(100), "feature": range(10)}, 91 | dims=("sample", "feature"), 92 | ) 93 | 94 | Xt_da, estimator = transpose( 95 | X_da, order=["feature", "sample"], return_estimator=True 96 | ) 97 | 98 | xrt.assert_allclose(Xt_da, X_da.transpose()) 99 | 100 | Xt_da = estimator.inverse_transform(Xt_da) 101 | 102 | xrt.assert_allclose(Xt_da, X_da) 103 | 104 | # test on Dataset with subset of dimensions 105 | X_ds = xr.Dataset( 106 | { 107 | "var_1": ( 108 | ["sample", "feat_1", "feat_2"], 109 | np.random.random((100, 10, 5)), 110 | ), 111 | "var_2": (["feat_2", "sample"], np.random.random((5, 100))), 112 | }, 113 | coords={"sample": range(100), "feat_1": range(10), "feat_2": range(5)}, 114 | ) 115 | 116 | Xt_ds, estimator = transpose( 117 | X_ds, order=["sample", "feat_2"], return_estimator=True 118 | ) 119 | 120 | xrt.assert_allclose(Xt_ds, X_ds.transpose("sample", "feat_1", "feat_2")) 121 | 122 | Xt_ds = estimator.inverse_transform(Xt_ds) 123 | 124 | xrt.assert_allclose(Xt_ds, X_ds) 125 | 126 | 127 | def test_split(): 128 | 129 | # test on DataArray with number of samples multiple of new length 130 | X_da = xr.DataArray( 131 | 
np.random.random((100, 10)), 132 | coords={ 133 | "sample": range(100), 134 | "feature": range(10), 135 | "coord_1": (["sample", "feature"], np.tile("Test", (100, 10))), 136 | }, 137 | dims=("sample", "feature"), 138 | ) 139 | 140 | estimator = Splitter( 141 | new_dim="split_sample", 142 | new_len=5, 143 | reduce_index="subsample", 144 | axis=1, 145 | keep_coords_as="sample_coord", 146 | ) 147 | 148 | Xt_da = estimator.fit_transform(X_da) 149 | 150 | assert Xt_da.shape == (20, 5, 10) 151 | npt.assert_allclose(Xt_da[0, :, 0], X_da[:5, 0]) 152 | 153 | Xit_da = estimator.inverse_transform(Xt_da) 154 | 155 | xrt.assert_allclose(X_da, Xit_da) 156 | 157 | # test on Dataset with number of samples NOT multiple of new length 158 | X_ds = xr.Dataset( 159 | {"var_1": (["sample", "feature"], np.random.random((100, 10)))}, 160 | coords={"sample": range(100), "feature": range(10)}, 161 | ) 162 | 163 | Xt_ds = split( 164 | X_ds, 165 | new_dim="split_sample", 166 | new_len=7, 167 | reduce_index="head", 168 | axis=1, 169 | new_index_func=None, 170 | ) 171 | 172 | assert Xt_ds["var_1"].shape == (14, 7, 10) 173 | npt.assert_allclose(Xt_ds.var_1[0, :, 0], X_ds.var_1[:7, 0]) 174 | 175 | 176 | def test_segment(): 177 | 178 | X_da = xr.DataArray( 179 | np.tile(np.arange(10), (100, 1)), 180 | coords={ 181 | "sample": range(100), 182 | "feature": range(10), 183 | "coord_1": (["sample", "feature"], np.tile("Test", (100, 10))), 184 | }, 185 | dims=("sample", "feature"), 186 | ) 187 | 188 | Xt_da, estimator = segment( 189 | X_da, 190 | new_dim="split_sample", 191 | new_len=20, 192 | step=5, 193 | axis=0, 194 | reduce_index="subsample", 195 | keep_coords_as="backup", 196 | return_estimator=True, 197 | ) 198 | 199 | assert Xt_da.coord_1.shape == (20, 17, 10) 200 | npt.assert_allclose(Xt_da[:, 0, 0], X_da[:20, 0]) 201 | 202 | Xit_da = estimator.inverse_transform(Xt_da) 203 | 204 | xrt.assert_allclose(Xit_da, X_da) 205 | 206 | X_ds = xr.Dataset( 207 | { 208 | "var_1": ( 209 | ["sample", "feat_1", "feat_2"], 210 | np.tile(np.arange(10), (100, 10, 1)), 211 | ), 212 | "var_2": (["feat_2"], np.random.random((10,))), 213 | }, 214 | coords={ 215 | "sample": range(100), 216 | "feat_1": range(10), 217 | "feat_2": range(10), 218 | "coord_1": (["sample", "feat_1"], np.tile("Test", (100, 10))), 219 | }, 220 | ) 221 | 222 | Xt_ds, estimator = segment( 223 | X_ds, 224 | new_dim="split_sample", 225 | new_len=20, 226 | step=5, 227 | reduce_index="head", 228 | keep_coords_as="backup", 229 | return_estimator=True, 230 | ) 231 | 232 | assert Xt_ds.var_1.shape == (17, 10, 10, 20) 233 | npt.assert_allclose(Xt_ds.var_1[0, 0, 0, :], X_ds.var_1[:20, 0, 0]) 234 | 235 | xrt.assert_allclose(estimator.inverse_transform(Xt_ds), X_ds) 236 | 237 | 238 | def test_resample(): 239 | 240 | import pandas as pd 241 | 242 | X_da = xr.DataArray( 243 | np.random.random((100, 10)), 244 | coords={ 245 | "sample": pd.timedelta_range(0, periods=100, freq="10ms"), 246 | "feature": range(10), 247 | }, 248 | dims=("sample", "feature"), 249 | ) 250 | 251 | resample(X_da, freq="20ms") 252 | 253 | X_ds = xr.Dataset( 254 | {"var_1": (["sample", "feature"], np.random.random((100, 10)))}, 255 | coords={ 256 | "sample": pd.timedelta_range(0, periods=100, freq="10ms"), 257 | "feature": range(10), 258 | }, 259 | ) 260 | 261 | # TODO: check result 262 | resample(X_ds, freq="20ms") 263 | 264 | 265 | def test_concatenate(): 266 | 267 | X_ds = xr.Dataset( 268 | { 269 | "var_1": (["sample", "feature"], np.random.random((100, 10))), 270 | "var_2": (["sample", "feature"], 
np.random.random((100, 10))), 271 | "var_3": (["sample", "feature"], np.random.random((100, 10))), 272 | }, 273 | coords={"sample": range(100), "feature": range(10)}, 274 | ) 275 | 276 | Xt_da, concatenator = concatenate( 277 | X_ds, return_array=True, return_estimator=True 278 | ) 279 | 280 | assert Xt_da.shape == (100, 30) 281 | 282 | xrt.assert_allclose(concatenator.inverse_transform(Xt_da), X_ds) 283 | 284 | Xt_ds2, concatenator2 = concatenate( 285 | X_ds, 286 | variables=["var_1", "var_2"], 287 | new_index_func=np.arange, 288 | return_estimator=True, 289 | ) 290 | 291 | assert Xt_ds2.Feature.shape == (100, 20) 292 | npt.assert_equal(Xt_ds2.feature.values, np.arange(20)) 293 | 294 | xrt.assert_allclose(concatenator2.inverse_transform(Xt_ds2), X_ds) 295 | 296 | 297 | def test_featurize(): 298 | 299 | X_da = xr.DataArray( 300 | np.random.random((100, 10, 10)), 301 | coords={ 302 | "sample": range(100), 303 | "feat_1": range(10), 304 | "feat_2": range(10), 305 | }, 306 | dims=("sample", "feat_1", "feat_2"), 307 | ) 308 | 309 | Xt_da, featurizer = featurize(X_da, return_estimator=True) 310 | 311 | assert Xt_da.shape == (100, 100) 312 | 313 | X_ds = xr.Dataset( 314 | { 315 | "var_1": ( 316 | ["sample", "feat_1", "feat_2"], 317 | np.random.random((100, 10, 10)), 318 | ), 319 | "var_2": (["sample", "feat_1"], np.random.random((100, 10))), 320 | }, 321 | coords={ 322 | "sample": range(100), 323 | "feat_1": range(10), 324 | "feat_2": range(10), 325 | }, 326 | ) 327 | 328 | Xt_ds, featurizer = featurize( 329 | X_ds, return_array=True, return_estimator=True 330 | ) 331 | 332 | assert Xt_ds.shape == (100, 110) 333 | 334 | 335 | def test_select(): 336 | 337 | X_da = xr.DataArray( 338 | np.random.random((100, 10)), 339 | coords={"sample": range(100), "feature": range(10)}, 340 | dims=("sample", "feature"), 341 | ) 342 | 343 | selector = np.zeros((100, 10)) 344 | selector[[5, 10, 15], 0] = 1 345 | X_da["selector"] = (["sample", "feature"], selector) 346 | 347 | Xt_da = select(X_da, coord="selector") 348 | 349 | npt.assert_equal(Xt_da.sample.values, np.array((5, 10, 15))) 350 | 351 | 352 | def test_sanitize(): 353 | 354 | X_da = xr.DataArray( 355 | np.random.random((100, 10)), 356 | coords={"sample": range(100), "feature": range(10)}, 357 | dims=("sample", "feature"), 358 | ) 359 | 360 | X_da[0, 0] = np.nan 361 | 362 | Xt_da = sanitize(X_da) 363 | 364 | xrt.assert_allclose(X_da[1:], Xt_da) 365 | 366 | X_ds = xr.Dataset( 367 | {"var_1": (["sample", "feature"], np.random.random((100, 10)))}, 368 | coords={"sample": range(100), "feature": range(10)}, 369 | ) 370 | 371 | X_ds["var_1"][0, 0] = np.nan 372 | 373 | Xt_ds = sanitize(X_ds) 374 | 375 | xrt.assert_allclose(X_ds.isel(sample=range(1, 100)), Xt_ds) 376 | 377 | 378 | def test_reduce(): 379 | 380 | X_da = xr.DataArray( 381 | np.random.random((100, 10)), 382 | coords={"sample": range(100), "feature": range(10)}, 383 | dims=("sample", "feature"), 384 | ) 385 | 386 | Xt_da = reduce(X_da) 387 | 388 | xrt.assert_allclose(Xt_da, X_da.reduce(np.linalg.norm, dim="feature")) 389 | -------------------------------------------------------------------------------- /tests/test_target.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import xarray as xr 3 | import numpy.testing as npt 4 | 5 | from sklearn_xarray import Target 6 | from sklearn.preprocessing import LabelBinarizer 7 | 8 | 9 | def test_constructor(): 10 | 11 | from sklearn_xarray.utils import convert_to_ndarray 12 | 13 | coord_1 = ["a"] * 51 
+ ["b"] * 49 14 | coord_2 = list(range(10)) * 10 15 | 16 | X_ds = xr.Dataset( 17 | {"var_1": (["sample", "feature"], np.random.random((100, 10)))}, 18 | coords={ 19 | "sample": range(100), 20 | "feature": range(10), 21 | "coord_1": (["sample"], coord_1), 22 | "coord_2": (["sample"], coord_2), 23 | }, 24 | ) 25 | 26 | target = Target(transform_func=convert_to_ndarray) 27 | target.assign_to(X_ds) 28 | 29 | npt.assert_equal(target.values, np.array(X_ds.var_1)) 30 | 31 | target = Target(coord="coord_1", transformer=LabelBinarizer())(X_ds) 32 | 33 | npt.assert_equal(target.values, LabelBinarizer().fit_transform(coord_1)) 34 | 35 | 36 | def test_str(): 37 | 38 | assert str(Target()).startswith( 39 | "Unassigned sklearn_xarray.Target without coordinate" 40 | ) 41 | 42 | assert str(Target(coord="test")).startswith( 43 | 'Unassigned sklearn_xarray.Target with coordinate "test"' 44 | ) 45 | 46 | assert str(Target()(np.ones(10))).startswith( 47 | "sklearn_xarray.Target with data:" 48 | ) 49 | 50 | 51 | def test_array(): 52 | 53 | coord_1 = ["a"] * 51 + ["b"] * 49 54 | coord_2 = list(range(10)) * 10 55 | 56 | X_ds = xr.Dataset( 57 | {"var_1": (["sample", "feature"], np.random.random((100, 10)))}, 58 | coords={ 59 | "sample": range(100), 60 | "feature": range(10), 61 | "coord_1": (["sample"], coord_1), 62 | "coord_2": (["sample"], coord_2), 63 | }, 64 | ) 65 | 66 | target = Target( 67 | coord="coord_1", 68 | transform_func=LabelBinarizer().fit_transform, 69 | lazy=True, 70 | )(X_ds) 71 | 72 | npt.assert_equal(np.array(target), LabelBinarizer().fit_transform(coord_1)) 73 | 74 | 75 | def test_getitem(): 76 | 77 | coord_1 = ["a"] * 51 + ["b"] * 49 78 | coord_2 = list(range(10)) * 10 79 | 80 | X_ds = xr.Dataset( 81 | {"var_1": (["sample", "feature"], np.random.random((100, 10)))}, 82 | coords={ 83 | "sample": range(100), 84 | "feature": range(10), 85 | "coord_1": (["sample"], coord_1), 86 | "coord_2": (["sample"], coord_2), 87 | }, 88 | ) 89 | 90 | target = Target( 91 | coord="coord_1", transform_func=LabelBinarizer().fit_transform 92 | )(X_ds) 93 | 94 | y_test = target[-1] 95 | 96 | assert y_test == LabelBinarizer().fit_transform(coord_1)[-1] 97 | 98 | # test lazy eval 99 | target = Target( 100 | coord="coord_1", 101 | transform_func=LabelBinarizer().fit_transform, 102 | lazy=True, 103 | )(X_ds) 104 | 105 | y_test = target[-1] 106 | 107 | assert y_test == LabelBinarizer().fit_transform(coord_1)[-1] 108 | assert not y_test.lazy 109 | 110 | 111 | def test_shape_and_ndim(): 112 | 113 | coord_1 = ["a"] * 51 + ["b"] * 49 114 | coord_2 = list(range(10)) * 10 115 | 116 | X_ds = xr.Dataset( 117 | {"var_1": (["sample", "feature"], np.random.random((100, 10)))}, 118 | coords={ 119 | "sample": range(100), 120 | "feature": range(10), 121 | "coord_1": (["sample"], coord_1), 122 | "coord_2": (["sample"], coord_2), 123 | }, 124 | ) 125 | 126 | target = Target( 127 | coord="coord_1", transform_func=LabelBinarizer().fit_transform 128 | )(X_ds) 129 | 130 | npt.assert_equal( 131 | target.shape, LabelBinarizer().fit_transform(coord_1).shape 132 | ) 133 | 134 | npt.assert_equal(target.ndim, LabelBinarizer().fit_transform(coord_1).ndim) 135 | 136 | 137 | def test_multidim_coord(): 138 | 139 | coord_1 = np.tile(["a"] * 51 + ["b"] * 49, (10, 1)).T 140 | coord_2 = np.random.random((100, 10, 10)) 141 | 142 | X_ds = xr.Dataset( 143 | { 144 | "var_1": ( 145 | ["sample", "feat_1", "feat_2"], 146 | np.random.random((100, 10, 10)), 147 | ) 148 | }, 149 | coords={ 150 | "sample": range(100), 151 | "feature": range(10), 152 | "coord_1": 
(["sample", "feat_1"], coord_1), 153 | "coord_2": (["sample", "feat_1", "feat_2"], coord_2), 154 | }, 155 | ) 156 | 157 | target_1 = Target( 158 | coord="coord_1", 159 | transform_func=LabelBinarizer().fit_transform, 160 | dim="sample", 161 | )(X_ds) 162 | target_2 = Target( 163 | coord="coord_2", dim=["sample", "feat_1"], reduce_func=np.mean 164 | )(X_ds) 165 | 166 | npt.assert_equal(target_1, LabelBinarizer().fit_transform(coord_1[:, 0])) 167 | npt.assert_equal(target_2, np.mean(coord_2, 2)) 168 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import xarray as xr 3 | 4 | import numpy.testing as npt 5 | 6 | from sklearn_xarray.utils import ( 7 | is_dataarray, 8 | is_dataset, 9 | is_target, 10 | convert_to_ndarray, 11 | get_group_indices, 12 | ) 13 | 14 | from sklearn_xarray import Target 15 | 16 | 17 | def test_is_dataarray(): 18 | 19 | X_da = xr.DataArray(np.random.random((100, 10))) 20 | 21 | assert is_dataarray(X_da) 22 | 23 | X_not_a_da = np.random.random((100, 10)) 24 | 25 | assert not is_dataarray(X_not_a_da) 26 | 27 | 28 | def test_is_dataset(): 29 | 30 | X_ds = xr.Dataset({"var_1": 1}) 31 | 32 | assert is_dataset(X_ds) 33 | 34 | X_not_a_ds = np.random.random((100, 10)) 35 | 36 | assert not is_dataarray(X_not_a_ds) 37 | 38 | 39 | def test_is_target(): 40 | 41 | target = Target() 42 | 43 | assert is_target(target) 44 | 45 | not_a_target = 1 46 | 47 | assert not is_target(not_a_target) 48 | 49 | 50 | def test_convert_to_ndarray(): 51 | 52 | from collections import OrderedDict 53 | 54 | X_ds = xr.Dataset( 55 | OrderedDict( 56 | [ 57 | ( 58 | "var_1", 59 | (["sample", "feature"], np.random.random((100, 10))), 60 | ), 61 | ( 62 | "var_2", 63 | (["sample", "feature"], np.random.random((100, 10))), 64 | ), 65 | ] 66 | ), 67 | coords={"sample": range(100), "feature": range(10)}, 68 | ) 69 | 70 | X_arr = convert_to_ndarray(X_ds) 71 | 72 | npt.assert_equal(X_arr, np.dstack((X_ds.var_1, X_ds.var_2))) 73 | 74 | 75 | def test_get_group_indices(): 76 | 77 | import itertools 78 | 79 | coord_1 = ["a"] * 50 + ["b"] * 50 80 | coord_2 = np.tile(list(range(10)) * 10, (10, 1)).T 81 | 82 | X_da = xr.DataArray( 83 | np.random.random((100, 10)), 84 | coords={ 85 | "sample": range(100), 86 | "feature": range(10), 87 | "coord_1": (["sample"], coord_1), 88 | "coord_2": (["sample", "feature"], coord_2), 89 | }, 90 | dims=["sample", "feature"], 91 | ) 92 | 93 | g1 = get_group_indices(X_da, "coord_1") 94 | for i, gg in enumerate(g1): 95 | idx = np.array(coord_1) == np.unique(coord_1)[i] 96 | npt.assert_equal(gg, idx) 97 | 98 | g2 = get_group_indices(X_da, ["coord_1", "coord_2"], group_dim="sample") 99 | combinations = list( 100 | itertools.product(np.unique(coord_1), np.unique(coord_2)) 101 | ) 102 | for i, gg in enumerate(g2): 103 | idx = (np.array(coord_1) == combinations[i][0]) & ( 104 | np.array(coord_2)[:, 0] == combinations[i][1] 105 | ) 106 | npt.assert_equal(gg, idx) 107 | 108 | 109 | def test_segment_array(): 110 | 111 | from sklearn_xarray.utils import segment_array 112 | 113 | arr = np.array( 114 | [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]] 115 | ) 116 | 117 | arr_seg_1 = segment_array(arr, axis=1, new_len=3, step=1) 118 | arr_target_1 = np.array( 119 | [ 120 | [[0, 1, 2], [1, 2, 3]], 121 | [[4, 5, 6], [5, 6, 7]], 122 | [[8, 9, 10], [9, 10, 11]], 123 | [[12, 13, 14], [13, 14, 15]], 124 | ] 125 | ) 126 | 127 | 
npt.assert_allclose(arr_target_1, arr_seg_1) 128 | 129 | arr_seg_2 = segment_array(arr, axis=1, new_len=2, step=2, new_axis=1) 130 | arr_target_2 = np.array( 131 | [ 132 | [[0, 1], [2, 3]], 133 | [[4, 5], [6, 7]], 134 | [[8, 9], [10, 11]], 135 | [[12, 13], [14, 15]], 136 | ] 137 | ).transpose((0, 2, 1)) 138 | 139 | npt.assert_allclose(arr_target_2, arr_seg_2) 140 | 141 | arr_seg_3 = segment_array(arr, axis=0, new_len=2, step=1, new_axis=1) 142 | arr_target_3 = np.array( 143 | [ 144 | [[0, 4], [1, 5], [2, 6], [3, 7]], 145 | [[4, 8], [5, 9], [6, 10], [7, 11]], 146 | [[8, 12], [9, 13], [10, 14], [11, 15]], 147 | ] 148 | ).transpose((0, 2, 1)) 149 | 150 | npt.assert_allclose(arr_target_3, arr_seg_3) 151 | 152 | arr_seg_4 = segment_array(arr, axis=1, new_len=3, step=2, new_axis=2) 153 | arr_target_4 = np.array( 154 | [[[0, 1, 2]], [[4, 5, 6]], [[8, 9, 10]], [[12, 13, 14]]] 155 | ) 156 | 157 | npt.assert_allclose(arr_target_4, arr_seg_4) 158 | --------------------------------------------------------------------------------
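The tests above double as usage documentation for the public API. The sketch below is not part of the repository's source files; it is a minimal, hedged illustration of how the wrap and Target helpers exercised in tests/test_common.py and tests/test_target.py are typically combined for a classification task. The data, the "label" coordinate name, and the parameter choices are invented for illustration only.

import numpy as np
import xarray as xr

from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import LabelEncoder

from sklearn_xarray import Target, wrap

# A DataArray with a class-label coordinate along the "sample" dimension.
X = xr.DataArray(
    np.random.random((100, 10)),
    coords={
        "sample": range(100),
        "feature": range(10),
        "label": (["sample"], ["a"] * 50 + ["b"] * 50),
    },
    dims=("sample", "feature"),
)

# Pull the target out of the "label" coordinate and encode it as integers.
y = Target(coord="label", transform_func=LabelEncoder().fit_transform)(X)

# Wrap a scikit-learn classifier; reshapes="feature" declares that the
# estimator's predict consumes the "feature" dimension.
estimator = wrap(LogisticRegression(), reshapes="feature")

estimator.fit(X, y)
yp = estimator.predict(X)  # a DataArray indexed by the "sample" coordinate

The same pattern appears, with mock estimators instead of LogisticRegression, throughout the test suite above.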