├── dask_searchcv ├── tests │ ├── __init__.py │ ├── test_model_selection.py │ └── test_model_selection_sklearn.py ├── _compat.py ├── __init__.py ├── SCIKIT_LEARN_LICENSE.txt ├── _normalize.py ├── utils.py ├── utils_test.py ├── methods.py └── _version.py ├── .coveragerc ├── .gitignore ├── MANIFEST.in ├── setup.cfg ├── docs ├── source │ ├── api.rst │ ├── images │ │ ├── merged_grid_search_graph.dot │ │ ├── unmerged_grid_search_graph.dot │ │ ├── merged_grid_search_graph.svg │ │ └── unmerged_grid_search_graph.svg │ ├── sphinxext │ │ ├── LICENSE.txt │ │ ├── numpydoc.py │ │ ├── docscrape_sphinx.py │ │ └── docscrape.py │ ├── conf.py │ └── index.rst ├── Makefile └── make.bat ├── setup.py ├── .travis.yml ├── LICENSE.txt └── README.rst /dask_searchcv/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | omit = 3 | dask_searchcv/_version.py 4 | */test_*.py 5 | source = 6 | dask_searchcv 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.egg-info 3 | docs/build 4 | build/ 5 | dist/ 6 | .idea/ 7 | log.* 8 | log 9 | .coverage 10 | .DS_Store 11 | *.swp 12 | *.swo 13 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include dask_searchcv *.py 2 | 3 | include setup.py 4 | include README.rst 5 | include LICENSE.txt 6 | include MANIFEST.in 7 | include versioneer.py 8 | -------------------------------------------------------------------------------- /dask_searchcv/_compat.py: -------------------------------------------------------------------------------- 1 | from 
distutils.version import LooseVersion 2 | from sklearn import __version__ 3 | 4 | _SK_VERSION = LooseVersion(__version__) 5 | 6 | _HAS_MULTIPLE_METRICS = _SK_VERSION >= '0.19.0' 7 | -------------------------------------------------------------------------------- /dask_searchcv/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .model_selection import GridSearchCV, RandomizedSearchCV 4 | 5 | from ._version import get_versions 6 | __version__ = get_versions()['version'] 7 | del get_versions 8 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | exclude = __init__.py 3 | max-line-length = 90 4 | 5 | [versioneer] 6 | VCS = git 7 | style = pep440 8 | versionfile_source = dask_searchcv/_version.py 9 | versionfile_build = dask_searchcv/_version.py 10 | tag_prefix = 11 | parentdir_prefix = dask_searchcv- 12 | -------------------------------------------------------------------------------- /docs/source/api.rst: -------------------------------------------------------------------------------- 1 | API 2 | === 3 | 4 | .. currentmodule:: dask_searchcv 5 | 6 | .. autosummary:: 7 | GridSearchCV 8 | RandomizedSearchCV 9 | 10 | .. autoclass:: GridSearchCV 11 | :members: 12 | :inherited-members: 13 | 14 | .. autoclass:: RandomizedSearchCV 15 | :members: 16 | :inherited-members: 17 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 
from os.path import exists

from setuptools import setup

import versioneer

# Runtime dependencies; lower bounds match the oldest versions exercised
# in the CI matrix (.travis.yml).
install_requires = ["dask >= 0.14.0",
                    "toolz >= 0.8.2",
                    "scikit-learn >= 0.18.0",
                    "scipy >= 0.13.0",
                    "numpy >= 1.8.0"]

setup(name='dask-searchcv',
      # Version and build commands are derived from git tags via versioneer.
      version=versioneer.get_version(),
      cmdclass=versioneer.get_cmdclass(),
      license='BSD',
      url='http://github.com/dask/dask-searchcv',
      maintainer='Jim Crist',
      maintainer_email='jcrist@continuum.io',
      install_requires=install_requires,
      description='Tools for doing hyperparameter search with Scikit-Learn and Dask',
      # Fall back to an empty long description when README.rst is absent
      # (e.g. in some build contexts that omit it).
      long_description=(open('README.rst').read() if exists('README.rst')
                        else ''),
      packages=['dask_searchcv', 'dask_searchcv.tests'])
17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 20 | echo.installed, then set the SPHINXBUILD environment variable to point 21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 22 | echo.may add the Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/source/images/merged_grid_search_graph.dot: -------------------------------------------------------------------------------- 1 | digraph merged { 2 | rankdir="BT" 3 | node [fontname = "Inconsolata"] 4 | 5 | data [label="Training Data"] 6 | 7 | vect1 [label="CountVectorizer\n- ngram_range=(1, 1)"] 8 | 9 | tfidf_1_1 [label="TfidfTransformer\n- norm='l1'"] 10 | tfidf_1_2 [label="TfidfTransformer\n- norm='l2'"] 11 | 12 | sgd_1_1_1 [label="SGDClassifier\n- alpha=1e-3"] 13 | sgd_1_1_2 [label="SGDClassifier\n- alpha=1e-4"] 14 | sgd_1_1_3 [label="SGDClassifier\n- alpha=1e-5"] 15 | 16 | sgd_1_2_1 [label="SGDClassifier\n- alpha=1e-3"] 17 | sgd_1_2_2 [label="SGDClassifier\n- alpha=1e-4"] 18 | sgd_1_2_3 [label="SGDClassifier\n- alpha=1e-5"] 19 | 20 | best [label="Choose Best Parameters"] 21 | 22 | data -> vect1 23 | 24 | vect1 -> tfidf_1_1 -> sgd_1_1_1 -> best 25 | tfidf_1_1 -> sgd_1_1_2 -> best 26 | tfidf_1_1 -> sgd_1_1_3 -> best 27 | 28 | vect1 -> tfidf_1_2 -> sgd_1_2_1 -> best 29 | tfidf_1_2 -> sgd_1_2_2 -> best 30 | tfidf_1_2 -> sgd_1_2_3 -> best 31 | } 32 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | sudo: false 
3 | 4 | env: 5 | matrix: 6 | - PYTHON=2.7 SKLEARN=0.18.1 DASK_DEV='false' TEST_FLAGS= 7 | - PYTHON=3.5 SKLEARN=0.18.0 DASK_DEV='false' TEST_FLAGS= 8 | - PYTHON=3.6 SKLEARN=0.19.1 DASK_DEV='true' TEST_FLAGS=--doctest-modules 9 | 10 | addons: 11 | apt: 12 | packages: 13 | - graphviz 14 | 15 | install: 16 | # Install conda 17 | - wget http://repo.continuum.io/miniconda/Miniconda-latest-Linux-x86_64.sh -O miniconda.sh 18 | - bash miniconda.sh -b -p $HOME/miniconda 19 | - export PATH="$HOME/miniconda/bin:$PATH" 20 | - conda config --set always_yes yes --set changeps1 no 21 | - conda update conda 22 | 23 | # Install dependencies 24 | - conda create -n test-environment python=$PYTHON 25 | - source activate test-environment 26 | - conda install dask distributed numpy scikit-learn=$SKLEARN cytoolz pytest 27 | - if [[ $DASK_DEV == 'true' ]]; then pip install -U git+https://github.com/dask/dask.git; fi 28 | - if [[ $DASK_DEV == 'true' ]]; then pip install -U git+https://github.com/dask/distributed.git; fi 29 | - pip install -q graphviz flake8 30 | - pip install --no-deps -e . 
31 | 32 | script: 33 | - py.test dask_searchcv --verbose $TEST_FLAGS 34 | - flake8 dask_searchcv 35 | 36 | notifications: 37 | email: false 38 | -------------------------------------------------------------------------------- /docs/source/images/unmerged_grid_search_graph.dot: -------------------------------------------------------------------------------- 1 | digraph unmerged { 2 | rankdir="BT" 3 | node [fontname = "Inconsolata"] 4 | 5 | data [label="Training Data"] 6 | 7 | vect1 [label="CountVectorizer\n- ngram_range=(1, 1)"] 8 | vect2 [label="CountVectorizer\n- ngram_range=(1, 1)"] 9 | vect3 [label="CountVectorizer\n- ngram_range=(1, 1)"] 10 | vect4 [label="CountVectorizer\n- ngram_range=(1, 1)"] 11 | vect5 [label="CountVectorizer\n- ngram_range=(1, 1)"] 12 | vect6 [label="CountVectorizer\n- ngram_range=(1, 1)"] 13 | 14 | tfidf1 [label="TfidfTransformer\n- norm='l1'"] 15 | tfidf2 [label="TfidfTransformer\n- norm='l1'"] 16 | tfidf3 [label="TfidfTransformer\n- norm='l1'"] 17 | tfidf4 [label="TfidfTransformer\n- norm='l2'"] 18 | tfidf5 [label="TfidfTransformer\n- norm='l2'"] 19 | tfidf6 [label="TfidfTransformer\n- norm='l2'"] 20 | 21 | sgd1 [label="SGDClassifier\n- alpha=1e-3"] 22 | sgd2 [label="SGDClassifier\n- alpha=1e-4"] 23 | sgd3 [label="SGDClassifier\n- alpha=1e-5"] 24 | sgd4 [label="SGDClassifier\n- alpha=1e-3"] 25 | sgd5 [label="SGDClassifier\n- alpha=1e-4"] 26 | sgd6 [label="SGDClassifier\n- alpha=1e-5"] 27 | 28 | best [label="Choose Best Parameters"] 29 | 30 | data -> vect1 -> tfidf1 -> sgd1 -> best 31 | data -> vect2 -> tfidf2 -> sgd2 -> best 32 | data -> vect3 -> tfidf3 -> sgd3 -> best 33 | data -> vect4 -> tfidf4 -> sgd4 -> best 34 | data -> vect5 -> tfidf5 -> sgd5 -> best 35 | data -> vect6 -> tfidf6 -> sgd6 -> best 36 | } 37 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017, Continuum Analytics, Inc. 
and contributors 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, 5 | are permitted provided that the following conditions are met: 6 | 7 | Redistributions of source code must retain the above copyright notice, 8 | this list of conditions and the following disclaimer. 9 | 10 | Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | Neither the name of Continuum Analytics nor the names of any contributors 15 | may be used to endorse or promote products derived from this software 16 | without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 22 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 28 | THE POSSIBILITY OF SUCH DAMAGE. 
29 | -------------------------------------------------------------------------------- /docs/source/sphinxext/LICENSE.txt: -------------------------------------------------------------------------------- 1 | ------------------------------------------------------------------------------- 2 | The files 3 | - numpydoc.py 4 | - docscrape.py 5 | - docscrape_sphinx.py 6 | have the following license: 7 | 8 | Copyright (C) 2008 Stefan van der Walt , Pauli Virtanen 9 | 10 | Redistribution and use in source and binary forms, with or without 11 | modification, are permitted provided that the following conditions are 12 | met: 13 | 14 | 1. Redistributions of source code must retain the above copyright 15 | notice, this list of conditions and the following disclaimer. 16 | 2. Redistributions in binary form must reproduce the above copyright 17 | notice, this list of conditions and the following disclaimer in 18 | the documentation and/or other materials provided with the 19 | distribution. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR 22 | IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, 25 | INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 27 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 28 | HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 29 | STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING 30 | IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 31 | POSSIBILITY OF SUCH DAMAGE. 
32 | -------------------------------------------------------------------------------- /dask_searchcv/SCIKIT_LEARN_LICENSE.txt: -------------------------------------------------------------------------------- 1 | New BSD License 2 | 3 | Copyright (c) 2007–2016 The scikit-learn developers. 4 | All rights reserved. 5 | 6 | 7 | Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | a. Redistributions of source code must retain the above copyright notice, 11 | this list of conditions and the following disclaimer. 12 | b. Redistributions in binary form must reproduce the above copyright 13 | notice, this list of conditions and the following disclaimer in the 14 | documentation and/or other materials provided with the distribution. 15 | c. Neither the name of the Scikit-learn Developers nor the names of 16 | its contributors may be used to endorse or promote products 17 | derived from this software without specific prior written 18 | permission. 19 | 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 26 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 27 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 28 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 29 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY 30 | OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH 31 | DAMAGE. 
"""Dask tokenization hooks for scikit-learn objects.

Registers ``normalize_token`` implementations so that estimators and CV
splitters with equal parameters produce equal dask task keys, letting
identical fits be deduplicated in the graph.
"""
from __future__ import absolute_import, division, print_function

import numpy as np
from dask.base import normalize_token

from sklearn.base import BaseEstimator
from sklearn.model_selection._split import (_BaseKFold,
                                            BaseShuffleSplit,
                                            LeaveOneOut,
                                            LeaveOneGroupOut,
                                            LeavePOut,
                                            LeavePGroupsOut,
                                            PredefinedSplit,
                                            _CVIterableWrapper)


@normalize_token.register(BaseEstimator)
def normalize_estimator(est):
    """Normalize an estimator.

    Note: Since scikit-learn requires duck-typing, but not sub-typing from
    ``BaseEstimator``, we sometimes need to call this function directly."""
    return type(est).__name__, normalize_token(est.get_params())


def normalize_random_state(random_state):
    # ``RandomState`` instances compare by identity; normalize to their full
    # internal state tuple so equivalent generators tokenize equally.
    if isinstance(random_state, np.random.RandomState):
        return random_state.get_state()
    return random_state


@normalize_token.register(_BaseKFold)
def normalize_KFold(cv):
    # The random state is irrelevant unless shuffling is enabled.
    if cv.shuffle:
        rs = normalize_random_state(cv.random_state)
    else:
        rs = None
    return (type(cv).__name__, cv.n_splits, cv.shuffle, rs)


@normalize_token.register(BaseShuffleSplit)
def normalize_ShuffleSplit(cv):
    return (type(cv).__name__, cv.n_splits, cv.test_size, cv.train_size,
            normalize_random_state(cv.random_state))


@normalize_token.register((LeaveOneOut, LeaveOneGroupOut))
def normalize_LeaveOneOut(cv):
    # These splitters are parameter-free; the class name fully determines them.
    return type(cv).__name__


@normalize_token.register((LeavePOut, LeavePGroupsOut))
def normalize_LeavePOut(cv):
    # LeavePOut stores ``p``; LeavePGroupsOut stores ``n_groups``.
    size = cv.p if hasattr(cv, 'p') else cv.n_groups
    return (type(cv).__name__, size)


@normalize_token.register(PredefinedSplit)
def normalize_PredefinedSplit(cv):
    return (type(cv).__name__, cv.test_fold)


@normalize_token.register(_CVIterableWrapper)
def normalize_CVIterableWrapper(cv):
    return (type(cv).__name__, cv.cv)
"""Small helpers shared by the search-CV implementations."""
import copy
from distutils.version import LooseVersion

import dask
import dask.array as da
from dask.base import tokenize
from dask.delayed import delayed, Delayed

from sklearn.utils.validation import indexable, _is_arraylike


# ``is_dask_collection`` was added to ``dask.base`` after 0.15.4; emulate it
# with an isinstance check against the old ``Base`` class on older dask.
if LooseVersion(dask.__version__) > '0.15.4':
    from dask.base import is_dask_collection
else:
    from dask.base import Base

    def is_dask_collection(x):
        return isinstance(x, Base)


def _indexable(x):
    # ``indexable`` validates and returns its args as a list; unwrap the
    # single element.
    return indexable(x)[0]


def _maybe_indexable(x):
    # Scalars pass through unchanged; only array-likes are validated.
    return indexable(x)[0] if _is_arraylike(x) else x


def to_indexable(*args, **kwargs):
    """Ensure that all args are an indexable type.

    Conversion runs lazily for dask objects, immediately otherwise.

    Parameters
    ----------
    args : array_like or scalar
    allow_scalars : bool, optional
        Whether to allow scalars in args. Default is False.
    """
    convert = (_maybe_indexable if kwargs.get('allow_scalars', False)
               else _indexable)
    for arg in args:
        if arg is None or isinstance(arg, da.Array):
            yield arg
        elif is_dask_collection(arg):
            # Defer validation until the collection is computed.
            yield delayed(convert, pure=True)(arg)
        else:
            yield convert(arg)


def to_keys(dsk, *args):
    """Merge each arg's graph into ``dsk`` and yield its key (None for None)."""
    for x in args:
        if x is None:
            yield None
            continue
        if isinstance(x, da.Array):
            # Collapse the array into a single delayed value first.
            x = delayed(x)
        if isinstance(x, Delayed):
            dsk.update(x.dask)
            yield x.key
        else:
            assert not is_dask_collection(x)
            key = 'array-' + tokenize(x)
            dsk[key] = x
            yield key


def copy_estimator(est):
    # Semantically, we'd like to use `sklearn.clone` here instead. However,
    # `sklearn.clone` isn't threadsafe, so we don't want to call it in
    # tasks. Since `est` is guaranteed to not be a fit estimator, we can
    # use `copy.deepcopy` here without fear of copying large data.
    return copy.deepcopy(est)


def unzip(itbl, n):
    # Transpose an iterable of n-tuples; an empty input yields n empty tuples.
    if not itbl:
        return [()] * n
    return zip(*itbl)
23 | 24 | - Works well with Dask collections. Dask arrays, dataframes, and delayed can be 25 | passed to ``fit``. 26 | 27 | - Candidate estimators with identical parameters and inputs will only be fit 28 | once. For composite-estimators such as ``Pipeline`` this can be significantly 29 | more efficient as it can avoid expensive repeated computations. 30 | 31 | 32 | For more information, check out the `documentation `__. 33 | 34 | 35 | Install 36 | ------- 37 | 38 | Dask-searchcv is available via ``conda`` or ``pip``: 39 | 40 | :: 41 | 42 | # Install with conda 43 | $ conda install dask-searchcv -c conda-forge 44 | 45 | # Install with pip 46 | $ pip install dask-searchcv 47 | 48 | 49 | Example 50 | ------- 51 | 52 | .. code-block:: python 53 | 54 | from sklearn.datasets import load_digits 55 | from sklearn.svm import SVC 56 | import dask_searchcv as dcv 57 | import numpy as np 58 | 59 | digits = load_digits() 60 | 61 | param_space = {'C': np.logspace(-4, 4, 9), 62 | 'gamma': np.logspace(-4, 4, 9), 63 | 'class_weight': [None, 'balanced']} 64 | 65 | model = SVC(kernel='rbf') 66 | search = dcv.GridSearchCV(model, param_space, cv=3) 67 | 68 | search.fit(digits.data, digits.target) 69 | 70 | 71 | .. |Travis Status| image:: https://travis-ci.org/dask/dask-searchcv.svg?branch=master 72 | :target: https://travis-ci.org/dask/dask-searchcv 73 | .. |Doc Status| image:: http://readthedocs.org/projects/dask-searchcv/badge/?version=latest 74 | :target: http://dask-searchcv.readthedocs.io/en/latest/index.html 75 | :alt: Documentation Status 76 | .. |PyPI Badge| image:: https://img.shields.io/pypi/v/dask-searchcv.svg 77 | :target: https://pypi.python.org/pypi/dask-searchcv 78 | .. 
|Conda Badge| image:: https://anaconda.org/conda-forge/dask-searchcv/badges/version.svg 79 | :target: https://anaconda.org/conda-forge/dask-searchcv 80 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | # dask-searchcv documentation build configuration file, created by 5 | # sphinx-quickstart on Fri Mar 31 16:55:14 2017. 6 | # 7 | # This file is execfile()d with the current directory set to its 8 | # containing dir. 9 | # 10 | # Note that not all possible configuration values are present in this 11 | # autogenerated file. 12 | # 13 | # All configuration values have a default; values that are commented out 14 | # serve to show the default. 15 | 16 | # If extensions (or modules to document with autodoc) are in another directory, 17 | # add these directories to sys.path here. If the directory is relative to the 18 | # documentation root, use os.path.abspath to make it absolute, like shown here. 19 | # 20 | import os 21 | import sys 22 | sys.path.insert(0, os.path.abspath('.')) 23 | 24 | 25 | # -- General configuration ------------------------------------------------ 26 | 27 | # If your documentation needs a minimal Sphinx version, state it here. 28 | # 29 | # needs_sphinx = '1.0' 30 | 31 | # Add any Sphinx extension module names here, as strings. They can be 32 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 33 | # ones. 34 | extensions = ['sphinx.ext.autodoc', 'sphinx.ext.autosummary', 35 | 'sphinx.ext.mathjax', 'sphinxext.numpydoc'] 36 | 37 | autosummary_generate = True 38 | 39 | # Add any paths that contain templates here, relative to this directory. 40 | templates_path = ['_templates'] 41 | 42 | # The suffix(es) of source filenames. 
43 | # You can specify multiple suffix as a list of string: 44 | # 45 | # source_suffix = ['.rst', '.md'] 46 | source_suffix = '.rst' 47 | 48 | # The master toctree document. 49 | master_doc = 'index' 50 | 51 | # General information about the project. 52 | project = 'dask-searchcv' 53 | copyright = '2017, Dask Development Team' 54 | author = 'Dask Development Team' 55 | 56 | # The version info for the project you're documenting, acts as replacement for 57 | # |version| and |release|, also used in various other places throughout the 58 | # built documents. 59 | # 60 | # The short X.Y version. 61 | version = '' 62 | # The full version, including alpha/beta/rc tags. 63 | release = '' 64 | 65 | # The language for content autogenerated by Sphinx. Refer to documentation 66 | # for a list of supported languages. 67 | # 68 | # This is also used if you do content translation via gettext catalogs. 69 | # Usually you set "language" from the command line for these cases. 70 | language = None 71 | 72 | # List of patterns, relative to source directory, that match files and 73 | # directories to ignore when looking for source files. 74 | # This patterns also effect to html_static_path and html_extra_path 75 | exclude_patterns = [] 76 | 77 | # The name of the Pygments (syntax highlighting) style to use. 78 | pygments_style = 'sphinx' 79 | 80 | # If true, `todo` and `todoList` produce output, else they produce nothing. 
81 | todo_include_todos = False 82 | 83 | 84 | # -- Options for HTML output ---------------------------------------------- 85 | 86 | # Taken from docs.readthedocs.io: 87 | # on_rtd is whether we are on readthedocs.io 88 | on_rtd = os.environ.get('READTHEDOCS', None) == 'True' 89 | 90 | if not on_rtd: # only import and set the theme if we're building docs locally 91 | import sphinx_rtd_theme 92 | html_theme = 'sphinx_rtd_theme' 93 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 94 | 95 | # Theme options are theme-specific and customize the look and feel of a theme 96 | # further. For a list of options available for each theme, see the 97 | # documentation. 98 | # 99 | # html_theme_options = {} 100 | 101 | # Add any paths that contain custom static files (such as style sheets) here, 102 | # relative to this directory. They are copied after the builtin static files, 103 | # so a file named "default.css" will overwrite the builtin "default.css". 104 | html_static_path = ['_static'] 105 | 106 | 107 | # -- Options for HTMLHelp output ------------------------------------------ 108 | 109 | # Output file base name for HTML help builder. 110 | htmlhelp_basename = 'dask-searchcvdoc' 111 | 112 | 113 | # -- Options for LaTeX output --------------------------------------------- 114 | 115 | latex_elements = { 116 | # The paper size ('letterpaper' or 'a4paper'). 117 | # 118 | # 'papersize': 'letterpaper', 119 | 120 | # The font size ('10pt', '11pt' or '12pt'). 121 | # 122 | # 'pointsize': '10pt', 123 | 124 | # Additional stuff for the LaTeX preamble. 125 | # 126 | # 'preamble': '', 127 | 128 | # Latex figure (float) alignment 129 | # 130 | # 'figure_align': 'htbp', 131 | } 132 | 133 | # Grouping the document tree into LaTeX files. List of tuples 134 | # (source start file, target name, title, 135 | # author, documentclass [howto, manual, or own class]). 
136 | latex_documents = [ 137 | (master_doc, 'dask-searchcv.tex', 'dask-searchcv Documentation', 138 | 'Dask Development Team', 'manual'), 139 | ] 140 | 141 | 142 | # -- Options for manual page output --------------------------------------- 143 | 144 | # One entry per manual page. List of tuples 145 | # (source start file, name, description, authors, manual section). 146 | man_pages = [ 147 | (master_doc, 'dask-searchcv', 'dask-searchcv Documentation', 148 | [author], 1) 149 | ] 150 | 151 | 152 | # -- Options for Texinfo output ------------------------------------------- 153 | 154 | # Grouping the document tree into Texinfo files. List of tuples 155 | # (source start file, target name, title, author, 156 | # dir menu entry, description, category) 157 | texinfo_documents = [ 158 | (master_doc, 'dask-searchcv', 'dask-searchcv Documentation', 159 | author, 'dask-searchcv', 'One line description of project.', 160 | 'Miscellaneous'), 161 | ] 162 | -------------------------------------------------------------------------------- /dask_searchcv/utils_test.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | from functools import wraps 4 | 5 | import numpy as np 6 | from sklearn.base import BaseEstimator, ClassifierMixin 7 | from sklearn.utils.validation import _num_samples, check_array 8 | 9 | 10 | # This class doesn't inherit from BaseEstimator to test hyperparameter search 11 | # on user-defined classifiers. 
class MockClassifier(object):
    """Dummy classifier to test the parameter search algorithms.

    Deliberately does *not* inherit from ``BaseEstimator`` so the searchers
    are exercised against a purely duck-typed estimator.
    """
    def __init__(self, foo_param=0):
        self.foo_param = foo_param

    def fit(self, X, Y):
        assert len(X) == len(Y)
        self.classes_ = np.unique(Y)
        return self

    def predict(self, T):
        # Mock prediction: just report the number of samples.
        return T.shape[0]

    # All prediction-like methods share the same mock implementation.
    predict_proba = predict
    predict_log_proba = predict
    decision_function = predict
    inverse_transform = predict

    def transform(self, X):
        return X

    def score(self, X=None, Y=None):
        # Score depends only on the hyperparameter, making the "best"
        # parameter setting fully predictable in tests.
        if self.foo_param > 1:
            score = 1.
        else:
            score = 0.
        return score

    def get_params(self, deep=False):
        return {'foo_param': self.foo_param}

    def set_params(self, **params):
        self.foo_param = params['foo_param']
        return self


class ScalingTransformer(BaseEstimator):
    """Transformer that multiplies its input by a constant ``factor``."""
    def __init__(self, factor=1):
        self.factor = factor

    def fit(self, X, y):
        return self

    def transform(self, X):
        return X * self.factor


class CheckXClassifier(BaseEstimator):
    """Used to check output of featureunions"""
    def __init__(self, expected_X=None):
        self.expected_X = expected_X

    def fit(self, X, y):
        # Fails loudly if an upstream transformer altered X unexpectedly.
        assert (X == self.expected_X).all()
        assert len(X) == len(y)
        return self

    def predict(self, X):
        return X.sum(axis=1)

    def score(self, X=None, y=None):
        return self.predict(X)[0]


class FailingClassifier(BaseEstimator):
    """Classifier that raises a ValueError on fit()"""

    # Fitting fails if and only if ``parameter`` equals this value.
    FAILING_PARAMETER = 2

    def __init__(self, parameter=None):
        self.parameter = parameter

    def fit(self, X, y=None):
        if self.parameter == FailingClassifier.FAILING_PARAMETER:
            raise ValueError("Failing classifier failed as required")
        return self

    def transform(self, X):
        return X

    def predict(self, X):
        return np.zeros(X.shape[0])


def ignore_warnings(f):
    """A super simple version of `sklearn.utils.testing.ignore_warnings"""
    @wraps(f)
    def _(*args, **kwargs):
        with warnings.catch_warnings(record=True):
            f(*args, **kwargs)
    return _


# XXX: Mocking classes copied from sklearn.utils.mocking to remove nose
# dependency. Can be removed when scikit-learn switches to pytest. See issue
# here: https://github.com/scikit-learn/scikit-learn/issues/7319

class ArraySlicingWrapper(object):
    """Adapter giving ``MockDataFrame`` a pandas-style ``iloc`` interface."""
    def __init__(self, array):
        self.array = array

    def __getitem__(self, aslice):
        return MockDataFrame(self.array[aslice])


class MockDataFrame(object):
    # have shape and length but don't support indexing.
    def __init__(self, array):
        self.array = array
        self.values = array
        self.shape = array.shape
        self.ndim = array.ndim
        # ugly hack to make iloc work.
        self.iloc = ArraySlicingWrapper(array)

    def __len__(self):
        return len(self.array)

    def __array__(self, dtype=None):
        # Pandas data frames also are array-like: we want to make sure that
        # input validation in cross-validation does not try to call that
        # method.
        return self.array


class CheckingClassifier(BaseEstimator, ClassifierMixin):
    """Dummy classifier to test pipelining and meta-estimators.

    Checks some property of X and y in fit / predict.
    This allows testing whether pipelines / cross-validation or metaestimators
    changed the input.
    """
    def __init__(self, check_y=None, check_X=None, foo_param=0,
                 expected_fit_params=None):
        self.check_y = check_y
        self.check_X = check_X
        self.foo_param = foo_param
        self.expected_fit_params = expected_fit_params

    def fit(self, X, y, **fit_params):
        assert len(X) == len(y)
        if self.check_X is not None:
            assert self.check_X(X)
        if self.check_y is not None:
            assert self.check_y(y)
        self.classes_ = np.unique(check_array(y, ensure_2d=False,
                                              allow_nd=True))
        if self.expected_fit_params:
            missing = set(self.expected_fit_params) - set(fit_params)
            assert len(missing) == 0, ('Expected fit parameter(s) %s not '
                                       'seen.' % list(missing))
            for key, value in fit_params.items():
                assert len(value) == len(X), ('Fit parameter %s has length %d; '
                                              'expected %d.' % (key, len(value),
                                                                len(X)))
        return self

    def predict(self, T):
        if self.check_X is not None:
            assert self.check_X(T)
        # ``np.int`` was a deprecated alias for the builtin ``int`` (removed
        # in NumPy 1.24); use ``int`` directly for the dtype.
        return self.classes_[np.zeros(_num_samples(T), dtype=int)]

    def score(self, X=None, Y=None):
        if self.foo_param > 1:
            score = 1.
        else:
            score = 0.
        return score
def mangle_docstrings(app, what, name, obj, options, lines,
                      reference_offset=[0]):
    """Rewrite ``lines`` (a docstring, modified in place) into Sphinx reST.

    ``reference_offset`` is deliberately a mutable default: it acts as a
    module-lifetime counter used to renumber citation references so that
    labels stay unique across all mangled docstrings.
    """
    cfg = dict(use_plots=app.config.numpydoc_use_plots,
               show_class_members=app.config.numpydoc_show_class_members)

    if what == 'module':
        # Strip top title
        title_re = re.compile(r'^\s*[#*=]{4,}\n[a-z0-9 -]+\n[#*=]{4,}\s*',
                              re.I | re.S)
        lines[:] = title_re.sub('', "\n".join(lines)).split("\n")
    else:
        doc = get_doc_object(obj, what, "\n".join(lines), config=cfg)
        if sys.version_info[0] < 3:
            lines[:] = unicode(doc).splitlines()  # noqa: F821 (Python 2 only)
        else:
            lines[:] = str(doc).splitlines()

    if app.config.numpydoc_edit_link and hasattr(obj, '__name__') and \
            obj.__name__:
        if hasattr(obj, '__module__'):
            v = dict(full_name="%s.%s" % (obj.__module__, obj.__name__))
        else:
            v = dict(full_name=obj.__name__)
        lines += [u'', u'.. htmlonly::', '']
        lines += [u'    %s' % x for x in
                  (app.config.numpydoc_edit_link % v).split("\n")]

    # Collect citation labels so duplicates across docstrings can be
    # renumbered.  BUG FIX: the character class previously lacked ``+`` so
    # only single-character labels like ``[1]`` were matched; labels such
    # as ``[R1]`` or ``[ref]`` were never renumbered and could collide.
    references = []
    for line in lines:
        line = line.strip()
        m = re.match(r'^.. \[([a-z0-9_.-]+)\]', line, re.I)
        if m:
            references.append(m.group(1))

    # start renaming from the longest string, to avoid overwriting parts
    references.sort(key=lambda x: -len(x))
    if references:
        for i, line in enumerate(lines):
            for r in references:
                if re.match(r'^\d+$', r):
                    new_r = "R%d" % (reference_offset[0] + int(r))
                else:
                    new_r = u"%s%d" % (r, reference_offset[0])
                lines[i] = lines[i].replace(u'[%s]_' % r,
                                            u'[%s]_' % new_r)
                lines[i] = lines[i].replace(u'.. [%s]' % r,
                                            u'.. [%s]' % new_r)

    reference_offset[0] += len(references)
class ManglingDomainBase(object):
    """Mixin that wraps a Sphinx domain's directives so their content is
    run through the numpydoc mangling pass.

    Subclasses populate ``directive_mangling_map`` with a mapping from
    directive name to the object type handed to the mangler.
    """

    directive_mangling_map = {}

    def __init__(self, *a, **kw):
        super(ManglingDomainBase, self).__init__(*a, **kw)
        self.wrap_mangling_directives()

    def wrap_mangling_directives(self):
        # Replace each listed directive with its mangling wrapper, in place.
        for directive_name, objtype in self.directive_mangling_map.items():
            original = self.directives[directive_name]
            self.directives[directive_name] = wrap_mangling_directive(
                original, objtype)
dependency 187 | from docutils.statemachine import ViewList 188 | self.content = ViewList(lines, self.content.parent) 189 | 190 | return base_directive.run(self) 191 | 192 | return directive 193 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | dask-searchcv 2 | ============= 3 | 4 | .. note:: 5 | 6 | Dask-SearchCV is now included in 7 | `Dask-ML _`. 8 | 9 | Further development to Dask-SearchCV is occuring in the 10 | `Dask-ML _` repository. Please post issues 11 | and make pull requests there. 12 | 13 | 14 | Tools for performing hyperparameter optimization of Scikit-Learn models using 15 | Dask. 16 | 17 | Introduction 18 | ------------ 19 | 20 | This library provides implementations of Scikit-Learn's ``GridSearchCV`` and 21 | ``RandomizedSearchCV``. They implement many (but not all) of the same 22 | parameters, and should be a drop-in replacement for the subset that they do 23 | implement. For certain problems, these implementations can be more efficient 24 | than those in Scikit-Learn, as they can avoid expensive repeated computations. 25 | 26 | For more information, see `this blogpost 27 | `__. 28 | 29 | Highlights 30 | ---------- 31 | 32 | - :ref:`Drop-in replacement ` for Scikit-Learn's 33 | ``GridSearchCV`` and ``RandomizedSearchCV``. 34 | 35 | - :ref:`Flexible Backends `. Hyperparameter 36 | optimization can be done in parallel using threads, processes, or distributed 37 | across a cluster. 38 | 39 | - :ref:`Works well with Dask collections `. Dask 40 | arrays, dataframes, and delayed can be passed to ``fit``. 41 | 42 | - :ref:`Avoid repeated work `. Candidate estimators with 43 | identical parameters and inputs will only be fit once. For 44 | composite-estimators such as ``Pipeline`` this can be significantly more 45 | efficient as it can avoid expensive repeated computations. 
46 | 47 | Install 48 | ------- 49 | 50 | Dask-searchcv is available via ``conda`` or ``pip``: 51 | 52 | .. code-block:: bash 53 | 54 | # Install with conda 55 | $ conda install dask-searchcv -c conda-forge 56 | 57 | # Install with pip 58 | $ pip install dask-searchcv 59 | 60 | 61 | Walkthrough 62 | ----------- 63 | 64 | .. _drop-in-replacement: 65 | 66 | Drop-In Replacement 67 | ^^^^^^^^^^^^^^^^^^^ 68 | 69 | Dask-searchcv provides (almost) drop-in replacements for Scikit-Learn's 70 | ``GridSearchCV`` and ``RandomizedSearchCV``. With the exception of a few 71 | keyword arguments, the api's are exactly the same, and often only an import 72 | change is necessary: 73 | 74 | .. code-block:: python 75 | :emphasize-lines: 4,5 76 | 77 | from sklearn.datasets import load_digits 78 | from sklearn.svm import SVC 79 | 80 | # Fit with dask-searchcv 81 | from dask_searchcv import GridSearchCV 82 | 83 | param_space = {'C': [1e-4, 1, 1e4], 84 | 'gamma': [1e-3, 1, 1e3], 85 | 'class_weight': [None, 'balanced']} 86 | 87 | model = SVC(kernel='rbf') 88 | 89 | digits = load_digits() 90 | 91 | search = GridSearchCV(model, param_space, cv=3) 92 | search.fit(digits.data, digits.target) 93 | 94 | .. raw:: html 95 | 96 | 97 | 117 | 118 | 119 | .. _flexible-backends: 120 | 121 | Flexible Backends 122 | ^^^^^^^^^^^^^^^^^ 123 | 124 | Dask-searchcv can use any of the dask schedulers. By default the threaded 125 | scheduler is used, but this can easily be swapped out for the multiprocessing 126 | or distributed scheduler: 127 | 128 | .. code-block:: python 129 | 130 | # Distribute grid-search across a cluster 131 | from dask.distributed import Client 132 | scheduler_address = '127.0.0.1:8786' 133 | client = Client(scheduler_address) 134 | 135 | search.fit(digits.data, digits.target) 136 | 137 | 138 | .. 
_works-with-dask-collections: 139 | 140 | Works Well With Dask Collections 141 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 142 | 143 | Dask collections such as ``dask.array``, ``dask.dataframe`` and 144 | ``dask.delayed`` can be passed to ``fit``. This means you can use dask to do 145 | your data loading and preprocessing as well, allowing for a clean workflow. 146 | This also allows you to work with remote data on a cluster without ever having 147 | to pull it locally to your computer: 148 | 149 | .. code-block:: python 150 | 151 | import dask.dataframe as dd 152 | 153 | # Load data from s3 154 | df = dd.read_csv('s3://bucket-name/my-data-*.csv') 155 | 156 | # Do some preprocessing steps 157 | df['x2'] = df.x - df.x.mean() 158 | # ... 159 | 160 | # Pass to fit without ever leaving the cluster 161 | search.fit(df[['x', 'x2']], df['y']) 162 | 163 | 164 | .. _avoid-repeated-work: 165 | 166 | Avoid Repeated Work 167 | ^^^^^^^^^^^^^^^^^^^ 168 | 169 | When searching over composite estimators like ``sklearn.pipeline.Pipeline`` or 170 | ``sklearn.pipeline.FeatureUnion``, dask-searchcv will avoid fitting the same 171 | estimator + parameter + data combination more than once. For pipelines with 172 | expensive early steps this can be faster, as repeated work is avoided. 173 | 174 | For example, given the following 3-stage pipeline and grid (modified from `this 175 | scikit-learn example 176 | `__). 177 | 178 | .. code-block:: python 179 | 180 | from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer 181 | from sklearn.linear_model import SGDClassifier 182 | from sklearn.pipeline import Pipeline 183 | 184 | pipeline = Pipeline([('vect', CountVectorizer()), 185 | ('tfidf', TfidfTransformer()), 186 | ('clf', SGDClassifier())]) 187 | 188 | grid = {'vect__ngram_range': [(1, 1)], 189 | 'tfidf__norm': ['l1', 'l2'], 190 | 'clf__alpha': [1e-3, 1e-4, 1e-5]} 191 | 192 | the Scikit-Learn grid-search implementation looks something like (simplified): 193 | 194 | .. 
code-block:: python 195 | 196 | scores = [] 197 | for ngram_range in parameters['vect__ngram_range']: 198 | for norm in parameters['tfidf__norm']: 199 | for alpha in parameters['clf__alpha']: 200 | vect = CountVectorizer(ngram_range=ngram_range) 201 | X2 = vect.fit_transform(X, y) 202 | tfidf = TfidfTransformer(norm=norm) 203 | X3 = tfidf.fit_transform(X2, y) 204 | clf = SGDClassifier(alpha=alpha) 205 | clf.fit(X3, y) 206 | scores.append(clf.score(X3, y)) 207 | best = choose_best_parameters(scores, parameters) 208 | 209 | 210 | As a directed acyclic graph, this might look like: 211 | 212 | .. figure:: images/unmerged_grid_search_graph.svg 213 | :alt: "scikit-learn grid-search directed acyclic graph" 214 | :align: center 215 | 216 | 217 | In contrast, the dask version looks more like: 218 | 219 | .. code-block:: python 220 | 221 | scores = [] 222 | for ngram_range in parameters['vect__ngram_range']: 223 | vect = CountVectorizer(ngram_range=ngram_range) 224 | X2 = vect.fit_transform(X, y) 225 | for norm in parameters['tfidf__norm']: 226 | tfidf = TfidfTransformer(norm=norm) 227 | X3 = tfidf.fit_transform(X2, y) 228 | for alpha in parameters['clf__alpha']: 229 | clf = SGDClassifier(alpha=alpha) 230 | clf.fit(X3, y) 231 | scores.append(clf.score(X3, y)) 232 | best = choose_best_parameters(scores, parameters) 233 | 234 | 235 | With a corresponding directed acyclic graph: 236 | 237 | .. figure:: images/merged_grid_search_graph.svg 238 | :alt: "dask-searchcv grid-search directed acyclic graph" 239 | :align: center 240 | 241 | 242 | Looking closely, you can see that the Scikit-Learn version ends up fitting 243 | earlier steps in the pipeline multiple times with the same parameters and data. 244 | Due to the increased flexibility of Dask over Joblib, we're able to merge these 245 | tasks in the graph and only perform the fit step once for any 246 | parameter/data/estimator combination. 
For pipelines that have relatively 247 | expensive early steps, this can be a big win when performing a grid search. 248 | 249 | 250 | Index 251 | ----- 252 | 253 | .. toctree:: 254 | 255 | api 256 | -------------------------------------------------------------------------------- /docs/source/sphinxext/docscrape_sphinx.py: -------------------------------------------------------------------------------- 1 | import re 2 | import inspect 3 | import textwrap 4 | import pydoc 5 | from .docscrape import NumpyDocString 6 | from .docscrape import FunctionDoc 7 | from .docscrape import ClassDoc 8 | 9 | 10 | class SphinxDocString(NumpyDocString): 11 | def __init__(self, docstring, config=None): 12 | config = {} if config is None else config 13 | self.use_plots = config.get('use_plots', False) 14 | NumpyDocString.__init__(self, docstring, config=config) 15 | 16 | # string conversion routines 17 | def _str_header(self, name, symbol='`'): 18 | return ['.. rubric:: ' + name, ''] 19 | 20 | def _str_field_list(self, name): 21 | return [':' + name + ':'] 22 | 23 | def _str_indent(self, doc, indent=4): 24 | out = [] 25 | for line in doc: 26 | out += [' ' * indent + line] 27 | return out 28 | 29 | def _str_signature(self): 30 | return [''] 31 | if self['Signature']: 32 | return ['``%s``' % self['Signature']] + [''] 33 | else: 34 | return [''] 35 | 36 | def _str_summary(self): 37 | return self['Summary'] + [''] 38 | 39 | def _str_extended_summary(self): 40 | return self['Extended Summary'] + [''] 41 | 42 | def _str_param_list(self, name): 43 | out = [] 44 | if self[name]: 45 | out += self._str_field_list(name) 46 | out += [''] 47 | for param, param_type, desc in self[name]: 48 | out += self._str_indent(['**%s** : %s' % (param.strip(), 49 | param_type)]) 50 | out += [''] 51 | out += self._str_indent(desc, 8) 52 | out += [''] 53 | return out 54 | 55 | @property 56 | def _obj(self): 57 | if hasattr(self, '_cls'): 58 | return self._cls 59 | elif hasattr(self, '_f'): 60 | return self._f 61 
| return None 62 | 63 | def _str_member_list(self, name): 64 | """ 65 | Generate a member listing, autosummary:: table where possible, 66 | and a table where not. 67 | 68 | """ 69 | out = [] 70 | if self[name]: 71 | out += ['.. rubric:: %s' % name, ''] 72 | prefix = getattr(self, '_name', '') 73 | 74 | if prefix: 75 | prefix = '~%s.' % prefix 76 | 77 | autosum = [] 78 | others = [] 79 | for param, param_type, desc in self[name]: 80 | param = param.strip() 81 | if not self._obj or hasattr(self._obj, param): 82 | autosum += [" %s%s" % (prefix, param)] 83 | else: 84 | others.append((param, param_type, desc)) 85 | 86 | if autosum: 87 | # GAEL: Toctree commented out below because it creates 88 | # hundreds of sphinx warnings 89 | # out += ['.. autosummary::', ' :toctree:', ''] 90 | out += ['.. autosummary::', ''] 91 | out += autosum 92 | 93 | if others: 94 | maxlen_0 = max([len(x[0]) for x in others]) 95 | maxlen_1 = max([len(x[1]) for x in others]) 96 | hdr = "=" * maxlen_0 + " " + "=" * maxlen_1 + " " + "=" * 10 97 | fmt = '%%%ds %%%ds ' % (maxlen_0, maxlen_1) 98 | n_indent = maxlen_0 + maxlen_1 + 4 99 | out += [hdr] 100 | for param, param_type, desc in others: 101 | out += [fmt % (param.strip(), param_type)] 102 | out += self._str_indent(desc, n_indent) 103 | out += [hdr] 104 | out += [''] 105 | return out 106 | 107 | def _str_section(self, name): 108 | out = [] 109 | if self[name]: 110 | out += self._str_header(name) 111 | out += [''] 112 | content = textwrap.dedent("\n".join(self[name])).split("\n") 113 | out += content 114 | out += [''] 115 | return out 116 | 117 | def _str_see_also(self, func_role): 118 | out = [] 119 | if self['See Also']: 120 | see_also = super(SphinxDocString, self)._str_see_also(func_role) 121 | out = ['.. seealso::', ''] 122 | out += self._str_indent(see_also[2:]) 123 | return out 124 | 125 | def _str_warnings(self): 126 | out = [] 127 | if self['Warnings']: 128 | out = ['.. 
warning::', ''] 129 | out += self._str_indent(self['Warnings']) 130 | return out 131 | 132 | def _str_index(self): 133 | idx = self['index'] 134 | out = [] 135 | if len(idx) == 0: 136 | return out 137 | 138 | out += ['.. index:: %s' % idx.get('default', '')] 139 | for section, references in idx.iteritems(): 140 | if section == 'default': 141 | continue 142 | elif section == 'refguide': 143 | out += [' single: %s' % (', '.join(references))] 144 | else: 145 | out += [' %s: %s' % (section, ','.join(references))] 146 | return out 147 | 148 | def _str_references(self): 149 | out = [] 150 | if self['References']: 151 | out += self._str_header('References') 152 | if isinstance(self['References'], str): 153 | self['References'] = [self['References']] 154 | out.extend(self['References']) 155 | out += [''] 156 | # Latex collects all references to a separate bibliography, 157 | # so we need to insert links to it 158 | import sphinx # local import to avoid test dependency 159 | if sphinx.__version__ >= "0.6": 160 | out += ['.. only:: latex', ''] 161 | else: 162 | out += ['.. latexonly::', ''] 163 | items = [] 164 | for line in self['References']: 165 | m = re.match(r'.. \[([a-z0-9._-]+)\]', line, re.I) 166 | if m: 167 | items.append(m.group(1)) 168 | out += [' ' + ", ".join(["[%s]_" % item for item in items]), ''] 169 | return out 170 | 171 | def _str_examples(self): 172 | examples_str = "\n".join(self['Examples']) 173 | 174 | if (self.use_plots and 'import matplotlib' in examples_str 175 | and 'plot::' not in examples_str): 176 | out = [] 177 | out += self._str_header('Examples') 178 | out += ['.. 
    def __str__(self, indent=0, func_role="obj"):
        """Assemble the full reST rendering of the docstring.

        Sections are emitted in a fixed order (signature, index, summaries,
        parameter-style lists, warnings, see-also, notes, references,
        examples, methods) and the whole result is indented by ``indent``
        spaces.
        """
        out = []
        out += self._str_signature()
        out += self._str_index() + ['']
        out += self._str_summary()
        out += self._str_extended_summary()
        # Field-list style sections share one renderer.
        for param_list in ('Parameters', 'Returns', 'Raises', 'Attributes'):
            out += self._str_param_list(param_list)
        out += self._str_warnings()
        out += self._str_see_also(func_role)
        out += self._str_section('Notes')
        out += self._str_references()
        out += self._str_examples()
        # Member listings (autosummary tables) come last.
        for param_list in ('Methods',):
            out += self._str_member_list(param_list)
        out = self._str_indent(out, indent)
        return '\n'.join(out)
SphinxFunctionDoc(obj, doc=doc, config=config) 237 | else: 238 | if doc is None: 239 | doc = pydoc.getdoc(obj) 240 | return SphinxObjDoc(obj, doc, config=config) 241 | -------------------------------------------------------------------------------- /docs/source/images/merged_grid_search_graph.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 6 | 7 | 9 | 10 | merged 11 | 12 | 13 | data 14 | 15 | Training Data 16 | 17 | 18 | vect1 19 | 20 | CountVectorizer 21 | - ngram_range=(1, 1) 22 | 23 | 24 | data->vect1 25 | 26 | 27 | 28 | 29 | tfidf_1_1 30 | 31 | TfidfTransformer 32 | - norm='l1' 33 | 34 | 35 | vect1->tfidf_1_1 36 | 37 | 38 | 39 | 40 | tfidf_1_2 41 | 42 | TfidfTransformer 43 | - norm='l2' 44 | 45 | 46 | vect1->tfidf_1_2 47 | 48 | 49 | 50 | 51 | sgd_1_1_1 52 | 53 | SGDClassifier 54 | - alpha=1e-3 55 | 56 | 57 | tfidf_1_1->sgd_1_1_1 58 | 59 | 60 | 61 | 62 | sgd_1_1_2 63 | 64 | SGDClassifier 65 | - alpha=1e-4 66 | 67 | 68 | tfidf_1_1->sgd_1_1_2 69 | 70 | 71 | 72 | 73 | sgd_1_1_3 74 | 75 | SGDClassifier 76 | - alpha=1e-5 77 | 78 | 79 | tfidf_1_1->sgd_1_1_3 80 | 81 | 82 | 83 | 84 | sgd_1_2_1 85 | 86 | SGDClassifier 87 | - alpha=1e-3 88 | 89 | 90 | tfidf_1_2->sgd_1_2_1 91 | 92 | 93 | 94 | 95 | sgd_1_2_2 96 | 97 | SGDClassifier 98 | - alpha=1e-4 99 | 100 | 101 | tfidf_1_2->sgd_1_2_2 102 | 103 | 104 | 105 | 106 | sgd_1_2_3 107 | 108 | SGDClassifier 109 | - alpha=1e-5 110 | 111 | 112 | tfidf_1_2->sgd_1_2_3 113 | 114 | 115 | 116 | 117 | best 118 | 119 | Choose Best Parameters 120 | 121 | 122 | sgd_1_1_1->best 123 | 124 | 125 | 126 | 127 | sgd_1_1_2->best 128 | 129 | 130 | 131 | 132 | sgd_1_1_3->best 133 | 134 | 135 | 136 | 137 | sgd_1_2_1->best 138 | 139 | 140 | 141 | 142 | sgd_1_2_2->best 143 | 144 | 145 | 146 | 147 | sgd_1_2_3->best 148 | 149 | 150 | 151 | 152 | 153 | -------------------------------------------------------------------------------- /dask_searchcv/methods.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import warnings 4 | from collections import defaultdict, Mapping 5 | from threading import Lock 6 | from timeit import default_timer 7 | from distutils.version import LooseVersion 8 | 9 | import numpy as np 10 | from toolz import pluck 11 | from scipy import sparse 12 | from scipy.stats import rankdata 13 | from dask.base import normalize_token 14 | 15 | from sklearn.exceptions import FitFailedWarning 16 | from sklearn.pipeline import Pipeline, FeatureUnion 17 | from sklearn.utils import safe_indexing 18 | from sklearn.utils.validation import check_consistent_length, _is_arraylike 19 | 20 | from .utils import copy_estimator 21 | 22 | # Copied from scikit-learn/sklearn/utils/fixes.py, can be removed once we drop 23 | # support for scikit-learn < 0.18.1 or numpy < 1.12.0. 24 | if LooseVersion(np.__version__) < '1.12.0': 25 | class MaskedArray(np.ma.MaskedArray): 26 | # Before numpy 1.12, np.ma.MaskedArray object is not picklable 27 | # This fix is needed to make our model_selection.GridSearchCV 28 | # picklable as the ``cv_results_`` param uses MaskedArray 29 | def __getstate__(self): 30 | """Return the internal state of the masked array, for pickling 31 | purposes. 
class CVCache(object):
    """Holds the cross-validation splits for ``(X, y)`` and (optionally)
    memoizes the materialized train/test selections so repeated extractions
    across the task graph are computed only once per split.
    """
    def __init__(self, splits, pairwise=False, cache=True):
        # splits: list of (train_indices, test_indices) pairs.
        # pairwise: X is a precomputed square kernel/affinity matrix.
        # cache: when True, memoize extracted slices keyed by split index.
        self.splits = splits
        self.pairwise = pairwise
        self.cache = {} if cache else None

    def __reduce__(self):
        # Cached slices are recomputable; pickle only whether caching was
        # enabled, not the cache contents.
        return (CVCache, (self.splits, self.pairwise, self.cache is not None))

    def num_test_samples(self):
        # Test indices may be boolean masks (count the Trues) or plain
        # index arrays (take the length).
        return np.array([i.sum() if i.dtype == bool else len(i)
                         for i in pluck(1, self.splits)])

    def extract(self, X, y, n, is_x=True, is_train=True):
        """Return the train or test selection of X (or y) for split ``n``."""
        if is_x:
            if self.pairwise:
                return self._extract_pairwise(X, y, n, is_train=is_train)
            return self._extract(X, y, n, is_x=True, is_train=is_train)
        if y is None:
            return None
        return self._extract(X, y, n, is_x=False, is_train=is_train)

    def extract_param(self, key, x, n):
        """Slice an array-like fit parameter to split ``n``'s training rows;
        non-array-like parameters pass through unchanged."""
        if self.cache is not None and (n, key) in self.cache:
            return self.cache[n, key]

        out = safe_indexing(x, self.splits[n][0]) if _is_arraylike(x) else x

        if self.cache is not None:
            self.cache[n, key] = out
        return out

    def _extract(self, X, y, n, is_x=True, is_train=True):
        # Memoized row-selection for the non-pairwise case.
        if self.cache is not None and (n, is_x, is_train) in self.cache:
            return self.cache[n, is_x, is_train]

        inds = self.splits[n][0] if is_train else self.splits[n][1]
        result = safe_indexing(X if is_x else y, inds)

        if self.cache is not None:
            self.cache[n, is_x, is_train] = result
        return result

    def _extract_pairwise(self, X, y, n, is_train=True):
        # For precomputed kernels: select rows for the requested side but
        # always columns for the training set, yielding the sub-kernel
        # between (train or test) samples and the training samples.
        if self.cache is not None and (n, True, is_train) in self.cache:
            return self.cache[n, True, is_train]

        if not hasattr(X, "shape"):
            raise ValueError("Precomputed kernels or affinity matrices have "
                             "to be passed as arrays or sparse matrices.")
        if X.shape[0] != X.shape[1]:
            raise ValueError("X should be a square kernel matrix")
        train, test = self.splits[n]
        result = X[np.ix_(train if is_train else test, train)]

        if self.cache is not None:
            self.cache[n, True, is_train] = result
        return result
def feature_union_concat(Xs, nsamples, weights):
    """Apply weights and concatenate outputs from a FeatureUnion"""
    # A failure in any sub-transformer poisons the whole union.
    for block in Xs:
        if block is FIT_FAILURE:
            return FIT_FAILURE
    weighted = []
    for block, weight in zip(Xs, weights):
        if block is None:
            continue
        weighted.append(block if weight is None else block * weight)
    if not weighted:
        # Every transformer was dropped: an empty (nsamples, 0) block.
        return np.zeros((nsamples, 0))
    if any(sparse.issparse(part) for part in weighted):
        return sparse.hstack(weighted).tocsr()
    return np.hstack(weighted)
def fit_transform(est, X, y, error_score='raise', fields=None, params=None,
                  fit_params=None):
    """Set ``params`` on a copy of ``est``, fit it on ``(X, y)``, and return
    ``((fitted_est, fit_time), Xt)`` where ``Xt`` is the transformed data.

    Mirrors ``fit`` but also produces the transformed output consumed by
    downstream pipeline steps.  When ``error_score != 'raise'``, a failed
    fit is reported as the ``FIT_FAILURE`` sentinel for both the estimator
    and ``Xt`` (after warning), instead of raising.
    """
    if X is FIT_FAILURE:
        # An upstream step already failed; propagate the sentinel.
        est, fit_time, Xt = FIT_FAILURE, 0.0, FIT_FAILURE
    else:
        if not fit_params:
            fit_params = {}
        start_time = default_timer()
        try:
            est = set_params(est, fields, params)
            # Prefer the estimator's fused fit_transform when available;
            # otherwise fall back to fit followed by transform.
            if hasattr(est, 'fit_transform'):
                Xt = est.fit_transform(X, y, **fit_params)
            else:
                est.fit(X, y, **fit_params)
                Xt = est.transform(X)
        except Exception as e:
            if error_score == 'raise':
                raise
            warn_fit_failure(error_score, e)
            est = Xt = FIT_FAILURE
        fit_time = default_timer() - start_time

    return (est, fit_time), Xt
def fit_and_score(est, cv, X, y, n, scorer,
                  error_score='raise', fields=None, params=None,
                  fit_params=None, return_train_score=True):
    """Fit on the ``n``-th training split and score on the matching test split.

    ``cv.extract`` selects the (X-or-y, train-or-test) slice for split ``n``.
    """
    X_train = cv.extract(X, y, n, True, True)
    y_train = cv.extract(X, y, n, False, True)
    X_test = cv.extract(X, y, n, True, False)
    y_test = cv.extract(X, y, n, False, False)
    fitted = fit(est, X_train, y_train, error_score,
                 fields, params, fit_params)
    if not return_train_score:
        # Drop the training split so ``score`` skips the train score.
        X_train = y_train = None
    return score(fitted, X_test, y_test, X_train, y_train, scorer)


def _store(results, key_name, array, n_splits, n_candidates,
           weights=None, splits=False, rank=False):
    """A small helper to store the scores/times to the cv_results_."""
    # Incoming order is split-major; reshape to (candidate, split).
    scores = np.array(array, dtype=np.float64).reshape(n_splits, n_candidates).T
    if splits:
        for i in range(n_splits):
            results["split%d_%s" % (i, key_name)] = scores[:, i]

    means = np.average(scores, axis=1, weights=weights)
    results['mean_%s' % key_name] = means
    # Weighted std is not directly available in numpy; compute it from the
    # weighted average of squared deviations.
    results['std_%s' % key_name] = np.sqrt(
        np.average((scores - means[:, np.newaxis]) ** 2,
                   axis=1, weights=weights))

    if rank:
        results["rank_%s" % key_name] = np.asarray(
            rankdata(-means, method='min'), dtype=np.int32)
def create_cv_results(scores, candidate_params, n_splits, error_score, weights,
                      multimetric):
    """Assemble the ``cv_results_`` dict from per-(split, candidate) scores.

    ``scores`` is a split-major sequence of ``(fit_time, test_score,
    score_time)`` or ``(fit_time, test_score, score_time, train_score)``
    tuples; failed fits are recorded as ``error_score``.
    """
    if len(scores[0]) == 4:
        fit_times, test_scores, score_times, train_scores = zip(*scores)
    else:
        fit_times, test_scores, score_times = zip(*scores)
        train_scores = None

    if not multimetric:
        # Replace failure markers with the configured error_score.
        test_scores = [error_score if s is FIT_FAILURE else s
                       for s in test_scores]
        if train_scores is not None:
            train_scores = [error_score if s is FIT_FAILURE else s
                            for s in train_scores]
    else:
        # Multimetric: each score is a dict keyed by metric name.
        test_scores = {k: [error_score if x is FIT_FAILURE else x[k]
                           for x in test_scores]
                       for k in multimetric}
        if train_scores is not None:
            train_scores = {k: [error_score if x is FIT_FAILURE else x[k]
                                for x in train_scores]
                            for k in multimetric}

    # Construct the `cv_results_` dictionary
    results = {'params': candidate_params}
    n_candidates = len(candidate_params)

    if weights is not None:
        # Repeat the per-split weights for every candidate row.
        weights = np.broadcast_to(weights[None, :],
                                  (n_candidates, len(weights)))

    _store(results, 'fit_time', fit_times, n_splits, n_candidates)
    _store(results, 'score_time', score_times, n_splits, n_candidates)

    if not multimetric:
        _store(results, 'test_score', test_scores, n_splits, n_candidates,
               splits=True, rank=True, weights=weights)
        if train_scores is not None:
            _store(results, 'train_score', train_scores,
                   n_splits, n_candidates, splits=True)
    else:
        for key in multimetric:
            _store(results, 'test_{}'.format(key), test_scores[key], n_splits,
                   n_candidates, splits=True, rank=True, weights=weights)
        if train_scores is not None:
            for key in multimetric:
                _store(results, 'train_{}'.format(key), train_scores[key],
                       n_splits, n_candidates, splits=True)

    # One MaskedArray per parameter; entries stay masked for candidates that
    # do not set that parameter (defaultdict, since each candidate may not
    # contain all the params).
    param_results = defaultdict(lambda: MaskedArray(np.empty(n_candidates),
                                                    mask=True,
                                                    dtype=object))
    for i, params in enumerate(candidate_params):
        for name, value in params.items():
            param_results["param_%s" % name][i] = value

    results.update(param_results)
    return results


def get_best_params(candidate_params, cv_results, scorer):
    """Return the candidate parameter dict ranked first for ``scorer``."""
    ranks = cv_results["rank_test_{}".format(scorer)]
    best_index = np.flatnonzero(ranks == 1)[0]
    return candidate_params[best_index]


def fit_best(estimator, params, X, y, fit_params):
    """Refit a fresh copy of ``estimator`` with ``params`` on the full data."""
    best = copy_estimator(estimator).set_params(**params)
    best.fit(X, y, **fit_params)
    return best
| 106 | tfidf3 107 | 108 | TfidfTransformer 109 | - norm='l1' 110 | 111 | 112 | vect3->tfidf3 113 | 114 | 115 | 116 | 117 | tfidf4 118 | 119 | TfidfTransformer 120 | - norm='l2' 121 | 122 | 123 | vect4->tfidf4 124 | 125 | 126 | 127 | 128 | tfidf5 129 | 130 | TfidfTransformer 131 | - norm='l2' 132 | 133 | 134 | vect5->tfidf5 135 | 136 | 137 | 138 | 139 | tfidf6 140 | 141 | TfidfTransformer 142 | - norm='l2' 143 | 144 | 145 | vect6->tfidf6 146 | 147 | 148 | 149 | 150 | sgd1 151 | 152 | SGDClassifier 153 | - alpha=1e-3 154 | 155 | 156 | tfidf1->sgd1 157 | 158 | 159 | 160 | 161 | sgd2 162 | 163 | SGDClassifier 164 | - alpha=1e-4 165 | 166 | 167 | tfidf2->sgd2 168 | 169 | 170 | 171 | 172 | sgd3 173 | 174 | SGDClassifier 175 | - alpha=1e-5 176 | 177 | 178 | tfidf3->sgd3 179 | 180 | 181 | 182 | 183 | sgd4 184 | 185 | SGDClassifier 186 | - alpha=1e-3 187 | 188 | 189 | tfidf4->sgd4 190 | 191 | 192 | 193 | 194 | sgd5 195 | 196 | SGDClassifier 197 | - alpha=1e-4 198 | 199 | 200 | tfidf5->sgd5 201 | 202 | 203 | 204 | 205 | sgd6 206 | 207 | SGDClassifier 208 | - alpha=1e-5 209 | 210 | 211 | tfidf6->sgd6 212 | 213 | 214 | 215 | 216 | best 217 | 218 | Choose Best Parameters 219 | 220 | 221 | sgd1->best 222 | 223 | 224 | 225 | 226 | sgd2->best 227 | 228 | 229 | 230 | 231 | sgd3->best 232 | 233 | 234 | 235 | 236 | sgd4->best 237 | 238 | 239 | 240 | 241 | sgd5->best 242 | 243 | 244 | 245 | 246 | sgd6->best 247 | 248 | 249 | 250 | 251 | 252 | -------------------------------------------------------------------------------- /docs/source/sphinxext/docscrape.py: -------------------------------------------------------------------------------- 1 | """Extract reference documentation from the NumPy source tree. 
2 | 3 | """ 4 | 5 | import inspect 6 | import textwrap 7 | import re 8 | import pydoc 9 | from warnings import warn 10 | # Try Python 2 first, otherwise load from Python 3 11 | try: 12 | from StringIO import StringIO 13 | except: 14 | from io import StringIO 15 | 16 | 17 | class Reader(object): 18 | """A line-based string reader. 19 | 20 | """ 21 | def __init__(self, data): 22 | """ 23 | Parameters 24 | ---------- 25 | data : str 26 | String with lines separated by '\n'. 27 | 28 | """ 29 | if isinstance(data, list): 30 | self._str = data 31 | else: 32 | self._str = data.split('\n') # store string as list of lines 33 | 34 | self.reset() 35 | 36 | def __getitem__(self, n): 37 | return self._str[n] 38 | 39 | def reset(self): 40 | self._l = 0 # current line nr 41 | 42 | def read(self): 43 | if not self.eof(): 44 | out = self[self._l] 45 | self._l += 1 46 | return out 47 | else: 48 | return '' 49 | 50 | def seek_next_non_empty_line(self): 51 | for l in self[self._l:]: 52 | if l.strip(): 53 | break 54 | else: 55 | self._l += 1 56 | 57 | def eof(self): 58 | return self._l >= len(self._str) 59 | 60 | def read_to_condition(self, condition_func): 61 | start = self._l 62 | for line in self[start:]: 63 | if condition_func(line): 64 | return self[start:self._l] 65 | self._l += 1 66 | if self.eof(): 67 | return self[start:self._l + 1] 68 | return [] 69 | 70 | def read_to_next_empty_line(self): 71 | self.seek_next_non_empty_line() 72 | 73 | def is_empty(line): 74 | return not line.strip() 75 | return self.read_to_condition(is_empty) 76 | 77 | def read_to_next_unindented_line(self): 78 | def is_unindented(line): 79 | return (line.strip() and (len(line.lstrip()) == len(line))) 80 | return self.read_to_condition(is_unindented) 81 | 82 | def peek(self, n=0): 83 | if self._l + n < len(self._str): 84 | return self[self._l + n] 85 | else: 86 | return '' 87 | 88 | def is_empty(self): 89 | return not ''.join(self._str).strip() 90 | 91 | 92 | class NumpyDocString(object): 93 | def 
    def __init__(self, docstring, config={}):
        # Parse a numpydoc-formatted docstring into named sections.
        # NOTE(review): mutable default ``config={}`` is shared between
        # calls; it is never mutated here, but confirm before relying on it.
        docstring = textwrap.dedent(docstring).split('\n')

        self._doc = Reader(docstring)
        # Canonical section names; _parse() fills these in.
        self._parsed_data = {
            'Signature': '',
            'Summary': [''],
            'Extended Summary': [],
            'Parameters': [],
            'Returns': [],
            'Raises': [],
            'Warns': [],
            'Other Parameters': [],
            'Attributes': [],
            'Methods': [],
            'See Also': [],
            'Notes': [],
            'Warnings': [],
            'References': '',
            'Examples': '',
            'index': {}
        }

        self._parse()

    def __getitem__(self, key):
        return self._parsed_data[key]

    def __setitem__(self, key, val):
        # Unknown section names are dropped with a warning, not an error.
        if key not in self._parsed_data:
            warn("Unknown section %s" % key)
        else:
            self._parsed_data[key] = val

    def _is_at_section(self):
        # True when the cursor sits on a section header: either an
        # '.. index::' directive, or a title underlined with '-' or '='.
        self._doc.seek_next_non_empty_line()

        if self._doc.eof():
            return False

        l1 = self._doc.peek().strip()  # e.g. Parameters

        if l1.startswith('.. index::'):
            return True

        l2 = self._doc.peek(1).strip()  # ---------- or ==========
        return l2.startswith('-' * len(l1)) or l2.startswith('=' * len(l1))

    def _strip(self, doc):
        # Trim leading and trailing blank lines from a list of lines.
        i = 0
        j = 0
        for i, line in enumerate(doc):
            if line.strip():
                break

        for j, line in enumerate(doc[::-1]):
            if line.strip():
                break

        return doc[i:len(doc) - j]

    def _read_to_next_section(self):
        # Collect lines until the next section header, preserving interior
        # blank lines.
        section = self._doc.read_to_next_empty_line()

        while not self._is_at_section() and not self._doc.eof():
            if not self._doc.peek(-1).strip():  # previous line was empty
                section += ['']

            section += self._doc.read_to_next_empty_line()

        return section

    def _read_sections(self):
        # Yield (section_name, content_lines) pairs for each section.
        while not self._doc.eof():
            data = self._read_to_next_section()
            name = data[0].strip()

            if name.startswith('..'):  # index section
                yield name, data[1:]
            elif len(data) < 2:
                # NOTE(review): this yields the StopIteration *class* as a
                # plain value, while callers unpack 2-tuples -- looks
                # unintentional; confirm against upstream numpydoc.
                yield StopIteration
            else:
                yield name, self._strip(data[2:])

    def _parse_param_list(self, content):
        # Parse 'name : type' headers, each followed by an indented
        # description block.
        r = Reader(content)
        params = []
        while not r.eof():
            header = r.read().strip()
            if ' : ' in header:
                arg_name, arg_type = header.split(' : ')[:2]
            else:
                arg_name, arg_type = header, ''

            desc = r.read_to_next_unindented_line()
            desc = dedent_lines(desc)

            params.append((arg_name, arg_type, desc))

        return params

    # Matches ':role:`name`' or a bare 'name' at the start of a See Also
    # entry.  (Named groups reconstructed; they were mangled in transit.)
    _name_rgx = re.compile(r"^\s*(:(?P<role>\w+):`(?P<name>[a-zA-Z0-9_.-]+)`|"
                           r" (?P<name2>[a-zA-Z0-9_.-]+))\s*", re.X)

    def _parse_see_also(self, content):
        """
        func_name : Descriptive text
            continued text
        another_func_name : Descriptive text
        func_name1, func_name2, :meth:`func_name`, func_name3

        """
        items = []

        def parse_item_name(text):
            """Match ':role:`name`' or 'name'"""
            m = self._name_rgx.match(text)
            if m:
                g = m.groups()
                if g[1] is None:
                    # bare name, no role
                    return g[3], None
                else:
                    return g[2], g[1]
            raise ValueError("%s is not a item name" % text)

        def push_item(name, rest):
            # Flush the accumulated description for ``name`` into items.
            if not name:
                return
            name, role = parse_item_name(name)
            items.append((name, list(rest), role))
            del rest[:]

        current_func = None
        rest = []

        for line in content:
            if not line.strip():
                continue

            m = self._name_rgx.match(line)
            if m and line[m.end():].strip().startswith(':'):
                # 'name : description' on a single line
                push_item(current_func, rest)
                current_func, line = line[:m.end()], line[m.end():]
                rest = [line.split(':', 1)[1].strip()]
                if not rest[0]:
                    rest = []
            elif not line.startswith(' '):
                # unindented line starts new item(s)
                push_item(current_func, rest)
                current_func = None
                if ',' in line:
                    # comma-separated names without descriptions
                    for func in line.split(','):
                        push_item(func, [])
                elif line.strip():
                    current_func = line
            elif current_func is not None:
                # indented continuation of the current description
                rest.append(line.strip())
        push_item(current_func, rest)
        return items

    def _parse_index(self, section, content):
        """
        .. index: default
           :refguide: something, else, and more

        """
        def strip_each_in(lst):
            return [s.strip() for s in lst]

        out = {}
        section = section.split('::')
        if len(section) > 1:
            out['default'] = strip_each_in(section[1].split(','))[0]
        for line in content:
            line = line.split(':')
            if len(line) > 2:
                out[line[1]] = strip_each_in(line[2].split(','))
        return out

    def _parse_summary(self):
        """Grab signature (if given) and summary"""
        if self._is_at_section():
            return

        summary = self._doc.read_to_next_empty_line()
        summary_str = " ".join([s.strip() for s in summary]).strip()
        # NOTE(review): non-raw pattern -- '\w' etc. survive as literal
        # escapes today, but a raw string would be safer.
        if re.compile('^([\w., ]+=)?\s*[\w\.]+\(.*\)$').match(summary_str):
            # first block looks like a signature, not a summary
            self['Signature'] = summary_str
            if not self._is_at_section():
                self['Summary'] = self._doc.read_to_next_empty_line()
        else:
            self['Summary'] = summary

        if not self._is_at_section():
            self['Extended Summary'] = self._read_to_next_section()

    def _parse(self):
        # Drive the full parse: summary first, then each titled section.
        self._doc.reset()
        self._parse_summary()

        for (section, content) in self._read_sections():
            if not section.startswith('..'):
                # normalise titles, e.g. 'see also' -> 'See Also'
                section = ' '.join([s.capitalize()
                                    for s in section.split(' ')])
            if section in ('Parameters', 'Attributes', 'Methods',
                           'Returns', 'Raises', 'Warns'):
                self[section] = self._parse_param_list(content)
            elif section.startswith('.. index::'):
                self['index'] = self._parse_index(section, content)
            elif section == 'See Also':
                self['See Also'] = self._parse_see_also(content)
            else:
                self[section] = content

    # string conversion routines

    def _str_header(self, name, symbol='-'):
        # reST section title with matching underline.
        return [name, len(name) * symbol]

    def _str_indent(self, doc, indent=4):
        out = []
        for line in doc:
            out += [' ' * indent + line]
        return out

    def _str_signature(self):
        if self['Signature']:
            # escape '*' so reST does not treat it as emphasis
            return [self['Signature'].replace('*', '\*')] + ['']
        else:
            return ['']

    def _str_summary(self):
        if self['Summary']:
            return self['Summary'] + ['']
        else:
            return []

    def _str_extended_summary(self):
        if self['Extended Summary']:
            return self['Extended Summary'] + ['']
        else:
            return []

    def _str_param_list(self, name):
        out = []
        if self[name]:
            out += self._str_header(name)
            for param, param_type, desc in self[name]:
                out += ['%s : %s' % (param, param_type)]
                out += self._str_indent(desc)
            out += ['']
        return out

    def _str_section(self, name):
        out = []
        if self[name]:
            out += self._str_header(name)
            out += self[name]
            out += ['']
        return out

    def _str_see_also(self, func_role):
        if not self['See Also']:
            return []
        out = []
        out += self._str_header("See Also")
        last_had_desc = True
        for func, desc, role in self['See Also']:
            if role:
                link = ':%s:`%s`' % (role, func)
            elif func_role:
                link = ':%s:`%s`' % (func_role, func)
            else:
                link = "`%s`_" % func
            if desc or last_had_desc:
                out += ['']
                out += [link]
            else:
                # run description-less names together on one line
                out[-1] += ", %s" % link
            if desc:
                out += self._str_indent([' '.join(desc)])
                last_had_desc = True
            else:
                last_had_desc = False
        out += ['']
        return out
| out += [''] 377 | return out 378 | 379 | def _str_index(self): 380 | idx = self['index'] 381 | out = [] 382 | out += ['.. index:: %s' % idx.get('default', '')] 383 | for section, references in idx.iteritems(): 384 | if section == 'default': 385 | continue 386 | out += [' :%s: %s' % (section, ', '.join(references))] 387 | return out 388 | 389 | def __str__(self, func_role=''): 390 | out = [] 391 | out += self._str_signature() 392 | out += self._str_summary() 393 | out += self._str_extended_summary() 394 | for param_list in ('Parameters', 'Returns', 'Raises'): 395 | out += self._str_param_list(param_list) 396 | out += self._str_section('Warnings') 397 | out += self._str_see_also(func_role) 398 | for s in ('Notes', 'References', 'Examples'): 399 | out += self._str_section(s) 400 | for param_list in ('Attributes', 'Methods'): 401 | out += self._str_param_list(param_list) 402 | out += self._str_index() 403 | return '\n'.join(out) 404 | 405 | 406 | def indent(str, indent=4): 407 | indent_str = ' ' * indent 408 | if str is None: 409 | return indent_str 410 | lines = str.split('\n') 411 | return '\n'.join(indent_str + l for l in lines) 412 | 413 | 414 | def dedent_lines(lines): 415 | """Deindent a list of lines maximally""" 416 | return textwrap.dedent("\n".join(lines)).split("\n") 417 | 418 | 419 | def header(text, style='-'): 420 | return text + '\n' + style * len(text) + '\n' 421 | 422 | 423 | class FunctionDoc(NumpyDocString): 424 | def __init__(self, func, role='func', doc=None, config={}): 425 | self._f = func 426 | self._role = role # e.g. 
"func" or "meth" 427 | 428 | if doc is None: 429 | if func is None: 430 | raise ValueError("No function or docstring given") 431 | doc = inspect.getdoc(func) or '' 432 | NumpyDocString.__init__(self, doc) 433 | 434 | if not self['Signature'] and func is not None: 435 | func, func_name = self.get_func() 436 | try: 437 | # try to read signature 438 | argspec = inspect.getargspec(func) 439 | argspec = inspect.formatargspec(*argspec) 440 | argspec = argspec.replace('*', '\*') 441 | signature = '%s%s' % (func_name, argspec) 442 | except TypeError as e: 443 | signature = '%s()' % func_name 444 | self['Signature'] = signature 445 | 446 | def get_func(self): 447 | func_name = getattr(self._f, '__name__', self.__class__.__name__) 448 | if inspect.isclass(self._f): 449 | func = getattr(self._f, '__call__', self._f.__init__) 450 | else: 451 | func = self._f 452 | return func, func_name 453 | 454 | def __str__(self): 455 | out = '' 456 | 457 | func, func_name = self.get_func() 458 | signature = self['Signature'].replace('*', '\*') 459 | 460 | roles = {'func': 'function', 461 | 'meth': 'method'} 462 | 463 | if self._role: 464 | if self._role not in roles: 465 | print("Warning: invalid role %s" % self._role) 466 | out += '.. %s:: %s\n \n\n' % (roles.get(self._role, ''), 467 | func_name) 468 | 469 | out += super(FunctionDoc, self).__str__(func_role=self._role) 470 | return out 471 | 472 | 473 | class ClassDoc(NumpyDocString): 474 | def __init__(self, cls, doc=None, modulename='', func_doc=FunctionDoc, 475 | config=None): 476 | if not inspect.isclass(cls) and cls is not None: 477 | raise ValueError("Expected a class or None, but got %r" % cls) 478 | self._cls = cls 479 | 480 | if modulename and not modulename.endswith('.'): 481 | modulename += '.' 
class ClassDoc(NumpyDocString):
    # Parses the docstring of a class; optionally auto-populates the
    # Methods/Attributes sections from the class's public members.

    def __init__(self, cls, doc=None, modulename='', func_doc=FunctionDoc,
                 config=None):
        if not inspect.isclass(cls) and cls is not None:
            raise ValueError("Expected a class or None, but got %r" % cls)
        self._cls = cls

        # normalise the module prefix so names render as 'module.Class'
        if modulename and not modulename.endswith('.'):
            modulename += '.'
        self._mod = modulename

        if doc is None:
            if cls is None:
                raise ValueError("No class or documentation string given")
            doc = pydoc.getdoc(cls)

        NumpyDocString.__init__(self, doc)

        if config is not None and config.get('show_class_members', True):
            # Fill Methods/Attributes only when the docstring left them empty.
            if not self['Methods']:
                self['Methods'] = [(name, '', '')
                                   for name in sorted(self.methods)]
            if not self['Attributes']:
                self['Attributes'] = [(name, '', '')
                                      for name in sorted(self.properties)]

    @property
    def methods(self):
        # Public callable members of the class (skips _private names).
        if self._cls is None:
            return []
        return [name for name, func in inspect.getmembers(self._cls)
                if not name.startswith('_') and callable(func)]

    @property
    def properties(self):
        # NOTE(review): treats members whose value is None as properties;
        # this matches old numpydoc behaviour but looks surprising -- confirm.
        if self._cls is None:
            return []
        return [name for name, func in inspect.getmembers(self._cls)
                if not name.startswith('_') and func is None]
def get_keywords():
    """Get the keywords needed to look up the version information."""
    # these strings will be replaced by git during git-archive.
    # setup.py/versioneer.py will grep for the variable names, so they must
    # each be defined on a line of their own. _version.py will just call
    # get_keywords().
    git_refnames = "$Format:%d$"
    git_full = "$Format:%H$"
    git_date = "$Format:%ci$"
    return {"refnames": git_refnames, "full": git_full, "date": git_date}


class VersioneerConfig:
    """Container for Versioneer configuration parameters."""


def get_config():
    """Create, populate and return the VersioneerConfig() object."""
    # these values mirror the [versioneer] section of setup.cfg and are
    # filled in when 'setup.py versioneer' creates _version.py
    cfg = VersioneerConfig()
    cfg.VCS = "git"
    cfg.style = "pep440"
    cfg.tag_prefix = ""
    cfg.parentdir_prefix = "dask_searchcv-"
    cfg.versionfile_source = "dask_searchcv/_version.py"
    cfg.verbose = False
    return cfg


class NotThisMethod(Exception):
    """Exception raised if a method is not valid for the current scenario."""


LONG_VERSION_PY = {}
HANDLERS = {}


def register_vcs_handler(vcs, method):  # decorator
    """Decorator to mark a method as the handler for a particular VCS."""
    def decorate(f):
        """Store f in HANDLERS[vcs][method]."""
        HANDLERS.setdefault(vcs, {})[method] = f
        return f
    return decorate
def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
                env=None):
    """Call the given command(s), trying each name in ``commands`` in turn.

    Returns ``(stdout, returncode)`` on success, ``(None, None)`` when no
    command could be launched, and ``(None, returncode)`` when the command
    ran but exited non-zero.
    """
    assert isinstance(commands, list)
    process = None
    for cmd in commands:
        dispcmd = str([cmd] + args)
        try:
            # remember shell=False, so use git.cmd on windows, not just git
            process = subprocess.Popen([cmd] + args, cwd=cwd, env=env,
                                       stdout=subprocess.PIPE,
                                       stderr=(subprocess.PIPE if hide_stderr
                                               else None))
            break
        except EnvironmentError as err:
            if err.errno == errno.ENOENT:
                # this name does not exist; try the next candidate
                continue
            if verbose:
                print("unable to run %s" % dispcmd)
                print(err)
            return None, None
    else:
        # loop exhausted without a successful Popen
        if verbose:
            print("unable to find command, tried %s" % (commands,))
        return None, None
    stdout = process.communicate()[0].strip()
    if sys.version_info[0] >= 3:
        stdout = stdout.decode()
    if process.returncode != 0:
        if verbose:
            print("unable to run %s (error)" % dispcmd)
            print("stdout was %s" % stdout)
        return None, process.returncode
    return stdout, process.returncode


def versions_from_parentdir(parentdir_prefix, root, verbose):
    """Try to determine the version from the parent directory name.

    Source tarballs conventionally unpack into a directory that includes both
    the project name and a version string. We will also support searching up
    two directory levels for an appropriately named parent directory
    """
    rootdirs = []

    for _ in range(3):
        dirname = os.path.basename(root)
        if dirname.startswith(parentdir_prefix):
            return {"version": dirname[len(parentdir_prefix):],
                    "full-revisionid": None,
                    "dirty": False, "error": None, "date": None}
        rootdirs.append(root)
        root = os.path.dirname(root)  # up a level

    if verbose:
        print("Tried directories %s but none started with prefix %s" %
              (str(rootdirs), parentdir_prefix))
    raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
@register_vcs_handler("git", "get_keywords")
def git_get_keywords(versionfile_abs):
    """Extract version information from the given file."""
    # the code embedded in _version.py can just fetch the value of these
    # keywords. When used from setup.py, we don't want to import _version.py,
    # so we do it with a regexp instead. This function is not used from
    # _version.py.
    keywords = {}
    try:
        f = open(versionfile_abs, "r")
        for line in f.readlines():
            if line.strip().startswith("git_refnames ="):
                mo = re.search(r'=\s*"(.*)"', line)
                if mo:
                    keywords["refnames"] = mo.group(1)
            if line.strip().startswith("git_full ="):
                mo = re.search(r'=\s*"(.*)"', line)
                if mo:
                    keywords["full"] = mo.group(1)
            if line.strip().startswith("git_date ="):
                mo = re.search(r'=\s*"(.*)"', line)
                if mo:
                    keywords["date"] = mo.group(1)
        f.close()
    except EnvironmentError:
        # missing/unreadable file: return whatever was collected (possibly {})
        pass
    return keywords


@register_vcs_handler("git", "keywords")
def git_versions_from_keywords(keywords, tag_prefix, verbose):
    """Get version information from git keywords."""
    if not keywords:
        raise NotThisMethod("no keywords at all, weird")
    date = keywords.get("date")
    if date is not None:
        # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant
        # datestamp. However we prefer "%ci" (which expands to an "ISO-8601
        # -like" string, which we must then edit to make compliant), because
        # it's been around since git-1.5.3, and it's too difficult to
        # discover which version we're using, or to work around using an
        # older one.
        date = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
    refnames = keywords["refnames"].strip()
    if refnames.startswith("$Format"):
        # keywords were never substituted, so this is not a git-archive
        if verbose:
            print("keywords are unexpanded, not using")
        raise NotThisMethod("unexpanded keywords, not a git-archive tarball")
    refs = set([r.strip() for r in refnames.strip("()").split(",")])
    # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
    # just "foo-1.0". If we see a "tag: " prefix, prefer those.
    TAG = "tag: "
    tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)])
    if not tags:
        # Either we're using git < 1.8.3, or there really are no tags. We use
        # a heuristic: assume all version tags have a digit. The old git %d
        # expansion behaves like git log --decorate=short and strips out the
        # refs/heads/ and refs/tags/ prefixes that would let us distinguish
        # between branches and tags. By ignoring refnames without digits, we
        # filter out many common branch names like "release" and
        # "stabilization", as well as "HEAD" and "master".
        tags = set([r for r in refs if re.search(r'\d', r)])
        if verbose:
            print("discarding '%s', no digits" % ",".join(refs - tags))
    if verbose:
        print("likely tags: %s" % ",".join(sorted(tags)))
    for ref in sorted(tags):
        # sorting will prefer e.g. "2.0" over "2.0rc1"
        if ref.startswith(tag_prefix):
            r = ref[len(tag_prefix):]
            if verbose:
                print("picking %s" % r)
            return {"version": r,
                    "full-revisionid": keywords["full"].strip(),
                    "dirty": False, "error": None,
                    "date": date}
    # no suitable tags, so version is "0+unknown", but full hex is still there
    if verbose:
        print("no suitable tags, using unknown + full revision id")
    return {"version": "0+unknown",
            "full-revisionid": keywords["full"].strip(),
            "dirty": False, "error": "no suitable tags", "date": None}
223 | """ 224 | GITS = ["git"] 225 | if sys.platform == "win32": 226 | GITS = ["git.cmd", "git.exe"] 227 | 228 | out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, 229 | hide_stderr=True) 230 | if rc != 0: 231 | if verbose: 232 | print("Directory %s not under git control" % root) 233 | raise NotThisMethod("'git rev-parse --git-dir' returned error") 234 | 235 | # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] 236 | # if there isn't one, this yields HEX[-dirty] (no NUM) 237 | describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", 238 | "--always", "--long", 239 | "--match", "%s*" % tag_prefix], 240 | cwd=root) 241 | # --long was added in git-1.5.5 242 | if describe_out is None: 243 | raise NotThisMethod("'git describe' failed") 244 | describe_out = describe_out.strip() 245 | full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) 246 | if full_out is None: 247 | raise NotThisMethod("'git rev-parse' failed") 248 | full_out = full_out.strip() 249 | 250 | pieces = {} 251 | pieces["long"] = full_out 252 | pieces["short"] = full_out[:7] # maybe improved later 253 | pieces["error"] = None 254 | 255 | # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] 256 | # TAG might have hyphens. 257 | git_describe = describe_out 258 | 259 | # look for -dirty suffix 260 | dirty = git_describe.endswith("-dirty") 261 | pieces["dirty"] = dirty 262 | if dirty: 263 | git_describe = git_describe[:git_describe.rindex("-dirty")] 264 | 265 | # now we have TAG-NUM-gHEX or HEX 266 | 267 | if "-" in git_describe: 268 | # TAG-NUM-gHEX 269 | mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) 270 | if not mo: 271 | # unparseable. Maybe git-describe is misbehaving? 
272 | pieces["error"] = ("unable to parse git-describe output: '%s'" 273 | % describe_out) 274 | return pieces 275 | 276 | # tag 277 | full_tag = mo.group(1) 278 | if not full_tag.startswith(tag_prefix): 279 | if verbose: 280 | fmt = "tag '%s' doesn't start with prefix '%s'" 281 | print(fmt % (full_tag, tag_prefix)) 282 | pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" 283 | % (full_tag, tag_prefix)) 284 | return pieces 285 | pieces["closest-tag"] = full_tag[len(tag_prefix):] 286 | 287 | # distance: number of commits since tag 288 | pieces["distance"] = int(mo.group(2)) 289 | 290 | # commit: short hex revision ID 291 | pieces["short"] = mo.group(3) 292 | 293 | else: 294 | # HEX: no tags 295 | pieces["closest-tag"] = None 296 | count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], 297 | cwd=root) 298 | pieces["distance"] = int(count_out) # total number of commits 299 | 300 | # commit date: see ISO-8601 comment in git_versions_from_keywords() 301 | date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], 302 | cwd=root)[0].strip() 303 | pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) 304 | 305 | return pieces 306 | 307 | 308 | def plus_or_dot(pieces): 309 | """Return a + if we don't already have one, else return a .""" 310 | if "+" in pieces.get("closest-tag", ""): 311 | return "." 312 | return "+" 313 | 314 | 315 | def render_pep440(pieces): 316 | """Build up version string, with post-release "local version identifier". 317 | 318 | Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you 319 | get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty 320 | 321 | Exceptions: 322 | 1: no tags. git_describe was just HEX. 
0+untagged.DISTANCE.gHEX[.dirty] 323 | """ 324 | if pieces["closest-tag"]: 325 | rendered = pieces["closest-tag"] 326 | if pieces["distance"] or pieces["dirty"]: 327 | rendered += plus_or_dot(pieces) 328 | rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) 329 | if pieces["dirty"]: 330 | rendered += ".dirty" 331 | else: 332 | # exception #1 333 | rendered = "0+untagged.%d.g%s" % (pieces["distance"], 334 | pieces["short"]) 335 | if pieces["dirty"]: 336 | rendered += ".dirty" 337 | return rendered 338 | 339 | 340 | def render_pep440_pre(pieces): 341 | """TAG[.post.devDISTANCE] -- No -dirty. 342 | 343 | Exceptions: 344 | 1: no tags. 0.post.devDISTANCE 345 | """ 346 | if pieces["closest-tag"]: 347 | rendered = pieces["closest-tag"] 348 | if pieces["distance"]: 349 | rendered += ".post.dev%d" % pieces["distance"] 350 | else: 351 | # exception #1 352 | rendered = "0.post.dev%d" % pieces["distance"] 353 | return rendered 354 | 355 | 356 | def render_pep440_post(pieces): 357 | """TAG[.postDISTANCE[.dev0]+gHEX] . 358 | 359 | The ".dev0" means dirty. Note that .dev0 sorts backwards 360 | (a dirty tree will appear "older" than the corresponding clean one), 361 | but you shouldn't be releasing software with -dirty anyways. 362 | 363 | Exceptions: 364 | 1: no tags. 0.postDISTANCE[.dev0] 365 | """ 366 | if pieces["closest-tag"]: 367 | rendered = pieces["closest-tag"] 368 | if pieces["distance"] or pieces["dirty"]: 369 | rendered += ".post%d" % pieces["distance"] 370 | if pieces["dirty"]: 371 | rendered += ".dev0" 372 | rendered += plus_or_dot(pieces) 373 | rendered += "g%s" % pieces["short"] 374 | else: 375 | # exception #1 376 | rendered = "0.post%d" % pieces["distance"] 377 | if pieces["dirty"]: 378 | rendered += ".dev0" 379 | rendered += "+g%s" % pieces["short"] 380 | return rendered 381 | 382 | 383 | def render_pep440_old(pieces): 384 | """TAG[.postDISTANCE[.dev0]] . 385 | 386 | The ".dev0" means dirty. 387 | 388 | Eexceptions: 389 | 1: no tags. 
0.postDISTANCE[.dev0] 390 | """ 391 | if pieces["closest-tag"]: 392 | rendered = pieces["closest-tag"] 393 | if pieces["distance"] or pieces["dirty"]: 394 | rendered += ".post%d" % pieces["distance"] 395 | if pieces["dirty"]: 396 | rendered += ".dev0" 397 | else: 398 | # exception #1 399 | rendered = "0.post%d" % pieces["distance"] 400 | if pieces["dirty"]: 401 | rendered += ".dev0" 402 | return rendered 403 | 404 | 405 | def render_git_describe(pieces): 406 | """TAG[-DISTANCE-gHEX][-dirty]. 407 | 408 | Like 'git describe --tags --dirty --always'. 409 | 410 | Exceptions: 411 | 1: no tags. HEX[-dirty] (note: no 'g' prefix) 412 | """ 413 | if pieces["closest-tag"]: 414 | rendered = pieces["closest-tag"] 415 | if pieces["distance"]: 416 | rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) 417 | else: 418 | # exception #1 419 | rendered = pieces["short"] 420 | if pieces["dirty"]: 421 | rendered += "-dirty" 422 | return rendered 423 | 424 | 425 | def render_git_describe_long(pieces): 426 | """TAG-DISTANCE-gHEX[-dirty]. 427 | 428 | Like 'git describe --tags --dirty --always -long'. 429 | The distance/hash is unconditional. 430 | 431 | Exceptions: 432 | 1: no tags. 
HEX[-dirty] (note: no 'g' prefix) 433 | """ 434 | if pieces["closest-tag"]: 435 | rendered = pieces["closest-tag"] 436 | rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) 437 | else: 438 | # exception #1 439 | rendered = pieces["short"] 440 | if pieces["dirty"]: 441 | rendered += "-dirty" 442 | return rendered 443 | 444 | 445 | def render(pieces, style): 446 | """Render the given version pieces into the requested style.""" 447 | if pieces["error"]: 448 | return {"version": "unknown", 449 | "full-revisionid": pieces.get("long"), 450 | "dirty": None, 451 | "error": pieces["error"], 452 | "date": None} 453 | 454 | if not style or style == "default": 455 | style = "pep440" # the default 456 | 457 | if style == "pep440": 458 | rendered = render_pep440(pieces) 459 | elif style == "pep440-pre": 460 | rendered = render_pep440_pre(pieces) 461 | elif style == "pep440-post": 462 | rendered = render_pep440_post(pieces) 463 | elif style == "pep440-old": 464 | rendered = render_pep440_old(pieces) 465 | elif style == "git-describe": 466 | rendered = render_git_describe(pieces) 467 | elif style == "git-describe-long": 468 | rendered = render_git_describe_long(pieces) 469 | else: 470 | raise ValueError("unknown style '%s'" % style) 471 | 472 | return {"version": rendered, "full-revisionid": pieces["long"], 473 | "dirty": pieces["dirty"], "error": None, 474 | "date": pieces.get("date")} 475 | 476 | 477 | def get_versions(): 478 | """Get version information or return default if unable to do so.""" 479 | # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have 480 | # __file__, we can work backwards from there to the root. Some 481 | # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which 482 | # case we can only use expanded keywords. 
483 | 484 | cfg = get_config() 485 | verbose = cfg.verbose 486 | 487 | try: 488 | return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, 489 | verbose) 490 | except NotThisMethod: 491 | pass 492 | 493 | try: 494 | root = os.path.realpath(__file__) 495 | # versionfile_source is the relative path from the top of the source 496 | # tree (where the .git directory might live) to this file. Invert 497 | # this to find the root from __file__. 498 | for i in cfg.versionfile_source.split('/'): 499 | root = os.path.dirname(root) 500 | except NameError: 501 | return {"version": "0+unknown", "full-revisionid": None, 502 | "dirty": None, 503 | "error": "unable to find root of source tree", 504 | "date": None} 505 | 506 | try: 507 | pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) 508 | return render(pieces, cfg.style) 509 | except NotThisMethod: 510 | pass 511 | 512 | try: 513 | if cfg.parentdir_prefix: 514 | return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) 515 | except NotThisMethod: 516 | pass 517 | 518 | return {"version": "0+unknown", "full-revisionid": None, 519 | "dirty": None, 520 | "error": "unable to compute version", "date": None} 521 | -------------------------------------------------------------------------------- /dask_searchcv/tests/test_model_selection.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import os 4 | import pickle 5 | import warnings 6 | from itertools import product 7 | from multiprocessing import cpu_count 8 | 9 | import pytest 10 | import numpy as np 11 | import pandas as pd 12 | 13 | import dask 14 | import dask.array as da 15 | from dask.base import tokenize 16 | from dask.callbacks import Callback 17 | from dask.delayed import delayed 18 | from dask.threaded import get as get_threading 19 | from dask.utils import tmpdir 20 | 21 | from sklearn.datasets import make_classification, load_iris 22 
| from sklearn.decomposition import PCA 23 | from sklearn.exceptions import NotFittedError, FitFailedWarning 24 | from sklearn.ensemble import RandomForestClassifier 25 | from sklearn.feature_selection import SelectKBest 26 | from sklearn.metrics.scorer import _passthrough_scorer 27 | from sklearn.model_selection import (KFold, 28 | GroupKFold, 29 | StratifiedKFold, 30 | TimeSeriesSplit, 31 | ShuffleSplit, 32 | GroupShuffleSplit, 33 | StratifiedShuffleSplit, 34 | LeaveOneOut, 35 | LeavePOut, 36 | LeaveOneGroupOut, 37 | LeavePGroupsOut, 38 | PredefinedSplit, 39 | GridSearchCV) 40 | from sklearn.model_selection._split import _CVIterableWrapper 41 | from sklearn.pipeline import Pipeline, FeatureUnion 42 | from sklearn.svm import SVC 43 | 44 | import dask_searchcv as dcv 45 | from dask_searchcv.model_selection import (compute_n_splits, check_cv, 46 | _normalize_n_jobs, _normalize_scheduler) 47 | from dask_searchcv._compat import _HAS_MULTIPLE_METRICS 48 | from dask_searchcv.methods import CVCache 49 | from dask_searchcv.utils_test import (FailingClassifier, MockClassifier, 50 | ScalingTransformer, CheckXClassifier, 51 | ignore_warnings) 52 | 53 | try: 54 | from distributed import Client 55 | from distributed.utils_test import cluster, loop 56 | has_distributed = True 57 | except ImportError: 58 | loop = pytest.fixture(lambda: None) 59 | has_distributed = False 60 | 61 | 62 | class assert_dask_compute(Callback): 63 | def __init__(self, compute=False): 64 | self.compute = compute 65 | 66 | def __enter__(self): 67 | self.ran = False 68 | super(assert_dask_compute, self).__enter__() 69 | 70 | def __exit__(self, *args): 71 | if not self.compute and self.ran: 72 | raise ValueError("Unexpected call to compute") 73 | elif self.compute and not self.ran: 74 | raise ValueError("Expected call to compute, but none happened") 75 | super(assert_dask_compute, self).__exit__(*args) 76 | 77 | def _start(self, dsk): 78 | self.ran = True 79 | 80 | 81 | def test_visualize(): 82 | 
pytest.importorskip('graphviz') 83 | 84 | X, y = make_classification(n_samples=100, n_classes=2, flip_y=.2, 85 | random_state=0) 86 | clf = SVC(random_state=0) 87 | grid = {'C': [.1, .5, .9]} 88 | gs = dcv.GridSearchCV(clf, grid).fit(X, y) 89 | 90 | assert hasattr(gs, 'dask_graph_') 91 | 92 | with tmpdir() as d: 93 | gs.visualize(filename=os.path.join(d, 'mydask')) 94 | assert os.path.exists(os.path.join(d, 'mydask.png')) 95 | 96 | # Doesn't work if not fitted 97 | gs = dcv.GridSearchCV(clf, grid) 98 | with pytest.raises(NotFittedError): 99 | gs.visualize() 100 | 101 | 102 | np_X = np.random.normal(size=(20, 3)) 103 | np_y = np.random.randint(2, size=20) 104 | np_groups = np.random.permutation(list(range(5)) * 4) 105 | da_X = da.from_array(np_X, chunks=(3, 3)) 106 | da_y = da.from_array(np_y, chunks=3) 107 | da_groups = da.from_array(np_groups, chunks=3) 108 | del_X = delayed(np_X) 109 | del_y = delayed(np_y) 110 | del_groups = delayed(np_groups) 111 | 112 | 113 | @pytest.mark.parametrize(['cls', 'has_shuffle'], 114 | [(KFold, True), 115 | (GroupKFold, False), 116 | (StratifiedKFold, True), 117 | (TimeSeriesSplit, False)]) 118 | def test_kfolds(cls, has_shuffle): 119 | assert tokenize(cls()) == tokenize(cls()) 120 | assert tokenize(cls(n_splits=3)) != tokenize(cls(n_splits=4)) 121 | if has_shuffle: 122 | assert (tokenize(cls(shuffle=True, random_state=0)) == 123 | tokenize(cls(shuffle=True, random_state=0))) 124 | 125 | rs = np.random.RandomState(42) 126 | assert (tokenize(cls(shuffle=True, random_state=rs)) == 127 | tokenize(cls(shuffle=True, random_state=rs))) 128 | 129 | assert (tokenize(cls(shuffle=True, random_state=0)) != 130 | tokenize(cls(shuffle=True, random_state=2))) 131 | 132 | assert (tokenize(cls(shuffle=False, random_state=0)) == 133 | tokenize(cls(shuffle=False, random_state=2))) 134 | 135 | cv = cls(n_splits=3) 136 | assert compute_n_splits(cv, np_X, np_y, np_groups) == 3 137 | 138 | with assert_dask_compute(False): 139 | assert 
compute_n_splits(cv, da_X, da_y, da_groups) == 3 140 | 141 | 142 | @pytest.mark.parametrize('cls', [ShuffleSplit, GroupShuffleSplit, 143 | StratifiedShuffleSplit]) 144 | def test_shuffle_split(cls): 145 | assert (tokenize(cls(n_splits=3, random_state=0)) == 146 | tokenize(cls(n_splits=3, random_state=0))) 147 | 148 | assert (tokenize(cls(n_splits=3, random_state=0)) != 149 | tokenize(cls(n_splits=3, random_state=2))) 150 | 151 | assert (tokenize(cls(n_splits=3, random_state=0)) != 152 | tokenize(cls(n_splits=4, random_state=0))) 153 | 154 | cv = cls(n_splits=3) 155 | assert compute_n_splits(cv, np_X, np_y, np_groups) == 3 156 | 157 | with assert_dask_compute(False): 158 | assert compute_n_splits(cv, da_X, da_y, da_groups) == 3 159 | 160 | 161 | @pytest.mark.parametrize('cvs', [(LeaveOneOut(),), 162 | (LeavePOut(2), LeavePOut(3))]) 163 | def test_leave_out(cvs): 164 | tokens = [] 165 | for cv in cvs: 166 | assert tokenize(cv) == tokenize(cv) 167 | tokens.append(cv) 168 | assert len(set(tokens)) == len(tokens) 169 | 170 | cv = cvs[0] 171 | sol = cv.get_n_splits(np_X, np_y, np_groups) 172 | assert compute_n_splits(cv, np_X, np_y, np_groups) == sol 173 | 174 | with assert_dask_compute(True): 175 | assert compute_n_splits(cv, da_X, da_y, da_groups) == sol 176 | 177 | with assert_dask_compute(False): 178 | assert compute_n_splits(cv, np_X, da_y, da_groups) == sol 179 | 180 | 181 | @pytest.mark.parametrize('cvs', [(LeaveOneGroupOut(),), 182 | (LeavePGroupsOut(2), LeavePGroupsOut(3))]) 183 | def test_leave_group_out(cvs): 184 | tokens = [] 185 | for cv in cvs: 186 | assert tokenize(cv) == tokenize(cv) 187 | tokens.append(cv) 188 | assert len(set(tokens)) == len(tokens) 189 | 190 | cv = cvs[0] 191 | sol = cv.get_n_splits(np_X, np_y, np_groups) 192 | assert compute_n_splits(cv, np_X, np_y, np_groups) == sol 193 | 194 | with assert_dask_compute(True): 195 | assert compute_n_splits(cv, da_X, da_y, da_groups) == sol 196 | 197 | with assert_dask_compute(False): 198 | assert 
compute_n_splits(cv, da_X, da_y, np_groups) == sol 199 | 200 | 201 | def test_predefined_split(): 202 | cv = PredefinedSplit(np.array(list(range(4)) * 5)) 203 | cv2 = PredefinedSplit(np.array(list(range(5)) * 4)) 204 | assert tokenize(cv) == tokenize(cv) 205 | assert tokenize(cv) != tokenize(cv2) 206 | 207 | sol = cv.get_n_splits(np_X, np_y, np_groups) 208 | assert compute_n_splits(cv, np_X, np_y, np_groups) == sol 209 | 210 | with assert_dask_compute(False): 211 | assert compute_n_splits(cv, da_X, da_y, da_groups) == sol 212 | 213 | 214 | def test_old_style_cv(): 215 | cv1 = _CVIterableWrapper([np.array([True, False, True, False] * 5), 216 | np.array([False, True, False, True] * 5)]) 217 | cv2 = _CVIterableWrapper([np.array([True, False, True, False] * 5), 218 | np.array([False, True, True, True] * 5)]) 219 | assert tokenize(cv1) == tokenize(cv1) 220 | assert tokenize(cv1) != tokenize(cv2) 221 | 222 | sol = cv1.get_n_splits(np_X, np_y, np_groups) 223 | assert compute_n_splits(cv1, np_X, np_y, np_groups) == sol 224 | with assert_dask_compute(False): 225 | assert compute_n_splits(cv1, da_X, da_y, da_groups) == sol 226 | 227 | 228 | def test_check_cv(): 229 | # No y, classifier=False 230 | cv = check_cv(3, classifier=False) 231 | assert isinstance(cv, KFold) and cv.n_splits == 3 232 | cv = check_cv(5, classifier=False) 233 | assert isinstance(cv, KFold) and cv.n_splits == 5 234 | 235 | # y, classifier = False 236 | dy = da.from_array(np.array([1, 0, 1, 0, 1]), chunks=2) 237 | with assert_dask_compute(False): 238 | assert isinstance(check_cv(y=dy, classifier=False), KFold) 239 | 240 | # Binary and multi-class y 241 | for y in [np.array([0, 1, 0, 1, 0, 0, 1, 1, 1]), 242 | np.array([0, 1, 0, 1, 2, 1, 2, 0, 2])]: 243 | cv = check_cv(5, y, classifier=True) 244 | assert isinstance(cv, StratifiedKFold) and cv.n_splits == 5 245 | 246 | dy = da.from_array(y, chunks=2) 247 | with assert_dask_compute(True): 248 | cv = check_cv(5, dy, classifier=True) 249 | assert isinstance(cv, 
StratifiedKFold) and cv.n_splits == 5 250 | 251 | # Non-binary/multi-class y 252 | y = np.array([[1, 2], [0, 3], [0, 0], [3, 1], [2, 0]]) 253 | assert isinstance(check_cv(y=y, classifier=True), KFold) 254 | 255 | dy = da.from_array(y, chunks=2) 256 | with assert_dask_compute(True): 257 | assert isinstance(check_cv(y=dy, classifier=True), KFold) 258 | 259 | # Old style 260 | cv = [np.array([True, False, True]), np.array([False, True, False])] 261 | with assert_dask_compute(False): 262 | assert isinstance(check_cv(cv, y=dy, classifier=True), 263 | _CVIterableWrapper) 264 | 265 | # CV instance passes through 266 | y = da.ones(5, chunks=2) 267 | cv = ShuffleSplit() 268 | with assert_dask_compute(False): 269 | assert check_cv(cv, y, classifier=True) is cv 270 | assert check_cv(cv, y, classifier=False) is cv 271 | 272 | 273 | def test_grid_search_dask_inputs(): 274 | # Numpy versions 275 | np_X, np_y = make_classification(n_samples=15, n_classes=2, random_state=0) 276 | np_groups = np.random.RandomState(0).randint(0, 3, 15) 277 | # Dask array versions 278 | da_X = da.from_array(np_X, chunks=5) 279 | da_y = da.from_array(np_y, chunks=5) 280 | da_groups = da.from_array(np_groups, chunks=5) 281 | # Delayed versions 282 | del_X = delayed(np_X) 283 | del_y = delayed(np_y) 284 | del_groups = delayed(np_groups) 285 | 286 | cv = GroupKFold() 287 | clf = SVC(random_state=0) 288 | grid = {'C': [1]} 289 | 290 | sol = SVC(C=1, random_state=0).fit(np_X, np_y).support_vectors_ 291 | 292 | for X, y, groups in product([np_X, da_X, del_X], 293 | [np_y, da_y, del_y], 294 | [np_groups, da_groups, del_groups]): 295 | gs = dcv.GridSearchCV(clf, grid, cv=cv) 296 | 297 | with pytest.raises(ValueError) as exc: 298 | gs.fit(X, y) 299 | assert "parameter should not be None" in str(exc.value) 300 | 301 | gs.fit(X, y, groups=groups) 302 | np.testing.assert_allclose(sol, gs.best_estimator_.support_vectors_) 303 | 304 | 305 | def test_pipeline_feature_union(): 306 | iris = load_iris() 307 | X, y = 
iris.data, iris.target 308 | 309 | pca = PCA(random_state=0) 310 | kbest = SelectKBest() 311 | empty_union = FeatureUnion([('first', None), ('second', None)]) 312 | empty_pipeline = Pipeline([('first', None), ('second', None)]) 313 | scaling = Pipeline([('transform', ScalingTransformer())]) 314 | svc = SVC(kernel='linear', random_state=0) 315 | 316 | pipe = Pipeline([('empty_pipeline', empty_pipeline), 317 | ('scaling', scaling), 318 | ('missing', None), 319 | ('union', FeatureUnion([('pca', pca), 320 | ('missing', None), 321 | ('kbest', kbest), 322 | ('empty_union', empty_union)], 323 | transformer_weights={'pca': 0.5})), 324 | ('svc', svc)]) 325 | 326 | param_grid = dict(scaling__transform__factor=[1, 2], 327 | union__pca__n_components=[1, 2, 3], 328 | union__kbest__k=[1, 2], 329 | svc__C=[0.1, 1, 10]) 330 | 331 | gs = GridSearchCV(pipe, param_grid=param_grid) 332 | gs.fit(X, y) 333 | dgs = dcv.GridSearchCV(pipe, param_grid=param_grid, scheduler='sync') 334 | dgs.fit(X, y) 335 | 336 | # Check best params match 337 | assert gs.best_params_ == dgs.best_params_ 338 | 339 | # Check PCA components match 340 | sk_pca = gs.best_estimator_.named_steps['union'].transformer_list[0][1] 341 | dk_pca = dgs.best_estimator_.named_steps['union'].transformer_list[0][1] 342 | np.testing.assert_allclose(sk_pca.components_, dk_pca.components_) 343 | 344 | # Check SelectKBest scores match 345 | sk_kbest = gs.best_estimator_.named_steps['union'].transformer_list[2][1] 346 | dk_kbest = dgs.best_estimator_.named_steps['union'].transformer_list[2][1] 347 | np.testing.assert_allclose(sk_kbest.scores_, dk_kbest.scores_) 348 | 349 | # Check SVC coefs match 350 | np.testing.assert_allclose(gs.best_estimator_.named_steps['svc'].coef_, 351 | dgs.best_estimator_.named_steps['svc'].coef_) 352 | 353 | 354 | def test_pipeline_sub_estimators(): 355 | iris = load_iris() 356 | X, y = iris.data, iris.target 357 | 358 | scaling = Pipeline([('transform', ScalingTransformer())]) 359 | 360 | pipe = 
Pipeline([('setup', None), 361 | ('missing', None), 362 | ('scaling', scaling), 363 | ('svc', SVC(kernel='linear', random_state=0))]) 364 | 365 | param_grid = [{'svc__C': [0.1, 0.1]}, # Duplicates to test culling 366 | {'setup': [None], 367 | 'svc__C': [0.1, 1, 10], 368 | 'scaling': [ScalingTransformer(), None]}, 369 | {'setup': [SelectKBest()], 370 | 'setup__k': [1, 2], 371 | 'svc': [SVC(kernel='linear', random_state=0, C=0.1), 372 | SVC(kernel='linear', random_state=0, C=1), 373 | SVC(kernel='linear', random_state=0, C=10)]}] 374 | 375 | gs = GridSearchCV(pipe, param_grid=param_grid, return_train_score=True) 376 | gs.fit(X, y) 377 | dgs = dcv.GridSearchCV(pipe, param_grid=param_grid, scheduler='sync', 378 | return_train_score=True) 379 | dgs.fit(X, y) 380 | 381 | # Check best params match 382 | assert gs.best_params_ == dgs.best_params_ 383 | 384 | # Check cv results match 385 | res = pd.DataFrame(dgs.cv_results_) 386 | sol = pd.DataFrame(gs.cv_results_) 387 | assert res.columns.equals(sol.columns) 388 | skip = ['mean_fit_time', 'std_fit_time', 'mean_score_time', 'std_score_time'] 389 | res = res.drop(skip, axis=1) 390 | sol = sol.drop(skip, axis=1) 391 | assert res.equals(sol) 392 | 393 | # Check SVC coefs match 394 | np.testing.assert_allclose(gs.best_estimator_.named_steps['svc'].coef_, 395 | dgs.best_estimator_.named_steps['svc'].coef_) 396 | 397 | 398 | def check_scores_all_nan(gs, bad_param, score_key='score'): 399 | bad_param = 'param_' + bad_param 400 | n_candidates = len(gs.cv_results_['params']) 401 | keys = ['split{}_test_{}'.format(s, score_key) 402 | for s in range(gs.n_splits_)] 403 | assert all(np.isnan([gs.cv_results_[key][cand_i] 404 | for key in keys]).all() 405 | for cand_i in range(n_candidates) 406 | if gs.cv_results_[bad_param][cand_i] == 407 | FailingClassifier.FAILING_PARAMETER) 408 | 409 | 410 | @pytest.mark.parametrize('weights', 411 | [None, (None, {'tr0': 2, 'tr2': 3}, {'tr0': 2, 'tr2': 4})]) 412 | def test_feature_union(weights): 413 
| X = np.ones((10, 5)) 414 | y = np.zeros(10) 415 | 416 | union = FeatureUnion([('tr0', ScalingTransformer()), 417 | ('tr1', ScalingTransformer()), 418 | ('tr2', ScalingTransformer())]) 419 | 420 | factors = [(2, 3, 5), (2, 4, 5), (2, 4, 6), 421 | (2, 4, None), (None, None, None)] 422 | params, sols, grid = [], [], [] 423 | for constants, w in product(factors, weights or [None]): 424 | p = {} 425 | for n, c in enumerate(constants): 426 | if c is None: 427 | p['tr%d' % n] = None 428 | elif n == 3: # 3rd is always an estimator 429 | p['tr%d' % n] = ScalingTransformer(c) 430 | else: 431 | p['tr%d__factor' % n] = c 432 | sol = union.set_params(transformer_weights=w, **p).transform(X) 433 | sols.append(sol) 434 | if w is not None: 435 | p['transformer_weights'] = w 436 | params.append(p) 437 | p2 = {'union__' + k: [v] for k, v in p.items()} 438 | p2['est'] = [CheckXClassifier(sol[0])] 439 | grid.append(p2) 440 | 441 | # Need to recreate the union after setting estimators to `None` above 442 | union = FeatureUnion([('tr0', ScalingTransformer()), 443 | ('tr1', ScalingTransformer()), 444 | ('tr2', ScalingTransformer())]) 445 | 446 | pipe = Pipeline([('union', union), ('est', CheckXClassifier())]) 447 | gs = dcv.GridSearchCV(pipe, grid, refit=False, cv=2) 448 | 449 | with warnings.catch_warnings(record=True): 450 | gs.fit(X, y) 451 | 452 | 453 | @ignore_warnings 454 | def test_feature_union_fit_failure(): 455 | X, y = make_classification(n_samples=100, n_features=10, random_state=0) 456 | 457 | pipe = Pipeline([('union', FeatureUnion([('good', MockClassifier()), 458 | ('bad', FailingClassifier())], 459 | transformer_weights={'bad': 0.5})), 460 | ('clf', MockClassifier())]) 461 | 462 | grid = {'union__bad__parameter': [0, 1, 2]} 463 | gs = dcv.GridSearchCV(pipe, grid, refit=False, scoring=None) 464 | 465 | # Check that failure raises if error_score is `'raise'` 466 | with pytest.raises(ValueError): 467 | gs.fit(X, y) 468 | 469 | # Check that grid scores were set to 
error_score on failure 470 | gs.error_score = float('nan') 471 | with pytest.warns(FitFailedWarning): 472 | gs.fit(X, y) 473 | check_scores_all_nan(gs, 'union__bad__parameter') 474 | 475 | 476 | @ignore_warnings 477 | @pytest.mark.skipif(not _HAS_MULTIPLE_METRICS, reason="Added in 0.19.0") 478 | def test_feature_union_fit_failure_multiple_metrics(): 479 | scoring = {"score_1": _passthrough_scorer, "score_2": _passthrough_scorer} 480 | X, y = make_classification(n_samples=100, n_features=10, random_state=0) 481 | 482 | pipe = Pipeline([('union', FeatureUnion([('good', MockClassifier()), 483 | ('bad', FailingClassifier())], 484 | transformer_weights={'bad': 0.5})), 485 | ('clf', MockClassifier())]) 486 | 487 | grid = {'union__bad__parameter': [0, 1, 2]} 488 | gs = dcv.GridSearchCV(pipe, grid, refit=False, scoring=scoring) 489 | 490 | # Check that failure raises if error_score is `'raise'` 491 | with pytest.raises(ValueError): 492 | gs.fit(X, y) 493 | 494 | # Check that grid scores were set to error_score on failure 495 | gs.error_score = float('nan') 496 | with pytest.warns(FitFailedWarning): 497 | gs.fit(X, y) 498 | 499 | for key in scoring: 500 | check_scores_all_nan(gs, 'union__bad__parameter', score_key=key) 501 | 502 | 503 | @ignore_warnings 504 | def test_pipeline_fit_failure(): 505 | X, y = make_classification(n_samples=100, n_features=10, random_state=0) 506 | 507 | pipe = Pipeline([('bad', FailingClassifier()), 508 | ('good1', MockClassifier()), 509 | ('good2', MockClassifier())]) 510 | 511 | grid = {'bad__parameter': [0, 1, 2]} 512 | gs = dcv.GridSearchCV(pipe, grid, refit=False) 513 | 514 | # Check that failure raises if error_score is `'raise'` 515 | with pytest.raises(ValueError): 516 | gs.fit(X, y) 517 | 518 | # Check that grid scores were set to error_score on failure 519 | gs.error_score = float('nan') 520 | with pytest.warns(FitFailedWarning): 521 | gs.fit(X, y) 522 | 523 | check_scores_all_nan(gs, 'bad__parameter') 524 | 525 | 526 | def 
test_pipeline_raises(): 527 | X, y = make_classification(n_samples=100, n_features=10, random_state=0) 528 | 529 | pipe = Pipeline([('step1', MockClassifier()), 530 | ('step2', MockClassifier())]) 531 | 532 | grid = {'step3__parameter': [0, 1, 2]} 533 | gs = dcv.GridSearchCV(pipe, grid, refit=False) 534 | with pytest.raises(ValueError): 535 | gs.fit(X, y) 536 | 537 | grid = {'steps': [[('one', MockClassifier()), ('two', MockClassifier())]]} 538 | gs = dcv.GridSearchCV(pipe, grid, refit=False) 539 | with pytest.raises(NotImplementedError): 540 | gs.fit(X, y) 541 | 542 | 543 | def test_feature_union_raises(): 544 | X, y = make_classification(n_samples=100, n_features=10, random_state=0) 545 | 546 | union = FeatureUnion([('tr0', MockClassifier()), 547 | ('tr1', MockClassifier())]) 548 | pipe = Pipeline([('union', union), ('est', MockClassifier())]) 549 | 550 | grid = {'union__tr2__parameter': [0, 1, 2]} 551 | gs = dcv.GridSearchCV(pipe, grid, refit=False) 552 | with pytest.raises(ValueError): 553 | gs.fit(X, y) 554 | 555 | grid = {'union__transformer_list': [[('one', MockClassifier())]]} 556 | gs = dcv.GridSearchCV(pipe, grid, refit=False) 557 | with pytest.raises(NotImplementedError): 558 | gs.fit(X, y) 559 | 560 | 561 | def test_bad_error_score(): 562 | X, y = make_classification(n_samples=100, n_features=10, random_state=0) 563 | gs = dcv.GridSearchCV(MockClassifier(), {'foo_param': [0, 1, 2]}, 564 | error_score='badparam') 565 | 566 | with pytest.raises(ValueError): 567 | gs.fit(X, y) 568 | 569 | 570 | class CountTakes(np.ndarray): 571 | count = 0 572 | 573 | def take(self, *args, **kwargs): 574 | self.count += 1 575 | return super(CountTakes, self).take(*args, **kwargs) 576 | 577 | 578 | def test_cache_cv(): 579 | X, y = make_classification(n_samples=100, n_features=10, random_state=0) 580 | X2 = X.view(CountTakes) 581 | gs = dcv.GridSearchCV(MockClassifier(), {'foo_param': [0, 1, 2]}, 582 | cv=3, cache_cv=False, scheduler='sync') 583 | gs.fit(X2, y) 584 | assert 
X2.count == 2 * 3 * 3 # (1 train + 1 test) * n_params * n_splits 585 | 586 | X2 = X.view(CountTakes) 587 | assert X2.count == 0 588 | gs.cache_cv = True 589 | gs.fit(X2, y) 590 | assert X2.count == 2 * 3 # (1 test + 1 train) * n_splits 591 | 592 | 593 | def test_CVCache_serializable(): 594 | inds = np.arange(10) 595 | splits = [(inds[:3], inds[3:]), (inds[3:], inds[:3])] 596 | X = np.arange(100).reshape((10, 10)) 597 | y = np.zeros(10) 598 | cache = CVCache(splits, pairwise=True, cache=True) 599 | 600 | # Add something to the cache 601 | r1 = cache.extract(X, y, 0) 602 | assert cache.extract(X, y, 0) is r1 603 | assert len(cache.cache) == 1 604 | 605 | cache2 = pickle.loads(pickle.dumps(cache)) 606 | assert len(cache2.cache) == 0 607 | assert cache2.pairwise == cache.pairwise 608 | assert all((cache2.splits[i][j] == cache.splits[i][j]).all() 609 | for i in range(2) for j in range(2)) 610 | 611 | 612 | def test_normalize_n_jobs(): 613 | assert _normalize_n_jobs(-1) is None 614 | assert _normalize_n_jobs(-2) == cpu_count() - 1 615 | with pytest.raises(TypeError): 616 | _normalize_n_jobs('not an integer') 617 | 618 | 619 | @pytest.mark.parametrize('scheduler,n_jobs,get', 620 | [(None, 4, get_threading), 621 | ('threading', 4, get_threading), 622 | ('threaded', 4, get_threading), 623 | ('threading', 1, dask.get), 624 | ('sequential', 4, dask.get), 625 | ('synchronous', 4, dask.get), 626 | ('sync', 4, dask.get), 627 | ('multiprocessing', 4, None), 628 | (dask.get, 4, dask.get)]) 629 | def test_scheduler_param(scheduler, n_jobs, get): 630 | if scheduler == 'multiprocessing': 631 | mp = pytest.importorskip('dask.multiprocessing') 632 | get = mp.get 633 | 634 | assert _normalize_scheduler(scheduler, n_jobs) is get 635 | 636 | X, y = make_classification(n_samples=100, n_features=10, random_state=0) 637 | gs = dcv.GridSearchCV(MockClassifier(), {'foo_param': [0, 1, 2]}, cv=3, 638 | scheduler=scheduler, n_jobs=n_jobs) 639 | gs.fit(X, y) 640 | 641 | 642 | 
@pytest.mark.skipif('not has_distributed') 643 | def test_scheduler_param_distributed(loop): 644 | X, y = make_classification(n_samples=100, n_features=10, random_state=0) 645 | with cluster() as (s, [a, b]): 646 | with Client(s['address'], loop=loop, set_as_default=False) as client: 647 | gs = dcv.GridSearchCV(MockClassifier(), {'foo_param': [0, 1, 2]}, 648 | cv=3, scheduler=client) 649 | gs.fit(X, y) 650 | 651 | 652 | def test_scheduler_param_bad(): 653 | with pytest.raises(ValueError): 654 | _normalize_scheduler('threeding', 4) 655 | 656 | 657 | @pytest.mark.skipif(not _HAS_MULTIPLE_METRICS, reason="Added in 0.19.0") 658 | def test_cv_multiplemetrics(): 659 | X, y = make_classification(random_state=0) 660 | 661 | param_grid = {'max_depth': [1, 5]} 662 | a = dcv.GridSearchCV(RandomForestClassifier(), param_grid, refit='score1', 663 | scoring={'score1': 'accuracy', 'score2': 'accuracy'}) 664 | b = GridSearchCV(RandomForestClassifier(), param_grid, refit='score1', 665 | scoring={'score1': 'accuracy', 'score2': 'accuracy'}) 666 | a.fit(X, y) 667 | b.fit(X, y) 668 | 669 | assert a.best_score_ > 0 670 | assert isinstance(a.best_index_, type(b.best_index_)) 671 | assert isinstance(a.best_params_, type(b.best_params_)) 672 | 673 | 674 | @pytest.mark.skipif(not _HAS_MULTIPLE_METRICS, reason="Added in 0.19.0") 675 | def test_cv_multiplemetrics_requires_refit_metric(): 676 | X, y = make_classification(random_state=0) 677 | 678 | param_grid = {'max_depth': [1, 5]} 679 | a = dcv.GridSearchCV(RandomForestClassifier(), param_grid, refit=True, 680 | scoring={'score1': 'accuracy', 'score2': 'accuracy'}) 681 | 682 | with pytest.raises(ValueError): 683 | a.fit(X, y) 684 | 685 | 686 | @pytest.mark.skipif(not _HAS_MULTIPLE_METRICS, reason="Added in 0.19.0") 687 | def test_cv_multiplemetrics_no_refit(): 688 | X, y = make_classification(random_state=0) 689 | 690 | param_grid = {'max_depth': [1, 5]} 691 | a = dcv.GridSearchCV(RandomForestClassifier(), param_grid, refit=False, 692 | 
scoring={'score1': 'accuracy', 'score2': 'accuracy'}) 693 | b = GridSearchCV(RandomForestClassifier(), param_grid, refit=False, 694 | scoring={'score1': 'accuracy', 'score2': 'accuracy'}) 695 | assert hasattr(a, 'best_index_') is hasattr(b, 'best_index_') 696 | assert hasattr(a, 'best_estimator_') is hasattr(b, 'best_estimator_') 697 | assert hasattr(a, 'best_score_') is hasattr(b, 'best_score_') 698 | -------------------------------------------------------------------------------- /dask_searchcv/tests/test_model_selection_sklearn.py: -------------------------------------------------------------------------------- 1 | # NOTE: These tests were copied (with modification) from the equivalent 2 | # scikit-learn testing code. The scikit-learn license has been included at 3 | # dask_searchcv/SCIKIT_LEARN_LICENSE.txt. 4 | 5 | import pickle 6 | import pytest 7 | 8 | import dask 9 | import dask.array as da 10 | import numpy as np 11 | from numpy.testing import (assert_array_equal, assert_array_almost_equal, 12 | assert_almost_equal) 13 | import scipy.sparse as sp 14 | from scipy.stats import expon 15 | 16 | from sklearn.base import BaseEstimator 17 | from sklearn.cluster import KMeans 18 | from sklearn.datasets import (make_classification, make_blobs, 19 | make_multilabel_classification) 20 | from sklearn.exceptions import NotFittedError, FitFailedWarning 21 | from sklearn.linear_model import Ridge 22 | from sklearn.metrics import (f1_score, make_scorer, roc_auc_score, 23 | accuracy_score) 24 | from sklearn.model_selection import (KFold, StratifiedKFold, 25 | StratifiedShuffleSplit, LeaveOneGroupOut, 26 | LeavePGroupsOut, GroupKFold, 27 | GroupShuffleSplit) 28 | from sklearn.neighbors import KernelDensity 29 | from sklearn.pipeline import Pipeline 30 | from sklearn.preprocessing import Imputer 31 | from sklearn.svm import LinearSVC, SVC 32 | from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier 33 | 34 | import dask_searchcv as dcv 35 | from 
dask_searchcv.utils_test import (FailingClassifier, MockClassifier, 36 | CheckingClassifier, MockDataFrame, 37 | ignore_warnings) 38 | from dask_searchcv._compat import _HAS_MULTIPLE_METRICS, _SK_VERSION 39 | 40 | 41 | class LinearSVCNoScore(LinearSVC): 42 | """An LinearSVC classifier that has no score method.""" 43 | @property 44 | def score(self): 45 | raise AttributeError 46 | 47 | 48 | X = np.array([[-1, -1], [-2, -1], [1, 1], [2, 1]]) 49 | y = np.array([1, 1, 2, 2]) 50 | 51 | da_X = da.from_array(np.random.normal(size=(20, 3)), chunks=(3, 3)) 52 | da_y = da.from_array(np.random.randint(2, size=20), chunks=3) 53 | 54 | 55 | def assert_grid_iter_equals_getitem(grid): 56 | assert list(grid) == [grid[i] for i in range(len(grid))] 57 | 58 | 59 | def test_grid_search(): 60 | # Test that the best estimator contains the right value for foo_param 61 | clf = MockClassifier() 62 | grid_search = dcv.GridSearchCV(clf, {'foo_param': [1, 2, 3]}) 63 | # make sure it selects the smallest parameter in case of ties 64 | grid_search.fit(X, y) 65 | assert grid_search.best_estimator_.foo_param == 2 66 | 67 | assert_array_equal(grid_search.cv_results_["param_foo_param"].data, 68 | [1, 2, 3]) 69 | 70 | # Smoke test the score etc: 71 | grid_search.score(X, y) 72 | grid_search.predict_proba(X) 73 | grid_search.decision_function(X) 74 | grid_search.transform(X) 75 | 76 | # Test exception handling on scoring 77 | grid_search.scoring = 'sklearn' 78 | with pytest.raises(ValueError): 79 | grid_search.fit(X, y) 80 | 81 | 82 | @pytest.mark.parametrize('cls,kwargs', 83 | [(dcv.GridSearchCV, {}), 84 | (dcv.RandomizedSearchCV, {'n_iter': 1})]) 85 | def test_hyperparameter_searcher_with_fit_params(cls, kwargs): 86 | X = np.arange(100).reshape(10, 10) 87 | y = np.array([0] * 5 + [1] * 5) 88 | clf = CheckingClassifier(expected_fit_params=['spam', 'eggs']) 89 | pipe = Pipeline([('clf', clf)]) 90 | searcher = cls(pipe, {'clf__foo_param': [1, 2, 3]}, cv=2, **kwargs) 91 | 92 | # The CheckingClassifer 
generates an assertion error if 93 | # a parameter is missing or has length != len(X). 94 | with pytest.raises(AssertionError) as exc: 95 | searcher.fit(X, y, clf__spam=np.ones(10)) 96 | assert "Expected fit parameter(s) ['eggs'] not seen." in str(exc.value) 97 | 98 | searcher.fit(X, y, clf__spam=np.ones(10), clf__eggs=np.zeros(10)) 99 | # Test with dask objects as parameters 100 | searcher.fit(X, y, clf__spam=da.ones(10, chunks=2), 101 | clf__eggs=dask.delayed(np.zeros(10))) 102 | 103 | 104 | @ignore_warnings 105 | def test_grid_search_no_score(): 106 | # Test grid-search on classifier that has no score function. 107 | clf = LinearSVC(random_state=0) 108 | X, y = make_blobs(random_state=0, centers=2) 109 | Cs = [.1, 1, 10] 110 | clf_no_score = LinearSVCNoScore(random_state=0) 111 | 112 | # XXX: It seems there's some global shared state in LinearSVC - fitting 113 | # multiple `SVC` instances in parallel using threads sometimes results in 114 | # wrong results. This only happens with threads, not processes/sync. 115 | # For now, we'll fit using the sync scheduler. 
116 | grid_search = dcv.GridSearchCV(clf, {'C': Cs}, scoring='accuracy', 117 | scheduler='sync') 118 | grid_search.fit(X, y) 119 | 120 | grid_search_no_score = dcv.GridSearchCV(clf_no_score, {'C': Cs}, 121 | scoring='accuracy', 122 | scheduler='sync') 123 | # smoketest grid search 124 | grid_search_no_score.fit(X, y) 125 | 126 | # check that best params are equal 127 | assert grid_search_no_score.best_params_ == grid_search.best_params_ 128 | # check that we can call score and that it gives the correct result 129 | assert grid_search.score(X, y) == grid_search_no_score.score(X, y) 130 | 131 | # giving no scoring function raises an error 132 | grid_search_no_score = dcv.GridSearchCV(clf_no_score, {'C': Cs}) 133 | with pytest.raises(TypeError) as exc: 134 | grid_search_no_score.fit([[1]]) 135 | assert "no scoring" in str(exc.value) 136 | 137 | 138 | def test_grid_search_score_method(): 139 | X, y = make_classification(n_samples=100, n_classes=2, flip_y=.2, 140 | random_state=0) 141 | clf = LinearSVC(random_state=0) 142 | grid = {'C': [.1]} 143 | 144 | search_no_scoring = dcv.GridSearchCV(clf, grid, scoring=None).fit(X, y) 145 | search_accuracy = dcv.GridSearchCV(clf, grid, scoring='accuracy').fit(X, y) 146 | search_no_score_method_auc = dcv.GridSearchCV(LinearSVCNoScore(), grid, 147 | scoring='roc_auc').fit(X, y) 148 | search_auc = dcv.GridSearchCV(clf, grid, scoring='roc_auc').fit(X, y) 149 | 150 | # Check warning only occurs in situation where behavior changed: 151 | # estimator requires score method to compete with scoring parameter 152 | score_no_scoring = search_no_scoring.score(X, y) 153 | score_accuracy = search_accuracy.score(X, y) 154 | score_no_score_auc = search_no_score_method_auc.score(X, y) 155 | score_auc = search_auc.score(X, y) 156 | 157 | # ensure the test is sane 158 | assert score_auc < 1.0 159 | assert score_accuracy < 1.0 160 | assert score_auc != score_accuracy 161 | 162 | assert_almost_equal(score_accuracy, score_no_scoring) 163 | 
assert_almost_equal(score_auc, score_no_score_auc) 164 | 165 | 166 | def test_grid_search_groups(): 167 | # Check if ValueError (when groups is None) propagates to dcv.GridSearchCV 168 | # And also check if groups is correctly passed to the cv object 169 | rng = np.random.RandomState(0) 170 | 171 | X, y = make_classification(n_samples=15, n_classes=2, random_state=0) 172 | groups = rng.randint(0, 3, 15) 173 | 174 | clf = LinearSVC(random_state=0) 175 | grid = {'C': [1]} 176 | 177 | group_cvs = [LeaveOneGroupOut(), LeavePGroupsOut(2), GroupKFold(), 178 | GroupShuffleSplit()] 179 | for cv in group_cvs: 180 | gs = dcv.GridSearchCV(clf, grid, cv=cv) 181 | 182 | with pytest.raises(ValueError) as exc: 183 | assert gs.fit(X, y) 184 | assert "parameter should not be None" in str(exc.value) 185 | 186 | gs.fit(X, y, groups=groups) 187 | 188 | non_group_cvs = [StratifiedKFold(), StratifiedShuffleSplit()] 189 | for cv in non_group_cvs: 190 | gs = dcv.GridSearchCV(clf, grid, cv=cv) 191 | # Should not raise an error 192 | gs.fit(X, y) 193 | 194 | 195 | @pytest.mark.skipif(_SK_VERSION < '0.19.1', 196 | reason='only deprecated for >= 0.19.1') 197 | def test_return_train_score_warn(): 198 | # Test that warnings are raised. 
Will be removed in sklearn 0.21 199 | X = np.arange(100).reshape(10, 10) 200 | y = np.array([0] * 5 + [1] * 5) 201 | grid = {'C': [1, 2]} 202 | 203 | for val in [True, False]: 204 | est = dcv.GridSearchCV(LinearSVC(random_state=0), grid, 205 | return_train_score=val) 206 | with pytest.warns(None) as warns: 207 | results = est.fit(X, y).cv_results_ 208 | assert not warns 209 | assert type(results) is dict 210 | 211 | est = dcv.GridSearchCV(LinearSVC(random_state=0), grid) 212 | with pytest.warns(None) as warns: 213 | results = est.fit(X, y).cv_results_ 214 | assert not warns 215 | 216 | train_keys = {'split0_train_score', 'split1_train_score', 217 | 'split2_train_score', 'mean_train_score', 'std_train_score'} 218 | 219 | for key in results: 220 | if key in train_keys: 221 | with pytest.warns(FutureWarning): 222 | results[key] 223 | else: 224 | with pytest.warns(None) as warns: 225 | results[key] 226 | assert not warns 227 | 228 | 229 | def test_classes__property(): 230 | # Test that classes_ property matches best_estimator_.classes_ 231 | X = np.arange(100).reshape(10, 10) 232 | y = np.array([0] * 5 + [1] * 5) 233 | Cs = [.1, 1, 10] 234 | 235 | grid_search = dcv.GridSearchCV(LinearSVC(random_state=0), {'C': Cs}) 236 | grid_search.fit(X, y) 237 | assert_array_equal(grid_search.best_estimator_.classes_, 238 | grid_search.classes_) 239 | 240 | # Test that regressors do not have a classes_ attribute 241 | grid_search = dcv.GridSearchCV(Ridge(), {'alpha': [1.0, 2.0]}) 242 | grid_search.fit(X, y) 243 | assert not hasattr(grid_search, 'classes_') 244 | 245 | # Test that the grid searcher has no classes_ attribute before it's fit 246 | grid_search = dcv.GridSearchCV(LinearSVC(random_state=0), {'C': Cs}) 247 | assert not hasattr(grid_search, 'classes_') 248 | 249 | # Test that the grid searcher has no classes_ attribute without a refit 250 | grid_search = dcv.GridSearchCV(LinearSVC(random_state=0), 251 | {'C': Cs}, refit=False) 252 | grid_search.fit(X, y) 253 | assert not 
hasattr(grid_search, 'classes_') 254 | 255 | 256 | def test_trivial_cv_results_attr(): 257 | # Test search over a "grid" with only one point. 258 | # Non-regression test: grid_scores_ wouldn't be set by dcv.GridSearchCV. 259 | clf = MockClassifier() 260 | grid_search = dcv.GridSearchCV(clf, {'foo_param': [1]}) 261 | grid_search.fit(X, y) 262 | assert hasattr(grid_search, "cv_results_") 263 | 264 | random_search = dcv.RandomizedSearchCV(clf, {'foo_param': [0]}, n_iter=1) 265 | random_search.fit(X, y) 266 | assert hasattr(grid_search, "cv_results_") 267 | 268 | 269 | def test_no_refit(): 270 | # Test that GSCV can be used for model selection alone without refitting 271 | clf = MockClassifier() 272 | grid_search = dcv.GridSearchCV(clf, {'foo_param': [1, 2, 3]}, refit=False) 273 | grid_search.fit(X, y) 274 | assert not hasattr(grid_search, "best_estimator_") 275 | assert not hasattr(grid_search, "best_index_") 276 | assert not hasattr(grid_search, "best_score_") 277 | assert not hasattr(grid_search, "best_params_") 278 | 279 | # Make sure the predict/transform etc fns raise meaningfull error msg 280 | for fn_name in ('predict', 'predict_proba', 'predict_log_proba', 281 | 'transform', 'inverse_transform'): 282 | with pytest.raises(NotFittedError) as exc: 283 | getattr(grid_search, fn_name)(X) 284 | assert (('refit=False. 
%s is available only after refitting on the ' 285 | 'best parameters' % fn_name) in str(exc.value)) 286 | 287 | 288 | @pytest.mark.skipif(not _HAS_MULTIPLE_METRICS, reason="Added in 0.19.0") 289 | def test_no_refit_multiple_metrics(): 290 | clf = DecisionTreeClassifier() 291 | scoring = {'score_1': 'accuracy', 'score_2': 'accuracy'} 292 | 293 | gs = dcv.GridSearchCV(clf, {'max_depth': [1, 2, 3]}, refit=False, 294 | scoring=scoring) 295 | gs.fit(da_X, da_y) 296 | assert not hasattr(gs, "best_estimator_") 297 | assert not hasattr(gs, "best_index_") 298 | assert not hasattr(gs, "best_score_") 299 | assert not hasattr(gs, "best_params_") 300 | 301 | for fn_name in ('predict', 'predict_proba', 'predict_log_proba'): 302 | with pytest.raises(NotFittedError) as exc: 303 | getattr(gs, fn_name)(X) 304 | assert (('refit=False. %s is available only after refitting on the ' 305 | 'best parameters' % fn_name) in str(exc.value)) 306 | 307 | 308 | def test_grid_search_error(): 309 | # Test that grid search will capture errors on data with different length 310 | X_, y_ = make_classification(n_samples=200, n_features=100, random_state=0) 311 | 312 | clf = LinearSVC() 313 | cv = dcv.GridSearchCV(clf, {'C': [0.1, 1.0]}) 314 | with pytest.raises(ValueError): 315 | cv.fit(X_[:180], y_) 316 | 317 | 318 | def test_grid_search_one_grid_point(): 319 | X_, y_ = make_classification(n_samples=200, n_features=100, random_state=0) 320 | param_dict = {"C": [1.0], "kernel": ["rbf"], "gamma": [0.1]} 321 | 322 | clf = SVC() 323 | cv = dcv.GridSearchCV(clf, param_dict) 324 | cv.fit(X_, y_) 325 | 326 | clf = SVC(C=1.0, kernel="rbf", gamma=0.1) 327 | clf.fit(X_, y_) 328 | 329 | assert_array_equal(clf.dual_coef_, cv.best_estimator_.dual_coef_) 330 | 331 | 332 | def test_grid_search_bad_param_grid(): 333 | param_dict = {"C": 1.0} 334 | clf = SVC() 335 | 336 | with pytest.raises(ValueError) as exc: 337 | dcv.GridSearchCV(clf, param_dict) 338 | assert ("Parameter values for parameter (C) need to be a 
sequence" 339 | "(but not a string) or np.ndarray.") in str(exc.value) 340 | 341 | param_dict = {"C": []} 342 | clf = SVC() 343 | 344 | with pytest.raises(ValueError) as exc: 345 | dcv.GridSearchCV(clf, param_dict) 346 | assert ("Parameter values for parameter (C) need to be a non-empty " 347 | "sequence.") in str(exc.value) 348 | 349 | param_dict = {"C": "1,2,3"} 350 | clf = SVC() 351 | 352 | with pytest.raises(ValueError) as exc: 353 | dcv.GridSearchCV(clf, param_dict) 354 | assert ("Parameter values for parameter (C) need to be a sequence" 355 | "(but not a string) or np.ndarray.") in str(exc.value) 356 | 357 | param_dict = {"C": np.ones(6).reshape(3, 2)} 358 | clf = SVC() 359 | with pytest.raises(ValueError): 360 | dcv.GridSearchCV(clf, param_dict) 361 | 362 | 363 | def test_grid_search_sparse(): 364 | # Test that grid search works with both dense and sparse matrices 365 | X_, y_ = make_classification(n_samples=200, n_features=100, random_state=0) 366 | 367 | clf = LinearSVC() 368 | cv = dcv.GridSearchCV(clf, {'C': [0.1, 1.0]}) 369 | cv.fit(X_[:180], y_[:180]) 370 | y_pred = cv.predict(X_[180:]) 371 | C = cv.best_estimator_.C 372 | 373 | X_ = sp.csr_matrix(X_) 374 | clf = LinearSVC() 375 | cv = dcv.GridSearchCV(clf, {'C': [0.1, 1.0]}) 376 | cv.fit(X_[:180].tocoo(), y_[:180]) 377 | y_pred2 = cv.predict(X_[180:]) 378 | C2 = cv.best_estimator_.C 379 | 380 | assert np.mean(y_pred == y_pred2) >= .9 381 | assert C == C2 382 | 383 | 384 | def test_grid_search_sparse_scoring(): 385 | X_, y_ = make_classification(n_samples=200, n_features=100, random_state=0) 386 | 387 | clf = LinearSVC() 388 | cv = dcv.GridSearchCV(clf, {'C': [0.1, 1.0]}, scoring="f1") 389 | cv.fit(X_[:180], y_[:180]) 390 | y_pred = cv.predict(X_[180:]) 391 | C = cv.best_estimator_.C 392 | 393 | X_ = sp.csr_matrix(X_) 394 | clf = LinearSVC() 395 | cv = dcv.GridSearchCV(clf, {'C': [0.1, 1.0]}, scoring="f1") 396 | cv.fit(X_[:180], y_[:180]) 397 | y_pred2 = cv.predict(X_[180:]) 398 | C2 = 
cv.best_estimator_.C 399 | 400 | assert_array_equal(y_pred, y_pred2) 401 | assert C == C2 402 | # Smoke test the score 403 | # np.testing.assert_allclose(f1_score(cv.predict(X_[:180]), y[:180]), 404 | # cv.score(X_[:180], y[:180])) 405 | 406 | # test loss where greater is worse 407 | def f1_loss(y_true_, y_pred_): 408 | return -f1_score(y_true_, y_pred_) 409 | F1Loss = make_scorer(f1_loss, greater_is_better=False) 410 | cv = dcv.GridSearchCV(clf, {'C': [0.1, 1.0]}, scoring=F1Loss) 411 | cv.fit(X_[:180], y_[:180]) 412 | y_pred3 = cv.predict(X_[180:]) 413 | C3 = cv.best_estimator_.C 414 | 415 | assert C == C3 416 | assert_array_equal(y_pred, y_pred3) 417 | 418 | 419 | def test_grid_search_precomputed_kernel(): 420 | # Test that grid search works when the input features are given in the 421 | # form of a precomputed kernel matrix 422 | X_, y_ = make_classification(n_samples=200, n_features=100, random_state=0) 423 | 424 | # compute the training kernel matrix corresponding to the linear kernel 425 | K_train = np.dot(X_[:180], X_[:180].T) 426 | y_train = y_[:180] 427 | 428 | clf = SVC(kernel='precomputed') 429 | cv = dcv.GridSearchCV(clf, {'C': [0.1, 1.0]}) 430 | cv.fit(K_train, y_train) 431 | 432 | assert cv.best_score_ >= 0 433 | 434 | # compute the test kernel matrix 435 | K_test = np.dot(X_[180:], X_[:180].T) 436 | y_test = y_[180:] 437 | 438 | y_pred = cv.predict(K_test) 439 | 440 | assert np.mean(y_pred == y_test) >= 0 441 | 442 | # test error is raised when the precomputed kernel is not array-like 443 | # or sparse 444 | with pytest.raises(ValueError): 445 | cv.fit(K_train.tolist(), y_train) 446 | 447 | 448 | def test_grid_search_precomputed_kernel_error_nonsquare(): 449 | # Test that grid search returns an error with a non-square precomputed 450 | # training kernel matrix 451 | K_train = np.zeros((10, 20)) 452 | y_train = np.ones((10, )) 453 | clf = SVC(kernel='precomputed') 454 | cv = dcv.GridSearchCV(clf, {'C': [0.1, 1.0]}) 455 | with 
pytest.raises(ValueError): 456 | cv.fit(K_train, y_train) 457 | 458 | 459 | class BrokenClassifier(BaseEstimator): 460 | """Broken classifier that cannot be fit twice""" 461 | 462 | def __init__(self, parameter=None): 463 | self.parameter = parameter 464 | 465 | def fit(self, X, y): 466 | assert not hasattr(self, 'has_been_fit_') 467 | self.has_been_fit_ = True 468 | 469 | def predict(self, X): 470 | return np.zeros(X.shape[0]) 471 | 472 | 473 | @ignore_warnings 474 | def test_refit(): 475 | # Regression test for bug in refitting 476 | # Simulates re-fitting a broken estimator; this used to break with 477 | # sparse SVMs. 478 | X = np.arange(100).reshape(10, 10) 479 | y = np.array([0] * 5 + [1] * 5) 480 | 481 | clf = dcv.GridSearchCV(BrokenClassifier(), [{'parameter': [0, 1]}], 482 | scoring="precision", refit=True) 483 | clf.fit(X, y) 484 | 485 | 486 | def test_gridsearch_nd(): 487 | # Pass X as list in dcv.GridSearchCV 488 | X_4d = np.arange(10 * 5 * 3 * 2).reshape(10, 5, 3, 2) 489 | y_3d = np.arange(10 * 7 * 11).reshape(10, 7, 11) 490 | clf = CheckingClassifier(check_X=lambda x: x.shape[1:] == (5, 3, 2), 491 | check_y=lambda x: x.shape[1:] == (7, 11)) 492 | grid_search = dcv.GridSearchCV(clf, {'foo_param': [1, 2, 3]}) 493 | grid_search.fit(X_4d, y_3d).score(X, y) 494 | assert hasattr(grid_search, "cv_results_") 495 | 496 | 497 | def test_X_as_list(): 498 | # Pass X as list in dcv.GridSearchCV 499 | X = np.arange(100).reshape(10, 10) 500 | y = np.array([0] * 5 + [1] * 5) 501 | 502 | clf = CheckingClassifier(check_X=lambda x: isinstance(x, list)) 503 | cv = KFold(n_splits=3) 504 | grid_search = dcv.GridSearchCV(clf, {'foo_param': [1, 2, 3]}, cv=cv) 505 | grid_search.fit(X.tolist(), y).score(X, y) 506 | assert hasattr(grid_search, "cv_results_") 507 | 508 | 509 | def test_y_as_list(): 510 | # Pass y as list in dcv.GridSearchCV 511 | X = np.arange(100).reshape(10, 10) 512 | y = np.array([0] * 5 + [1] * 5) 513 | 514 | clf = CheckingClassifier(check_y=lambda x: 
isinstance(x, list)) 515 | cv = KFold(n_splits=3) 516 | grid_search = dcv.GridSearchCV(clf, {'foo_param': [1, 2, 3]}, cv=cv) 517 | grid_search.fit(X, y.tolist()).score(X, y) 518 | assert hasattr(grid_search, "cv_results_") 519 | 520 | 521 | @ignore_warnings 522 | def test_pandas_input(): 523 | # check cross_val_score doesn't destroy pandas dataframe 524 | types = [(MockDataFrame, MockDataFrame)] 525 | try: 526 | from pandas import Series, DataFrame 527 | types.append((DataFrame, Series)) 528 | except ImportError: 529 | pass 530 | 531 | X = np.arange(100).reshape(10, 10) 532 | y = np.array([0] * 5 + [1] * 5) 533 | 534 | for InputFeatureType, TargetType in types: 535 | # X dataframe, y series 536 | X_df, y_ser = InputFeatureType(X), TargetType(y) 537 | clf = CheckingClassifier(check_X=lambda x: isinstance(x, InputFeatureType), 538 | check_y=lambda x: isinstance(x, TargetType)) 539 | 540 | grid_search = dcv.GridSearchCV(clf, {'foo_param': [1, 2, 3]}) 541 | grid_search.fit(X_df, y_ser).score(X_df, y_ser) 542 | grid_search.predict(X_df) 543 | assert hasattr(grid_search, "cv_results_") 544 | 545 | 546 | def test_unsupervised_grid_search(): 547 | # test grid-search with unsupervised estimator 548 | X, y = make_blobs(random_state=0) 549 | km = KMeans(random_state=0) 550 | grid_search = dcv.GridSearchCV(km, param_grid=dict(n_clusters=[2, 3, 4]), 551 | scoring='adjusted_rand_score') 552 | grid_search.fit(X, y) 553 | # ARI can find the right number :) 554 | assert grid_search.best_params_["n_clusters"] == 3 555 | 556 | # Now without a score, and without y 557 | grid_search = dcv.GridSearchCV(km, param_grid=dict(n_clusters=[2, 3, 4])) 558 | grid_search.fit(X) 559 | assert grid_search.best_params_["n_clusters"] == 4 560 | 561 | 562 | def test_gridsearch_no_predict(): 563 | # test grid-search with an estimator without predict. 
564 | # slight duplication of a test from KDE 565 | def custom_scoring(estimator, X): 566 | return 42 if estimator.bandwidth == .1 else 0 567 | X, _ = make_blobs(cluster_std=.1, random_state=1, 568 | centers=[[0, 1], [1, 0], [0, 0]]) 569 | search = dcv.GridSearchCV(KernelDensity(), 570 | param_grid=dict(bandwidth=[.01, .1, 1]), 571 | scoring=custom_scoring) 572 | search.fit(X) 573 | assert search.best_params_['bandwidth'] == .1 574 | assert search.best_score_ == 42 575 | 576 | 577 | def check_cv_results_array_types(cv_results, param_keys, score_keys): 578 | # Check if the search `cv_results`'s array are of correct types 579 | assert all(isinstance(cv_results[param], np.ma.MaskedArray) 580 | for param in param_keys) 581 | assert all(cv_results[key].dtype == object for key in param_keys) 582 | assert not any(isinstance(cv_results[key], np.ma.MaskedArray) 583 | for key in score_keys) 584 | assert all(cv_results[key].dtype == np.float64 585 | for key in score_keys if not key.startswith('rank')) 586 | assert cv_results['rank_test_score'].dtype == np.int32 587 | 588 | 589 | def check_cv_results_keys(cv_results, param_keys, score_keys, n_cand): 590 | # Test the search.cv_results_ contains all the required results 591 | assert_array_equal(sorted(cv_results.keys()), 592 | sorted(param_keys + score_keys + ('params',))) 593 | assert all(cv_results[key].shape == (n_cand,) 594 | for key in param_keys + score_keys) 595 | 596 | 597 | def test_grid_search_cv_results(): 598 | X, y = make_classification(n_samples=50, n_features=4, 599 | random_state=42) 600 | 601 | n_splits = 3 602 | n_grid_points = 6 603 | params = [dict(kernel=['rbf', ], C=[1, 10], gamma=[0.1, 1]), 604 | dict(kernel=['poly', ], degree=[1, 2])] 605 | grid_search = dcv.GridSearchCV(SVC(), cv=n_splits, iid=False, 606 | param_grid=params, return_train_score=True) 607 | grid_search.fit(X, y) 608 | grid_search_iid = dcv.GridSearchCV(SVC(), cv=n_splits, iid=True, 609 | param_grid=params, return_train_score=True) 610 | 
grid_search_iid.fit(X, y) 611 | 612 | param_keys = ('param_C', 'param_degree', 'param_gamma', 'param_kernel') 613 | score_keys = ('mean_test_score', 'mean_train_score', 614 | 'rank_test_score', 615 | 'split0_test_score', 'split1_test_score', 616 | 'split2_test_score', 617 | 'split0_train_score', 'split1_train_score', 618 | 'split2_train_score', 619 | 'std_test_score', 'std_train_score', 620 | 'mean_fit_time', 'std_fit_time', 621 | 'mean_score_time', 'std_score_time') 622 | n_candidates = n_grid_points 623 | 624 | for search, iid in zip((grid_search, grid_search_iid), (False, True)): 625 | assert iid == search.iid 626 | cv_results = search.cv_results_ 627 | # Check if score and timing are reasonable 628 | assert all(cv_results['rank_test_score'] >= 1) 629 | assert all(all(cv_results[k] >= 0) for k in score_keys 630 | if k != 'rank_test_score') 631 | assert all(all(cv_results[k] <= 1) for k in score_keys 632 | if 'time' not in k and k != 'rank_test_score') 633 | # Check cv_results structure 634 | check_cv_results_array_types(cv_results, param_keys, score_keys) 635 | check_cv_results_keys(cv_results, param_keys, score_keys, n_candidates) 636 | # Check masking 637 | cv_results = grid_search.cv_results_ 638 | n_candidates = len(grid_search.cv_results_['params']) 639 | assert all((cv_results['param_C'].mask[i] and 640 | cv_results['param_gamma'].mask[i] and 641 | not cv_results['param_degree'].mask[i]) 642 | for i in range(n_candidates) 643 | if cv_results['param_kernel'][i] == 'linear') 644 | assert all((not cv_results['param_C'].mask[i] and 645 | not cv_results['param_gamma'].mask[i] and 646 | cv_results['param_degree'].mask[i]) 647 | for i in range(n_candidates) 648 | if cv_results['param_kernel'][i] == 'rbf') 649 | 650 | 651 | def test_random_search_cv_results(): 652 | # Make a dataset with a lot of noise to get various kind of prediction 653 | # errors across CV folds and parameter settings 654 | X, y = make_classification(n_samples=200, n_features=100, 
n_informative=3, 655 | random_state=0) 656 | 657 | # scipy.stats dists now supports `seed` but we still support scipy 0.12 658 | # which doesn't support the seed. Hence the assertions in the test for 659 | # random_search alone should not depend on randomization. 660 | n_splits = 3 661 | n_search_iter = 30 662 | params = dict(C=expon(scale=10), gamma=expon(scale=0.1)) 663 | random_search = dcv.RandomizedSearchCV(SVC(), n_iter=n_search_iter, 664 | cv=n_splits, iid=False, 665 | param_distributions=params, 666 | return_train_score=True) 667 | random_search.fit(X, y) 668 | random_search_iid = dcv.RandomizedSearchCV(SVC(), n_iter=n_search_iter, 669 | cv=n_splits, iid=True, 670 | param_distributions=params, 671 | return_train_score=True) 672 | random_search_iid.fit(X, y) 673 | 674 | param_keys = ('param_C', 'param_gamma') 675 | score_keys = ('mean_test_score', 'mean_train_score', 676 | 'rank_test_score', 677 | 'split0_test_score', 'split1_test_score', 678 | 'split2_test_score', 679 | 'split0_train_score', 'split1_train_score', 680 | 'split2_train_score', 681 | 'std_test_score', 'std_train_score', 682 | 'mean_fit_time', 'std_fit_time', 683 | 'mean_score_time', 'std_score_time') 684 | n_cand = n_search_iter 685 | 686 | for search, iid in zip((random_search, random_search_iid), (False, True)): 687 | assert iid == search.iid 688 | cv_results = search.cv_results_ 689 | # Check results structure 690 | check_cv_results_array_types(cv_results, param_keys, score_keys) 691 | check_cv_results_keys(cv_results, param_keys, score_keys, n_cand) 692 | # For random_search, all the param array vals should be unmasked 693 | assert not (any(cv_results['param_C'].mask) or 694 | any(cv_results['param_gamma'].mask)) 695 | 696 | 697 | def test_search_iid_param(): 698 | # Test the IID parameter 699 | # noise-free simple 2d-data 700 | X, y = make_blobs(centers=[[0, 0], [1, 0], [0, 1], [1, 1]], random_state=0, 701 | cluster_std=0.1, shuffle=False, n_samples=80) 702 | # split dataset into two folds 
that are not iid 703 | # first one contains data of all 4 blobs, second only from two. 704 | mask = np.ones(X.shape[0], dtype=np.bool) 705 | mask[np.where(y == 1)[0][::2]] = 0 706 | mask[np.where(y == 2)[0][::2]] = 0 707 | # this leads to perfect classification on one fold and a score of 1/3 on 708 | # the other 709 | # create "cv" for splits 710 | cv = [[mask, ~mask], [~mask, mask]] 711 | # once with iid=True (default) 712 | grid_search = dcv.GridSearchCV(SVC(), param_grid={'C': [1, 10]}, cv=cv, 713 | return_train_score=True) 714 | random_search = dcv.RandomizedSearchCV(SVC(), n_iter=2, 715 | param_distributions={'C': [1, 10]}, 716 | return_train_score=True, 717 | cv=cv) 718 | for search in (grid_search, random_search): 719 | search.fit(X, y) 720 | assert search.iid 721 | 722 | test_cv_scores = np.array(list(search.cv_results_['split%d_test_score' 723 | % s_i][0] 724 | for s_i in range(search.n_splits_))) 725 | train_cv_scores = np.array(list(search.cv_results_['split%d_train_' 726 | 'score' % s_i][0] 727 | for s_i in range(search.n_splits_))) 728 | test_mean = search.cv_results_['mean_test_score'][0] 729 | test_std = search.cv_results_['std_test_score'][0] 730 | 731 | train_cv_scores = np.array(list(search.cv_results_['split%d_train_' 732 | 'score' % s_i][0] 733 | for s_i in range(search.n_splits_))) 734 | train_mean = search.cv_results_['mean_train_score'][0] 735 | train_std = search.cv_results_['std_train_score'][0] 736 | 737 | # Test the first candidate 738 | assert search.cv_results_['param_C'][0] == 1 739 | assert_array_almost_equal(test_cv_scores, [1, 1. / 3.]) 740 | assert_array_almost_equal(train_cv_scores, [1, 1]) 741 | 742 | # for first split, 1/4 of dataset is in test, for second 3/4. 743 | # take weighted average and weighted std 744 | expected_test_mean = 1 * 1. / 4. + 1. / 3. * 3. / 4. 745 | expected_test_std = np.sqrt(1. / 4 * (expected_test_mean - 1) ** 2 + 746 | 3. / 4 * (expected_test_mean - 1. / 3.) 
** 747 | 2) 748 | assert_almost_equal(test_mean, expected_test_mean) 749 | assert_almost_equal(test_std, expected_test_std) 750 | 751 | # For the train scores, we do not take a weighted mean irrespective of 752 | # i.i.d. or not 753 | assert_almost_equal(train_mean, 1) 754 | assert_almost_equal(train_std, 0) 755 | 756 | # once with iid=False 757 | grid_search = dcv.GridSearchCV(SVC(), param_grid={'C': [1, 10]}, 758 | cv=cv, iid=False, return_train_score=True) 759 | random_search = dcv.RandomizedSearchCV(SVC(), n_iter=2, 760 | param_distributions={'C': [1, 10]}, 761 | cv=cv, iid=False, return_train_score=True) 762 | 763 | for search in (grid_search, random_search): 764 | search.fit(X, y) 765 | assert not search.iid 766 | 767 | test_cv_scores = np.array(list(search.cv_results_['split%d_test_score' 768 | % s][0] 769 | for s in range(search.n_splits_))) 770 | test_mean = search.cv_results_['mean_test_score'][0] 771 | test_std = search.cv_results_['std_test_score'][0] 772 | 773 | train_cv_scores = np.array(list(search.cv_results_['split%d_train_' 774 | 'score' % s][0] 775 | for s in range(search.n_splits_))) 776 | train_mean = search.cv_results_['mean_train_score'][0] 777 | train_std = search.cv_results_['std_train_score'][0] 778 | 779 | assert search.cv_results_['param_C'][0] == 1 780 | # scores are the same as above 781 | assert_array_almost_equal(test_cv_scores, [1, 1. / 3.]) 782 | # Unweighted mean/std is used 783 | assert_almost_equal(test_mean, np.mean(test_cv_scores)) 784 | assert_almost_equal(test_std, np.std(test_cv_scores)) 785 | 786 | # For the train scores, we do not take a weighted mean irrespective of 787 | # i.i.d. 
or not 788 | assert_almost_equal(train_mean, 1) 789 | assert_almost_equal(train_std, 0) 790 | 791 | 792 | def test_search_cv_results_rank_tie_breaking(): 793 | X, y = make_blobs(n_samples=50, random_state=42) 794 | 795 | # The two C values are close enough to give similar models 796 | # which would result in a tie of their mean cv-scores 797 | param_grid = {'C': [1, 1.001, 0.001]} 798 | 799 | grid_search = dcv.GridSearchCV(SVC(), param_grid=param_grid, 800 | return_train_score=True) 801 | random_search = dcv.RandomizedSearchCV(SVC(), n_iter=3, 802 | param_distributions=param_grid, 803 | return_train_score=True) 804 | 805 | for search in (grid_search, random_search): 806 | search.fit(X, y) 807 | cv_results = search.cv_results_ 808 | # Check tie breaking strategy - 809 | # Check that there is a tie in the mean scores between 810 | # candidates 1 and 2 alone 811 | assert_almost_equal(cv_results['mean_test_score'][0], 812 | cv_results['mean_test_score'][1]) 813 | assert_almost_equal(cv_results['mean_train_score'][0], 814 | cv_results['mean_train_score'][1]) 815 | try: 816 | assert_almost_equal(cv_results['mean_test_score'][1], 817 | cv_results['mean_test_score'][2]) 818 | except AssertionError: 819 | pass 820 | try: 821 | assert_almost_equal(cv_results['mean_train_score'][1], 822 | cv_results['mean_train_score'][2]) 823 | except AssertionError: 824 | pass 825 | # 'min' rank should be assigned to the tied candidates 826 | assert_almost_equal(search.cv_results_['rank_test_score'], [1, 1, 3]) 827 | 828 | 829 | def test_search_cv_results_none_param(): 830 | X, y = [[1], [2], [3], [4], [5]], [0, 0, 0, 0, 1] 831 | estimators = (DecisionTreeRegressor(), DecisionTreeClassifier()) 832 | est_parameters = {"random_state": [0, None]} 833 | cv = KFold(random_state=0) 834 | 835 | for est in estimators: 836 | grid_search = dcv.GridSearchCV(est, est_parameters, cv=cv).fit(X, y) 837 | assert_array_equal(grid_search.cv_results_['param_random_state'], 838 | [0, None]) 839 | 840 | 841 | 
def test_grid_search_correct_score_results():
    """Per-split scores stored in ``cv_results_`` must equal the scores
    obtained by refitting and scoring each (C, split) pair by hand."""
    n_splits = 3
    clf = LinearSVC(random_state=0)
    X, y = make_blobs(random_state=0, centers=2)
    Cs = [.1, 1, 10]
    for score in ['f1', 'roc_auc']:
        # XXX: It seems there's some global shared state in LinearSVC - fitting
        # multiple `SVC` instances in parallel using threads sometimes results
        # in wrong results. This only happens with threads, not processes/sync.
        # For now, we'll fit using the sync scheduler.
        grid_search = dcv.GridSearchCV(clf, {'C': Cs}, scoring=score,
                                       cv=n_splits, scheduler='sync')
        cv_results = grid_search.fit(X, y).cv_results_

        # The expected result keys for the single default scorer.
        result_keys = list(cv_results.keys())
        expected_keys = (("mean_test_score", "rank_test_score") +
                         tuple("split%d_test_score" % cv_i
                               for cv_i in range(n_splits)))
        assert all(np.in1d(expected_keys, result_keys))

        cv = StratifiedKFold(n_splits=n_splits)
        n_splits = grid_search.n_splits_
        for candidate_i, C in enumerate(Cs):
            clf.set_params(C=C)
            cv_scores = np.array(
                [grid_search.cv_results_['split%d_test_score'
                                         % s][candidate_i]
                 for s in range(n_splits)])
            for i, (train, test) in enumerate(cv.split(X, y)):
                # Recompute the score manually for this split and compare.
                clf.fit(X[train], y[train])
                if score == "f1":
                    correct_score = f1_score(y[test], clf.predict(X[test]))
                elif score == "roc_auc":
                    dec = clf.decision_function(X[test])
                    correct_score = roc_auc_score(y[test], dec)
                assert_almost_equal(correct_score, cv_scores[i])


def test_pickle():
    """A fitted search must survive a pickle round-trip and keep predicting
    identically afterwards."""
    clf = MockClassifier()
    grid_search = dcv.GridSearchCV(clf, {'foo_param': [1, 2, 3]}, refit=True)
    grid_search.fit(X, y)
    grid_search_pickled = pickle.loads(pickle.dumps(grid_search))
    assert_array_almost_equal(grid_search.predict(X),
                              grid_search_pickled.predict(X))

    random_search = dcv.RandomizedSearchCV(clf, {'foo_param': [1, 2, 3]},
                                           refit=True, n_iter=3)
    random_search.fit(X, y)
    random_search_pickled = pickle.loads(pickle.dumps(random_search))
    assert_array_almost_equal(random_search.predict(X),
                              random_search_pickled.predict(X))


def test_grid_search_with_multioutput_data():
    """Search must work with multi-output estimators; stored split scores
    must match scores recomputed by hand for each candidate/split."""
    X, y = make_multilabel_classification(return_indicator=True,
                                          random_state=0)

    est_parameters = {"max_depth": [1, 2, 3, 4]}
    # NOTE(review): random_state has no effect on KFold unless shuffle=True
    # is also passed — confirm the intent here.
    cv = KFold(random_state=0)

    estimators = [DecisionTreeRegressor(random_state=0),
                  DecisionTreeClassifier(random_state=0)]

    # Grid search: every candidate's per-split score must be reproducible.
    for est in estimators:
        grid_search = dcv.GridSearchCV(est, est_parameters, cv=cv)
        grid_search.fit(X, y)
        res_params = grid_search.cv_results_['params']
        for cand_i in range(len(res_params)):
            est.set_params(**res_params[cand_i])

            for i, (train, test) in enumerate(cv.split(X, y)):
                est.fit(X[train], y[train])
                correct_score = est.score(X[test], y[test])
                assert_almost_equal(
                    correct_score,
                    grid_search.cv_results_['split%d_test_score' % i][cand_i])

    # Same check for randomized search.
    for est in estimators:
        random_search = dcv.RandomizedSearchCV(est, est_parameters,
                                               cv=cv, n_iter=3)
        random_search.fit(X, y)
        res_params = random_search.cv_results_['params']
        for cand_i in range(len(res_params)):
            est.set_params(**res_params[cand_i])

            for i, (train, test) in enumerate(cv.split(X, y)):
                est.fit(X[train], y[train])
                correct_score = est.score(X[test], y[test])
                assert_almost_equal(
                    correct_score,
                    random_search.cv_results_['split%d_test_score'
                                              % i][cand_i])


def test_predict_proba_disabled():
    """The search must not expose ``predict_proba`` when the underlying
    estimator has it disabled."""
    X = np.arange(20).reshape(5, -1)
    y = [0, 0, 1, 1, 1]
    clf = SVC(probability=False)
    gs = dcv.GridSearchCV(clf, {}, cv=2).fit(X, y)
    assert not hasattr(gs, "predict_proba")


def test_grid_search_allows_nans():
    """GridSearchCV must accept NaN-containing input when the pipeline
    starts with an imputer."""
    X = np.arange(20, dtype=np.float64).reshape(5, -1)
    X[2, :] = np.nan
    y = [0, 0, 1, 1, 1]
    p = Pipeline([
        ('imputer', Imputer(strategy='mean', missing_values='NaN')),
        ('classifier', MockClassifier()),
    ])
    dcv.GridSearchCV(p, {'classifier__foo_param': [1, 2, 3]}, cv=2).fit(X, y)


@ignore_warnings
def test_grid_search_failing_classifier():
    """With ``error_score`` set, a fit failure must warn and record the
    given score (0.0 or NaN) for the failing candidate instead of raising."""
    X, y = make_classification(n_samples=20, n_features=10, random_state=0)
    clf = FailingClassifier()

    # refit=False because we want to test the behaviour of the grid search part
    gs = dcv.GridSearchCV(clf, [{'parameter': [0, 1, 2]}], scoring='accuracy',
                          refit=False, error_score=0.0)

    with pytest.warns(FitFailedWarning):
        gs.fit(X, y)

    n_candidates = len(gs.cv_results_['params'])

    # Ensure that grid scores were set to zero as required for those fits
    # that are expected to fail.
    def get_cand_scores(i):
        # Per-split test scores of candidate i, across all CV splits.
        return np.array([gs.cv_results_['split%d_test_score' % s][i]
                         for s in range(gs.n_splits_)])

    assert all((np.all(get_cand_scores(cand_i) == 0.0)
                for cand_i in range(n_candidates)
                if gs.cv_results_['param_parameter'][cand_i] ==
                FailingClassifier.FAILING_PARAMETER))

    # Same scenario with error_score=NaN: failing candidates score NaN.
    gs = dcv.GridSearchCV(clf, [{'parameter': [0, 1, 2]}], scoring='accuracy',
                          refit=False, error_score=float('nan'))

    with pytest.warns(FitFailedWarning):
        gs.fit(X, y)

    n_candidates = len(gs.cv_results_['params'])
    assert all(np.all(np.isnan(get_cand_scores(cand_i)))
               for cand_i in range(n_candidates)
               if gs.cv_results_['param_parameter'][cand_i] ==
               FailingClassifier.FAILING_PARAMETER)


def test_grid_search_failing_classifier_raise():
    """With ``error_score='raise'`` a fit failure must propagate."""
    X, y = make_classification(n_samples=20, n_features=10, random_state=0)
    clf = FailingClassifier()

    # refit=False because we want to test the behaviour of the grid search part
    gs = dcv.GridSearchCV(clf, [{'parameter': [0, 1, 2]}], scoring='accuracy',
                          refit=False, error_score='raise')

    # FailingClassifier issues a ValueError so this is what we look for.
    with pytest.raises(ValueError):
        gs.fit(X, y)


def test_search_train_scores_set_to_false():
    """``return_train_score=False`` must suppress all *train_score keys."""
    X = np.arange(6).reshape(6, -1)
    y = [0, 0, 0, 1, 1, 1]
    clf = LinearSVC(random_state=0)

    gs = dcv.GridSearchCV(clf, param_grid={'C': [0.1, 0.2]},
                          return_train_score=False)
    gs.fit(X, y)
    for key in gs.cv_results_:
        assert not key.endswith('train_score')


@pytest.mark.skipif(not _HAS_MULTIPLE_METRICS, reason="Added in 0.19.0")
def test_multiple_metrics():
    """Multi-metric scoring must populate per-scorer result keys."""
    scoring = {'AUC': 'roc_auc', 'Accuracy': make_scorer(accuracy_score)}

    # Setting refit='AUC', refits an estimator on the whole dataset with the
    # parameter setting that has the best cross-validated AUC score.
    # That estimator is made available at ``gs.best_estimator_`` along with
    # parameters like ``gs.best_score_``, ``gs.best_parameters_`` and
    # ``gs.best_index_``
    gs = dcv.GridSearchCV(DecisionTreeClassifier(random_state=42),
                          param_grid={'min_samples_split': range(2, 403, 10)},
                          scoring=scoring, cv=5, refit='AUC')
    gs.fit(da_X, da_y)
    # some basic checks
    assert set(gs.scorer_) == {'AUC', 'Accuracy'}
    cv_results = gs.cv_results_.keys()
    assert 'split0_test_AUC' in cv_results
    assert 'split0_train_AUC' in cv_results

    assert 'split0_test_Accuracy' in cv_results
    # BUG FIX: this assertion previously duplicated the *test* key check;
    # mirroring the AUC checks above, it must verify the *train* key.
    assert 'split0_train_Accuracy' in cv_results

    assert 'mean_train_AUC' in cv_results
    assert 'mean_train_Accuracy' in cv_results

    assert 'std_train_AUC' in cv_results
    assert 'std_train_Accuracy' in cv_results