├── .coveragerc ├── .gitignore ├── .travis.yml ├── LICENSE.txt ├── MANIFEST.in ├── README.rst ├── dask_searchcv ├── SCIKIT_LEARN_LICENSE.txt ├── __init__.py ├── _normalize.py ├── _version.py ├── methods.py ├── model_selection.py ├── tests │ ├── __init__.py │ ├── test_model_selection.py │ └── test_model_selection_sklearn.py ├── utils.py └── utils_test.py ├── docs ├── Makefile ├── make.bat └── source │ ├── api.rst │ ├── conf.py │ ├── index.rst │ └── sphinxext │ ├── LICENSE.txt │ ├── docscrape.py │ ├── docscrape_sphinx.py │ └── numpydoc.py ├── setup.cfg ├── setup.py └── versioneer.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | omit = 3 | dask_searchcv/_version.py 4 | */test_*.py 5 | source = 6 | dask_searchcv 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.egg-info 3 | docs/build 4 | build/ 5 | dist/ 6 | .idea/ 7 | log.* 8 | log 9 | .coverage 10 | .DS_Store 11 | *.swp 12 | *.swo 13 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | sudo: false 3 | 4 | env: 5 | matrix: 6 | - PYTHON=2.7 SKLEARN=0.18.1 TEST_FLAGS= 7 | - PYTHON=3.5 SKLEARN=0.18.0 TEST_FLAGS= 8 | - PYTHON=3.6 SKLEARN=0.18.1 TEST_FLAGS=--doctest-modules 9 | 10 | addons: 11 | apt: 12 | packages: 13 | - graphviz 14 | 15 | install: 16 | # Install conda 17 | - wget http://repo.continuum.io/miniconda/Miniconda-latest-Linux-x86_64.sh -O miniconda.sh 18 | - bash miniconda.sh -b -p $HOME/miniconda 19 | - export PATH="$HOME/miniconda/bin:$PATH" 20 | - conda config --set always_yes yes --set changeps1 no 21 | - conda update conda 22 | 23 | # Install dependencies 24 | - conda create -n test-environment python=$PYTHON 25 | - source activate test-environment 26 | - conda install dask distributed numpy scikit-learn=$SKLEARN cytoolz pytest 27 | - pip install -q graphviz flake8 28 | - pip install --no-deps -e . 29 | 30 | script: 31 | - py.test dask_searchcv --verbose $TEST_FLAGS 32 | - flake8 dask_searchcv 33 | 34 | notifications: 35 | email: false 36 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017, Continuum Analytics, Inc. and contributors 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, 5 | are permitted provided that the following conditions are met: 6 | 7 | Redistributions of source code must retain the above copyright notice, 8 | this list of conditions and the following disclaimer. 9 | 10 | Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | Neither the name of Continuum Analytics nor the names of any contributors 15 | may be used to endorse or promote products derived from this software 16 | without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 22 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 28 | THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include dask_searchcv *.py 2 | 3 | include setup.py 4 | include README.rst 5 | include LICENSE.txt 6 | include MANIFEST.in 7 | include versioneer.py 8 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | dask-searchcv 2 | ============= 3 | 4 | |Travis Status| |Doc Status| |Conda Badge| |PyPI Badge| 5 | 6 | Tools for performing hyperparameter search with 7 | `Scikit-Learn `_ and `Dask `_. 8 | 9 | This library provides implementations of Scikit-Learn's ``GridSearchCV`` and 10 | ``RandomizedSearchCV``. They implement many (but not all) of the same 11 | parameters, and should be a drop-in replacement for the subset that they do 12 | implement. For certain problems, these implementations can be more efficient 13 | than those in Scikit-Learn, as they can avoid expensive repeated computations. 14 | 15 | For more information, check out the `documentation `_. 16 | 17 | Install 18 | ------- 19 | 20 | Dask-searchcv is available via ``conda`` or ``pip``: 21 | 22 | :: 23 | 24 | # Install with conda 25 | $ conda install dask-searchcv -c conda-forge 26 | 27 | # Install with pip 28 | $ pip install dask-searchcv 29 | 30 | 31 | Example 32 | ------- 33 | 34 | .. code-block:: python 35 | 36 | from sklearn.datasets import load_digits 37 | from sklearn.svm import SVC 38 | import dask_searchcv as dcv 39 | import numpy as np 40 | 41 | digits = load_digits() 42 | 43 | param_space = {'C': np.logspace(-4, 4, 9), 44 | 'gamma': np.logspace(-4, 4, 9), 45 | 'class_weight': [None, 'balanced']} 46 | 47 | model = SVC(kernel='rbf') 48 | search = dcv.GridSearchCV(model, param_space, cv=3) 49 | 50 | search.fit(digits.data, digits.target) 51 | 52 | 53 | .. |Travis Status| image:: https://travis-ci.org/dask/dask-searchcv.svg?branch=master 54 | :target: https://travis-ci.org/dask/dask-searchcv 55 | .. |Doc Status| image:: http://readthedocs.org/projects/dask-searchcv/badge/?version=latest 56 | :target: http://dask-searchcv.readthedocs.io/en/latest/index.html 57 | :alt: Documentation Status 58 | .. |PyPI Badge| image:: https://img.shields.io/pypi/v/dask-searchcv.svg 59 | :target: https://pypi.python.org/pypi/dask-searchcv 60 | .. |Conda Badge| image:: https://anaconda.org/conda-forge/dask-searchcv/badges/version.svg 61 | :target: https://anaconda.org/conda-forge/dask-searchcv 62 | -------------------------------------------------------------------------------- /dask_searchcv/SCIKIT_LEARN_LICENSE.txt: -------------------------------------------------------------------------------- 1 | New BSD License 2 | 3 | Copyright (c) 2007–2016 The scikit-learn developers. 4 | All rights reserved. 
5 | 6 | 7 | Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | a. Redistributions of source code must retain the above copyright notice, 11 | this list of conditions and the following disclaimer. 12 | b. Redistributions in binary form must reproduce the above copyright 13 | notice, this list of conditions and the following disclaimer in the 14 | documentation and/or other materials provided with the distribution. 15 | c. Neither the name of the Scikit-learn Developers nor the names of 16 | its contributors may be used to endorse or promote products 17 | derived from this software without specific prior written 18 | permission. 19 | 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 26 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 27 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 28 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 29 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY 30 | OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH 31 | DAMAGE. 32 | 33 | -------------------------------------------------------------------------------- /dask_searchcv/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .model_selection import GridSearchCV, RandomizedSearchCV 4 | 5 | from ._version import get_versions 6 | __version__ = get_versions()['version'] 7 | del get_versions 8 | -------------------------------------------------------------------------------- /dask_searchcv/_normalize.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import numpy as np 4 | from dask.base import normalize_token 5 | 6 | from sklearn.base import BaseEstimator 7 | from sklearn.model_selection._split import (_BaseKFold, 8 | BaseShuffleSplit, 9 | LeaveOneOut, 10 | LeaveOneGroupOut, 11 | LeavePOut, 12 | LeavePGroupsOut, 13 | PredefinedSplit, 14 | _CVIterableWrapper) 15 | 16 | 17 | @normalize_token.register(BaseEstimator) 18 | def normalize_estimator(est): 19 | """Normalize an estimator. 
20 | 21 | Note: Since scikit-learn requires duck-typing, but not sub-typing from 22 | ``BaseEstimator``, we sometimes need to call this function directly.""" 23 | return type(est).__name__, normalize_token(est.get_params()) 24 | 25 | 26 | def normalize_random_state(random_state): 27 | if isinstance(random_state, np.random.RandomState): 28 | return random_state.get_state() 29 | return random_state 30 | 31 | 32 | @normalize_token.register(_BaseKFold) 33 | def normalize_KFold(x): 34 | # Doesn't matter if shuffle is False 35 | rs = normalize_random_state(x.random_state) if x.shuffle else None 36 | return (type(x).__name__, x.n_splits, x.shuffle, rs) 37 | 38 | 39 | @normalize_token.register(BaseShuffleSplit) 40 | def normalize_ShuffleSplit(x): 41 | return (type(x).__name__, x.n_splits, x.test_size, x.train_size, 42 | normalize_random_state(x.random_state)) 43 | 44 | 45 | @normalize_token.register((LeaveOneOut, LeaveOneGroupOut)) 46 | def normalize_LeaveOneOut(x): 47 | return type(x).__name__ 48 | 49 | 50 | @normalize_token.register((LeavePOut, LeavePGroupsOut)) 51 | def normalize_LeavePOut(x): 52 | return (type(x).__name__, x.p if hasattr(x, 'p') else x.n_groups) 53 | 54 | 55 | @normalize_token.register(PredefinedSplit) 56 | def normalize_PredefinedSplit(x): 57 | return (type(x).__name__, x.test_fold) 58 | 59 | 60 | @normalize_token.register(_CVIterableWrapper) 61 | def normalize_CVIterableWrapper(x): 62 | return (type(x).__name__, x.cv) 63 | -------------------------------------------------------------------------------- /dask_searchcv/_version.py: -------------------------------------------------------------------------------- 1 | 2 | # This file helps to compute a version number in source trees obtained from 3 | # git-archive tarball (such as those provided by githubs download-from-tag 4 | # feature). Distribution tarballs (built by setup.py sdist) and build 5 | # directories (produced by setup.py build) will contain a much shorter file 6 | # that just contains the computed version number. 7 | 8 | # This file is released into the public domain. Generated by 9 | # versioneer-0.17 (https://github.com/warner/python-versioneer) 10 | 11 | """Git implementation of _version.py.""" 12 | 13 | import errno 14 | import os 15 | import re 16 | import subprocess 17 | import sys 18 | 19 | 20 | def get_keywords(): 21 | """Get the keywords needed to look up the version information.""" 22 | # these strings will be replaced by git during git-archive. 23 | # setup.py/versioneer.py will grep for the variable names, so they must 24 | # each be defined on a line of their own. _version.py will just call 25 | # get_keywords(). 
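# Illustrative note (example values are hypothetical): when this file is
# exported via `git archive`, git's export-subst mechanism expands the
# placeholders below, e.g. "$Format:%d$" becomes something like
# " (HEAD -> master, tag: 0.2.0)" and "$Format:%H$" becomes the full commit
# hash. In a plain checkout they remain unexpanded, which is how
# git_versions_from_keywords() detects that case.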
26 | git_refnames = "$Format:%d$" 27 | git_full = "$Format:%H$" 28 | git_date = "$Format:%ci$" 29 | keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} 30 | return keywords 31 | 32 | 33 | class VersioneerConfig: 34 | """Container for Versioneer configuration parameters.""" 35 | 36 | 37 | def get_config(): 38 | """Create, populate and return the VersioneerConfig() object.""" 39 | # these strings are filled in when 'setup.py versioneer' creates 40 | # _version.py 41 | cfg = VersioneerConfig() 42 | cfg.VCS = "git" 43 | cfg.style = "pep440" 44 | cfg.tag_prefix = "" 45 | cfg.parentdir_prefix = "dask_searchcv-" 46 | cfg.versionfile_source = "dask_searchcv/_version.py" 47 | cfg.verbose = False 48 | return cfg 49 | 50 | 51 | class NotThisMethod(Exception): 52 | """Exception raised if a method is not valid for the current scenario.""" 53 | 54 | 55 | LONG_VERSION_PY = {} 56 | HANDLERS = {} 57 | 58 | 59 | def register_vcs_handler(vcs, method): # decorator 60 | """Decorator to mark a method as the handler for a particular VCS.""" 61 | def decorate(f): 62 | """Store f in HANDLERS[vcs][method].""" 63 | if vcs not in HANDLERS: 64 | HANDLERS[vcs] = {} 65 | HANDLERS[vcs][method] = f 66 | return f 67 | return decorate 68 | 69 | 70 | def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, 71 | env=None): 72 | """Call the given command(s).""" 73 | assert isinstance(commands, list) 74 | p = None 75 | for c in commands: 76 | try: 77 | dispcmd = str([c] + args) 78 | # remember shell=False, so use git.cmd on windows, not just git 79 | p = subprocess.Popen([c] + args, cwd=cwd, env=env, 80 | stdout=subprocess.PIPE, 81 | stderr=(subprocess.PIPE if hide_stderr 82 | else None)) 83 | break 84 | except EnvironmentError: 85 | e = sys.exc_info()[1] 86 | if e.errno == errno.ENOENT: 87 | continue 88 | if verbose: 89 | print("unable to run %s" % dispcmd) 90 | print(e) 91 | return None, None 92 | else: 93 | if verbose: 94 | print("unable to find command, tried %s" % (commands,)) 95 | return None, None 96 | stdout = p.communicate()[0].strip() 97 | if sys.version_info[0] >= 3: 98 | stdout = stdout.decode() 99 | if p.returncode != 0: 100 | if verbose: 101 | print("unable to run %s (error)" % dispcmd) 102 | print("stdout was %s" % stdout) 103 | return None, p.returncode 104 | return stdout, p.returncode 105 | 106 | 107 | def versions_from_parentdir(parentdir_prefix, root, verbose): 108 | """Try to determine the version from the parent directory name. 109 | 110 | Source tarballs conventionally unpack into a directory that includes both 111 | the project name and a version string. 
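For example (version number hypothetical), a tarball unpacked as
``dask_searchcv-0.2.0/`` would yield version "0.2.0", given the configured
parentdir_prefix of "dask_searchcv-".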
We will also support searching up 112 | two directory levels for an appropriately named parent directory 113 | """ 114 | rootdirs = [] 115 | 116 | for i in range(3): 117 | dirname = os.path.basename(root) 118 | if dirname.startswith(parentdir_prefix): 119 | return {"version": dirname[len(parentdir_prefix):], 120 | "full-revisionid": None, 121 | "dirty": False, "error": None, "date": None} 122 | else: 123 | rootdirs.append(root) 124 | root = os.path.dirname(root) # up a level 125 | 126 | if verbose: 127 | print("Tried directories %s but none started with prefix %s" % 128 | (str(rootdirs), parentdir_prefix)) 129 | raise NotThisMethod("rootdir doesn't start with parentdir_prefix") 130 | 131 | 132 | @register_vcs_handler("git", "get_keywords") 133 | def git_get_keywords(versionfile_abs): 134 | """Extract version information from the given file.""" 135 | # the code embedded in _version.py can just fetch the value of these 136 | # keywords. When used from setup.py, we don't want to import _version.py, 137 | # so we do it with a regexp instead. This function is not used from 138 | # _version.py. 139 | keywords = {} 140 | try: 141 | f = open(versionfile_abs, "r") 142 | for line in f.readlines(): 143 | if line.strip().startswith("git_refnames ="): 144 | mo = re.search(r'=\s*"(.*)"', line) 145 | if mo: 146 | keywords["refnames"] = mo.group(1) 147 | if line.strip().startswith("git_full ="): 148 | mo = re.search(r'=\s*"(.*)"', line) 149 | if mo: 150 | keywords["full"] = mo.group(1) 151 | if line.strip().startswith("git_date ="): 152 | mo = re.search(r'=\s*"(.*)"', line) 153 | if mo: 154 | keywords["date"] = mo.group(1) 155 | f.close() 156 | except EnvironmentError: 157 | pass 158 | return keywords 159 | 160 | 161 | @register_vcs_handler("git", "keywords") 162 | def git_versions_from_keywords(keywords, tag_prefix, verbose): 163 | """Get version information from git keywords.""" 164 | if not keywords: 165 | raise NotThisMethod("no keywords at all, weird") 166 | date = keywords.get("date") 167 | if date is not None: 168 | # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant 169 | # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 170 | # -like" string, which we must then edit to make compliant), because 171 | # it's been around since git-1.5.3, and it's too difficult to 172 | # discover which version we're using, or to work around using an 173 | # older one. 174 | date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) 175 | refnames = keywords["refnames"].strip() 176 | if refnames.startswith("$Format"): 177 | if verbose: 178 | print("keywords are unexpanded, not using") 179 | raise NotThisMethod("unexpanded keywords, not a git-archive tarball") 180 | refs = set([r.strip() for r in refnames.strip("()").split(",")]) 181 | # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of 182 | # just "foo-1.0". If we see a "tag: " prefix, prefer those. 183 | TAG = "tag: " 184 | tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) 185 | if not tags: 186 | # Either we're using git < 1.8.3, or there really are no tags. We use 187 | # a heuristic: assume all version tags have a digit. The old git %d 188 | # expansion behaves like git log --decorate=short and strips out the 189 | # refs/heads/ and refs/tags/ prefixes that would let us distinguish 190 | # between branches and tags. By ignoring refnames without digits, we 191 | # filter out many common branch names like "release" and 192 | # "stabilization", as well as "HEAD" and "master". 
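# Illustrative example (ref names are hypothetical): given
# refs = {"HEAD", "master", "release", "0.2.0"}, the digit heuristic below
# keeps only {"0.2.0"}.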
193 | tags = set([r for r in refs if re.search(r'\d', r)]) 194 | if verbose: 195 | print("discarding '%s', no digits" % ",".join(refs - tags)) 196 | if verbose: 197 | print("likely tags: %s" % ",".join(sorted(tags))) 198 | for ref in sorted(tags): 199 | # sorting will prefer e.g. "2.0" over "2.0rc1" 200 | if ref.startswith(tag_prefix): 201 | r = ref[len(tag_prefix):] 202 | if verbose: 203 | print("picking %s" % r) 204 | return {"version": r, 205 | "full-revisionid": keywords["full"].strip(), 206 | "dirty": False, "error": None, 207 | "date": date} 208 | # no suitable tags, so version is "0+unknown", but full hex is still there 209 | if verbose: 210 | print("no suitable tags, using unknown + full revision id") 211 | return {"version": "0+unknown", 212 | "full-revisionid": keywords["full"].strip(), 213 | "dirty": False, "error": "no suitable tags", "date": None} 214 | 215 | 216 | @register_vcs_handler("git", "pieces_from_vcs") 217 | def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): 218 | """Get version from 'git describe' in the root of the source tree. 219 | 220 | This only gets called if the git-archive 'subst' keywords were *not* 221 | expanded, and _version.py hasn't already been rewritten with a short 222 | version string, meaning we're inside a checked out source tree. 223 | """ 224 | GITS = ["git"] 225 | if sys.platform == "win32": 226 | GITS = ["git.cmd", "git.exe"] 227 | 228 | out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, 229 | hide_stderr=True) 230 | if rc != 0: 231 | if verbose: 232 | print("Directory %s not under git control" % root) 233 | raise NotThisMethod("'git rev-parse --git-dir' returned error") 234 | 235 | # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] 236 | # if there isn't one, this yields HEX[-dirty] (no NUM) 237 | describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", 238 | "--always", "--long", 239 | "--match", "%s*" % tag_prefix], 240 | cwd=root) 241 | # --long was added in git-1.5.5 242 | if describe_out is None: 243 | raise NotThisMethod("'git describe' failed") 244 | describe_out = describe_out.strip() 245 | full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) 246 | if full_out is None: 247 | raise NotThisMethod("'git rev-parse' failed") 248 | full_out = full_out.strip() 249 | 250 | pieces = {} 251 | pieces["long"] = full_out 252 | pieces["short"] = full_out[:7] # maybe improved later 253 | pieces["error"] = None 254 | 255 | # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] 256 | # TAG might have hyphens. 257 | git_describe = describe_out 258 | 259 | # look for -dirty suffix 260 | dirty = git_describe.endswith("-dirty") 261 | pieces["dirty"] = dirty 262 | if dirty: 263 | git_describe = git_describe[:git_describe.rindex("-dirty")] 264 | 265 | # now we have TAG-NUM-gHEX or HEX 266 | 267 | if "-" in git_describe: 268 | # TAG-NUM-gHEX 269 | mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) 270 | if not mo: 271 | # unparseable. Maybe git-describe is misbehaving? 
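# For reference (hypothetical output): a well-formed describe string such
# as "0.2.0-14-g1a2b3c4" parses as tag "0.2.0", distance 14, and short
# hash "1a2b3c4"; anything that doesn't match the regex above ends up here.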
272 | pieces["error"] = ("unable to parse git-describe output: '%s'" 273 | % describe_out) 274 | return pieces 275 | 276 | # tag 277 | full_tag = mo.group(1) 278 | if not full_tag.startswith(tag_prefix): 279 | if verbose: 280 | fmt = "tag '%s' doesn't start with prefix '%s'" 281 | print(fmt % (full_tag, tag_prefix)) 282 | pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" 283 | % (full_tag, tag_prefix)) 284 | return pieces 285 | pieces["closest-tag"] = full_tag[len(tag_prefix):] 286 | 287 | # distance: number of commits since tag 288 | pieces["distance"] = int(mo.group(2)) 289 | 290 | # commit: short hex revision ID 291 | pieces["short"] = mo.group(3) 292 | 293 | else: 294 | # HEX: no tags 295 | pieces["closest-tag"] = None 296 | count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], 297 | cwd=root) 298 | pieces["distance"] = int(count_out) # total number of commits 299 | 300 | # commit date: see ISO-8601 comment in git_versions_from_keywords() 301 | date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], 302 | cwd=root)[0].strip() 303 | pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) 304 | 305 | return pieces 306 | 307 | 308 | def plus_or_dot(pieces): 309 | """Return a + if we don't already have one, else return a .""" 310 | if "+" in pieces.get("closest-tag", ""): 311 | return "." 312 | return "+" 313 | 314 | 315 | def render_pep440(pieces): 316 | """Build up version string, with post-release "local version identifier". 317 | 318 | Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you 319 | get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty 320 | 321 | Exceptions: 322 | 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] 323 | """ 324 | if pieces["closest-tag"]: 325 | rendered = pieces["closest-tag"] 326 | if pieces["distance"] or pieces["dirty"]: 327 | rendered += plus_or_dot(pieces) 328 | rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) 329 | if pieces["dirty"]: 330 | rendered += ".dirty" 331 | else: 332 | # exception #1 333 | rendered = "0+untagged.%d.g%s" % (pieces["distance"], 334 | pieces["short"]) 335 | if pieces["dirty"]: 336 | rendered += ".dirty" 337 | return rendered 338 | 339 | 340 | def render_pep440_pre(pieces): 341 | """TAG[.post.devDISTANCE] -- No -dirty. 342 | 343 | Exceptions: 344 | 1: no tags. 0.post.devDISTANCE 345 | """ 346 | if pieces["closest-tag"]: 347 | rendered = pieces["closest-tag"] 348 | if pieces["distance"]: 349 | rendered += ".post.dev%d" % pieces["distance"] 350 | else: 351 | # exception #1 352 | rendered = "0.post.dev%d" % pieces["distance"] 353 | return rendered 354 | 355 | 356 | def render_pep440_post(pieces): 357 | """TAG[.postDISTANCE[.dev0]+gHEX] . 358 | 359 | The ".dev0" means dirty. Note that .dev0 sorts backwards 360 | (a dirty tree will appear "older" than the corresponding clean one), 361 | but you shouldn't be releasing software with -dirty anyways. 362 | 363 | Exceptions: 364 | 1: no tags. 
0.postDISTANCE[.dev0] 365 | """ 366 | if pieces["closest-tag"]: 367 | rendered = pieces["closest-tag"] 368 | if pieces["distance"] or pieces["dirty"]: 369 | rendered += ".post%d" % pieces["distance"] 370 | if pieces["dirty"]: 371 | rendered += ".dev0" 372 | rendered += plus_or_dot(pieces) 373 | rendered += "g%s" % pieces["short"] 374 | else: 375 | # exception #1 376 | rendered = "0.post%d" % pieces["distance"] 377 | if pieces["dirty"]: 378 | rendered += ".dev0" 379 | rendered += "+g%s" % pieces["short"] 380 | return rendered 381 | 382 | 383 | def render_pep440_old(pieces): 384 | """TAG[.postDISTANCE[.dev0]] . 385 | 386 | The ".dev0" means dirty. 387 | 388 | Exceptions: 389 | 1: no tags. 0.postDISTANCE[.dev0] 390 | """ 391 | if pieces["closest-tag"]: 392 | rendered = pieces["closest-tag"] 393 | if pieces["distance"] or pieces["dirty"]: 394 | rendered += ".post%d" % pieces["distance"] 395 | if pieces["dirty"]: 396 | rendered += ".dev0" 397 | else: 398 | # exception #1 399 | rendered = "0.post%d" % pieces["distance"] 400 | if pieces["dirty"]: 401 | rendered += ".dev0" 402 | return rendered 403 | 404 | 405 | def render_git_describe(pieces): 406 | """TAG[-DISTANCE-gHEX][-dirty]. 407 | 408 | Like 'git describe --tags --dirty --always'. 409 | 410 | Exceptions: 411 | 1: no tags. HEX[-dirty] (note: no 'g' prefix) 412 | """ 413 | if pieces["closest-tag"]: 414 | rendered = pieces["closest-tag"] 415 | if pieces["distance"]: 416 | rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) 417 | else: 418 | # exception #1 419 | rendered = pieces["short"] 420 | if pieces["dirty"]: 421 | rendered += "-dirty" 422 | return rendered 423 | 424 | 425 | def render_git_describe_long(pieces): 426 | """TAG-DISTANCE-gHEX[-dirty]. 427 | 428 | Like 'git describe --tags --dirty --always --long'. 429 | The distance/hash is unconditional. 430 | 431 | Exceptions: 432 | 1: no tags. HEX[-dirty] (note: no 'g' prefix) 433 | """ 434 | if pieces["closest-tag"]: 435 | rendered = pieces["closest-tag"] 436 | rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) 437 | else: 438 | # exception #1 439 | rendered = pieces["short"] 440 | if pieces["dirty"]: 441 | rendered += "-dirty" 442 | return rendered 443 | 444 | 445 | def render(pieces, style): 446 | """Render the given version pieces into the requested style.""" 447 | if pieces["error"]: 448 | return {"version": "unknown", 449 | "full-revisionid": pieces.get("long"), 450 | "dirty": None, 451 | "error": pieces["error"], 452 | "date": None} 453 | 454 | if not style or style == "default": 455 | style = "pep440" # the default 456 | 457 | if style == "pep440": 458 | rendered = render_pep440(pieces) 459 | elif style == "pep440-pre": 460 | rendered = render_pep440_pre(pieces) 461 | elif style == "pep440-post": 462 | rendered = render_pep440_post(pieces) 463 | elif style == "pep440-old": 464 | rendered = render_pep440_old(pieces) 465 | elif style == "git-describe": 466 | rendered = render_git_describe(pieces) 467 | elif style == "git-describe-long": 468 | rendered = render_git_describe_long(pieces) 469 | else: 470 | raise ValueError("unknown style '%s'" % style) 471 | 472 | return {"version": rendered, "full-revisionid": pieces["long"], 473 | "dirty": pieces["dirty"], "error": None, 474 | "date": pieces.get("date")} 475 | 476 | 477 | def get_versions(): 478 | """Get version information or return default if unable to do so.""" 479 | # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have 480 | # __file__, we can work backwards from there to the root.
Some 481 | # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which 482 | # case we can only use expanded keywords. 483 | 484 | cfg = get_config() 485 | verbose = cfg.verbose 486 | 487 | try: 488 | return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, 489 | verbose) 490 | except NotThisMethod: 491 | pass 492 | 493 | try: 494 | root = os.path.realpath(__file__) 495 | # versionfile_source is the relative path from the top of the source 496 | # tree (where the .git directory might live) to this file. Invert 497 | # this to find the root from __file__. 498 | for i in cfg.versionfile_source.split('/'): 499 | root = os.path.dirname(root) 500 | except NameError: 501 | return {"version": "0+unknown", "full-revisionid": None, 502 | "dirty": None, 503 | "error": "unable to find root of source tree", 504 | "date": None} 505 | 506 | try: 507 | pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) 508 | return render(pieces, cfg.style) 509 | except NotThisMethod: 510 | pass 511 | 512 | try: 513 | if cfg.parentdir_prefix: 514 | return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) 515 | except NotThisMethod: 516 | pass 517 | 518 | return {"version": "0+unknown", "full-revisionid": None, 519 | "dirty": None, 520 | "error": "unable to compute version", "date": None} 521 | -------------------------------------------------------------------------------- /dask_searchcv/methods.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import warnings 4 | from collections import defaultdict 5 | from threading import Lock 6 | from distutils.version import LooseVersion 7 | 8 | import numpy as np 9 | from toolz import pluck 10 | from scipy import sparse 11 | from dask.base import normalize_token 12 | 13 | from sklearn.exceptions import FitFailedWarning 14 | from sklearn.pipeline import Pipeline, FeatureUnion 15 | from sklearn.utils import safe_indexing 16 | from sklearn.utils.fixes import rankdata 17 | from sklearn.utils.validation import check_consistent_length, _is_arraylike 18 | 19 | from .utils import copy_estimator 20 | 21 | # Copied from scikit-learn/sklearn/utils/fixes.py, can be removed once we drop 22 | # support for scikit-learn < 0.18.1 or numpy < 1.12.0. 23 | if LooseVersion(np.__version__) < '1.12.0': 24 | class MaskedArray(np.ma.MaskedArray): 25 | # Before numpy 1.12, np.ma.MaskedArray object is not picklable 26 | # This fix is needed to make our model_selection.GridSearchCV 27 | # picklable as the ``cv_results_`` param uses MaskedArray 28 | def __getstate__(self): 29 | """Return the internal state of the masked array, for pickling 30 | purposes. 
31 | 32 | """ 33 | cf = 'CF'[self.flags.fnc] 34 | data_state = super(np.ma.MaskedArray, self).__reduce__()[2] 35 | return data_state + (np.ma.getmaskarray(self).tostring(cf), 36 | self._fill_value) 37 | else: 38 | from numpy.ma import MaskedArray # noqa 39 | 40 | # A singleton to indicate a missing parameter 41 | MISSING = type('MissingParameter', (object,), 42 | {'__slots__': (), 43 | '__reduce__': lambda self: 'MISSING', 44 | '__doc__': "A singleton to indicate a missing parameter"})() 45 | normalize_token.register(type(MISSING), lambda x: 'MISSING') 46 | 47 | 48 | # A singleton to indicate a failed estimator fit 49 | FIT_FAILURE = type('FitFailure', (object,), 50 | {'__slots__': (), 51 | '__reduce__': lambda self: 'FIT_FAILURE', 52 | '__doc__': "A singleton to indicate fit failure"})() 53 | 54 | 55 | def warn_fit_failure(error_score, e): 56 | warnings.warn("Classifier fit failed. The score on this train-test" 57 | " partition for these parameters will be set to %f. " 58 | "Details: \n%r" % (error_score, e), FitFailedWarning) 59 | 60 | 61 | # ----------------------- # 62 | # Functions in the graphs # 63 | # ----------------------- # 64 | 65 | 66 | class CVCache(object): 67 | def __init__(self, splits, pairwise=False, cache=True): 68 | self.splits = splits 69 | self.pairwise = pairwise 70 | self.cache = {} if cache else None 71 | 72 | def __reduce__(self): 73 | return (CVCache, (self.splits, self.pairwise, self.cache is not None)) 74 | 75 | def num_test_samples(self): 76 | return np.array([i.sum() if i.dtype == bool else len(i) 77 | for i in pluck(1, self.splits)]) 78 | 79 | def extract(self, X, y, n, is_x=True, is_train=True): 80 | if is_x: 81 | if self.pairwise: 82 | return self._extract_pairwise(X, y, n, is_train=is_train) 83 | return self._extract(X, y, n, is_x=True, is_train=is_train) 84 | if y is None: 85 | return None 86 | return self._extract(X, y, n, is_x=False, is_train=is_train) 87 | 88 | def extract_param(self, key, x, n): 89 | if self.cache is not None and (n, key) in self.cache: 90 | return self.cache[n, key] 91 | 92 | out = safe_indexing(x, self.splits[n][0]) if _is_arraylike(x) else x 93 | 94 | if self.cache is not None: 95 | self.cache[n, key] = out 96 | return out 97 | 98 | def _extract(self, X, y, n, is_x=True, is_train=True): 99 | if self.cache is not None and (n, is_x, is_train) in self.cache: 100 | return self.cache[n, is_x, is_train] 101 | 102 | inds = self.splits[n][0] if is_train else self.splits[n][1] 103 | result = safe_indexing(X if is_x else y, inds) 104 | 105 | if self.cache is not None: 106 | self.cache[n, is_x, is_train] = result 107 | return result 108 | 109 | def _extract_pairwise(self, X, y, n, is_train=True): 110 | if self.cache is not None and (n, True, is_train) in self.cache: 111 | return self.cache[n, True, is_train] 112 | 113 | if not hasattr(X, "shape"): 114 | raise ValueError("Precomputed kernels or affinity matrices have " 115 | "to be passed as arrays or sparse matrices.") 116 | if X.shape[0] != X.shape[1]: 117 | raise ValueError("X should be a square kernel matrix") 118 | train, test = self.splits[n] 119 | result = X[np.ix_(train if is_train else test, train)] 120 | 121 | if self.cache is not None: 122 | self.cache[n, True, is_train] = result 123 | return result 124 | 125 | 126 | def cv_split(cv, X, y, groups, is_pairwise, cache): 127 | check_consistent_length(X, y, groups) 128 | return CVCache(list(cv.split(X, y, groups)), is_pairwise, cache) 129 | 130 | 131 | def cv_n_samples(cvs): 132 | return cvs.num_test_samples() 133 | 134 | 135 | def 
cv_extract(cvs, X, y, is_X, is_train, n): 136 | return cvs.extract(X, y, n, is_X, is_train) 137 | 138 | 139 | def cv_extract_params(cvs, keys, vals, n): 140 | return {k: cvs.extract_param(tok, v, n) for (k, tok), v in zip(keys, vals)} 141 | 142 | 143 | def decompress_params(fields, params): 144 | return [{k: v for k, v in zip(fields, p) if v is not MISSING} 145 | for p in params] 146 | 147 | 148 | def pipeline(names, steps): 149 | """Reconstruct a Pipeline from names and steps""" 150 | if any(s is FIT_FAILURE for s in steps): 151 | return FIT_FAILURE 152 | return Pipeline(list(zip(names, steps))) 153 | 154 | 155 | def feature_union(names, steps, weights): 156 | """Reconstruct a FeatureUnion from names, steps, and weights""" 157 | if any(s is FIT_FAILURE for s in steps): 158 | return FIT_FAILURE 159 | return FeatureUnion(list(zip(names, steps)), 160 | transformer_weights=weights) 161 | 162 | 163 | def feature_union_concat(Xs, nsamples, weights): 164 | """Apply weights and concatenate outputs from a FeatureUnion""" 165 | if any(x is FIT_FAILURE for x in Xs): 166 | return FIT_FAILURE 167 | Xs = [X if w is None else X * w for X, w in zip(Xs, weights) 168 | if X is not None] 169 | if not Xs: 170 | return np.zeros((nsamples, 0)) 171 | if any(sparse.issparse(f) for f in Xs): 172 | return sparse.hstack(Xs).tocsr() 173 | return np.hstack(Xs) 174 | 175 | 176 | # Current set_params isn't threadsafe 177 | SET_PARAMS_LOCK = Lock() 178 | 179 | 180 | def set_params(est, fields=None, params=None, copy=True): 181 | if copy: 182 | est = copy_estimator(est) 183 | if fields is None: 184 | return est 185 | params = {f: p for (f, p) in zip(fields, params) if p is not MISSING} 186 | # TODO: rewrite set_params to avoid lock for classes that use the standard 187 | # set_params/get_params methods 188 | with SET_PARAMS_LOCK: 189 | return est.set_params(**params) 190 | 191 | 192 | def fit(est, X, y, error_score='raise', fields=None, params=None, 193 | fit_params=None): 194 | if est is FIT_FAILURE or X is FIT_FAILURE: 195 | return FIT_FAILURE 196 | if not fit_params: 197 | fit_params = {} 198 | try: 199 | est = set_params(est, fields, params) 200 | est.fit(X, y, **fit_params) 201 | except Exception as e: 202 | if error_score == 'raise': 203 | raise 204 | warn_fit_failure(error_score, e) 205 | est = FIT_FAILURE 206 | return est 207 | 208 | 209 | def fit_transform(est, X, y, error_score='raise', fields=None, params=None, 210 | fit_params=None): 211 | if est is FIT_FAILURE or X is FIT_FAILURE: 212 | return FIT_FAILURE, FIT_FAILURE 213 | if not fit_params: 214 | fit_params = {} 215 | try: 216 | est = set_params(est, fields, params) 217 | if hasattr(est, 'fit_transform'): 218 | Xt = est.fit_transform(X, y, **fit_params) 219 | else: 220 | est.fit(X, y, **fit_params) 221 | Xt = est.transform(X) 222 | except Exception as e: 223 | if error_score == 'raise': 224 | raise 225 | warn_fit_failure(error_score, e) 226 | est = Xt = FIT_FAILURE 227 | return est, Xt 228 | 229 | 230 | def _score(est, X, y, scorer): 231 | if est is FIT_FAILURE: 232 | return FIT_FAILURE 233 | return scorer(est, X) if y is None else scorer(est, X, y) 234 | 235 | 236 | def score(est, X_test, y_test, X_train, y_train, scorer): 237 | test_score = _score(est, X_test, y_test, scorer) 238 | if X_train is None: 239 | return test_score 240 | train_score = _score(est, X_train, y_train, scorer) 241 | return test_score, train_score 242 | 243 | 244 | def fit_and_score(est, cv, X, y, n, scorer, 245 | error_score='raise', fields=None, params=None, 246 | fit_params=None, 
return_train_score=True): 247 | X_train = cv.extract(X, y, n, True, True) 248 | y_train = cv.extract(X, y, n, False, True) 249 | X_test = cv.extract(X, y, n, True, False) 250 | y_test = cv.extract(X, y, n, False, False) 251 | est = fit(est, X_train, y_train, error_score, fields, params, fit_params) 252 | if not return_train_score: 253 | X_train = y_train = None 254 | return score(est, X_test, y_test, X_train, y_train, scorer) 255 | 256 | 257 | def _store(results, key_name, array, n_splits, n_candidates, 258 | weights=None, splits=False, rank=False): 259 | """A small helper to store the scores/times to the cv_results_""" 260 | # When iterated first by n_splits and then by parameters 261 | array = np.array(array, dtype=np.float64).reshape(n_splits, n_candidates).T 262 | if splits: 263 | for split_i in range(n_splits): 264 | results["split%d_%s" % (split_i, key_name)] = array[:, split_i] 265 | 266 | array_means = np.average(array, axis=1, weights=weights) 267 | results['mean_%s' % key_name] = array_means 268 | # Weighted std is not directly available in numpy 269 | array_stds = np.sqrt(np.average((array - array_means[:, np.newaxis]) ** 2, 270 | axis=1, weights=weights)) 271 | results['std_%s' % key_name] = array_stds 272 | 273 | if rank: 274 | results["rank_%s" % key_name] = np.asarray( 275 | rankdata(-array_means, method='min'), dtype=np.int32) 276 | 277 | 278 | def create_cv_results(scores, candidate_params, n_splits, error_score, weights): 279 | if isinstance(scores[0], tuple): 280 | test_scores, train_scores = zip(*scores) 281 | else: 282 | test_scores = scores 283 | train_scores = None 284 | 285 | test_scores = [error_score if s is FIT_FAILURE else s for s in test_scores] 286 | if train_scores is not None: 287 | train_scores = [error_score if s is FIT_FAILURE else s 288 | for s in train_scores] 289 | 290 | # Construct the `cv_results_` dictionary 291 | results = {'params': candidate_params} 292 | n_candidates = len(candidate_params) 293 | 294 | if weights is not None: 295 | weights = np.broadcast_to(weights[None, :], 296 | (len(candidate_params), len(weights))) 297 | 298 | _store(results, 'test_score', test_scores, n_splits, n_candidates, 299 | splits=True, rank=True, weights=weights) 300 | if train_scores is not None: 301 | _store(results, 'train_score', train_scores, 302 | n_splits, n_candidates, splits=True) 303 | 304 | # Use one MaskedArray and mask all the places where the param is not 305 | # applicable for that candidate. 
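# Illustrative example (parameter names hypothetical): for candidates
# [{'C': 1}, {'gamma': 0.1}], 'param_C' is masked at index 1 and
# 'param_gamma' is masked at index 0.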
Use defaultdict as each candidate may 306 | # not contain all the params 307 | param_results = defaultdict(lambda: MaskedArray(np.empty(n_candidates), 308 | mask=True, 309 | dtype=object)) 310 | for cand_i, params in enumerate(candidate_params): 311 | for name, value in params.items(): 312 | param_results["param_%s" % name][cand_i] = value 313 | 314 | results.update(param_results) 315 | return results 316 | 317 | 318 | def get_best_params(candidate_params, cv_results): 319 | best_index = np.flatnonzero(cv_results["rank_test_score"] == 1)[0] 320 | return candidate_params[best_index] 321 | 322 | 323 | def fit_best(estimator, params, X, y, fit_params): 324 | estimator = copy_estimator(estimator).set_params(**params) 325 | estimator.fit(X, y, **fit_params) 326 | return estimator 327 | -------------------------------------------------------------------------------- /dask_searchcv/model_selection.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | from operator import getitem 4 | from collections import defaultdict 5 | from itertools import repeat 6 | from multiprocessing import cpu_count 7 | import numbers 8 | 9 | import numpy as np 10 | import dask 11 | from dask.base import tokenize, Base 12 | from dask.delayed import delayed 13 | from dask.threaded import get as threaded_get 14 | from dask.utils import derived_from 15 | from sklearn import model_selection 16 | from sklearn.base import is_classifier, clone, BaseEstimator, MetaEstimatorMixin 17 | from sklearn.exceptions import NotFittedError 18 | from sklearn.metrics.scorer import check_scoring 19 | from sklearn.model_selection._search import _check_param_grid, BaseSearchCV 20 | from sklearn.model_selection._split import (_BaseKFold, 21 | BaseShuffleSplit, 22 | KFold, 23 | StratifiedKFold, 24 | LeaveOneOut, 25 | LeaveOneGroupOut, 26 | LeavePOut, 27 | LeavePGroupsOut, 28 | PredefinedSplit, 29 | _CVIterableWrapper) 30 | from sklearn.pipeline import Pipeline, FeatureUnion 31 | from sklearn.utils.metaestimators import if_delegate_has_method 32 | from sklearn.utils.multiclass import type_of_target 33 | from sklearn.utils.validation import _num_samples, check_is_fitted 34 | 35 | from ._normalize import normalize_estimator 36 | from .methods import (fit, fit_transform, fit_and_score, pipeline, fit_best, 37 | get_best_params, create_cv_results, cv_split, 38 | cv_n_samples, cv_extract, cv_extract_params, 39 | decompress_params, score, feature_union, 40 | feature_union_concat, MISSING) 41 | from .utils import to_indexable, to_keys, unzip 42 | 43 | try: 44 | from cytoolz import get, pluck 45 | except: # pragma: no cover 46 | from toolz import get, pluck 47 | 48 | 49 | __all__ = ['GridSearchCV', 'RandomizedSearchCV'] 50 | 51 | 52 | class TokenIterator(object): 53 | def __init__(self, base_token): 54 | self.token = base_token 55 | self.counts = defaultdict(int) 56 | 57 | def __call__(self, est): 58 | typ = type(est) 59 | c = self.counts[typ] 60 | self.counts[typ] += 1 61 | return self.token if c == 0 else self.token + str(c) 62 | 63 | 64 | def build_graph(estimator, cv, scorer, candidate_params, X, y=None, 65 | groups=None, fit_params=None, iid=True, refit=True, 66 | error_score='raise', return_train_score=True, cache_cv=True): 67 | 68 | X, y, groups = to_indexable(X, y, groups) 69 | cv = check_cv(cv, y, is_classifier(estimator)) 70 | # "pairwise" estimators require a different graph for CV splitting 71 | is_pairwise = getattr(estimator, 
'_pairwise', False) 72 | 73 | dsk = {} 74 | X_name, y_name, groups_name = to_keys(dsk, X, y, groups) 75 | n_splits = compute_n_splits(cv, X, y, groups) 76 | 77 | if fit_params: 78 | # A mapping of {name: (name, graph-key)} 79 | param_values = to_indexable(*fit_params.values(), allow_scalars=True) 80 | fit_params = {k: (k, v) for (k, v) in 81 | zip(fit_params, to_keys(dsk, *param_values))} 82 | else: 83 | fit_params = {} 84 | 85 | fields, tokens, params = normalize_params(candidate_params) 86 | main_token = tokenize(normalize_estimator(estimator), fields, params, 87 | X_name, y_name, groups_name, fit_params, cv, 88 | error_score == 'raise', return_train_score) 89 | 90 | cv_name = 'cv-split-' + main_token 91 | dsk[cv_name] = (cv_split, cv, X_name, y_name, groups_name, 92 | is_pairwise, cache_cv) 93 | 94 | if iid: 95 | weights = 'cv-n-samples-' + main_token 96 | dsk[weights] = (cv_n_samples, cv_name) 97 | else: 98 | weights = None 99 | 100 | scores = do_fit_and_score(dsk, main_token, estimator, cv_name, fields, 101 | tokens, params, X_name, y_name, fit_params, 102 | n_splits, error_score, scorer, 103 | return_train_score) 104 | 105 | cv_results = 'cv-results-' + main_token 106 | candidate_params_name = 'cv-parameters-' + main_token 107 | dsk[candidate_params_name] = (decompress_params, fields, params) 108 | dsk[cv_results] = (create_cv_results, scores, candidate_params_name, 109 | n_splits, error_score, weights) 110 | keys = [cv_results] 111 | 112 | if refit: 113 | best_params = 'best-params-' + main_token 114 | dsk[best_params] = (get_best_params, candidate_params_name, cv_results) 115 | best_estimator = 'best-estimator-' + main_token 116 | if fit_params: 117 | fit_params = (dict, (zip, list(fit_params.keys()), 118 | list(pluck(1, fit_params.values())))) 119 | dsk[best_estimator] = (fit_best, clone(estimator), best_params, 120 | X_name, y_name, fit_params) 121 | keys.append(best_estimator) 122 | 123 | return dsk, keys, n_splits 124 | 125 | 126 | def normalize_params(params): 127 | """Take a list of dictionaries, and tokenize/normalize.""" 128 | # Collect a set of all fields 129 | fields = set() 130 | for p in params: 131 | fields.update(p) 132 | fields = sorted(fields) 133 | 134 | params2 = list(pluck(fields, params, MISSING)) 135 | # Non-basic types (including MISSING) are unique to their id 136 | tokens = [tuple(x if isinstance(x, (int, float, str)) else id(x) 137 | for x in p) for p in params2] 138 | 139 | return fields, tokens, params2 140 | 141 | 142 | def _get_fit_params(cv, fit_params, n_splits): 143 | if not fit_params: 144 | return [(n, None) for n in range(n_splits)] 145 | keys = [] 146 | vals = [] 147 | for name, (full_name, val) in fit_params.items(): 148 | vals.append(val) 149 | keys.append((name, full_name)) 150 | return [(n, (cv_extract_params, cv, keys, vals, n)) 151 | for n in range(n_splits)] 152 | 153 | 154 | def _group_fit_params(steps, fit_params): 155 | param_lk = {n: {} for n, _ in steps} 156 | for pname, pval in fit_params.items(): 157 | step, param = pname.split('__', 1) 158 | param_lk[step][param] = pval 159 | return param_lk 160 | 161 | 162 | def do_fit_and_score(dsk, main_token, est, cv, fields, tokens, params, 163 | X, y, fit_params, n_splits, error_score, scorer, 164 | return_train_score): 165 | if not isinstance(est, Pipeline): 166 | # Fitting and scoring can all be done as a single task 167 | n_and_fit_params = _get_fit_params(cv, fit_params, n_splits) 168 | 169 | est_type = type(est).__name__.lower() 170 | est_name = '%s-%s' % (est_type, main_token) 171 | 
score_name = '%s-fit-score-%s' % (est_type, main_token) 172 | dsk[est_name] = est 173 | 174 | seen = {} 175 | m = 0 176 | out = [] 177 | out_append = out.append 178 | 179 | for t, p in zip(tokens, params): 180 | if t in seen: 181 | out_append(seen[t]) 182 | else: 183 | for n, fit_params in n_and_fit_params: 184 | dsk[(score_name, m, n)] = (fit_and_score, est_name, cv, 185 | X, y, n, scorer, error_score, 186 | fields, p, fit_params, 187 | return_train_score) 188 | seen[t] = (score_name, m) 189 | out_append((score_name, m)) 190 | m += 1 191 | scores = [k + (n,) for n in range(n_splits) for k in out] 192 | else: 193 | X_train = (cv_extract, cv, X, y, True, True) 194 | X_test = (cv_extract, cv, X, y, True, False) 195 | y_train = (cv_extract, cv, X, y, False, True) 196 | y_test = (cv_extract, cv, X, y, False, False) 197 | 198 | # Fit the estimator on the training data 199 | X_trains = [X_train] * len(params) 200 | y_trains = [y_train] * len(params) 201 | fit_ests = do_fit(dsk, TokenIterator(main_token), est, cv, 202 | fields, tokens, params, X_trains, y_trains, 203 | fit_params, n_splits, error_score) 204 | 205 | score_name = 'score-' + main_token 206 | 207 | scores = [] 208 | scores_append = scores.append 209 | for n in range(n_splits): 210 | if return_train_score: 211 | xtrain = X_train + (n,) 212 | ytrain = y_train + (n,) 213 | else: 214 | xtrain = ytrain = None 215 | 216 | xtest = X_test + (n,) 217 | ytest = y_test + (n,) 218 | 219 | for (name, m) in fit_ests: 220 | dsk[(score_name, m, n)] = (score, (name, m, n), 221 | xtest, ytest, xtrain, ytrain, scorer) 222 | scores_append((score_name, m, n)) 223 | return scores 224 | 225 | 226 | def do_fit(dsk, next_token, est, cv, fields, tokens, params, Xs, ys, 227 | fit_params, n_splits, error_score): 228 | if isinstance(est, Pipeline) and params is not None: 229 | return _do_pipeline(dsk, next_token, est, cv, fields, tokens, params, 230 | Xs, ys, fit_params, n_splits, error_score, False) 231 | else: 232 | n_and_fit_params = _get_fit_params(cv, fit_params, n_splits) 233 | 234 | if params is None: 235 | params = tokens = repeat(None) 236 | fields = None 237 | 238 | token = next_token(est) 239 | est_type = type(est).__name__.lower() 240 | est_name = '%s-%s' % (est_type, token) 241 | fit_name = '%s-fit-%s' % (est_type, token) 242 | dsk[est_name] = est 243 | 244 | seen = {} 245 | m = 0 246 | out = [] 247 | out_append = out.append 248 | 249 | for X, y, t, p in zip(Xs, ys, tokens, params): 250 | if (X, y, t) in seen: 251 | out_append(seen[X, y, t]) 252 | else: 253 | for n, fit_params in n_and_fit_params: 254 | dsk[(fit_name, m, n)] = (fit, est_name, X + (n,), 255 | y + (n,), error_score, 256 | fields, p, fit_params) 257 | seen[(X, y, t)] = (fit_name, m) 258 | out_append((fit_name, m)) 259 | m += 1 260 | 261 | return out 262 | 263 | 264 | def do_fit_transform(dsk, next_token, est, cv, fields, tokens, params, Xs, ys, 265 | fit_params, n_splits, error_score): 266 | if isinstance(est, Pipeline) and params is not None: 267 | return _do_pipeline(dsk, next_token, est, cv, fields, tokens, params, 268 | Xs, ys, fit_params, n_splits, error_score, True) 269 | elif isinstance(est, FeatureUnion) and params is not None: 270 | return _do_featureunion(dsk, next_token, est, cv, fields, tokens, 271 | params, Xs, ys, fit_params, n_splits, 272 | error_score) 273 | else: 274 | n_and_fit_params = _get_fit_params(cv, fit_params, n_splits) 275 | 276 | if params is None: 277 | params = tokens = repeat(None) 278 | fields = None 279 | 280 | name = type(est).__name__.lower() 281 | 
token = next_token(est) 282 | fit_Xt_name = '%s-fit-transform-%s' % (name, token) 283 | fit_name = '%s-fit-%s' % (name, token) 284 | Xt_name = '%s-transform-%s' % (name, token) 285 | est_name = '%s-%s' % (type(est).__name__.lower(), token) 286 | dsk[est_name] = est 287 | 288 | seen = {} 289 | m = 0 290 | out = [] 291 | out_append = out.append 292 | 293 | for X, y, t, p in zip(Xs, ys, tokens, params): 294 | if (X, y, t) in seen: 295 | out_append(seen[X, y, t]) 296 | else: 297 | for n, fit_params in n_and_fit_params: 298 | dsk[(fit_Xt_name, m, n)] = (fit_transform, est_name, 299 | X + (n,), y + (n,), 300 | error_score, fields, p, 301 | fit_params) 302 | dsk[(fit_name, m, n)] = (getitem, (fit_Xt_name, m, n), 0) 303 | dsk[(Xt_name, m, n)] = (getitem, (fit_Xt_name, m, n), 1) 304 | seen[X, y, t] = m 305 | out_append(m) 306 | m += 1 307 | 308 | return [(fit_name, i) for i in out], [(Xt_name, i) for i in out] 309 | 310 | 311 | def _group_subparams(steps, fields, ignore=()): 312 | # Group the fields into a mapping of {stepname: [(newname, orig_index)]} 313 | field_to_index = dict(zip(fields, range(len(fields)))) 314 | step_fields_lk = {s: [] for s, _ in steps} 315 | for f in fields: 316 | if '__' in f: 317 | step, param = f.split('__', 1) 318 | if step in step_fields_lk: 319 | step_fields_lk[step].append((param, field_to_index[f])) 320 | continue 321 | if f not in step_fields_lk and f not in ignore: 322 | raise ValueError("Unknown parameter: `%s`" % f) 323 | return field_to_index, step_fields_lk 324 | 325 | 326 | def _group_ids_by_index(index, tokens): 327 | id_groups = [] 328 | 329 | def new_group(): 330 | o = [] 331 | id_groups.append(o) 332 | return o.append 333 | 334 | _id_groups = defaultdict(new_group) 335 | for n, t in enumerate(pluck(index, tokens)): 336 | _id_groups[t](n) 337 | return id_groups 338 | 339 | 340 | def _do_fit_step(dsk, next_token, step, cv, fields, tokens, params, Xs, ys, 341 | fit_params, n_splits, error_score, step_fields_lk, 342 | fit_params_lk, field_to_index, step_name, none_passthrough, 343 | is_transform): 344 | sub_fields, sub_inds = map(list, unzip(step_fields_lk[step_name], 2)) 345 | sub_fit_params = fit_params_lk[step_name] 346 | 347 | if step_name in field_to_index: 348 | # The estimator may change each call 349 | new_fits = {} 350 | new_Xs = {} 351 | est_index = field_to_index[step_name] 352 | 353 | for ids in _group_ids_by_index(est_index, tokens): 354 | # Get the estimator for this subgroup 355 | sub_est = params[ids[0]][est_index] 356 | if sub_est is MISSING: 357 | sub_est = step 358 | 359 | # If an estimator is `None`, there's nothing to do 360 | if sub_est is None: 361 | nones = dict.fromkeys(ids, None) 362 | new_fits.update(nones) 363 | if is_transform: 364 | if none_passthrough: 365 | new_Xs.update(zip(ids, get(ids, Xs))) 366 | else: 367 | new_Xs.update(nones) 368 | else: 369 | # Extract the proper subset of Xs, ys 370 | sub_Xs = get(ids, Xs) 371 | sub_ys = get(ids, ys) 372 | # Only subset the parameters/tokens if necessary 373 | if sub_fields: 374 | sub_tokens = list(pluck(sub_inds, get(ids, tokens))) 375 | sub_params = list(pluck(sub_inds, get(ids, params))) 376 | else: 377 | sub_tokens = sub_params = None 378 | 379 | if is_transform: 380 | sub_fits, sub_Xs = do_fit_transform(dsk, next_token, 381 | sub_est, cv, sub_fields, 382 | sub_tokens, sub_params, 383 | sub_Xs, sub_ys, 384 | sub_fit_params, 385 | n_splits, error_score) 386 | new_Xs.update(zip(ids, sub_Xs)) 387 | new_fits.update(zip(ids, sub_fits)) 388 | else: 389 | sub_fits = do_fit(dsk, 
next_token, sub_est, cv, 390 | sub_fields, sub_tokens, sub_params, 391 | sub_Xs, sub_ys, sub_fit_params, 392 | n_splits, error_score) 393 | new_fits.update(zip(ids, sub_fits)) 394 | # Extract lists of transformed Xs and fit steps 395 | all_ids = list(range(len(Xs))) 396 | if is_transform: 397 | Xs = get(all_ids, new_Xs) 398 | fits = get(all_ids, new_fits) 399 | elif step is None: 400 | # Nothing to do 401 | fits = [None] * len(Xs) 402 | if not none_passthrough: 403 | Xs = fits 404 | else: 405 | # Only subset the parameters/tokens if necessary 406 | if sub_fields: 407 | sub_tokens = list(pluck(sub_inds, tokens)) 408 | sub_params = list(pluck(sub_inds, params)) 409 | else: 410 | sub_tokens = sub_params = None 411 | 412 | if is_transform: 413 | fits, Xs = do_fit_transform(dsk, next_token, step, cv, 414 | sub_fields, sub_tokens, sub_params, 415 | Xs, ys, sub_fit_params, n_splits, 416 | error_score) 417 | else: 418 | fits = do_fit(dsk, next_token, step, cv, sub_fields, 419 | sub_tokens, sub_params, Xs, ys, sub_fit_params, 420 | n_splits, error_score) 421 | return (fits, Xs) if is_transform else (fits, None) 422 | 423 | 424 | def _do_pipeline(dsk, next_token, est, cv, fields, tokens, params, Xs, ys, 425 | fit_params, n_splits, error_score, is_transform): 426 | if 'steps' in fields: 427 | raise NotImplementedError("Setting Pipeline.steps in a gridsearch") 428 | 429 | field_to_index, step_fields_lk = _group_subparams(est.steps, fields) 430 | fit_params_lk = _group_fit_params(est.steps, fit_params) 431 | 432 | # A list of (step, is_transform) 433 | instrs = [(s, True) for s in est.steps[:-1]] 434 | instrs.append((est.steps[-1], is_transform)) 435 | 436 | fit_steps = [] 437 | for (step_name, step), transform in instrs: 438 | fits, Xs = _do_fit_step(dsk, next_token, step, cv, fields, tokens, 439 | params, Xs, ys, fit_params, n_splits, 440 | error_score, step_fields_lk, fit_params_lk, 441 | field_to_index, step_name, True, transform) 442 | fit_steps.append(fits) 443 | 444 | # Rebuild the pipelines 445 | step_names = [n for n, _ in est.steps] 446 | out_ests = [] 447 | out_ests_append = out_ests.append 448 | name = 'pipeline-' + next_token(est) 449 | m = 0 450 | seen = {} 451 | for steps in zip(*fit_steps): 452 | if steps in seen: 453 | out_ests_append(seen[steps]) 454 | else: 455 | for n in range(n_splits): 456 | dsk[(name, m, n)] = (pipeline, step_names, 457 | [None if s is None else s + (n,) 458 | for s in steps]) 459 | seen[steps] = (name, m) 460 | out_ests_append((name, m)) 461 | m += 1 462 | 463 | if is_transform: 464 | return out_ests, Xs 465 | return out_ests 466 | 467 | 468 | def _do_n_samples(dsk, token, Xs, n_splits): 469 | name = 'n_samples-' + token 470 | n_samples = [] 471 | n_samples_append = n_samples.append 472 | seen = {} 473 | m = 0 474 | for x in Xs: 475 | if x in seen: 476 | n_samples_append(seen[x]) 477 | else: 478 | for n in range(n_splits): 479 | dsk[name, m, n] = (_num_samples, x + (n,)) 480 | n_samples_append((name, m)) 481 | seen[x] = (name, m) 482 | m += 1 483 | return n_samples 484 | 485 | 486 | def _do_featureunion(dsk, next_token, est, cv, fields, tokens, params, Xs, ys, 487 | fit_params, n_splits, error_score): 488 | if 'transformer_list' in fields: 489 | raise NotImplementedError("Setting FeatureUnion.transformer_list " 490 | "in a gridsearch") 491 | 492 | (field_to_index, 493 | step_fields_lk) = _group_subparams(est.transformer_list, fields, 494 | ignore=('transformer_weights')) 495 | fit_params_lk = _group_fit_params(est.transformer_list, fit_params) 496 | 497 | token 
= next_token(est) 498 | 499 | n_samples = _do_n_samples(dsk, token, Xs, n_splits) 500 | 501 | fit_steps = [] 502 | tr_Xs = [] 503 | for (step_name, step) in est.transformer_list: 504 | fits, out_Xs = _do_fit_step(dsk, next_token, step, cv, fields, tokens, 505 | params, Xs, ys, fit_params, n_splits, 506 | error_score, step_fields_lk, fit_params_lk, 507 | field_to_index, step_name, False, True) 508 | fit_steps.append(fits) 509 | tr_Xs.append(out_Xs) 510 | 511 | # Rebuild the FeatureUnions 512 | step_names = [n for n, _ in est.transformer_list] 513 | 514 | if 'transformer_weights' in field_to_index: 515 | index = field_to_index['transformer_weights'] 516 | weight_lk = {} 517 | weight_tokens = list(pluck(index, tokens)) 518 | for i, tok in enumerate(weight_tokens): 519 | if tok not in weight_lk: 520 | weights = params[i][index] 521 | if weights is MISSING: 522 | weights = est.transformer_weights 523 | lk = weights or {} 524 | weight_list = [lk.get(n) for n in step_names] 525 | weight_lk[tok] = (weights, weight_list) 526 | weights = get(weight_tokens, weight_lk) 527 | else: 528 | lk = est.transformer_weights or {} 529 | weight_list = [lk.get(n) for n in step_names] 530 | weight_tokens = repeat(None) 531 | weights = repeat((est.transformer_weights, weight_list)) 532 | 533 | out = [] 534 | out_append = out.append 535 | fit_name = 'feature-union-' + token 536 | tr_name = 'feature-union-concat-' + token 537 | m = 0 538 | seen = {} 539 | for steps, Xs, wt, (w, wl), nsamp in zip(zip(*fit_steps), zip(*tr_Xs), 540 | weight_tokens, weights, n_samples): 541 | if (steps, wt) in seen: 542 | out_append(seen[steps, wt]) 543 | else: 544 | for n in range(n_splits): 545 | dsk[(fit_name, m, n)] = (feature_union, step_names, 546 | [None if s is None else s + (n,) 547 | for s in steps], w) 548 | dsk[(tr_name, m, n)] = (feature_union_concat, 549 | [None if x is None else x + (n,) 550 | for x in Xs], nsamp + (n,), wl) 551 | seen[steps, wt] = m 552 | out_append(m) 553 | m += 1 554 | return [(fit_name, i) for i in out], [(tr_name, i) for i in out] 555 | 556 | 557 | # ------------ # 558 | # CV splitting # 559 | # ------------ # 560 | 561 | def check_cv(cv=3, y=None, classifier=False): 562 | """Dask aware version of ``sklearn.model_selection.check_cv`` 563 | 564 | Same as the scikit-learn version, but works if ``y`` is a dask object. 565 | """ 566 | if cv is None: 567 | cv = 3 568 | 569 | # If ``cv`` is not an integer, the scikit-learn implementation doesn't 570 | # touch the ``y`` object, so passing on a dask object is fine 571 | if not isinstance(y, Base) or not isinstance(cv, numbers.Integral): 572 | return model_selection.check_cv(cv, y, classifier) 573 | 574 | if classifier: 575 | # ``y`` is a dask object. We need to compute the target type 576 | target_type = delayed(type_of_target, pure=True)(y).compute() 577 | if target_type in ('binary', 'multiclass'): 578 | return StratifiedKFold(cv) 579 | return KFold(cv) 580 | 581 | 582 | def compute_n_splits(cv, X, y=None, groups=None): 583 | """Return the number of splits. 
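
    When any of ``X``, ``y``, or ``groups`` is a dask object, the number of
    splits is derived without computing it whenever the splitter allows.
    A quick sketch (``da_X`` here stands for any dask array)::

        from sklearn.model_selection import KFold
        compute_n_splits(KFold(3), da_X)  # -> 3, without computing da_X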
584 | 585 | Parameters 586 | ---------- 587 | cv : BaseCrossValidator 588 | X, y, groups : array_like, dask object, or None 589 | 590 | Returns 591 | ------- 592 | n_splits : int 593 | """ 594 | if not any(isinstance(i, Base) for i in (X, y, groups)): 595 | return cv.get_n_splits(X, y, groups) 596 | 597 | if isinstance(cv, (_BaseKFold, BaseShuffleSplit)): 598 | return cv.n_splits 599 | 600 | elif isinstance(cv, PredefinedSplit): 601 | return len(cv.unique_folds) 602 | 603 | elif isinstance(cv, _CVIterableWrapper): 604 | return len(cv.cv) 605 | 606 | elif isinstance(cv, (LeaveOneOut, LeavePOut)) and not isinstance(X, Base): 607 | # Only `X` is referenced for these classes 608 | return cv.get_n_splits(X, None, None) 609 | 610 | elif (isinstance(cv, (LeaveOneGroupOut, LeavePGroupsOut)) and not 611 | isinstance(groups, Base)): 612 | # Only `groups` is referenced for these classes 613 | return cv.get_n_splits(None, None, groups) 614 | 615 | else: 616 | return delayed(cv).get_n_splits(X, y, groups).compute() 617 | 618 | 619 | def _normalize_n_jobs(n_jobs): 620 | if not isinstance(n_jobs, int): 621 | raise TypeError("n_jobs should be an int, got %s" % n_jobs) 622 | if n_jobs == -1: 623 | n_jobs = None # Scheduler default is use all cores 624 | elif n_jobs < -1: 625 | n_jobs = cpu_count() + 1 + n_jobs 626 | return n_jobs 627 | 628 | 629 | _scheduler_aliases = {'sync': 'synchronous', 630 | 'sequential': 'synchronous', 631 | 'threaded': 'threading'} 632 | 633 | 634 | def _normalize_scheduler(scheduler, n_jobs, loop=None): 635 | # Default 636 | if scheduler is None: 637 | scheduler = dask.context._globals.get('get') 638 | if scheduler is None: 639 | scheduler = dask.get if n_jobs == 1 else threaded_get 640 | return scheduler 641 | 642 | # Get-functions 643 | if callable(scheduler): 644 | return scheduler 645 | 646 | # Support name aliases 647 | if isinstance(scheduler, str): 648 | scheduler = _scheduler_aliases.get(scheduler, scheduler) 649 | 650 | if scheduler in ('threading', 'multiprocessing') and n_jobs == 1: 651 | scheduler = dask.get 652 | elif scheduler == 'threading': 653 | scheduler = threaded_get 654 | elif scheduler == 'multiprocessing': 655 | from dask.multiprocessing import get as scheduler 656 | elif scheduler == 'synchronous': 657 | scheduler = dask.get 658 | else: 659 | try: 660 | from dask.distributed import Client 661 | # We pass loop to make testing possible, not needed for normal use 662 | return Client(scheduler, set_as_default=False, loop=loop).get 663 | except Exception as e: 664 | msg = ("Failed to initialize scheduler from parameter %r. " 665 | "This could be due to a typo, or a failure to initialize " 666 | "the distributed scheduler. 
Original error is below:\n\n"
667 |                    "%r" % (scheduler, e))
668 |         # Re-raise outside the except to provide a cleaner error message
669 |         raise ValueError(msg)
670 |     return scheduler
671 | 
672 | 
673 | class DaskBaseSearchCV(BaseEstimator, MetaEstimatorMixin):
674 |     """Base class for hyper parameter search with cross-validation."""
675 | 
676 |     def __init__(self, estimator, scoring=None, iid=True, refit=True, cv=None,
677 |                  error_score='raise', return_train_score=True, scheduler=None,
678 |                  n_jobs=-1, cache_cv=True):
679 |         self.scoring = scoring
680 |         self.estimator = estimator
681 |         self.iid = iid
682 |         self.refit = refit
683 |         self.cv = cv
684 |         self.error_score = error_score
685 |         self.return_train_score = return_train_score
686 |         self.scheduler = scheduler
687 |         self.n_jobs = n_jobs
688 |         self.cache_cv = cache_cv
689 | 
690 |     @property
691 |     def _estimator_type(self):
692 |         return self.estimator._estimator_type
693 | 
694 |     @property
695 |     def best_params_(self):
696 |         check_is_fitted(self, 'cv_results_')
697 |         return self.cv_results_['params'][self.best_index_]
698 | 
699 |     @property
700 |     def best_score_(self):
701 |         check_is_fitted(self, 'cv_results_')
702 |         return self.cv_results_['mean_test_score'][self.best_index_]
703 | 
704 |     def _check_is_fitted(self, method_name):
705 |         if not self.refit:
706 |             msg = ('This {0} instance was initialized with refit=False. {1} '
707 |                    'is available only after refitting on the best '
708 |                    'parameters.').format(type(self).__name__, method_name)
709 |             raise NotFittedError(msg)
710 |         else:
711 |             check_is_fitted(self, 'best_estimator_')
712 | 
713 |     @property
714 |     def classes_(self):
715 |         self._check_is_fitted("classes_")
716 |         return self.best_estimator_.classes_
717 | 
718 |     @if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
719 |     @derived_from(BaseSearchCV)
720 |     def predict(self, X):
721 |         self._check_is_fitted('predict')
722 |         return self.best_estimator_.predict(X)
723 | 
724 |     @if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
725 |     @derived_from(BaseSearchCV)
726 |     def predict_proba(self, X):
727 |         self._check_is_fitted('predict_proba')
728 |         return self.best_estimator_.predict_proba(X)
729 | 
730 |     @if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
731 |     @derived_from(BaseSearchCV)
732 |     def predict_log_proba(self, X):
733 |         self._check_is_fitted('predict_log_proba')
734 |         return self.best_estimator_.predict_log_proba(X)
735 | 
736 |     @if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
737 |     @derived_from(BaseSearchCV)
738 |     def decision_function(self, X):
739 |         self._check_is_fitted('decision_function')
740 |         return self.best_estimator_.decision_function(X)
741 | 
742 |     @if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
743 |     @derived_from(BaseSearchCV)
744 |     def transform(self, X):
745 |         self._check_is_fitted('transform')
746 |         return self.best_estimator_.transform(X)
747 | 
748 |     @if_delegate_has_method(delegate=('best_estimator_', 'estimator'))
749 |     @derived_from(BaseSearchCV)
750 |     def inverse_transform(self, Xt):
751 |         self._check_is_fitted('inverse_transform')
752 |         return self.best_estimator_.inverse_transform(Xt)
753 | 
754 |     @derived_from(BaseSearchCV)
755 |     def score(self, X, y=None):
756 |         if self.scorer_ is None:
757 |             raise ValueError("No score function explicitly defined, "
758 |                              "and the estimator doesn't provide one %s"
759 |                              % self.best_estimator_)
760 |         return self.scorer_(self.best_estimator_, X, y)
761 | 
762 |     def fit(self, X, y=None, groups=None, **fit_params):
763 | """Run fit with all sets of parameters. 764 | 765 | Parameters 766 | ---------- 767 | X : array-like, shape = [n_samples, n_features] 768 | Training vector, where n_samples is the number of samples and 769 | n_features is the number of features. 770 | y : array-like, shape = [n_samples] or [n_samples, n_output], optional 771 | Target relative to X for classification or regression; 772 | None for unsupervised learning. 773 | groups : array-like, shape = [n_samples], optional 774 | Group labels for the samples used while splitting the dataset into 775 | train/test set. 776 | **fit_params 777 | Parameters passed to the ``fit`` method of the estimator 778 | """ 779 | estimator = self.estimator 780 | self.scorer_ = check_scoring(estimator, scoring=self.scoring) 781 | error_score = self.error_score 782 | if not (isinstance(error_score, numbers.Number) or 783 | error_score == 'raise'): 784 | raise ValueError("error_score must be the string 'raise' or a" 785 | " numeric value.") 786 | 787 | dsk, keys, n_splits = build_graph(estimator, self.cv, self.scorer_, 788 | list(self._get_param_iterator()), 789 | X, y, groups, fit_params, 790 | iid=self.iid, 791 | refit=self.refit, 792 | error_score=error_score, 793 | return_train_score=self.return_train_score, 794 | cache_cv=self.cache_cv) 795 | self.dask_graph_ = dsk 796 | self.n_splits_ = n_splits 797 | 798 | n_jobs = _normalize_n_jobs(self.n_jobs) 799 | scheduler = _normalize_scheduler(self.scheduler, n_jobs) 800 | 801 | out = scheduler(dsk, keys, num_workers=n_jobs) 802 | 803 | self.cv_results_ = results = out[0] 804 | self.best_index_ = np.flatnonzero(results["rank_test_score"] == 1)[0] 805 | 806 | if self.refit: 807 | self.best_estimator_ = out[1] 808 | return self 809 | 810 | def visualize(self, filename='mydask', format=None, **kwargs): 811 | """Render the task graph for this parameter search using ``graphviz``. 812 | 813 | Requires ``graphviz`` to be installed. 814 | 815 | Parameters 816 | ---------- 817 | filename : str or None, optional 818 | The name (without an extension) of the file to write to disk. If 819 | `filename` is None, no file will be written, and we communicate 820 | with dot using only pipes. 821 | format : {'png', 'pdf', 'dot', 'svg', 'jpeg', 'jpg'}, optional 822 | Format in which to write output file. Default is 'png'. 823 | **kwargs 824 | Additional keyword arguments to forward to ``dask.dot.to_graphviz``. 825 | 826 | Returns 827 | ------- 828 | result : IPython.diplay.Image, IPython.display.SVG, or None 829 | See ``dask.dot.dot_graph`` for more information. 830 | """ 831 | check_is_fitted(self, 'dask_graph_') 832 | return dask.visualize(self.dask_graph_, filename=filename, 833 | format=format, **kwargs) 834 | 835 | 836 | _DOC_TEMPLATE = """{oneliner} 837 | 838 | {name} implements a "fit" and a "score" method. 839 | It also implements "predict", "predict_proba", "decision_function", 840 | "transform" and "inverse_transform" if they are implemented in the 841 | estimator used. 842 | 843 | {description} 844 | 845 | Parameters 846 | ---------- 847 | estimator : estimator object. 848 | This is assumed to implement the scikit-learn estimator interface. 849 | Either estimator needs to provide a ``score`` function, 850 | or ``scoring`` must be passed. 851 | 852 | {parameters} 853 | 854 | scoring : string, callable or None, default=None 855 | A string (see model evaluation documentation) or 856 | a scorer callable object / function with signature 857 | ``scorer(estimator, X, y)``. 
858 |     If ``None``, the ``score`` method of the estimator is used.
859 | 
860 | iid : boolean, default=True
861 |     If True, the data is assumed to be identically distributed across
862 |     the folds, and the loss minimized is the total loss per sample,
863 |     and not the mean loss across the folds.
864 | 
865 | cv : int, cross-validation generator or an iterable, optional
866 |     Determines the cross-validation splitting strategy.
867 |     Possible inputs for cv are:
868 |     - None, to use the default 3-fold cross validation,
869 |     - integer, to specify the number of folds in a ``(Stratified)KFold``,
870 |     - an object to be used as a cross-validation generator,
871 |     - an iterable yielding (train, test) splits.
872 | 
873 |     For integer/None inputs, if the estimator is a classifier and ``y`` is
874 |     either binary or multiclass, ``StratifiedKFold`` is used. In all
875 |     other cases, ``KFold`` is used.
876 | 
877 | refit : boolean, default=True
878 |     Refit the best estimator with the entire dataset.
879 |     If ``False``, it is impossible to make predictions using
880 |     this {name} instance after fitting.
881 | 
882 | error_score : 'raise' (default) or numeric
883 |     Value to assign to the score if an error occurs in estimator fitting.
884 |     If set to 'raise', the error is raised. If a numeric value is given,
885 |     ``FitFailedWarning`` is raised. This parameter does not affect the refit
886 |     step, which will always raise the error.
887 | 
888 | return_train_score : boolean, default=True
889 |     If ``False``, the ``cv_results_`` attribute will not include training
890 |     scores.
891 | 
892 | scheduler : string, callable, or None, default=None
893 |     The dask scheduler to use. Default is to use the global scheduler if set,
894 |     and fall back to the threaded scheduler otherwise. To use a different
895 |     scheduler, specify it by name (either "threading", "multiprocessing",
896 |     or "synchronous") or provide the scheduler ``get`` function. Other
897 |     arguments are assumed to be the address of a distributed scheduler,
898 |     and passed to ``dask.distributed.Client``.
899 | 
900 | n_jobs : int, default=-1
901 |     Number of jobs to run in parallel. Ignored for the synchronous and
902 |     distributed schedulers. If ``n_jobs == -1`` [default] all CPUs are used.
903 |     For ``n_jobs < -1``, ``(n_cpus + 1 + n_jobs)`` are used.
904 | 
905 | cache_cv : bool, default=True
906 |     Whether to extract each train/test subset at most once in each worker
907 |     process, or every time that subset is needed. Caching the splits can
908 |     speed up computation at the cost of increased memory usage per worker
909 |     process.
910 | 
911 |     If True, worst case memory usage is ``(n_splits + 1) * (X.nbytes +
912 |     y.nbytes)`` per worker. If False, worst case memory usage is
913 |     ``(n_threads_per_worker + 1) * (X.nbytes + y.nbytes)`` per worker.
914 | 
915 | Examples
916 | --------
917 | {example}
918 | 
919 | Attributes
920 | ----------
921 | cv_results_ : dict of numpy (masked) ndarrays
922 |     A dict with keys as column headers and values as columns, that can be
923 |     imported into a pandas ``DataFrame``.
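    A minimal sketch of inspecting the results after fitting (assuming
    ``pandas`` is installed; ``search`` stands for any fitted instance of
    this class)::

        import pandas as pd
        # One row per parameter candidate
        results = pd.DataFrame(search.cv_results_)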
924 | 
925 |     For instance, the table below
926 | 
927 |     +------------+-----------+------------+-----------------+---+---------+
928 |     |param_kernel|param_gamma|param_degree|split0_test_score|...|rank_t...|
929 |     +============+===========+============+=================+===+=========+
930 |     | 'poly'     | --        | 2          | 0.8             |...| 2       |
931 |     +------------+-----------+------------+-----------------+---+---------+
932 |     | 'poly'     | --        | 3          | 0.7             |...| 4       |
933 |     +------------+-----------+------------+-----------------+---+---------+
934 |     | 'rbf'      | 0.1       | --         | 0.8             |...| 3       |
935 |     +------------+-----------+------------+-----------------+---+---------+
936 |     | 'rbf'      | 0.2       | --         | 0.9             |...| 1       |
937 |     +------------+-----------+------------+-----------------+---+---------+
938 | 
939 |     will be represented by a ``cv_results_`` dict of::
940 | 
941 |         {{
942 |         'param_kernel': masked_array(data = ['poly', 'poly', 'rbf', 'rbf'],
943 |                                      mask = [False False False False]...)
944 |         'param_gamma': masked_array(data = [-- -- 0.1 0.2],
945 |                                     mask = [ True  True False False]...),
946 |         'param_degree': masked_array(data = [2.0 3.0 -- --],
947 |                                      mask = [False False  True  True]...),
948 |         'split0_test_score' : [0.8, 0.7, 0.8, 0.9],
949 |         'split1_test_score' : [0.82, 0.5, 0.7, 0.78],
950 |         'mean_test_score' : [0.81, 0.60, 0.75, 0.82],
951 |         'std_test_score' : [0.02, 0.01, 0.03, 0.03],
952 |         'rank_test_score' : [2, 4, 3, 1],
953 |         'split0_train_score' : [0.8, 0.9, 0.7, 0.93],
954 |         'split1_train_score' : [0.82, 0.5, 0.7, 0.87],
955 |         'mean_train_score' : [0.81, 0.7, 0.7, 0.9],
956 |         'std_train_score' : [0.03, 0.03, 0.04, 0.03],
957 |         'params' : [{{'kernel': 'poly', 'degree': 2}}, ...],
958 |         }}
959 | 
960 |     Note that the key ``'params'`` is used to store a list of parameter
961 |     settings dicts for all the parameter candidates.
962 | 
963 | best_estimator_ : estimator
964 |     Estimator that was chosen by the search, i.e. the estimator
965 |     which gave the highest score (or smallest loss if specified)
966 |     on the left-out data. Not available if ``refit=False``.
967 | 
968 | best_score_ : float
969 |     Score of ``best_estimator_`` on the left-out data.
970 | 
971 | best_params_ : dict
972 |     Parameter setting that gave the best results on the hold-out data.
973 | 
974 | best_index_ : int
975 |     The index (of the ``cv_results_`` arrays) which corresponds to the best
976 |     candidate parameter setting.
977 | 
978 |     The dict at ``search.cv_results_['params'][search.best_index_]`` gives
979 |     the parameter setting for the best model, which gives the highest
980 |     mean score (``search.best_score_``).
981 | 
982 | scorer_ : function
983 |     Scorer function used on the held out data to choose the best
984 |     parameters for the model.
985 | 
986 | n_splits_ : int
987 |     The number of cross-validation splits (folds/iterations).
988 | 
989 | Notes
990 | -----
991 | The parameters selected are those that maximize the score of the left-out
992 | data, unless an explicit score is passed in which case it is used instead.
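
A short sketch of selecting a scheduler (here ``...`` stands in for the
search space argument, and the address is only a placeholder for a running
``dask.distributed`` scheduler)::

    # Run everything in the local thread pool (the default)
    search = {name}(estimator, ..., scheduler='threading')

    # Or hand the graph to a distributed cluster
    search = {name}(estimator, ..., scheduler='127.0.0.1:8786')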
993 | """ 994 | 995 | # ------------ # 996 | # GridSearchCV # 997 | # ------------ # 998 | 999 | _grid_oneliner = """\ 1000 | Exhaustive search over specified parameter values for an estimator.\ 1001 | """ 1002 | _grid_description = """\ 1003 | The parameters of the estimator used to apply these methods are optimized 1004 | by cross-validated grid-search over a parameter grid.\ 1005 | """ 1006 | _grid_parameters = """\ 1007 | param_grid : dict or list of dictionaries 1008 | Dictionary with parameters names (string) as keys and lists of 1009 | parameter settings to try as values, or a list of such 1010 | dictionaries, in which case the grids spanned by each dictionary 1011 | in the list are explored. This enables searching over any sequence 1012 | of parameter settings.\ 1013 | """ 1014 | _grid_example = """\ 1015 | >>> import dask_searchcv as dcv 1016 | >>> from sklearn import svm, datasets 1017 | >>> iris = datasets.load_iris() 1018 | >>> parameters = {'kernel': ['linear', 'rbf'], 'C': [1, 10]} 1019 | >>> svc = svm.SVC() 1020 | >>> clf = dcv.GridSearchCV(svc, parameters) 1021 | >>> clf.fit(iris.data, iris.target) # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS 1022 | GridSearchCV(cache_cv=..., cv=..., error_score=..., 1023 | estimator=SVC(C=..., cache_size=..., class_weight=..., coef0=..., 1024 | decision_function_shape=..., degree=..., gamma=..., 1025 | kernel=..., max_iter=-1, probability=False, 1026 | random_state=..., shrinking=..., tol=..., 1027 | verbose=...), 1028 | iid=..., n_jobs=..., param_grid=..., refit=..., return_train_score=..., 1029 | scheduler=..., scoring=...) 1030 | >>> sorted(clf.cv_results_.keys()) # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS 1031 | ['mean_test_score', 'mean_train_score', 'param_C', 'param_kernel',... 1032 | 'params', 'rank_test_score', 'split0_test_score', 'split0_train_score',... 1033 | 'split1_test_score', 'split1_train_score', 'split2_test_score',... 1034 | 'split2_train_score', 'std_test_score', 'std_train_score'...]\ 1035 | """ 1036 | 1037 | 1038 | class GridSearchCV(DaskBaseSearchCV): 1039 | __doc__ = _DOC_TEMPLATE.format(name="GridSearchCV", 1040 | oneliner=_grid_oneliner, 1041 | description=_grid_description, 1042 | parameters=_grid_parameters, 1043 | example=_grid_example) 1044 | 1045 | def __init__(self, estimator, param_grid, scoring=None, iid=True, 1046 | refit=True, cv=None, error_score='raise', 1047 | return_train_score=True, scheduler=None, n_jobs=-1, 1048 | cache_cv=True): 1049 | super(GridSearchCV, self).__init__(estimator=estimator, 1050 | scoring=scoring, iid=iid, refit=refit, cv=cv, 1051 | error_score=error_score, return_train_score=return_train_score, 1052 | scheduler=scheduler, n_jobs=n_jobs, cache_cv=cache_cv) 1053 | 1054 | _check_param_grid(param_grid) 1055 | self.param_grid = param_grid 1056 | 1057 | def _get_param_iterator(self): 1058 | """Return ParameterGrid instance for the given param_grid""" 1059 | return model_selection.ParameterGrid(self.param_grid) 1060 | 1061 | 1062 | # ------------------ # 1063 | # RandomizedSearchCV # 1064 | # ------------------ # 1065 | 1066 | _randomized_oneliner = "Randomized search on hyper parameters." 1067 | _randomized_description = """\ 1068 | In contrast to GridSearchCV, not all parameter values are tried out, but 1069 | rather a fixed number of parameter settings is sampled from the specified 1070 | distributions. The number of parameter settings that are tried is 1071 | given by n_iter. 1072 | 1073 | If all parameters are presented as a list, sampling without replacement is 1074 | performed. 
If at least one parameter is given as a distribution, sampling 1075 | with replacement is used. It is highly recommended to use continuous 1076 | distributions for continuous parameters.\ 1077 | """ 1078 | _randomized_parameters = """\ 1079 | param_distributions : dict 1080 | Dictionary with parameters names (string) as keys and distributions 1081 | or lists of parameters to try. Distributions must provide a ``rvs`` 1082 | method for sampling (such as those from scipy.stats.distributions). 1083 | If a list is given, it is sampled uniformly. 1084 | 1085 | n_iter : int, default=10 1086 | Number of parameter settings that are sampled. n_iter trades 1087 | off runtime vs quality of the solution. 1088 | 1089 | random_state : int or RandomState 1090 | Pseudo random number generator state used for random uniform sampling 1091 | from lists of possible values instead of scipy.stats distributions.\ 1092 | """ 1093 | _randomized_example = """\ 1094 | >>> import dask_searchcv as dcv 1095 | >>> from scipy.stats import expon 1096 | >>> from sklearn import svm, datasets 1097 | >>> iris = datasets.load_iris() 1098 | >>> parameters = {'C': expon(scale=100), 'kernel': ['linear', 'rbf']} 1099 | >>> svc = svm.SVC() 1100 | >>> clf = dcv.RandomizedSearchCV(svc, parameters, n_iter=100) 1101 | >>> clf.fit(iris.data, iris.target) # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS 1102 | RandomizedSearchCV(cache_cv=..., cv=..., error_score=..., 1103 | estimator=SVC(C=..., cache_size=..., class_weight=..., coef0=..., 1104 | decision_function_shape=..., degree=..., gamma=..., 1105 | kernel=..., max_iter=..., probability=..., 1106 | random_state=..., shrinking=..., tol=..., 1107 | verbose=...), 1108 | iid=..., n_iter=..., n_jobs=..., param_distributions=..., 1109 | random_state=..., refit=..., return_train_score=..., 1110 | scheduler=..., scoring=...) 1111 | >>> sorted(clf.cv_results_.keys()) # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS 1112 | ['mean_test_score', 'mean_train_score', 'param_C', 'param_kernel',... 1113 | 'params', 'rank_test_score', 'split0_test_score', 'split0_train_score',... 1114 | 'split1_test_score', 'split1_train_score', 'split2_test_score',... 
1115 | 'split2_train_score', 'std_test_score', 'std_train_score'...]\ 1116 | """ 1117 | 1118 | 1119 | class RandomizedSearchCV(DaskBaseSearchCV): 1120 | __doc__ = _DOC_TEMPLATE.format(name="RandomizedSearchCV", 1121 | oneliner=_randomized_oneliner, 1122 | description=_randomized_description, 1123 | parameters=_randomized_parameters, 1124 | example=_randomized_example) 1125 | 1126 | def __init__(self, estimator, param_distributions, n_iter=10, 1127 | random_state=None, scoring=None, iid=True, refit=True, 1128 | cv=None, error_score='raise', return_train_score=True, 1129 | scheduler=None, n_jobs=-1, cache_cv=True): 1130 | 1131 | super(RandomizedSearchCV, self).__init__(estimator=estimator, 1132 | scoring=scoring, iid=iid, refit=refit, cv=cv, 1133 | error_score=error_score, return_train_score=return_train_score, 1134 | scheduler=scheduler, n_jobs=n_jobs, cache_cv=cache_cv) 1135 | 1136 | self.param_distributions = param_distributions 1137 | self.n_iter = n_iter 1138 | self.random_state = random_state 1139 | 1140 | def _get_param_iterator(self): 1141 | """Return ParameterSampler instance for the given distributions""" 1142 | return model_selection.ParameterSampler(self.param_distributions, 1143 | self.n_iter, random_state=self.random_state) 1144 | -------------------------------------------------------------------------------- /dask_searchcv/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jcrist/dask-searchcv/ba7cffdc4b2064c07cb0932082fe245b63c9392e/dask_searchcv/tests/__init__.py -------------------------------------------------------------------------------- /dask_searchcv/tests/test_model_selection.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import os 4 | import pickle 5 | from itertools import product 6 | from multiprocessing import cpu_count 7 | 8 | import pytest 9 | import numpy as np 10 | import pandas as pd 11 | 12 | import dask 13 | import dask.array as da 14 | from dask.base import tokenize 15 | from dask.callbacks import Callback 16 | from dask.delayed import delayed 17 | from dask.threaded import get as get_threading 18 | from dask.utils import tmpdir 19 | 20 | from sklearn.datasets import make_classification, load_iris 21 | from sklearn.decomposition import PCA 22 | from sklearn.exceptions import NotFittedError, FitFailedWarning 23 | from sklearn.feature_selection import SelectKBest 24 | from sklearn.model_selection import (KFold, 25 | GroupKFold, 26 | StratifiedKFold, 27 | TimeSeriesSplit, 28 | ShuffleSplit, 29 | GroupShuffleSplit, 30 | StratifiedShuffleSplit, 31 | LeaveOneOut, 32 | LeavePOut, 33 | LeaveOneGroupOut, 34 | LeavePGroupsOut, 35 | PredefinedSplit, 36 | GridSearchCV) 37 | from sklearn.model_selection._split import _CVIterableWrapper 38 | from sklearn.pipeline import Pipeline, FeatureUnion 39 | from sklearn.svm import SVC 40 | 41 | import dask_searchcv as dcv 42 | from dask_searchcv.model_selection import (compute_n_splits, check_cv, 43 | _normalize_n_jobs, _normalize_scheduler) 44 | from dask_searchcv.methods import CVCache 45 | from dask_searchcv.utils_test import (FailingClassifier, MockClassifier, 46 | ScalingTransformer, CheckXClassifier, 47 | ignore_warnings) 48 | 49 | try: 50 | from distributed import Client 51 | from distributed.utils_test import cluster, loop 52 | has_distributed = True 53 | except: 54 | loop = pytest.fixture(lambda: None) 55 | has_distributed = 
False 56 | 57 | 58 | class assert_dask_compute(Callback): 59 | def __init__(self, compute=False): 60 | self.compute = compute 61 | 62 | def __enter__(self): 63 | self.ran = False 64 | super(assert_dask_compute, self).__enter__() 65 | 66 | def __exit__(self, *args): 67 | if not self.compute and self.ran: 68 | raise ValueError("Unexpected call to compute") 69 | elif self.compute and not self.ran: 70 | raise ValueError("Expected call to compute, but none happened") 71 | super(assert_dask_compute, self).__exit__(*args) 72 | 73 | def _start(self, dsk): 74 | self.ran = True 75 | 76 | 77 | def test_visualize(): 78 | pytest.importorskip('graphviz') 79 | 80 | X, y = make_classification(n_samples=100, n_classes=2, flip_y=.2, 81 | random_state=0) 82 | clf = SVC(random_state=0) 83 | grid = {'C': [.1, .5, .9]} 84 | gs = dcv.GridSearchCV(clf, grid).fit(X, y) 85 | 86 | assert hasattr(gs, 'dask_graph_') 87 | 88 | with tmpdir() as d: 89 | gs.visualize(filename=os.path.join(d, 'mydask')) 90 | assert os.path.exists(os.path.join(d, 'mydask.png')) 91 | 92 | # Doesn't work if not fitted 93 | gs = dcv.GridSearchCV(clf, grid) 94 | with pytest.raises(NotFittedError): 95 | gs.visualize() 96 | 97 | 98 | np_X = np.random.normal(size=(20, 3)) 99 | np_y = np.random.randint(2, size=20) 100 | np_groups = np.random.permutation(list(range(5)) * 4) 101 | da_X = da.from_array(np_X, chunks=(3, 3)) 102 | da_y = da.from_array(np_y, chunks=3) 103 | da_groups = da.from_array(np_groups, chunks=3) 104 | del_X = delayed(np_X) 105 | del_y = delayed(np_y) 106 | del_groups = delayed(np_groups) 107 | 108 | 109 | @pytest.mark.parametrize(['cls', 'has_shuffle'], 110 | [(KFold, True), 111 | (GroupKFold, False), 112 | (StratifiedKFold, True), 113 | (TimeSeriesSplit, False)]) 114 | def test_kfolds(cls, has_shuffle): 115 | assert tokenize(cls()) == tokenize(cls()) 116 | assert tokenize(cls(n_splits=3)) != tokenize(cls(n_splits=4)) 117 | if has_shuffle: 118 | assert (tokenize(cls(shuffle=True, random_state=0)) == 119 | tokenize(cls(shuffle=True, random_state=0))) 120 | 121 | rs = np.random.RandomState(42) 122 | assert (tokenize(cls(shuffle=True, random_state=rs)) == 123 | tokenize(cls(shuffle=True, random_state=rs))) 124 | 125 | assert (tokenize(cls(shuffle=True, random_state=0)) != 126 | tokenize(cls(shuffle=True, random_state=2))) 127 | 128 | assert (tokenize(cls(shuffle=False, random_state=0)) == 129 | tokenize(cls(shuffle=False, random_state=2))) 130 | 131 | cv = cls(n_splits=3) 132 | assert compute_n_splits(cv, np_X, np_y, np_groups) == 3 133 | 134 | with assert_dask_compute(False): 135 | assert compute_n_splits(cv, da_X, da_y, da_groups) == 3 136 | 137 | 138 | @pytest.mark.parametrize('cls', [ShuffleSplit, GroupShuffleSplit, 139 | StratifiedShuffleSplit]) 140 | def test_shuffle_split(cls): 141 | assert (tokenize(cls(n_splits=3, random_state=0)) == 142 | tokenize(cls(n_splits=3, random_state=0))) 143 | 144 | assert (tokenize(cls(n_splits=3, random_state=0)) != 145 | tokenize(cls(n_splits=3, random_state=2))) 146 | 147 | assert (tokenize(cls(n_splits=3, random_state=0)) != 148 | tokenize(cls(n_splits=4, random_state=0))) 149 | 150 | cv = cls(n_splits=3) 151 | assert compute_n_splits(cv, np_X, np_y, np_groups) == 3 152 | 153 | with assert_dask_compute(False): 154 | assert compute_n_splits(cv, da_X, da_y, da_groups) == 3 155 | 156 | 157 | @pytest.mark.parametrize('cvs', [(LeaveOneOut(),), 158 | (LeavePOut(2), LeavePOut(3))]) 159 | def test_leave_out(cvs): 160 | tokens = [] 161 | for cv in cvs: 162 | assert tokenize(cv) == tokenize(cv) 163 | 
tokens.append(cv) 164 | assert len(set(tokens)) == len(tokens) 165 | 166 | cv = cvs[0] 167 | sol = cv.get_n_splits(np_X, np_y, np_groups) 168 | assert compute_n_splits(cv, np_X, np_y, np_groups) == sol 169 | 170 | with assert_dask_compute(True): 171 | assert compute_n_splits(cv, da_X, da_y, da_groups) == sol 172 | 173 | with assert_dask_compute(False): 174 | assert compute_n_splits(cv, np_X, da_y, da_groups) == sol 175 | 176 | 177 | @pytest.mark.parametrize('cvs', [(LeaveOneGroupOut(),), 178 | (LeavePGroupsOut(2), LeavePGroupsOut(3))]) 179 | def test_leave_group_out(cvs): 180 | tokens = [] 181 | for cv in cvs: 182 | assert tokenize(cv) == tokenize(cv) 183 | tokens.append(cv) 184 | assert len(set(tokens)) == len(tokens) 185 | 186 | cv = cvs[0] 187 | sol = cv.get_n_splits(np_X, np_y, np_groups) 188 | assert compute_n_splits(cv, np_X, np_y, np_groups) == sol 189 | 190 | with assert_dask_compute(True): 191 | assert compute_n_splits(cv, da_X, da_y, da_groups) == sol 192 | 193 | with assert_dask_compute(False): 194 | assert compute_n_splits(cv, da_X, da_y, np_groups) == sol 195 | 196 | 197 | def test_predefined_split(): 198 | cv = PredefinedSplit(np.array(list(range(4)) * 5)) 199 | cv2 = PredefinedSplit(np.array(list(range(5)) * 4)) 200 | assert tokenize(cv) == tokenize(cv) 201 | assert tokenize(cv) != tokenize(cv2) 202 | 203 | sol = cv.get_n_splits(np_X, np_y, np_groups) 204 | assert compute_n_splits(cv, np_X, np_y, np_groups) == sol 205 | 206 | with assert_dask_compute(False): 207 | assert compute_n_splits(cv, da_X, da_y, da_groups) == sol 208 | 209 | 210 | def test_old_style_cv(): 211 | cv1 = _CVIterableWrapper([np.array([True, False, True, False] * 5), 212 | np.array([False, True, False, True] * 5)]) 213 | cv2 = _CVIterableWrapper([np.array([True, False, True, False] * 5), 214 | np.array([False, True, True, True] * 5)]) 215 | assert tokenize(cv1) == tokenize(cv1) 216 | assert tokenize(cv1) != tokenize(cv2) 217 | 218 | sol = cv1.get_n_splits(np_X, np_y, np_groups) 219 | assert compute_n_splits(cv1, np_X, np_y, np_groups) == sol 220 | with assert_dask_compute(False): 221 | assert compute_n_splits(cv1, da_X, da_y, da_groups) == sol 222 | 223 | 224 | def test_check_cv(): 225 | # No y, classifier=False 226 | cv = check_cv(3, classifier=False) 227 | assert isinstance(cv, KFold) and cv.n_splits == 3 228 | cv = check_cv(5, classifier=False) 229 | assert isinstance(cv, KFold) and cv.n_splits == 5 230 | 231 | # y, classifier = False 232 | dy = da.from_array(np.array([1, 0, 1, 0, 1]), chunks=2) 233 | with assert_dask_compute(False): 234 | assert isinstance(check_cv(y=dy, classifier=False), KFold) 235 | 236 | # Binary and multi-class y 237 | for y in [np.array([0, 1, 0, 1, 0, 0, 1, 1, 1]), 238 | np.array([0, 1, 0, 1, 2, 1, 2, 0, 2])]: 239 | cv = check_cv(5, y, classifier=True) 240 | assert isinstance(cv, StratifiedKFold) and cv.n_splits == 5 241 | 242 | dy = da.from_array(y, chunks=2) 243 | with assert_dask_compute(True): 244 | cv = check_cv(5, dy, classifier=True) 245 | assert isinstance(cv, StratifiedKFold) and cv.n_splits == 5 246 | 247 | # Non-binary/multi-class y 248 | y = np.array([[1, 2], [0, 3], [0, 0], [3, 1], [2, 0]]) 249 | assert isinstance(check_cv(y=y, classifier=True), KFold) 250 | 251 | dy = da.from_array(y, chunks=2) 252 | with assert_dask_compute(True): 253 | assert isinstance(check_cv(y=dy, classifier=True), KFold) 254 | 255 | # Old style 256 | cv = [np.array([True, False, True]), np.array([False, True, False])] 257 | with assert_dask_compute(False): 258 | assert 
isinstance(check_cv(cv, y=dy, classifier=True), 259 | _CVIterableWrapper) 260 | 261 | # CV instance passes through 262 | y = da.ones(5, chunks=2) 263 | cv = ShuffleSplit() 264 | with assert_dask_compute(False): 265 | assert check_cv(cv, y, classifier=True) is cv 266 | assert check_cv(cv, y, classifier=False) is cv 267 | 268 | 269 | def test_grid_search_dask_inputs(): 270 | # Numpy versions 271 | np_X, np_y = make_classification(n_samples=15, n_classes=2, random_state=0) 272 | np_groups = np.random.RandomState(0).randint(0, 3, 15) 273 | # Dask array versions 274 | da_X = da.from_array(np_X, chunks=5) 275 | da_y = da.from_array(np_y, chunks=5) 276 | da_groups = da.from_array(np_groups, chunks=5) 277 | # Delayed versions 278 | del_X = delayed(np_X) 279 | del_y = delayed(np_y) 280 | del_groups = delayed(np_groups) 281 | 282 | cv = GroupKFold() 283 | clf = SVC(random_state=0) 284 | grid = {'C': [1]} 285 | 286 | sol = SVC(C=1, random_state=0).fit(np_X, np_y).support_vectors_ 287 | 288 | for X, y, groups in product([np_X, da_X, del_X], 289 | [np_y, da_y, del_y], 290 | [np_groups, da_groups, del_groups]): 291 | gs = dcv.GridSearchCV(clf, grid, cv=cv) 292 | 293 | with pytest.raises(ValueError) as exc: 294 | gs.fit(X, y) 295 | assert "The groups parameter should not be None" in str(exc.value) 296 | 297 | gs.fit(X, y, groups=groups) 298 | np.testing.assert_allclose(sol, gs.best_estimator_.support_vectors_) 299 | 300 | 301 | def test_pipeline_feature_union(): 302 | iris = load_iris() 303 | X, y = iris.data, iris.target 304 | 305 | pca = PCA(random_state=0) 306 | kbest = SelectKBest() 307 | empty_union = FeatureUnion([('first', None), ('second', None)]) 308 | empty_pipeline = Pipeline([('first', None), ('second', None)]) 309 | scaling = Pipeline([('transform', ScalingTransformer())]) 310 | svc = SVC(kernel='linear', random_state=0) 311 | 312 | pipe = Pipeline([('empty_pipeline', empty_pipeline), 313 | ('scaling', scaling), 314 | ('missing', None), 315 | ('union', FeatureUnion([('pca', pca), 316 | ('missing', None), 317 | ('kbest', kbest), 318 | ('empty_union', empty_union)], 319 | transformer_weights={'pca': 0.5})), 320 | ('svc', svc)]) 321 | 322 | param_grid = dict(scaling__transform__factor=[1, 2], 323 | union__pca__n_components=[1, 2, 3], 324 | union__kbest__k=[1, 2], 325 | svc__C=[0.1, 1, 10]) 326 | 327 | gs = GridSearchCV(pipe, param_grid=param_grid) 328 | gs.fit(X, y) 329 | dgs = dcv.GridSearchCV(pipe, param_grid=param_grid, scheduler='sync') 330 | dgs.fit(X, y) 331 | 332 | # Check best params match 333 | assert gs.best_params_ == dgs.best_params_ 334 | 335 | # Check PCA components match 336 | sk_pca = gs.best_estimator_.named_steps['union'].transformer_list[0][1] 337 | dk_pca = dgs.best_estimator_.named_steps['union'].transformer_list[0][1] 338 | np.testing.assert_allclose(sk_pca.components_, dk_pca.components_) 339 | 340 | # Check SelectKBest scores match 341 | sk_kbest = gs.best_estimator_.named_steps['union'].transformer_list[2][1] 342 | dk_kbest = dgs.best_estimator_.named_steps['union'].transformer_list[2][1] 343 | np.testing.assert_allclose(sk_kbest.scores_, dk_kbest.scores_) 344 | 345 | # Check SVC coefs match 346 | np.testing.assert_allclose(gs.best_estimator_.named_steps['svc'].coef_, 347 | dgs.best_estimator_.named_steps['svc'].coef_) 348 | 349 | 350 | def test_pipeline_sub_estimators(): 351 | iris = load_iris() 352 | X, y = iris.data, iris.target 353 | 354 | scaling = Pipeline([('transform', ScalingTransformer())]) 355 | 356 | pipe = Pipeline([('setup', None), 357 | ('missing', None), 
358 | ('scaling', scaling), 359 | ('svc', SVC(kernel='linear', random_state=0))]) 360 | 361 | param_grid = [{'svc__C': [0.1, 0.1]}, # Duplicates to test culling 362 | {'setup': [None], 363 | 'svc__C': [0.1, 1, 10], 364 | 'scaling': [ScalingTransformer(), None]}, 365 | {'setup': [SelectKBest()], 366 | 'setup__k': [1, 2], 367 | 'svc': [SVC(kernel='linear', random_state=0, C=0.1), 368 | SVC(kernel='linear', random_state=0, C=1), 369 | SVC(kernel='linear', random_state=0, C=10)]}] 370 | 371 | gs = GridSearchCV(pipe, param_grid=param_grid) 372 | gs.fit(X, y) 373 | dgs = dcv.GridSearchCV(pipe, param_grid=param_grid, scheduler='sync') 374 | dgs.fit(X, y) 375 | 376 | # Check best params match 377 | assert gs.best_params_ == dgs.best_params_ 378 | 379 | # Check cv results match 380 | res = pd.DataFrame(dgs.cv_results_) 381 | sol = pd.DataFrame(gs.cv_results_)[res.columns] 382 | assert res.equals(sol) 383 | 384 | # Check SVC coefs match 385 | np.testing.assert_allclose(gs.best_estimator_.named_steps['svc'].coef_, 386 | dgs.best_estimator_.named_steps['svc'].coef_) 387 | 388 | 389 | def check_scores_all_nan(gs, bad_param): 390 | bad_param = 'param_' + bad_param 391 | n_candidates = len(gs.cv_results_['params']) 392 | assert all(np.isnan([gs.cv_results_['split%d_test_score' % s][cand_i] 393 | for s in range(gs.n_splits_)]).all() 394 | for cand_i in range(n_candidates) 395 | if gs.cv_results_[bad_param][cand_i] == 396 | FailingClassifier.FAILING_PARAMETER) 397 | 398 | 399 | @pytest.mark.parametrize('weights', 400 | [None, (None, {'tr0': 2, 'tr2': 3}, {'tr0': 2, 'tr2': 4})]) 401 | def test_feature_union(weights): 402 | X = np.ones((10, 5)) 403 | y = np.zeros(10) 404 | 405 | union = FeatureUnion([('tr0', ScalingTransformer()), 406 | ('tr1', ScalingTransformer()), 407 | ('tr2', ScalingTransformer())]) 408 | 409 | factors = [(2, 3, 5), (2, 4, 5), (2, 4, 6), 410 | (2, 4, None), (None, None, None)] 411 | params, sols, grid = [], [], [] 412 | for constants, w in product(factors, weights or [None]): 413 | p = {} 414 | for n, c in enumerate(constants): 415 | if c is None: 416 | p['tr%d' % n] = None 417 | elif n == 3: # 3rd is always an estimator 418 | p['tr%d' % n] = ScalingTransformer(c) 419 | else: 420 | p['tr%d__factor' % n] = c 421 | sol = union.set_params(transformer_weights=w, **p).transform(X) 422 | sols.append(sol) 423 | if w is not None: 424 | p['transformer_weights'] = w 425 | params.append(p) 426 | p2 = {'union__' + k: [v] for k, v in p.items()} 427 | p2['est'] = [CheckXClassifier(sol[0])] 428 | grid.append(p2) 429 | 430 | # Need to recreate the union after setting estimators to `None` above 431 | union = FeatureUnion([('tr0', ScalingTransformer()), 432 | ('tr1', ScalingTransformer()), 433 | ('tr2', ScalingTransformer())]) 434 | 435 | pipe = Pipeline([('union', union), ('est', CheckXClassifier())]) 436 | gs = dcv.GridSearchCV(pipe, grid, refit=False, cv=2) 437 | gs.fit(X, y) 438 | 439 | 440 | @ignore_warnings 441 | def test_feature_union_fit_failure(): 442 | X, y = make_classification(n_samples=100, n_features=10, random_state=0) 443 | 444 | pipe = Pipeline([('union', FeatureUnion([('good', MockClassifier()), 445 | ('bad', FailingClassifier())], 446 | transformer_weights={'bad': 0.5})), 447 | ('clf', MockClassifier())]) 448 | 449 | grid = {'union__bad__parameter': [0, 1, 2]} 450 | gs = dcv.GridSearchCV(pipe, grid, refit=False) 451 | 452 | # Check that failure raises if error_score is `'raise'` 453 | with pytest.raises(ValueError): 454 | gs.fit(X, y) 455 | 456 | # Check that grid scores were set to 
error_score on failure 457 | gs.error_score = float('nan') 458 | with pytest.warns(FitFailedWarning): 459 | gs.fit(X, y) 460 | 461 | check_scores_all_nan(gs, 'union__bad__parameter') 462 | 463 | 464 | @ignore_warnings 465 | def test_pipeline_fit_failure(): 466 | X, y = make_classification(n_samples=100, n_features=10, random_state=0) 467 | 468 | pipe = Pipeline([('bad', FailingClassifier()), 469 | ('good1', MockClassifier()), 470 | ('good2', MockClassifier())]) 471 | 472 | grid = {'bad__parameter': [0, 1, 2]} 473 | gs = dcv.GridSearchCV(pipe, grid, refit=False) 474 | 475 | # Check that failure raises if error_score is `'raise'` 476 | with pytest.raises(ValueError): 477 | gs.fit(X, y) 478 | 479 | # Check that grid scores were set to error_score on failure 480 | gs.error_score = float('nan') 481 | with pytest.warns(FitFailedWarning): 482 | gs.fit(X, y) 483 | 484 | check_scores_all_nan(gs, 'bad__parameter') 485 | 486 | 487 | def test_pipeline_raises(): 488 | X, y = make_classification(n_samples=100, n_features=10, random_state=0) 489 | 490 | pipe = Pipeline([('step1', MockClassifier()), 491 | ('step2', MockClassifier())]) 492 | 493 | grid = {'step3__parameter': [0, 1, 2]} 494 | gs = dcv.GridSearchCV(pipe, grid, refit=False) 495 | with pytest.raises(ValueError): 496 | gs.fit(X, y) 497 | 498 | grid = {'steps': [[('one', MockClassifier()), ('two', MockClassifier())]]} 499 | gs = dcv.GridSearchCV(pipe, grid, refit=False) 500 | with pytest.raises(NotImplementedError): 501 | gs.fit(X, y) 502 | 503 | 504 | def test_feature_union_raises(): 505 | X, y = make_classification(n_samples=100, n_features=10, random_state=0) 506 | 507 | union = FeatureUnion([('tr0', MockClassifier()), 508 | ('tr1', MockClassifier())]) 509 | pipe = Pipeline([('union', union), ('est', MockClassifier())]) 510 | 511 | grid = {'union__tr2__parameter': [0, 1, 2]} 512 | gs = dcv.GridSearchCV(pipe, grid, refit=False) 513 | with pytest.raises(ValueError): 514 | gs.fit(X, y) 515 | 516 | grid = {'union__transformer_list': [[('one', MockClassifier())]]} 517 | gs = dcv.GridSearchCV(pipe, grid, refit=False) 518 | with pytest.raises(NotImplementedError): 519 | gs.fit(X, y) 520 | 521 | 522 | def test_bad_error_score(): 523 | X, y = make_classification(n_samples=100, n_features=10, random_state=0) 524 | gs = dcv.GridSearchCV(MockClassifier(), {'foo_param': [0, 1, 2]}, 525 | error_score='badparam') 526 | 527 | with pytest.raises(ValueError): 528 | gs.fit(X, y) 529 | 530 | 531 | class CountTakes(np.ndarray): 532 | count = 0 533 | 534 | def take(self, *args, **kwargs): 535 | self.count += 1 536 | return super(CountTakes, self).take(*args, **kwargs) 537 | 538 | 539 | def test_cache_cv(): 540 | X, y = make_classification(n_samples=100, n_features=10, random_state=0) 541 | X2 = X.view(CountTakes) 542 | gs = dcv.GridSearchCV(MockClassifier(), {'foo_param': [0, 1, 2]}, 543 | cv=3, cache_cv=False, scheduler='sync') 544 | gs.fit(X2, y) 545 | assert X2.count == 2 * 3 * 3 # (1 train + 1 test) * n_params * n_splits 546 | 547 | X2 = X.view(CountTakes) 548 | assert X2.count == 0 549 | gs.cache_cv = True 550 | gs.fit(X2, y) 551 | assert X2.count == 2 * 3 # (1 test + 1 train) * n_splits 552 | 553 | 554 | def test_CVCache_serializable(): 555 | inds = np.arange(10) 556 | splits = [(inds[:3], inds[3:]), (inds[3:], inds[:3])] 557 | X = np.arange(100).reshape((10, 10)) 558 | y = np.zeros(10) 559 | cache = CVCache(splits, pairwise=True, cache=True) 560 | 561 | # Add something to the cache 562 | r1 = cache.extract(X, y, 0) 563 | assert cache.extract(X, y, 0) is r1 
564 | assert len(cache.cache) == 1 565 | 566 | cache2 = pickle.loads(pickle.dumps(cache)) 567 | assert len(cache2.cache) == 0 568 | assert cache2.pairwise == cache.pairwise 569 | assert all((cache2.splits[i][j] == cache.splits[i][j]).all() 570 | for i in range(2) for j in range(2)) 571 | 572 | 573 | def test_normalize_n_jobs(): 574 | assert _normalize_n_jobs(-1) is None 575 | assert _normalize_n_jobs(-2) == cpu_count() - 1 576 | with pytest.raises(TypeError): 577 | _normalize_n_jobs('not an integer') 578 | 579 | 580 | @pytest.mark.parametrize('scheduler,n_jobs,get', 581 | [(None, 4, get_threading), 582 | ('threading', 4, get_threading), 583 | ('threaded', 4, get_threading), 584 | ('threading', 1, dask.get), 585 | ('sequential', 4, dask.get), 586 | ('synchronous', 4, dask.get), 587 | ('sync', 4, dask.get), 588 | ('multiprocessing', 4, None), 589 | (dask.get, 4, dask.get)]) 590 | def test_scheduler_param(scheduler, n_jobs, get): 591 | if scheduler == 'multiprocessing': 592 | mp = pytest.importorskip('dask.multiprocessing') 593 | get = mp.get 594 | 595 | assert _normalize_scheduler(scheduler, n_jobs) is get 596 | 597 | X, y = make_classification(n_samples=100, n_features=10, random_state=0) 598 | gs = dcv.GridSearchCV(MockClassifier(), {'foo_param': [0, 1, 2]}, cv=3, 599 | scheduler=scheduler, n_jobs=n_jobs) 600 | gs.fit(X, y) 601 | 602 | 603 | @pytest.mark.skipif('not has_distributed') 604 | def test_scheduler_param_distributed(loop): 605 | X, y = make_classification(n_samples=100, n_features=10, random_state=0) 606 | gs = dcv.GridSearchCV(MockClassifier(), {'foo_param': [0, 1, 2]}, cv=3) 607 | with cluster() as (s, [a, b]): 608 | with Client(s['address'], loop=loop): 609 | gs.fit(X, y) 610 | 611 | 612 | def test_scheduler_param_bad(loop): 613 | with pytest.raises(ValueError): 614 | _normalize_scheduler('threeding', 4, loop) 615 | -------------------------------------------------------------------------------- /dask_searchcv/tests/test_model_selection_sklearn.py: -------------------------------------------------------------------------------- 1 | # NOTE: These tests were copied (with modification) from the equivalent 2 | # scikit-learn testing code. The scikit-learn license has been included at 3 | # dask_searchcv/SCIKIT_LEARN_LICENSE.txt. 
4 | 
5 | import pickle
6 | import pytest
7 | 
8 | import dask
9 | import dask.array as da
10 | import numpy as np
11 | from numpy.testing import (assert_array_equal, assert_array_almost_equal,
12 |                            assert_almost_equal)
13 | import scipy.sparse as sp
14 | from scipy.stats import expon
15 | 
16 | from sklearn.base import BaseEstimator
17 | from sklearn.cluster import KMeans
18 | from sklearn.datasets import (make_classification, make_blobs,
19 |                               make_multilabel_classification)
20 | from sklearn.exceptions import NotFittedError, FitFailedWarning
21 | from sklearn.linear_model import Ridge
22 | from sklearn.metrics import f1_score, make_scorer, roc_auc_score
23 | from sklearn.model_selection import (KFold, StratifiedKFold,
24 |                                      StratifiedShuffleSplit, LeaveOneGroupOut,
25 |                                      LeavePGroupsOut, GroupKFold,
26 |                                      GroupShuffleSplit)
27 | from sklearn.neighbors import KernelDensity
28 | from sklearn.pipeline import Pipeline
29 | from sklearn.preprocessing import Imputer
30 | from sklearn.svm import LinearSVC, SVC
31 | from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier
32 | from sklearn.utils.fixes import in1d
33 | 
34 | import dask_searchcv as dcv
35 | from dask_searchcv.utils_test import (FailingClassifier, MockClassifier,
36 |                                       CheckingClassifier, MockDataFrame,
37 |                                       ignore_warnings)
38 | 
39 | 
40 | class LinearSVCNoScore(LinearSVC):
41 |     """A LinearSVC classifier that has no score method."""
42 |     @property
43 |     def score(self):
44 |         raise AttributeError
45 | 
46 | 
47 | X = np.array([[-1, -1], [-2, -1], [1, 1], [2, 1]])
48 | y = np.array([1, 1, 2, 2])
49 | 
50 | 
51 | def assert_grid_iter_equals_getitem(grid):
52 |     assert list(grid) == [grid[i] for i in range(len(grid))]
53 | 
54 | 
55 | def test_grid_search():
56 |     # Test that the best estimator contains the right value for foo_param
57 |     clf = MockClassifier()
58 |     grid_search = dcv.GridSearchCV(clf, {'foo_param': [1, 2, 3]})
59 |     # make sure it selects the smallest parameter in case of ties
60 |     grid_search.fit(X, y)
61 |     assert grid_search.best_estimator_.foo_param == 2
62 | 
63 |     assert_array_equal(grid_search.cv_results_["param_foo_param"].data,
64 |                        [1, 2, 3])
65 | 
66 |     # Smoke test the score etc:
67 |     grid_search.score(X, y)
68 |     grid_search.predict_proba(X)
69 |     grid_search.decision_function(X)
70 |     grid_search.transform(X)
71 | 
72 |     # Test exception handling on scoring
73 |     grid_search.scoring = 'sklearn'
74 |     with pytest.raises(ValueError):
75 |         grid_search.fit(X, y)
76 | 
77 | 
78 | @pytest.mark.parametrize('cls,kwargs',
79 |                          [(dcv.GridSearchCV, {}),
80 |                           (dcv.RandomizedSearchCV, {'n_iter': 1})])
81 | def test_hyperparameter_searcher_with_fit_params(cls, kwargs):
82 |     X = np.arange(100).reshape(10, 10)
83 |     y = np.array([0] * 5 + [1] * 5)
84 |     clf = CheckingClassifier(expected_fit_params=['spam', 'eggs'])
85 |     pipe = Pipeline([('clf', clf)])
86 |     searcher = cls(pipe, {'clf__foo_param': [1, 2, 3]}, cv=2, **kwargs)
87 | 
88 |     # The CheckingClassifier generates an assertion error if
89 |     # a parameter is missing or has length != len(X).
90 |     with pytest.raises(AssertionError) as exc:
91 |         searcher.fit(X, y, clf__spam=np.ones(10))
92 |     assert "Expected fit parameter(s) ['eggs'] not seen."
in str(exc.value) 93 | 94 | searcher.fit(X, y, clf__spam=np.ones(10), clf__eggs=np.zeros(10)) 95 | # Test with dask objects as parameters 96 | searcher.fit(X, y, clf__spam=da.ones(10, chunks=2), 97 | clf__eggs=dask.delayed(np.zeros(10))) 98 | 99 | 100 | @ignore_warnings 101 | def test_grid_search_no_score(): 102 | # Test grid-search on classifier that has no score function. 103 | clf = LinearSVC(random_state=0) 104 | X, y = make_blobs(random_state=0, centers=2) 105 | Cs = [.1, 1, 10] 106 | clf_no_score = LinearSVCNoScore(random_state=0) 107 | 108 | # XXX: It seems there's some global shared state in LinearSVC - fitting 109 | # multiple `SVC` instances in parallel using threads sometimes results in 110 | # wrong results. This only happens with threads, not processes/sync. 111 | # For now, we'll fit using the sync scheduler. 112 | grid_search = dcv.GridSearchCV(clf, {'C': Cs}, scoring='accuracy', 113 | scheduler='sync') 114 | grid_search.fit(X, y) 115 | 116 | grid_search_no_score = dcv.GridSearchCV(clf_no_score, {'C': Cs}, 117 | scoring='accuracy', 118 | scheduler='sync') 119 | # smoketest grid search 120 | grid_search_no_score.fit(X, y) 121 | 122 | # check that best params are equal 123 | assert grid_search_no_score.best_params_ == grid_search.best_params_ 124 | # check that we can call score and that it gives the correct result 125 | assert grid_search.score(X, y) == grid_search_no_score.score(X, y) 126 | 127 | # giving no scoring function raises an error 128 | grid_search_no_score = dcv.GridSearchCV(clf_no_score, {'C': Cs}) 129 | with pytest.raises(TypeError) as exc: 130 | grid_search_no_score.fit([[1]]) 131 | assert "no scoring" in str(exc.value) 132 | 133 | 134 | def test_grid_search_score_method(): 135 | X, y = make_classification(n_samples=100, n_classes=2, flip_y=.2, 136 | random_state=0) 137 | clf = LinearSVC(random_state=0) 138 | grid = {'C': [.1]} 139 | 140 | search_no_scoring = dcv.GridSearchCV(clf, grid, scoring=None).fit(X, y) 141 | search_accuracy = dcv.GridSearchCV(clf, grid, scoring='accuracy').fit(X, y) 142 | search_no_score_method_auc = dcv.GridSearchCV(LinearSVCNoScore(), grid, 143 | scoring='roc_auc').fit(X, y) 144 | search_auc = dcv.GridSearchCV(clf, grid, scoring='roc_auc').fit(X, y) 145 | 146 | # Check warning only occurs in situation where behavior changed: 147 | # estimator requires score method to compete with scoring parameter 148 | score_no_scoring = search_no_scoring.score(X, y) 149 | score_accuracy = search_accuracy.score(X, y) 150 | score_no_score_auc = search_no_score_method_auc.score(X, y) 151 | score_auc = search_auc.score(X, y) 152 | 153 | # ensure the test is sane 154 | assert score_auc < 1.0 155 | assert score_accuracy < 1.0 156 | assert score_auc != score_accuracy 157 | 158 | assert_almost_equal(score_accuracy, score_no_scoring) 159 | assert_almost_equal(score_auc, score_no_score_auc) 160 | 161 | 162 | def test_grid_search_groups(): 163 | # Check if ValueError (when groups is None) propagates to dcv.GridSearchCV 164 | # And also check if groups is correctly passed to the cv object 165 | rng = np.random.RandomState(0) 166 | 167 | X, y = make_classification(n_samples=15, n_classes=2, random_state=0) 168 | groups = rng.randint(0, 3, 15) 169 | 170 | clf = LinearSVC(random_state=0) 171 | grid = {'C': [1]} 172 | 173 | group_cvs = [LeaveOneGroupOut(), LeavePGroupsOut(2), GroupKFold(), 174 | GroupShuffleSplit()] 175 | for cv in group_cvs: 176 | gs = dcv.GridSearchCV(clf, grid, cv=cv) 177 | 178 | with pytest.raises(ValueError) as exc: 179 | assert gs.fit(X, y) 
180 | assert "The groups parameter should not be None" in str(exc.value) 181 | 182 | gs.fit(X, y, groups=groups) 183 | 184 | non_group_cvs = [StratifiedKFold(), StratifiedShuffleSplit()] 185 | for cv in non_group_cvs: 186 | gs = dcv.GridSearchCV(clf, grid, cv=cv) 187 | # Should not raise an error 188 | gs.fit(X, y) 189 | 190 | 191 | def test_classes__property(): 192 | # Test that classes_ property matches best_estimator_.classes_ 193 | X = np.arange(100).reshape(10, 10) 194 | y = np.array([0] * 5 + [1] * 5) 195 | Cs = [.1, 1, 10] 196 | 197 | grid_search = dcv.GridSearchCV(LinearSVC(random_state=0), {'C': Cs}) 198 | grid_search.fit(X, y) 199 | assert_array_equal(grid_search.best_estimator_.classes_, 200 | grid_search.classes_) 201 | 202 | # Test that regressors do not have a classes_ attribute 203 | grid_search = dcv.GridSearchCV(Ridge(), {'alpha': [1.0, 2.0]}) 204 | grid_search.fit(X, y) 205 | assert not hasattr(grid_search, 'classes_') 206 | 207 | # Test that the grid searcher has no classes_ attribute before it's fit 208 | grid_search = dcv.GridSearchCV(LinearSVC(random_state=0), {'C': Cs}) 209 | assert not hasattr(grid_search, 'classes_') 210 | 211 | # Test that the grid searcher has no classes_ attribute without a refit 212 | grid_search = dcv.GridSearchCV(LinearSVC(random_state=0), 213 | {'C': Cs}, refit=False) 214 | grid_search.fit(X, y) 215 | assert not hasattr(grid_search, 'classes_') 216 | 217 | 218 | def test_trivial_cv_results_attr(): 219 | # Test search over a "grid" with only one point. 220 | # Non-regression test: grid_scores_ wouldn't be set by dcv.GridSearchCV. 221 | clf = MockClassifier() 222 | grid_search = dcv.GridSearchCV(clf, {'foo_param': [1]}) 223 | grid_search.fit(X, y) 224 | assert hasattr(grid_search, "cv_results_") 225 | 226 | random_search = dcv.RandomizedSearchCV(clf, {'foo_param': [0]}, n_iter=1) 227 | random_search.fit(X, y) 228 | assert hasattr(grid_search, "cv_results_") 229 | 230 | 231 | def test_no_refit(): 232 | # Test that GSCV can be used for model selection alone without refitting 233 | clf = MockClassifier() 234 | grid_search = dcv.GridSearchCV(clf, {'foo_param': [1, 2, 3]}, refit=False) 235 | grid_search.fit(X, y) 236 | assert (not hasattr(grid_search, "best_estimator_") and 237 | hasattr(grid_search, "best_index_") and 238 | hasattr(grid_search, "best_params_")) 239 | 240 | # Make sure the predict/transform etc fns raise meaningfull error msg 241 | for fn_name in ('predict', 'predict_proba', 'predict_log_proba', 242 | 'transform', 'inverse_transform'): 243 | with pytest.raises(NotFittedError) as exc: 244 | getattr(grid_search, fn_name)(X) 245 | assert (('refit=False. 
%s is available only after refitting on the ' 246 | 'best parameters' % fn_name) in str(exc.value)) 247 | 248 | 249 | def test_grid_search_error(): 250 | # Test that grid search will capture errors on data with different length 251 | X_, y_ = make_classification(n_samples=200, n_features=100, random_state=0) 252 | 253 | clf = LinearSVC() 254 | cv = dcv.GridSearchCV(clf, {'C': [0.1, 1.0]}) 255 | with pytest.raises(ValueError): 256 | cv.fit(X_[:180], y_) 257 | 258 | 259 | def test_grid_search_one_grid_point(): 260 | X_, y_ = make_classification(n_samples=200, n_features=100, random_state=0) 261 | param_dict = {"C": [1.0], "kernel": ["rbf"], "gamma": [0.1]} 262 | 263 | clf = SVC() 264 | cv = dcv.GridSearchCV(clf, param_dict) 265 | cv.fit(X_, y_) 266 | 267 | clf = SVC(C=1.0, kernel="rbf", gamma=0.1) 268 | clf.fit(X_, y_) 269 | 270 | assert_array_equal(clf.dual_coef_, cv.best_estimator_.dual_coef_) 271 | 272 | 273 | def test_grid_search_bad_param_grid(): 274 | param_dict = {"C": 1.0} 275 | clf = SVC() 276 | 277 | with pytest.raises(ValueError) as exc: 278 | dcv.GridSearchCV(clf, param_dict) 279 | assert ("Parameter values for parameter (C) need to be a sequence" 280 | "(but not a string) or np.ndarray.") in str(exc.value) 281 | 282 | param_dict = {"C": []} 283 | clf = SVC() 284 | 285 | with pytest.raises(ValueError) as exc: 286 | dcv.GridSearchCV(clf, param_dict) 287 | assert ("Parameter values for parameter (C) need to be a non-empty " 288 | "sequence.") in str(exc.value) 289 | 290 | param_dict = {"C": "1,2,3"} 291 | clf = SVC() 292 | 293 | with pytest.raises(ValueError) as exc: 294 | dcv.GridSearchCV(clf, param_dict) 295 | assert ("Parameter values for parameter (C) need to be a sequence" 296 | "(but not a string) or np.ndarray.") in str(exc.value) 297 | 298 | param_dict = {"C": np.ones(6).reshape(3, 2)} 299 | clf = SVC() 300 | with pytest.raises(ValueError): 301 | dcv.GridSearchCV(clf, param_dict) 302 | 303 | 304 | def test_grid_search_sparse(): 305 | # Test that grid search works with both dense and sparse matrices 306 | X_, y_ = make_classification(n_samples=200, n_features=100, random_state=0) 307 | 308 | clf = LinearSVC() 309 | cv = dcv.GridSearchCV(clf, {'C': [0.1, 1.0]}) 310 | cv.fit(X_[:180], y_[:180]) 311 | y_pred = cv.predict(X_[180:]) 312 | C = cv.best_estimator_.C 313 | 314 | X_ = sp.csr_matrix(X_) 315 | clf = LinearSVC() 316 | cv = dcv.GridSearchCV(clf, {'C': [0.1, 1.0]}) 317 | cv.fit(X_[:180].tocoo(), y_[:180]) 318 | y_pred2 = cv.predict(X_[180:]) 319 | C2 = cv.best_estimator_.C 320 | 321 | assert np.mean(y_pred == y_pred2) >= .9 322 | assert C == C2 323 | 324 | 325 | def test_grid_search_sparse_scoring(): 326 | X_, y_ = make_classification(n_samples=200, n_features=100, random_state=0) 327 | 328 | clf = LinearSVC() 329 | cv = dcv.GridSearchCV(clf, {'C': [0.1, 1.0]}, scoring="f1") 330 | cv.fit(X_[:180], y_[:180]) 331 | y_pred = cv.predict(X_[180:]) 332 | C = cv.best_estimator_.C 333 | 334 | X_ = sp.csr_matrix(X_) 335 | clf = LinearSVC() 336 | cv = dcv.GridSearchCV(clf, {'C': [0.1, 1.0]}, scoring="f1") 337 | cv.fit(X_[:180], y_[:180]) 338 | y_pred2 = cv.predict(X_[180:]) 339 | C2 = cv.best_estimator_.C 340 | 341 | assert_array_equal(y_pred, y_pred2) 342 | assert C == C2 343 | # Smoke test the score 344 | # np.testing.assert_allclose(f1_score(cv.predict(X_[:180]), y[:180]), 345 | # cv.score(X_[:180], y[:180])) 346 | 347 | # test loss where greater is worse 348 | def f1_loss(y_true_, y_pred_): 349 | return -f1_score(y_true_, y_pred_) 350 | F1Loss = make_scorer(f1_loss, 
greater_is_better=False) 351 | cv = dcv.GridSearchCV(clf, {'C': [0.1, 1.0]}, scoring=F1Loss) 352 | cv.fit(X_[:180], y_[:180]) 353 | y_pred3 = cv.predict(X_[180:]) 354 | C3 = cv.best_estimator_.C 355 | 356 | assert C == C3 357 | assert_array_equal(y_pred, y_pred3) 358 | 359 | 360 | def test_grid_search_precomputed_kernel(): 361 | # Test that grid search works when the input features are given in the 362 | # form of a precomputed kernel matrix 363 | X_, y_ = make_classification(n_samples=200, n_features=100, random_state=0) 364 | 365 | # compute the training kernel matrix corresponding to the linear kernel 366 | K_train = np.dot(X_[:180], X_[:180].T) 367 | y_train = y_[:180] 368 | 369 | clf = SVC(kernel='precomputed') 370 | cv = dcv.GridSearchCV(clf, {'C': [0.1, 1.0]}) 371 | cv.fit(K_train, y_train) 372 | 373 | assert cv.best_score_ >= 0 374 | 375 | # compute the test kernel matrix 376 | K_test = np.dot(X_[180:], X_[:180].T) 377 | y_test = y_[180:] 378 | 379 | y_pred = cv.predict(K_test) 380 | 381 | assert np.mean(y_pred == y_test) >= 0 382 | 383 | # test error is raised when the precomputed kernel is not array-like 384 | # or sparse 385 | with pytest.raises(ValueError): 386 | cv.fit(K_train.tolist(), y_train) 387 | 388 | 389 | def test_grid_search_precomputed_kernel_error_nonsquare(): 390 | # Test that grid search returns an error with a non-square precomputed 391 | # training kernel matrix 392 | K_train = np.zeros((10, 20)) 393 | y_train = np.ones((10, )) 394 | clf = SVC(kernel='precomputed') 395 | cv = dcv.GridSearchCV(clf, {'C': [0.1, 1.0]}) 396 | with pytest.raises(ValueError): 397 | cv.fit(K_train, y_train) 398 | 399 | 400 | class BrokenClassifier(BaseEstimator): 401 | """Broken classifier that cannot be fit twice""" 402 | 403 | def __init__(self, parameter=None): 404 | self.parameter = parameter 405 | 406 | def fit(self, X, y): 407 | assert not hasattr(self, 'has_been_fit_') 408 | self.has_been_fit_ = True 409 | 410 | def predict(self, X): 411 | return np.zeros(X.shape[0]) 412 | 413 | 414 | @ignore_warnings 415 | def test_refit(): 416 | # Regression test for bug in refitting 417 | # Simulates re-fitting a broken estimator; this used to break with 418 | # sparse SVMs. 
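    # Why this holds: dask-searchcv fits deep copies of the unfit estimator
    # (see copy_estimator in dask_searchcv/utils.py), so no single instance
    # is ever fit twice. A minimal sketch of that invariant, reusing this
    # file's BrokenClassifier (illustrative only, not part of the test):
    #
    #     import copy
    #     est = BrokenClassifier()
    #     copy.deepcopy(est).fit(X, y)   # a fresh copy each time, so the
    #     copy.deepcopy(est).fit(X, y)   # "cannot be fit twice" assert passes
    #     assert not hasattr(est, 'has_been_fit_')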
419 |     X = np.arange(100).reshape(10, 10)
420 |     y = np.array([0] * 5 + [1] * 5)
421 | 
422 |     clf = dcv.GridSearchCV(BrokenClassifier(), [{'parameter': [0, 1]}],
423 |                            scoring="precision", refit=True)
424 |     clf.fit(X, y)
425 | 
426 | 
427 | def test_gridsearch_nd():
428 |     # Pass n-dimensional X and y through dcv.GridSearchCV
429 |     X_4d = np.arange(10 * 5 * 3 * 2).reshape(10, 5, 3, 2)
430 |     y_3d = np.arange(10 * 7 * 11).reshape(10, 7, 11)
431 |     clf = CheckingClassifier(check_X=lambda x: x.shape[1:] == (5, 3, 2),
432 |                              check_y=lambda x: x.shape[1:] == (7, 11))
433 |     grid_search = dcv.GridSearchCV(clf, {'foo_param': [1, 2, 3]})
434 |     grid_search.fit(X_4d, y_3d).score(X, y)
435 |     assert hasattr(grid_search, "cv_results_")
436 | 
437 | 
438 | def test_X_as_list():
439 |     # Pass X as list in dcv.GridSearchCV
440 |     X = np.arange(100).reshape(10, 10)
441 |     y = np.array([0] * 5 + [1] * 5)
442 | 
443 |     clf = CheckingClassifier(check_X=lambda x: isinstance(x, list))
444 |     cv = KFold(n_splits=3)
445 |     grid_search = dcv.GridSearchCV(clf, {'foo_param': [1, 2, 3]}, cv=cv)
446 |     grid_search.fit(X.tolist(), y).score(X, y)
447 |     assert hasattr(grid_search, "cv_results_")
448 | 
449 | 
450 | def test_y_as_list():
451 |     # Pass y as list in dcv.GridSearchCV
452 |     X = np.arange(100).reshape(10, 10)
453 |     y = np.array([0] * 5 + [1] * 5)
454 | 
455 |     clf = CheckingClassifier(check_y=lambda x: isinstance(x, list))
456 |     cv = KFold(n_splits=3)
457 |     grid_search = dcv.GridSearchCV(clf, {'foo_param': [1, 2, 3]}, cv=cv)
458 |     grid_search.fit(X, y.tolist()).score(X, y)
459 |     assert hasattr(grid_search, "cv_results_")
460 | 
461 | 
462 | @ignore_warnings
463 | def test_pandas_input():
464 |     # check that the grid search doesn't destroy a pandas dataframe
465 |     types = [(MockDataFrame, MockDataFrame)]
466 |     try:
467 |         from pandas import Series, DataFrame
468 |         types.append((DataFrame, Series))
469 |     except ImportError:
470 |         pass
471 | 
472 |     X = np.arange(100).reshape(10, 10)
473 |     y = np.array([0] * 5 + [1] * 5)
474 | 
475 |     for InputFeatureType, TargetType in types:
476 |         # X dataframe, y series
477 |         X_df, y_ser = InputFeatureType(X), TargetType(y)
478 |         clf = CheckingClassifier(check_X=lambda x: isinstance(x, InputFeatureType),
479 |                                  check_y=lambda x: isinstance(x, TargetType))
480 | 
481 |         grid_search = dcv.GridSearchCV(clf, {'foo_param': [1, 2, 3]})
482 |         grid_search.fit(X_df, y_ser).score(X_df, y_ser)
483 |         grid_search.predict(X_df)
484 |         assert hasattr(grid_search, "cv_results_")
485 | 
486 | 
487 | def test_unsupervised_grid_search():
488 |     # test grid-search with unsupervised estimator
489 |     X, y = make_blobs(random_state=0)
490 |     km = KMeans(random_state=0)
491 |     grid_search = dcv.GridSearchCV(km, param_grid=dict(n_clusters=[2, 3, 4]),
492 |                                    scoring='adjusted_rand_score')
493 |     grid_search.fit(X, y)
494 |     # ARI can find the right number :)
495 |     assert grid_search.best_params_["n_clusters"] == 3
496 | 
497 |     # Now without a score, and without y
498 |     grid_search = dcv.GridSearchCV(km, param_grid=dict(n_clusters=[2, 3, 4]))
499 |     grid_search.fit(X)
500 |     assert grid_search.best_params_["n_clusters"] == 4
501 | 
502 | 
503 | def test_gridsearch_no_predict():
504 |     # test grid-search with an estimator without predict.
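    # The scoring callable receives (estimator, X) -- no y, since the search
    # below is fit without labels. A sketch of another scorer with the same
    # signature (hypothetical name, not used in this test):
    #
    #     def log_likelihood_scoring(estimator, X_test):
    #         return estimator.score(X_test)  # KernelDensity's log-likelihood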
505 | # slight duplication of a test from KDE 506 | def custom_scoring(estimator, X): 507 | return 42 if estimator.bandwidth == .1 else 0 508 | X, _ = make_blobs(cluster_std=.1, random_state=1, 509 | centers=[[0, 1], [1, 0], [0, 0]]) 510 | search = dcv.GridSearchCV(KernelDensity(), 511 | param_grid=dict(bandwidth=[.01, .1, 1]), 512 | scoring=custom_scoring) 513 | search.fit(X) 514 | assert search.best_params_['bandwidth'] == .1 515 | assert search.best_score_ == 42 516 | 517 | 518 | def check_cv_results_array_types(cv_results, param_keys, score_keys): 519 | # Check if the search `cv_results`'s array are of correct types 520 | assert all(isinstance(cv_results[param], np.ma.MaskedArray) 521 | for param in param_keys) 522 | assert all(cv_results[key].dtype == object for key in param_keys) 523 | assert not any(isinstance(cv_results[key], np.ma.MaskedArray) 524 | for key in score_keys) 525 | assert all(cv_results[key].dtype == np.float64 526 | for key in score_keys if not key.startswith('rank')) 527 | assert cv_results['rank_test_score'].dtype == np.int32 528 | 529 | 530 | def check_cv_results_keys(cv_results, param_keys, score_keys, n_cand): 531 | # Test the search.cv_results_ contains all the required results 532 | assert_array_equal(sorted(cv_results.keys()), 533 | sorted(param_keys + score_keys + ('params',))) 534 | assert all(cv_results[key].shape == (n_cand,) 535 | for key in param_keys + score_keys) 536 | 537 | 538 | def test_grid_search_cv_results(): 539 | X, y = make_classification(n_samples=50, n_features=4, 540 | random_state=42) 541 | 542 | n_splits = 3 543 | n_grid_points = 6 544 | params = [dict(kernel=['rbf', ], C=[1, 10], gamma=[0.1, 1]), 545 | dict(kernel=['poly', ], degree=[1, 2])] 546 | grid_search = dcv.GridSearchCV(SVC(), cv=n_splits, iid=False, 547 | param_grid=params) 548 | grid_search.fit(X, y) 549 | grid_search_iid = dcv.GridSearchCV(SVC(), cv=n_splits, iid=True, 550 | param_grid=params) 551 | grid_search_iid.fit(X, y) 552 | 553 | param_keys = ('param_C', 'param_degree', 'param_gamma', 'param_kernel') 554 | score_keys = ('mean_test_score', 'mean_train_score', 555 | 'rank_test_score', 556 | 'split0_test_score', 'split1_test_score', 557 | 'split2_test_score', 558 | 'split0_train_score', 'split1_train_score', 559 | 'split2_train_score', 560 | 'std_test_score', 'std_train_score') 561 | n_candidates = n_grid_points 562 | 563 | for search, iid in zip((grid_search, grid_search_iid), (False, True)): 564 | assert iid == search.iid 565 | cv_results = search.cv_results_ 566 | # Check if score and timing are reasonable 567 | assert all(cv_results['rank_test_score'] >= 1) 568 | assert all(all(cv_results[k] >= 0) for k in score_keys 569 | if k != 'rank_test_score') 570 | assert all(all(cv_results[k] <= 1) for k in score_keys 571 | if 'time' not in k and k != 'rank_test_score') 572 | # Check cv_results structure 573 | check_cv_results_array_types(cv_results, param_keys, score_keys) 574 | check_cv_results_keys(cv_results, param_keys, score_keys, n_candidates) 575 | # Check masking 576 | cv_results = grid_search.cv_results_ 577 | n_candidates = len(grid_search.cv_results_['params']) 578 | assert all((cv_results['param_C'].mask[i] and 579 | cv_results['param_gamma'].mask[i] and 580 | not cv_results['param_degree'].mask[i]) 581 | for i in range(n_candidates) 582 | if cv_results['param_kernel'][i] == 'linear') 583 | assert all((not cv_results['param_C'].mask[i] and 584 | not cv_results['param_gamma'].mask[i] and 585 | cv_results['param_degree'].mask[i]) 586 | for i in range(n_candidates) 
587 | if cv_results['param_kernel'][i] == 'rbf') 588 | 589 | 590 | def test_random_search_cv_results(): 591 | # Make a dataset with a lot of noise to get various kind of prediction 592 | # errors across CV folds and parameter settings 593 | X, y = make_classification(n_samples=200, n_features=100, n_informative=3, 594 | random_state=0) 595 | 596 | # scipy.stats dists now supports `seed` but we still support scipy 0.12 597 | # which doesn't support the seed. Hence the assertions in the test for 598 | # random_search alone should not depend on randomization. 599 | n_splits = 3 600 | n_search_iter = 30 601 | params = dict(C=expon(scale=10), gamma=expon(scale=0.1)) 602 | random_search = dcv.RandomizedSearchCV(SVC(), n_iter=n_search_iter, 603 | cv=n_splits, iid=False, 604 | param_distributions=params) 605 | random_search.fit(X, y) 606 | random_search_iid = dcv.RandomizedSearchCV(SVC(), n_iter=n_search_iter, 607 | cv=n_splits, iid=True, 608 | param_distributions=params) 609 | random_search_iid.fit(X, y) 610 | 611 | param_keys = ('param_C', 'param_gamma') 612 | score_keys = ('mean_test_score', 'mean_train_score', 613 | 'rank_test_score', 614 | 'split0_test_score', 'split1_test_score', 615 | 'split2_test_score', 616 | 'split0_train_score', 'split1_train_score', 617 | 'split2_train_score', 618 | 'std_test_score', 'std_train_score') 619 | n_cand = n_search_iter 620 | 621 | for search, iid in zip((random_search, random_search_iid), (False, True)): 622 | assert iid == search.iid 623 | cv_results = search.cv_results_ 624 | # Check results structure 625 | check_cv_results_array_types(cv_results, param_keys, score_keys) 626 | check_cv_results_keys(cv_results, param_keys, score_keys, n_cand) 627 | # For random_search, all the param array vals should be unmasked 628 | assert not (any(cv_results['param_C'].mask) or 629 | any(cv_results['param_gamma'].mask)) 630 | 631 | 632 | def test_search_iid_param(): 633 | # Test the IID parameter 634 | # noise-free simple 2d-data 635 | X, y = make_blobs(centers=[[0, 0], [1, 0], [0, 1], [1, 1]], random_state=0, 636 | cluster_std=0.1, shuffle=False, n_samples=80) 637 | # split dataset into two folds that are not iid 638 | # first one contains data of all 4 blobs, second only from two. 
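    # With iid=True, fold test scores are averaged weighted by test-set
    # size; with iid=False, a plain mean/std over folds is used. The
    # weighted case checked further down, spelled out (fold test sizes are
    # 20 and 60 of the 80 samples):
    #
    #     mean = 1.0 * (20 / 80.) + (1. / 3) * (60 / 80.)   # == 0.5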
639 | mask = np.ones(X.shape[0], dtype=np.bool) 640 | mask[np.where(y == 1)[0][::2]] = 0 641 | mask[np.where(y == 2)[0][::2]] = 0 642 | # this leads to perfect classification on one fold and a score of 1/3 on 643 | # the other 644 | # create "cv" for splits 645 | cv = [[mask, ~mask], [~mask, mask]] 646 | # once with iid=True (default) 647 | grid_search = dcv.GridSearchCV(SVC(), param_grid={'C': [1, 10]}, cv=cv) 648 | random_search = dcv.RandomizedSearchCV(SVC(), n_iter=2, 649 | param_distributions={'C': [1, 10]}, 650 | cv=cv) 651 | for search in (grid_search, random_search): 652 | search.fit(X, y) 653 | assert search.iid 654 | 655 | test_cv_scores = np.array(list(search.cv_results_['split%d_test_score' 656 | % s_i][0] 657 | for s_i in range(search.n_splits_))) 658 | train_cv_scores = np.array(list(search.cv_results_['split%d_train_' 659 | 'score' % s_i][0] 660 | for s_i in range(search.n_splits_))) 661 | test_mean = search.cv_results_['mean_test_score'][0] 662 | test_std = search.cv_results_['std_test_score'][0] 663 | 664 | train_cv_scores = np.array(list(search.cv_results_['split%d_train_' 665 | 'score' % s_i][0] 666 | for s_i in range(search.n_splits_))) 667 | train_mean = search.cv_results_['mean_train_score'][0] 668 | train_std = search.cv_results_['std_train_score'][0] 669 | 670 | # Test the first candidate 671 | assert search.cv_results_['param_C'][0] == 1 672 | assert_array_almost_equal(test_cv_scores, [1, 1. / 3.]) 673 | assert_array_almost_equal(train_cv_scores, [1, 1]) 674 | 675 | # for first split, 1/4 of dataset is in test, for second 3/4. 676 | # take weighted average and weighted std 677 | expected_test_mean = 1 * 1. / 4. + 1. / 3. * 3. / 4. 678 | expected_test_std = np.sqrt(1. / 4 * (expected_test_mean - 1) ** 2 + 679 | 3. / 4 * (expected_test_mean - 1. / 3.) ** 680 | 2) 681 | assert_almost_equal(test_mean, expected_test_mean) 682 | assert_almost_equal(test_std, expected_test_std) 683 | 684 | # For the train scores, we do not take a weighted mean irrespective of 685 | # i.i.d. or not 686 | assert_almost_equal(train_mean, 1) 687 | assert_almost_equal(train_std, 0) 688 | 689 | # once with iid=False 690 | grid_search = dcv.GridSearchCV(SVC(), param_grid={'C': [1, 10]}, 691 | cv=cv, iid=False) 692 | random_search = dcv.RandomizedSearchCV(SVC(), n_iter=2, 693 | param_distributions={'C': [1, 10]}, 694 | cv=cv, iid=False) 695 | 696 | for search in (grid_search, random_search): 697 | search.fit(X, y) 698 | assert not search.iid 699 | 700 | test_cv_scores = np.array(list(search.cv_results_['split%d_test_score' 701 | % s][0] 702 | for s in range(search.n_splits_))) 703 | test_mean = search.cv_results_['mean_test_score'][0] 704 | test_std = search.cv_results_['std_test_score'][0] 705 | 706 | train_cv_scores = np.array(list(search.cv_results_['split%d_train_' 707 | 'score' % s][0] 708 | for s in range(search.n_splits_))) 709 | train_mean = search.cv_results_['mean_train_score'][0] 710 | train_std = search.cv_results_['std_train_score'][0] 711 | 712 | assert search.cv_results_['param_C'][0] == 1 713 | # scores are the same as above 714 | assert_array_almost_equal(test_cv_scores, [1, 1. / 3.]) 715 | # Unweighted mean/std is used 716 | assert_almost_equal(test_mean, np.mean(test_cv_scores)) 717 | assert_almost_equal(test_std, np.std(test_cv_scores)) 718 | 719 | # For the train scores, we do not take a weighted mean irrespective of 720 | # i.i.d. 
or not 721 | assert_almost_equal(train_mean, 1) 722 | assert_almost_equal(train_std, 0) 723 | 724 | 725 | def test_search_cv_results_rank_tie_breaking(): 726 | X, y = make_blobs(n_samples=50, random_state=42) 727 | 728 | # The two C values are close enough to give similar models 729 | # which would result in a tie of their mean cv-scores 730 | param_grid = {'C': [1, 1.001, 0.001]} 731 | 732 | grid_search = dcv.GridSearchCV(SVC(), param_grid=param_grid) 733 | random_search = dcv.RandomizedSearchCV(SVC(), n_iter=3, 734 | param_distributions=param_grid) 735 | 736 | for search in (grid_search, random_search): 737 | search.fit(X, y) 738 | cv_results = search.cv_results_ 739 | # Check tie breaking strategy - 740 | # Check that there is a tie in the mean scores between 741 | # candidates 1 and 2 alone 742 | assert_almost_equal(cv_results['mean_test_score'][0], 743 | cv_results['mean_test_score'][1]) 744 | assert_almost_equal(cv_results['mean_train_score'][0], 745 | cv_results['mean_train_score'][1]) 746 | try: 747 | assert_almost_equal(cv_results['mean_test_score'][1], 748 | cv_results['mean_test_score'][2]) 749 | except AssertionError: 750 | pass 751 | try: 752 | assert_almost_equal(cv_results['mean_train_score'][1], 753 | cv_results['mean_train_score'][2]) 754 | except AssertionError: 755 | pass 756 | # 'min' rank should be assigned to the tied candidates 757 | assert_almost_equal(search.cv_results_['rank_test_score'], [1, 1, 3]) 758 | 759 | 760 | def test_search_cv_results_none_param(): 761 | X, y = [[1], [2], [3], [4], [5]], [0, 0, 0, 0, 1] 762 | estimators = (DecisionTreeRegressor(), DecisionTreeClassifier()) 763 | est_parameters = {"random_state": [0, None]} 764 | cv = KFold(random_state=0) 765 | 766 | for est in estimators: 767 | grid_search = dcv.GridSearchCV(est, est_parameters, cv=cv).fit(X, y) 768 | assert_array_equal(grid_search.cv_results_['param_random_state'], 769 | [0, None]) 770 | 771 | 772 | def test_grid_search_correct_score_results(): 773 | # test that correct scores are used 774 | n_splits = 3 775 | clf = LinearSVC(random_state=0) 776 | X, y = make_blobs(random_state=0, centers=2) 777 | Cs = [.1, 1, 10] 778 | for score in ['f1', 'roc_auc']: 779 | # XXX: It seems there's some global shared state in LinearSVC - fitting 780 | # multiple `SVC` instances in parallel using threads sometimes results 781 | # in wrong results. This only happens with threads, not processes/sync. 782 | # For now, we'll fit using the sync scheduler. 
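        # The scheduler keyword selects how dask executes the fits. Besides
        # the single-threaded 'sync' scheduler pinned here, the standard
        # dask thread- and process-based schedulers can be named instead
        # (a sketch; spellings other than 'sync' are assumed from dask's
        # conventions and are not exercised by this test):
        #
        #     dcv.GridSearchCV(clf, {'C': Cs}, scheduler='threading')
        #     dcv.GridSearchCV(clf, {'C': Cs}, scheduler='multiprocessing')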
783 | grid_search = dcv.GridSearchCV(clf, {'C': Cs}, scoring=score, 784 | cv=n_splits, scheduler='sync') 785 | cv_results = grid_search.fit(X, y).cv_results_ 786 | 787 | # Test scorer names 788 | result_keys = list(cv_results.keys()) 789 | expected_keys = (("mean_test_score", "rank_test_score") + 790 | tuple("split%d_test_score" % cv_i 791 | for cv_i in range(n_splits))) 792 | assert all(in1d(expected_keys, result_keys)) 793 | 794 | cv = StratifiedKFold(n_splits=n_splits) 795 | n_splits = grid_search.n_splits_ 796 | for candidate_i, C in enumerate(Cs): 797 | clf.set_params(C=C) 798 | cv_scores = np.array( 799 | list(grid_search.cv_results_['split%d_test_score' 800 | % s][candidate_i] 801 | for s in range(n_splits))) 802 | for i, (train, test) in enumerate(cv.split(X, y)): 803 | clf.fit(X[train], y[train]) 804 | if score == "f1": 805 | correct_score = f1_score(y[test], clf.predict(X[test])) 806 | elif score == "roc_auc": 807 | dec = clf.decision_function(X[test]) 808 | correct_score = roc_auc_score(y[test], dec) 809 | assert_almost_equal(correct_score, cv_scores[i]) 810 | 811 | 812 | def test_pickle(): 813 | # Test that a fit search can be pickled 814 | clf = MockClassifier() 815 | grid_search = dcv.GridSearchCV(clf, {'foo_param': [1, 2, 3]}, refit=True) 816 | grid_search.fit(X, y) 817 | grid_search_pickled = pickle.loads(pickle.dumps(grid_search)) 818 | assert_array_almost_equal(grid_search.predict(X), 819 | grid_search_pickled.predict(X)) 820 | 821 | random_search = dcv.RandomizedSearchCV(clf, {'foo_param': [1, 2, 3]}, 822 | refit=True, n_iter=3) 823 | random_search.fit(X, y) 824 | random_search_pickled = pickle.loads(pickle.dumps(random_search)) 825 | assert_array_almost_equal(random_search.predict(X), 826 | random_search_pickled.predict(X)) 827 | 828 | 829 | def test_grid_search_with_multioutput_data(): 830 | # Test search with multi-output estimator 831 | 832 | X, y = make_multilabel_classification(return_indicator=True, 833 | random_state=0) 834 | 835 | est_parameters = {"max_depth": [1, 2, 3, 4]} 836 | cv = KFold(random_state=0) 837 | 838 | estimators = [DecisionTreeRegressor(random_state=0), 839 | DecisionTreeClassifier(random_state=0)] 840 | 841 | # Test with grid search cv 842 | for est in estimators: 843 | grid_search = dcv.GridSearchCV(est, est_parameters, cv=cv) 844 | grid_search.fit(X, y) 845 | res_params = grid_search.cv_results_['params'] 846 | for cand_i in range(len(res_params)): 847 | est.set_params(**res_params[cand_i]) 848 | 849 | for i, (train, test) in enumerate(cv.split(X, y)): 850 | est.fit(X[train], y[train]) 851 | correct_score = est.score(X[test], y[test]) 852 | assert_almost_equal( 853 | correct_score, 854 | grid_search.cv_results_['split%d_test_score' % i][cand_i]) 855 | 856 | # Test with a randomized search 857 | for est in estimators: 858 | random_search = dcv.RandomizedSearchCV(est, est_parameters, 859 | cv=cv, n_iter=3) 860 | random_search.fit(X, y) 861 | res_params = random_search.cv_results_['params'] 862 | for cand_i in range(len(res_params)): 863 | est.set_params(**res_params[cand_i]) 864 | 865 | for i, (train, test) in enumerate(cv.split(X, y)): 866 | est.fit(X[train], y[train]) 867 | correct_score = est.score(X[test], y[test]) 868 | assert_almost_equal( 869 | correct_score, 870 | random_search.cv_results_['split%d_test_score' 871 | % i][cand_i]) 872 | 873 | 874 | def test_predict_proba_disabled(): 875 | # Test predict_proba when disabled on estimator. 
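    # The search object only exposes predict_proba when the fitted best
    # estimator does; for SVC that means passing probability=True. A sketch
    # of the enabled counterpart (assumed behaviour, not asserted here):
    #
    #     gs = dcv.GridSearchCV(SVC(probability=True), {}, cv=2).fit(X, y)
    #     gs.predict_proba(X)   # available, returns per-class probabilities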
876 | X = np.arange(20).reshape(5, -1) 877 | y = [0, 0, 1, 1, 1] 878 | clf = SVC(probability=False) 879 | gs = dcv.GridSearchCV(clf, {}, cv=2).fit(X, y) 880 | assert not hasattr(gs, "predict_proba") 881 | 882 | 883 | def test_grid_search_allows_nans(): 884 | # Test dcv.GridSearchCV with Imputer 885 | X = np.arange(20, dtype=np.float64).reshape(5, -1) 886 | X[2, :] = np.nan 887 | y = [0, 0, 1, 1, 1] 888 | p = Pipeline([ 889 | ('imputer', Imputer(strategy='mean', missing_values='NaN')), 890 | ('classifier', MockClassifier()), 891 | ]) 892 | dcv.GridSearchCV(p, {'classifier__foo_param': [1, 2, 3]}, cv=2).fit(X, y) 893 | 894 | 895 | @ignore_warnings 896 | def test_grid_search_failing_classifier(): 897 | X, y = make_classification(n_samples=20, n_features=10, random_state=0) 898 | clf = FailingClassifier() 899 | 900 | # refit=False because we want to test the behaviour of the grid search part 901 | gs = dcv.GridSearchCV(clf, [{'parameter': [0, 1, 2]}], scoring='accuracy', 902 | refit=False, error_score=0.0) 903 | 904 | with pytest.warns(FitFailedWarning): 905 | gs.fit(X, y) 906 | 907 | n_candidates = len(gs.cv_results_['params']) 908 | 909 | # Ensure that grid scores were set to zero as required for those fits 910 | # that are expected to fail. 911 | def get_cand_scores(i): 912 | return np.array(list(gs.cv_results_['split%d_test_score' % s][i] 913 | for s in range(gs.n_splits_))) 914 | 915 | assert all((np.all(get_cand_scores(cand_i) == 0.0) 916 | for cand_i in range(n_candidates) 917 | if gs.cv_results_['param_parameter'][cand_i] == 918 | FailingClassifier.FAILING_PARAMETER)) 919 | 920 | gs = dcv.GridSearchCV(clf, [{'parameter': [0, 1, 2]}], scoring='accuracy', 921 | refit=False, error_score=float('nan')) 922 | 923 | with pytest.warns(FitFailedWarning): 924 | gs.fit(X, y) 925 | 926 | n_candidates = len(gs.cv_results_['params']) 927 | assert all(np.all(np.isnan(get_cand_scores(cand_i))) 928 | for cand_i in range(n_candidates) 929 | if gs.cv_results_['param_parameter'][cand_i] == 930 | FailingClassifier.FAILING_PARAMETER) 931 | 932 | 933 | def test_grid_search_failing_classifier_raise(): 934 | X, y = make_classification(n_samples=20, n_features=10, random_state=0) 935 | clf = FailingClassifier() 936 | 937 | # refit=False because we want to test the behaviour of the grid search part 938 | gs = dcv.GridSearchCV(clf, [{'parameter': [0, 1, 2]}], scoring='accuracy', 939 | refit=False, error_score='raise') 940 | 941 | # FailingClassifier issues a ValueError so this is what we look for. 942 | with pytest.raises(ValueError): 943 | gs.fit(X, y) 944 | 945 | 946 | def test_search_train_scores_set_to_false(): 947 | X = np.arange(6).reshape(6, -1) 948 | y = [0, 0, 0, 1, 1, 1] 949 | clf = LinearSVC(random_state=0) 950 | 951 | gs = dcv.GridSearchCV(clf, param_grid={'C': [0.1, 0.2]}, 952 | return_train_score=False) 953 | gs.fit(X, y) 954 | -------------------------------------------------------------------------------- /dask_searchcv/utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import dask.array as da 4 | from dask.base import Base, tokenize 5 | from dask.delayed import delayed, Delayed 6 | 7 | from sklearn.utils.validation import indexable, _is_arraylike 8 | 9 | 10 | def _indexable(x): 11 | return indexable(x)[0] 12 | 13 | 14 | def _maybe_indexable(x): 15 | return indexable(x)[0] if _is_arraylike(x) else x 16 | 17 | 18 | def to_indexable(*args, **kwargs): 19 | """Ensure that all args are an indexable type. 
20 | 21 | Conversion runs lazily for dask objects, immediately otherwise. 22 | 23 | Parameters 24 | ---------- 25 | args : array_like or scalar 26 | allow_scalars : bool, optional 27 | Whether to allow scalars in args. Default is False. 28 | """ 29 | if kwargs.get('allow_scalars', False): 30 | indexable = _maybe_indexable 31 | else: 32 | indexable = _indexable 33 | for x in args: 34 | if x is None or isinstance(x, da.Array): 35 | yield x 36 | elif isinstance(x, Base): 37 | yield delayed(indexable, pure=True)(x) 38 | else: 39 | yield indexable(x) 40 | 41 | 42 | def to_keys(dsk, *args): 43 | for x in args: 44 | if x is None: 45 | yield None 46 | elif isinstance(x, da.Array): 47 | x = delayed(x) 48 | dsk.update(x.dask) 49 | yield x.key 50 | elif isinstance(x, Delayed): 51 | dsk.update(x.dask) 52 | yield x.key 53 | else: 54 | assert not isinstance(x, Base) 55 | key = 'array-' + tokenize(x) 56 | dsk[key] = x 57 | yield key 58 | 59 | 60 | def copy_estimator(est): 61 | # Semantically, we'd like to use `sklearn.clone` here instead. However, 62 | # `sklearn.clone` isn't threadsafe, so we don't want to call it in 63 | # tasks. Since `est` is guaranteed to not be a fit estimator, we can 64 | # use `copy.deepcopy` here without fear of copying large data. 65 | return copy.deepcopy(est) 66 | 67 | 68 | def unzip(itbl, n): 69 | return zip(*itbl) if itbl else [()] * n 70 | -------------------------------------------------------------------------------- /dask_searchcv/utils_test.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import numpy as np 4 | from sklearn.base import BaseEstimator, ClassifierMixin 5 | from sklearn.utils.validation import _num_samples, check_array 6 | 7 | 8 | # This class doesn't inherit from BaseEstimator to test hyperparameter search 9 | # on user-defined classifiers. 10 | class MockClassifier(object): 11 | """Dummy classifier to test the parameter search algorithms""" 12 | def __init__(self, foo_param=0): 13 | self.foo_param = foo_param 14 | 15 | def fit(self, X, Y): 16 | assert len(X) == len(Y) 17 | self.classes_ = np.unique(Y) 18 | return self 19 | 20 | def predict(self, T): 21 | return T.shape[0] 22 | 23 | predict_proba = predict 24 | predict_log_proba = predict 25 | decision_function = predict 26 | inverse_transform = predict 27 | 28 | def transform(self, X): 29 | return X 30 | 31 | def score(self, X=None, Y=None): 32 | if self.foo_param > 1: 33 | score = 1. 34 | else: 35 | score = 0. 
36 | return score 37 | 38 | def get_params(self, deep=False): 39 | return {'foo_param': self.foo_param} 40 | 41 | def set_params(self, **params): 42 | self.foo_param = params['foo_param'] 43 | return self 44 | 45 | 46 | class ScalingTransformer(BaseEstimator): 47 | def __init__(self, factor=1): 48 | self.factor = factor 49 | 50 | def fit(self, X, y): 51 | return self 52 | 53 | def transform(self, X): 54 | return X * self.factor 55 | 56 | 57 | class CheckXClassifier(BaseEstimator): 58 | """Used to check output of featureunions""" 59 | def __init__(self, expected_X=None): 60 | self.expected_X = expected_X 61 | 62 | def fit(self, X, y): 63 | assert (X == self.expected_X).all() 64 | assert len(X) == len(y) 65 | return self 66 | 67 | def predict(self, X): 68 | return X.sum(axis=1) 69 | 70 | def score(self, X=None, y=None): 71 | return self.predict(X)[0] 72 | 73 | 74 | class FailingClassifier(BaseEstimator): 75 | """Classifier that raises a ValueError on fit()""" 76 | 77 | FAILING_PARAMETER = 2 78 | 79 | def __init__(self, parameter=None): 80 | self.parameter = parameter 81 | 82 | def fit(self, X, y=None): 83 | if self.parameter == FailingClassifier.FAILING_PARAMETER: 84 | raise ValueError("Failing classifier failed as required") 85 | return self 86 | 87 | def transform(self, X): 88 | return X 89 | 90 | def predict(self, X): 91 | return np.zeros(X.shape[0]) 92 | 93 | 94 | def ignore_warnings(f): 95 | """A super simple version of `sklearn.utils.testing.ignore_warnings""" 96 | def _(*args, **kwargs): 97 | with warnings.catch_warnings(record=True): 98 | f(*args, **kwargs) 99 | return _ 100 | 101 | 102 | # XXX: Mocking classes copied from sklearn.utils.mocking to remove nose 103 | # dependency. Can be removed when scikit-learn switches to pytest. See issue 104 | # here: https://github.com/scikit-learn/scikit-learn/issues/7319 105 | 106 | class ArraySlicingWrapper(object): 107 | def __init__(self, array): 108 | self.array = array 109 | 110 | def __getitem__(self, aslice): 111 | return MockDataFrame(self.array[aslice]) 112 | 113 | 114 | class MockDataFrame(object): 115 | # have shape and length but don't support indexing. 116 | def __init__(self, array): 117 | self.array = array 118 | self.values = array 119 | self.shape = array.shape 120 | self.ndim = array.ndim 121 | # ugly hack to make iloc work. 122 | self.iloc = ArraySlicingWrapper(array) 123 | 124 | def __len__(self): 125 | return len(self.array) 126 | 127 | def __array__(self, dtype=None): 128 | # Pandas data frames also are array-like: we want to make sure that 129 | # input validation in cross-validation does not try to call that 130 | # method. 131 | return self.array 132 | 133 | 134 | class CheckingClassifier(BaseEstimator, ClassifierMixin): 135 | """Dummy classifier to test pipelining and meta-estimators. 136 | 137 | Checks some property of X and y in fit / predict. 138 | This allows testing whether pipelines / cross-validation or metaestimators 139 | changed the input. 
140 | """ 141 | def __init__(self, check_y=None, check_X=None, foo_param=0, 142 | expected_fit_params=None): 143 | self.check_y = check_y 144 | self.check_X = check_X 145 | self.foo_param = foo_param 146 | self.expected_fit_params = expected_fit_params 147 | 148 | def fit(self, X, y, **fit_params): 149 | assert len(X) == len(y) 150 | if self.check_X is not None: 151 | assert self.check_X(X) 152 | if self.check_y is not None: 153 | assert self.check_y(y) 154 | self.classes_ = np.unique(check_array(y, ensure_2d=False, 155 | allow_nd=True)) 156 | if self.expected_fit_params: 157 | missing = set(self.expected_fit_params) - set(fit_params) 158 | assert len(missing) == 0, ('Expected fit parameter(s) %s not ' 159 | 'seen.' % list(missing)) 160 | for key, value in fit_params.items(): 161 | assert len(value) == len(X), ('Fit parameter %s has length %d; ' 162 | 'expected %d.' % (key, len(value), 163 | len(X))) 164 | return self 165 | 166 | def predict(self, T): 167 | if self.check_X is not None: 168 | assert self.check_X(T) 169 | return self.classes_[np.zeros(_num_samples(T), dtype=np.int)] 170 | 171 | def score(self, X=None, Y=None): 172 | if self.foo_param > 1: 173 | score = 1. 174 | else: 175 | score = 0. 176 | return score 177 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = dask-searchcv 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | set SPHINXPROJ=dask-searchcv 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 20 | echo.installed, then set the SPHINXBUILD environment variable to point 21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 22 | echo.may add the Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/source/api.rst: -------------------------------------------------------------------------------- 1 | API 2 | === 3 | 4 | .. currentmodule:: dask_searchcv 5 | 6 | .. autosummary:: 7 | GridSearchCV 8 | RandomizedSearchCV 9 | 10 | .. 
autoclass:: GridSearchCV 11 | :members: 12 | :inherited-members: 13 | 14 | .. autoclass:: RandomizedSearchCV 15 | :members: 16 | :inherited-members: 17 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | # dask-searchcv documentation build configuration file, created by 5 | # sphinx-quickstart on Fri Mar 31 16:55:14 2017. 6 | # 7 | # This file is execfile()d with the current directory set to its 8 | # containing dir. 9 | # 10 | # Note that not all possible configuration values are present in this 11 | # autogenerated file. 12 | # 13 | # All configuration values have a default; values that are commented out 14 | # serve to show the default. 15 | 16 | # If extensions (or modules to document with autodoc) are in another directory, 17 | # add these directories to sys.path here. If the directory is relative to the 18 | # documentation root, use os.path.abspath to make it absolute, like shown here. 19 | # 20 | import os 21 | import sys 22 | sys.path.insert(0, os.path.abspath('.')) 23 | 24 | 25 | # -- General configuration ------------------------------------------------ 26 | 27 | # If your documentation needs a minimal Sphinx version, state it here. 28 | # 29 | # needs_sphinx = '1.0' 30 | 31 | # Add any Sphinx extension module names here, as strings. They can be 32 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 33 | # ones. 34 | extensions = ['sphinx.ext.autodoc', 'sphinx.ext.autosummary', 35 | 'sphinx.ext.mathjax', 'sphinxext.numpydoc'] 36 | 37 | autosummary_generate = True 38 | 39 | # Add any paths that contain templates here, relative to this directory. 40 | templates_path = ['_templates'] 41 | 42 | # The suffix(es) of source filenames. 43 | # You can specify multiple suffix as a list of string: 44 | # 45 | # source_suffix = ['.rst', '.md'] 46 | source_suffix = '.rst' 47 | 48 | # The master toctree document. 49 | master_doc = 'index' 50 | 51 | # General information about the project. 52 | project = 'dask-searchcv' 53 | copyright = '2017, Dask Development Team' 54 | author = 'Dask Development Team' 55 | 56 | # The version info for the project you're documenting, acts as replacement for 57 | # |version| and |release|, also used in various other places throughout the 58 | # built documents. 59 | # 60 | # The short X.Y version. 61 | version = '' 62 | # The full version, including alpha/beta/rc tags. 63 | release = '' 64 | 65 | # The language for content autogenerated by Sphinx. Refer to documentation 66 | # for a list of supported languages. 67 | # 68 | # This is also used if you do content translation via gettext catalogs. 69 | # Usually you set "language" from the command line for these cases. 70 | language = None 71 | 72 | # List of patterns, relative to source directory, that match files and 73 | # directories to ignore when looking for source files. 74 | # This patterns also effect to html_static_path and html_extra_path 75 | exclude_patterns = [] 76 | 77 | # The name of the Pygments (syntax highlighting) style to use. 78 | pygments_style = 'sphinx' 79 | 80 | # If true, `todo` and `todoList` produce output, else they produce nothing. 
81 | todo_include_todos = False 82 | 83 | 84 | # -- Options for HTML output ---------------------------------------------- 85 | 86 | # Taken from docs.readthedocs.io: 87 | # on_rtd is whether we are on readthedocs.io 88 | on_rtd = os.environ.get('READTHEDOCS', None) == 'True' 89 | 90 | if not on_rtd: # only import and set the theme if we're building docs locally 91 | import sphinx_rtd_theme 92 | html_theme = 'sphinx_rtd_theme' 93 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 94 | 95 | # Theme options are theme-specific and customize the look and feel of a theme 96 | # further. For a list of options available for each theme, see the 97 | # documentation. 98 | # 99 | # html_theme_options = {} 100 | 101 | # Add any paths that contain custom static files (such as style sheets) here, 102 | # relative to this directory. They are copied after the builtin static files, 103 | # so a file named "default.css" will overwrite the builtin "default.css". 104 | html_static_path = ['_static'] 105 | 106 | 107 | # -- Options for HTMLHelp output ------------------------------------------ 108 | 109 | # Output file base name for HTML help builder. 110 | htmlhelp_basename = 'dask-searchcvdoc' 111 | 112 | 113 | # -- Options for LaTeX output --------------------------------------------- 114 | 115 | latex_elements = { 116 | # The paper size ('letterpaper' or 'a4paper'). 117 | # 118 | # 'papersize': 'letterpaper', 119 | 120 | # The font size ('10pt', '11pt' or '12pt'). 121 | # 122 | # 'pointsize': '10pt', 123 | 124 | # Additional stuff for the LaTeX preamble. 125 | # 126 | # 'preamble': '', 127 | 128 | # Latex figure (float) alignment 129 | # 130 | # 'figure_align': 'htbp', 131 | } 132 | 133 | # Grouping the document tree into LaTeX files. List of tuples 134 | # (source start file, target name, title, 135 | # author, documentclass [howto, manual, or own class]). 136 | latex_documents = [ 137 | (master_doc, 'dask-searchcv.tex', 'dask-searchcv Documentation', 138 | 'Dask Development Team', 'manual'), 139 | ] 140 | 141 | 142 | # -- Options for manual page output --------------------------------------- 143 | 144 | # One entry per manual page. List of tuples 145 | # (source start file, name, description, authors, manual section). 146 | man_pages = [ 147 | (master_doc, 'dask-searchcv', 'dask-searchcv Documentation', 148 | [author], 1) 149 | ] 150 | 151 | 152 | # -- Options for Texinfo output ------------------------------------------- 153 | 154 | # Grouping the document tree into Texinfo files. List of tuples 155 | # (source start file, target name, title, author, 156 | # dir menu entry, description, category) 157 | texinfo_documents = [ 158 | (master_doc, 'dask-searchcv', 'dask-searchcv Documentation', 159 | author, 'dask-searchcv', 'One line description of project.', 160 | 'Miscellaneous'), 161 | ] 162 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | dask-searchcv 2 | ============= 3 | 4 | Tools for performing hyperparameter optimization of Scikit-Learn models using 5 | Dask. 6 | 7 | Introduction 8 | ------------ 9 | 10 | This library provides implementations of Scikit-Learn's ``GridSearchCV`` and 11 | ``RandomizedSearchCV``. They implement many (but not all) of the same 12 | parameters, and should be a drop-in replacement for the subset that they do 13 | implement. 
For certain problems, these implementations can be more efficient 14 | than those in Scikit-Learn, as they can avoid expensive repeated computations. 15 | 16 | Highlights 17 | ---------- 18 | 19 | - Drop-in replacement for Scikit-Learn's ``GridSearchCV`` and 20 | ``RandomizedSearchCV``. 21 | 22 | - Hyperparameter optimization can be done in parallel using threads, processes, 23 | or distributed across a cluster. 24 | 25 | - Works well with Dask collections. Dask arrays, dataframes, and delayed can be 26 | passed to ``fit``. 27 | 28 | - Candidate estimators with identical parameters and inputs will only be fit 29 | once. For meta-estimators such as ``Pipeline`` this can be significantly more 30 | efficient as it can avoid expensive repeated computations. 31 | 32 | Install 33 | ------- 34 | 35 | Dask-searchcv is available via ``pip``: 36 | 37 | :: 38 | 39 | $ pip install dask-searchcv 40 | 41 | Example 42 | ------- 43 | 44 | .. code-block:: python 45 | 46 | from sklearn.datasets import load_digits 47 | from sklearn.svm import SVC 48 | import dask_searchcv as dcv 49 | import numpy as np 50 | 51 | digits = load_digits() 52 | 53 | param_space = {'C': np.logspace(-4, 4, 9), 54 | 'gamma': np.logspace(-4, 4, 9), 55 | 'class_weight': [None, 'balanced']} 56 | 57 | model = SVC(kernel='rbf') 58 | search = dcv.GridSearchCV(model, param_space, cv=3) 59 | 60 | search.fit(digits.data, digits.target) 61 | 62 | Index 63 | ----- 64 | 65 | .. toctree:: 66 | 67 | api 68 | -------------------------------------------------------------------------------- /docs/source/sphinxext/LICENSE.txt: -------------------------------------------------------------------------------- 1 | ------------------------------------------------------------------------------- 2 | The files 3 | - numpydoc.py 4 | - docscrape.py 5 | - docscrape_sphinx.py 6 | have the following license: 7 | 8 | Copyright (C) 2008 Stefan van der Walt , Pauli Virtanen 9 | 10 | Redistribution and use in source and binary forms, with or without 11 | modification, are permitted provided that the following conditions are 12 | met: 13 | 14 | 1. Redistributions of source code must retain the above copyright 15 | notice, this list of conditions and the following disclaimer. 16 | 2. Redistributions in binary form must reproduce the above copyright 17 | notice, this list of conditions and the following disclaimer in 18 | the documentation and/or other materials provided with the 19 | distribution. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR 22 | IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, 25 | INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 27 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 28 | HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 29 | STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING 30 | IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 31 | POSSIBILITY OF SUCH DAMAGE. 32 | -------------------------------------------------------------------------------- /docs/source/sphinxext/docscrape.py: -------------------------------------------------------------------------------- 1 | """Extract reference documentation from the NumPy source tree. 
2 | 3 | """ 4 | 5 | import inspect 6 | import textwrap 7 | import re 8 | import pydoc 9 | from warnings import warn 10 | # Try Python 2 first, otherwise load from Python 3 11 | try: 12 | from StringIO import StringIO 13 | except: 14 | from io import StringIO 15 | 16 | 17 | class Reader(object): 18 | """A line-based string reader. 19 | 20 | """ 21 | def __init__(self, data): 22 | """ 23 | Parameters 24 | ---------- 25 | data : str 26 | String with lines separated by '\n'. 27 | 28 | """ 29 | if isinstance(data, list): 30 | self._str = data 31 | else: 32 | self._str = data.split('\n') # store string as list of lines 33 | 34 | self.reset() 35 | 36 | def __getitem__(self, n): 37 | return self._str[n] 38 | 39 | def reset(self): 40 | self._l = 0 # current line nr 41 | 42 | def read(self): 43 | if not self.eof(): 44 | out = self[self._l] 45 | self._l += 1 46 | return out 47 | else: 48 | return '' 49 | 50 | def seek_next_non_empty_line(self): 51 | for l in self[self._l:]: 52 | if l.strip(): 53 | break 54 | else: 55 | self._l += 1 56 | 57 | def eof(self): 58 | return self._l >= len(self._str) 59 | 60 | def read_to_condition(self, condition_func): 61 | start = self._l 62 | for line in self[start:]: 63 | if condition_func(line): 64 | return self[start:self._l] 65 | self._l += 1 66 | if self.eof(): 67 | return self[start:self._l + 1] 68 | return [] 69 | 70 | def read_to_next_empty_line(self): 71 | self.seek_next_non_empty_line() 72 | 73 | def is_empty(line): 74 | return not line.strip() 75 | return self.read_to_condition(is_empty) 76 | 77 | def read_to_next_unindented_line(self): 78 | def is_unindented(line): 79 | return (line.strip() and (len(line.lstrip()) == len(line))) 80 | return self.read_to_condition(is_unindented) 81 | 82 | def peek(self, n=0): 83 | if self._l + n < len(self._str): 84 | return self[self._l + n] 85 | else: 86 | return '' 87 | 88 | def is_empty(self): 89 | return not ''.join(self._str).strip() 90 | 91 | 92 | class NumpyDocString(object): 93 | def __init__(self, docstring, config={}): 94 | docstring = textwrap.dedent(docstring).split('\n') 95 | 96 | self._doc = Reader(docstring) 97 | self._parsed_data = { 98 | 'Signature': '', 99 | 'Summary': [''], 100 | 'Extended Summary': [], 101 | 'Parameters': [], 102 | 'Returns': [], 103 | 'Raises': [], 104 | 'Warns': [], 105 | 'Other Parameters': [], 106 | 'Attributes': [], 107 | 'Methods': [], 108 | 'See Also': [], 109 | 'Notes': [], 110 | 'Warnings': [], 111 | 'References': '', 112 | 'Examples': '', 113 | 'index': {} 114 | } 115 | 116 | self._parse() 117 | 118 | def __getitem__(self, key): 119 | return self._parsed_data[key] 120 | 121 | def __setitem__(self, key, val): 122 | if key not in self._parsed_data: 123 | warn("Unknown section %s" % key) 124 | else: 125 | self._parsed_data[key] = val 126 | 127 | def _is_at_section(self): 128 | self._doc.seek_next_non_empty_line() 129 | 130 | if self._doc.eof(): 131 | return False 132 | 133 | l1 = self._doc.peek().strip() # e.g. Parameters 134 | 135 | if l1.startswith('.. 
index::'): 136 | return True 137 | 138 | l2 = self._doc.peek(1).strip() # ---------- or ========== 139 | return l2.startswith('-' * len(l1)) or l2.startswith('=' * len(l1)) 140 | 141 | def _strip(self, doc): 142 | i = 0 143 | j = 0 144 | for i, line in enumerate(doc): 145 | if line.strip(): 146 | break 147 | 148 | for j, line in enumerate(doc[::-1]): 149 | if line.strip(): 150 | break 151 | 152 | return doc[i:len(doc) - j] 153 | 154 | def _read_to_next_section(self): 155 | section = self._doc.read_to_next_empty_line() 156 | 157 | while not self._is_at_section() and not self._doc.eof(): 158 | if not self._doc.peek(-1).strip(): # previous line was empty 159 | section += [''] 160 | 161 | section += self._doc.read_to_next_empty_line() 162 | 163 | return section 164 | 165 | def _read_sections(self): 166 | while not self._doc.eof(): 167 | data = self._read_to_next_section() 168 | name = data[0].strip() 169 | 170 | if name.startswith('..'): # index section 171 | yield name, data[1:] 172 | elif len(data) < 2: 173 | yield StopIteration 174 | else: 175 | yield name, self._strip(data[2:]) 176 | 177 | def _parse_param_list(self, content): 178 | r = Reader(content) 179 | params = [] 180 | while not r.eof(): 181 | header = r.read().strip() 182 | if ' : ' in header: 183 | arg_name, arg_type = header.split(' : ')[:2] 184 | else: 185 | arg_name, arg_type = header, '' 186 | 187 | desc = r.read_to_next_unindented_line() 188 | desc = dedent_lines(desc) 189 | 190 | params.append((arg_name, arg_type, desc)) 191 | 192 | return params 193 | 194 | _name_rgx = re.compile(r"^\s*(:(?P\w+):`(?P[a-zA-Z0-9_.-]+)`|" 195 | r" (?P[a-zA-Z0-9_.-]+))\s*", re.X) 196 | 197 | def _parse_see_also(self, content): 198 | """ 199 | func_name : Descriptive text 200 | continued text 201 | another_func_name : Descriptive text 202 | func_name1, func_name2, :meth:`func_name`, func_name3 203 | 204 | """ 205 | items = [] 206 | 207 | def parse_item_name(text): 208 | """Match ':role:`name`' or 'name'""" 209 | m = self._name_rgx.match(text) 210 | if m: 211 | g = m.groups() 212 | if g[1] is None: 213 | return g[3], None 214 | else: 215 | return g[2], g[1] 216 | raise ValueError("%s is not a item name" % text) 217 | 218 | def push_item(name, rest): 219 | if not name: 220 | return 221 | name, role = parse_item_name(name) 222 | items.append((name, list(rest), role)) 223 | del rest[:] 224 | 225 | current_func = None 226 | rest = [] 227 | 228 | for line in content: 229 | if not line.strip(): 230 | continue 231 | 232 | m = self._name_rgx.match(line) 233 | if m and line[m.end():].strip().startswith(':'): 234 | push_item(current_func, rest) 235 | current_func, line = line[:m.end()], line[m.end():] 236 | rest = [line.split(':', 1)[1].strip()] 237 | if not rest[0]: 238 | rest = [] 239 | elif not line.startswith(' '): 240 | push_item(current_func, rest) 241 | current_func = None 242 | if ',' in line: 243 | for func in line.split(','): 244 | push_item(func, []) 245 | elif line.strip(): 246 | current_func = line 247 | elif current_func is not None: 248 | rest.append(line.strip()) 249 | push_item(current_func, rest) 250 | return items 251 | 252 | def _parse_index(self, section, content): 253 | """ 254 | .. 
index: default 255 | :refguide: something, else, and more 256 | 257 | """ 258 | def strip_each_in(lst): 259 | return [s.strip() for s in lst] 260 | 261 | out = {} 262 | section = section.split('::') 263 | if len(section) > 1: 264 | out['default'] = strip_each_in(section[1].split(','))[0] 265 | for line in content: 266 | line = line.split(':') 267 | if len(line) > 2: 268 | out[line[1]] = strip_each_in(line[2].split(',')) 269 | return out 270 | 271 | def _parse_summary(self): 272 | """Grab signature (if given) and summary""" 273 | if self._is_at_section(): 274 | return 275 | 276 | summary = self._doc.read_to_next_empty_line() 277 | summary_str = " ".join([s.strip() for s in summary]).strip() 278 | if re.compile('^([\w., ]+=)?\s*[\w\.]+\(.*\)$').match(summary_str): 279 | self['Signature'] = summary_str 280 | if not self._is_at_section(): 281 | self['Summary'] = self._doc.read_to_next_empty_line() 282 | else: 283 | self['Summary'] = summary 284 | 285 | if not self._is_at_section(): 286 | self['Extended Summary'] = self._read_to_next_section() 287 | 288 | def _parse(self): 289 | self._doc.reset() 290 | self._parse_summary() 291 | 292 | for (section, content) in self._read_sections(): 293 | if not section.startswith('..'): 294 | section = ' '.join([s.capitalize() 295 | for s in section.split(' ')]) 296 | if section in ('Parameters', 'Attributes', 'Methods', 297 | 'Returns', 'Raises', 'Warns'): 298 | self[section] = self._parse_param_list(content) 299 | elif section.startswith('.. index::'): 300 | self['index'] = self._parse_index(section, content) 301 | elif section == 'See Also': 302 | self['See Also'] = self._parse_see_also(content) 303 | else: 304 | self[section] = content 305 | 306 | # string conversion routines 307 | 308 | def _str_header(self, name, symbol='-'): 309 | return [name, len(name) * symbol] 310 | 311 | def _str_indent(self, doc, indent=4): 312 | out = [] 313 | for line in doc: 314 | out += [' ' * indent + line] 315 | return out 316 | 317 | def _str_signature(self): 318 | if self['Signature']: 319 | return [self['Signature'].replace('*', '\*')] + [''] 320 | else: 321 | return [''] 322 | 323 | def _str_summary(self): 324 | if self['Summary']: 325 | return self['Summary'] + [''] 326 | else: 327 | return [] 328 | 329 | def _str_extended_summary(self): 330 | if self['Extended Summary']: 331 | return self['Extended Summary'] + [''] 332 | else: 333 | return [] 334 | 335 | def _str_param_list(self, name): 336 | out = [] 337 | if self[name]: 338 | out += self._str_header(name) 339 | for param, param_type, desc in self[name]: 340 | out += ['%s : %s' % (param, param_type)] 341 | out += self._str_indent(desc) 342 | out += [''] 343 | return out 344 | 345 | def _str_section(self, name): 346 | out = [] 347 | if self[name]: 348 | out += self._str_header(name) 349 | out += self[name] 350 | out += [''] 351 | return out 352 | 353 | def _str_see_also(self, func_role): 354 | if not self['See Also']: 355 | return [] 356 | out = [] 357 | out += self._str_header("See Also") 358 | last_had_desc = True 359 | for func, desc, role in self['See Also']: 360 | if role: 361 | link = ':%s:`%s`' % (role, func) 362 | elif func_role: 363 | link = ':%s:`%s`' % (func_role, func) 364 | else: 365 | link = "`%s`_" % func 366 | if desc or last_had_desc: 367 | out += [''] 368 | out += [link] 369 | else: 370 | out[-1] += ", %s" % link 371 | if desc: 372 | out += self._str_indent([' '.join(desc)]) 373 | last_had_desc = True 374 | else: 375 | last_had_desc = False 376 | out += [''] 377 | return out 378 | 379 | def 
_str_index(self): 380 | idx = self['index'] 381 | out = [] 382 | out += ['.. index:: %s' % idx.get('default', '')] 383 | for section, references in idx.iteritems(): 384 | if section == 'default': 385 | continue 386 | out += [' :%s: %s' % (section, ', '.join(references))] 387 | return out 388 | 389 | def __str__(self, func_role=''): 390 | out = [] 391 | out += self._str_signature() 392 | out += self._str_summary() 393 | out += self._str_extended_summary() 394 | for param_list in ('Parameters', 'Returns', 'Raises'): 395 | out += self._str_param_list(param_list) 396 | out += self._str_section('Warnings') 397 | out += self._str_see_also(func_role) 398 | for s in ('Notes', 'References', 'Examples'): 399 | out += self._str_section(s) 400 | for param_list in ('Attributes', 'Methods'): 401 | out += self._str_param_list(param_list) 402 | out += self._str_index() 403 | return '\n'.join(out) 404 | 405 | 406 | def indent(str, indent=4): 407 | indent_str = ' ' * indent 408 | if str is None: 409 | return indent_str 410 | lines = str.split('\n') 411 | return '\n'.join(indent_str + l for l in lines) 412 | 413 | 414 | def dedent_lines(lines): 415 | """Deindent a list of lines maximally""" 416 | return textwrap.dedent("\n".join(lines)).split("\n") 417 | 418 | 419 | def header(text, style='-'): 420 | return text + '\n' + style * len(text) + '\n' 421 | 422 | 423 | class FunctionDoc(NumpyDocString): 424 | def __init__(self, func, role='func', doc=None, config={}): 425 | self._f = func 426 | self._role = role # e.g. "func" or "meth" 427 | 428 | if doc is None: 429 | if func is None: 430 | raise ValueError("No function or docstring given") 431 | doc = inspect.getdoc(func) or '' 432 | NumpyDocString.__init__(self, doc) 433 | 434 | if not self['Signature'] and func is not None: 435 | func, func_name = self.get_func() 436 | try: 437 | # try to read signature 438 | argspec = inspect.getargspec(func) 439 | argspec = inspect.formatargspec(*argspec) 440 | argspec = argspec.replace('*', '\*') 441 | signature = '%s%s' % (func_name, argspec) 442 | except TypeError as e: 443 | signature = '%s()' % func_name 444 | self['Signature'] = signature 445 | 446 | def get_func(self): 447 | func_name = getattr(self._f, '__name__', self.__class__.__name__) 448 | if inspect.isclass(self._f): 449 | func = getattr(self._f, '__call__', self._f.__init__) 450 | else: 451 | func = self._f 452 | return func, func_name 453 | 454 | def __str__(self): 455 | out = '' 456 | 457 | func, func_name = self.get_func() 458 | signature = self['Signature'].replace('*', '\*') 459 | 460 | roles = {'func': 'function', 461 | 'meth': 'method'} 462 | 463 | if self._role: 464 | if self._role not in roles: 465 | print("Warning: invalid role %s" % self._role) 466 | out += '.. %s:: %s\n \n\n' % (roles.get(self._role, ''), 467 | func_name) 468 | 469 | out += super(FunctionDoc, self).__str__(func_role=self._role) 470 | return out 471 | 472 | 473 | class ClassDoc(NumpyDocString): 474 | def __init__(self, cls, doc=None, modulename='', func_doc=FunctionDoc, 475 | config=None): 476 | if not inspect.isclass(cls) and cls is not None: 477 | raise ValueError("Expected a class or None, but got %r" % cls) 478 | self._cls = cls 479 | 480 | if modulename and not modulename.endswith('.'): 481 | modulename += '.' 
482 | self._mod = modulename 483 | 484 | if doc is None: 485 | if cls is None: 486 | raise ValueError("No class or documentation string given") 487 | doc = pydoc.getdoc(cls) 488 | 489 | NumpyDocString.__init__(self, doc) 490 | 491 | if config is not None and config.get('show_class_members', True): 492 | if not self['Methods']: 493 | self['Methods'] = [(name, '', '') 494 | for name in sorted(self.methods)] 495 | if not self['Attributes']: 496 | self['Attributes'] = [(name, '', '') 497 | for name in sorted(self.properties)] 498 | 499 | @property 500 | def methods(self): 501 | if self._cls is None: 502 | return [] 503 | return [name for name, func in inspect.getmembers(self._cls) 504 | if not name.startswith('_') and callable(func)] 505 | 506 | @property 507 | def properties(self): 508 | if self._cls is None: 509 | return [] 510 | return [name for name, func in inspect.getmembers(self._cls) 511 | if not name.startswith('_') and func is None] 512 | -------------------------------------------------------------------------------- /docs/source/sphinxext/docscrape_sphinx.py: -------------------------------------------------------------------------------- 1 | import re 2 | import inspect 3 | import textwrap 4 | import pydoc 5 | from .docscrape import NumpyDocString 6 | from .docscrape import FunctionDoc 7 | from .docscrape import ClassDoc 8 | 9 | 10 | class SphinxDocString(NumpyDocString): 11 | def __init__(self, docstring, config=None): 12 | config = {} if config is None else config 13 | self.use_plots = config.get('use_plots', False) 14 | NumpyDocString.__init__(self, docstring, config=config) 15 | 16 | # string conversion routines 17 | def _str_header(self, name, symbol='`'): 18 | return ['.. rubric:: ' + name, ''] 19 | 20 | def _str_field_list(self, name): 21 | return [':' + name + ':'] 22 | 23 | def _str_indent(self, doc, indent=4): 24 | out = [] 25 | for line in doc: 26 | out += [' ' * indent + line] 27 | return out 28 | 29 | def _str_signature(self): 30 | # The signature is intentionally not rendered here: autodoc 31 | # already emits it, and repeating it in the docstring body 32 | # would show the signature twice. An empty line is returned 33 | # instead so the surrounding layout stays unchanged. 34 | return [''] 35 | 36 | def _str_summary(self): 37 | return self['Summary'] + [''] 38 | 39 | def _str_extended_summary(self): 40 | return self['Extended Summary'] + [''] 41 | 42 | def _str_param_list(self, name): 43 | out = [] 44 | if self[name]: 45 | out += self._str_field_list(name) 46 | out += [''] 47 | for param, param_type, desc in self[name]: 48 | out += self._str_indent(['**%s** : %s' % (param.strip(), 49 | param_type)]) 50 | out += [''] 51 | out += self._str_indent(desc, 8) 52 | out += [''] 53 | return out 54 | 55 | @property 56 | def _obj(self): 57 | if hasattr(self, '_cls'): 58 | return self._cls 59 | elif hasattr(self, '_f'): 60 | return self._f 61 | return None 62 | 63 | def _str_member_list(self, name): 64 | """ 65 | Generate a member listing: an autosummary:: table where possible, 66 | and a plain table where not. 67 | 68 | """ 69 | out = [] 70 | if self[name]: 71 | out += ['.. rubric:: %s' % name, ''] 72 | prefix = getattr(self, '_name', '') 73 | 74 | if prefix: 75 | prefix = '~%s.' % prefix 76 | 77 | autosum = [] 78 | others = [] 79 | for param, param_type, desc in self[name]: 80 | param = param.strip() 81 | if not self._obj or hasattr(self._obj, param): 82 | autosum += [" %s%s" % (prefix, param)] 83 | else: 84 | others.append((param, param_type, desc)) 85 | 86 | if autosum: 87 | # GAEL: Toctree commented out below because it creates 88 | # hundreds of sphinx warnings
89 | # out += ['.. autosummary::', ' :toctree:', ''] 90 | out += ['.. autosummary::', ''] 91 | out += autosum 92 | 93 | if others: 94 | maxlen_0 = max([len(x[0]) for x in others]) 95 | maxlen_1 = max([len(x[1]) for x in others]) 96 | hdr = "=" * maxlen_0 + " " + "=" * maxlen_1 + " " + "=" * 10 97 | fmt = '%%%ds %%%ds ' % (maxlen_0, maxlen_1) 98 | n_indent = maxlen_0 + maxlen_1 + 4 99 | out += [hdr] 100 | for param, param_type, desc in others: 101 | out += [fmt % (param.strip(), param_type)] 102 | out += self._str_indent(desc, n_indent) 103 | out += [hdr] 104 | out += [''] 105 | return out 106 | 107 | def _str_section(self, name): 108 | out = [] 109 | if self[name]: 110 | out += self._str_header(name) 111 | out += [''] 112 | content = textwrap.dedent("\n".join(self[name])).split("\n") 113 | out += content 114 | out += [''] 115 | return out 116 | 117 | def _str_see_also(self, func_role): 118 | out = [] 119 | if self['See Also']: 120 | see_also = super(SphinxDocString, self)._str_see_also(func_role) 121 | out = ['.. seealso::', ''] 122 | out += self._str_indent(see_also[2:]) 123 | return out 124 | 125 | def _str_warnings(self): 126 | out = [] 127 | if self['Warnings']: 128 | out = ['.. warning::', ''] 129 | out += self._str_indent(self['Warnings']) 130 | return out 131 | 132 | def _str_index(self): 133 | idx = self['index'] 134 | out = [] 135 | if len(idx) == 0: 136 | return out 137 | 138 | out += ['.. index:: %s' % idx.get('default', '')] 139 | for section, references in idx.items(): 140 | if section == 'default': 141 | continue 142 | elif section == 'refguide': 143 | out += [' single: %s' % (', '.join(references))] 144 | else: 145 | out += [' %s: %s' % (section, ', '.join(references))] 146 | return out 147 | 148 | def _str_references(self): 149 | out = [] 150 | if self['References']: 151 | out += self._str_header('References') 152 | if isinstance(self['References'], str): 153 | self['References'] = [self['References']] 154 | out.extend(self['References']) 155 | out += [''] 156 | # Latex collects all references to a separate bibliography, 157 | # so we need to insert links to it 158 | import sphinx # local import to avoid test dependency 159 | if sphinx.__version__ >= "0.6": # lexicographic, but OK for 0.5/0.6 vs 1.x 160 | out += ['.. only:: latex', ''] 161 | else: 162 | out += ['.. latexonly::', ''] 163 | items = [] 164 | for line in self['References']: 165 | m = re.match(r'.. \[([a-z0-9._-]+)\]', line, re.I) 166 | if m: 167 | items.append(m.group(1)) 168 | out += [' ' + ", ".join(["[%s]_" % item for item in items]), ''] 169 | return out 170 | 171 | def _str_examples(self): 172 | examples_str = "\n".join(self['Examples']) 173 | 174 | if (self.use_plots and 'import matplotlib' in examples_str 175 | and 'plot::' not in examples_str): 176 | out = [] 177 | out += self._str_header('Examples') 178 | out += ['..
plot::', ''] 179 | out += self._str_indent(self['Examples']) 180 | out += [''] 181 | return out 182 | else: 183 | return self._str_section('Examples') 184 | 185 | def __str__(self, indent=0, func_role="obj"): 186 | out = [] 187 | out += self._str_signature() 188 | out += self._str_index() + [''] 189 | out += self._str_summary() 190 | out += self._str_extended_summary() 191 | for param_list in ('Parameters', 'Returns', 'Raises', 'Attributes'): 192 | out += self._str_param_list(param_list) 193 | out += self._str_warnings() 194 | out += self._str_see_also(func_role) 195 | out += self._str_section('Notes') 196 | out += self._str_references() 197 | out += self._str_examples() 198 | for param_list in ('Methods',): 199 | out += self._str_member_list(param_list) 200 | out = self._str_indent(out, indent) 201 | return '\n'.join(out) 202 | 203 | 204 | class SphinxFunctionDoc(SphinxDocString, FunctionDoc): 205 | def __init__(self, obj, doc=None, config={}): 206 | self.use_plots = config.get('use_plots', False) 207 | FunctionDoc.__init__(self, obj, doc=doc, config=config) 208 | 209 | 210 | class SphinxClassDoc(SphinxDocString, ClassDoc): 211 | def __init__(self, obj, doc=None, func_doc=None, config={}): 212 | self.use_plots = config.get('use_plots', False) 213 | ClassDoc.__init__(self, obj, doc=doc, func_doc=None, config=config) 214 | 215 | 216 | class SphinxObjDoc(SphinxDocString): 217 | def __init__(self, obj, doc=None, config=None): 218 | self._f = obj 219 | SphinxDocString.__init__(self, doc, config=config) 220 | 221 | 222 | def get_doc_object(obj, what=None, doc=None, config={}): 223 | if what is None: 224 | if inspect.isclass(obj): 225 | what = 'class' 226 | elif inspect.ismodule(obj): 227 | what = 'module' 228 | elif callable(obj): 229 | what = 'function' 230 | else: 231 | what = 'object' 232 | if what == 'class': 233 | return SphinxClassDoc(obj, func_doc=SphinxFunctionDoc, doc=doc, 234 | config=config) 235 | elif what in ('function', 'method'): 236 | return SphinxFunctionDoc(obj, doc=doc, config=config) 237 | else: 238 | if doc is None: 239 | doc = pydoc.getdoc(obj) 240 | return SphinxObjDoc(obj, doc, config=config) 241 | -------------------------------------------------------------------------------- /docs/source/sphinxext/numpydoc.py: -------------------------------------------------------------------------------- 1 | """ 2 | ======== 3 | numpydoc 4 | ======== 5 | 6 | Sphinx extension that handles docstrings in the Numpy standard format. [1] 7 | 8 | It will: 9 | 10 | - Convert Parameters etc. sections to field lists. 11 | - Convert See Also section to a See also entry. 12 | - Renumber references. 13 | - Extract the signature from the docstring, if it can't be determined 14 | otherwise. 15 | 16 | .. 
[1] http://projects.scipy.org/numpy/wiki/CodingStyleGuidelines#docstring-standard 17 | 18 | """ 19 | 20 | from __future__ import unicode_literals 21 | 22 | import sys # Only needed to check Python version 23 | import os 24 | import re 25 | import pydoc 26 | from .docscrape_sphinx import get_doc_object 27 | from .docscrape_sphinx import SphinxDocString 28 | import inspect 29 | 30 | 31 | def mangle_docstrings(app, what, name, obj, options, lines, 32 | reference_offset=[0]): # mutable default: persistent counter 33 | 34 | cfg = dict(use_plots=app.config.numpydoc_use_plots, 35 | show_class_members=app.config.numpydoc_show_class_members) 36 | 37 | if what == 'module': 38 | # Strip top title 39 | title_re = re.compile(r'^\s*[#*=]{4,}\n[a-z0-9 -]+\n[#*=]{4,}\s*', 40 | re.I | re.S) 41 | lines[:] = title_re.sub('', "\n".join(lines)).split("\n") 42 | else: 43 | doc = get_doc_object(obj, what, "\n".join(lines), config=cfg) 44 | if sys.version_info[0] < 3: 45 | lines[:] = unicode(doc).splitlines() 46 | else: 47 | lines[:] = str(doc).splitlines() 48 | 49 | if app.config.numpydoc_edit_link and hasattr(obj, '__name__') and \ 50 | obj.__name__: 51 | if hasattr(obj, '__module__'): 52 | v = dict(full_name="%s.%s" % (obj.__module__, obj.__name__)) 53 | else: 54 | v = dict(full_name=obj.__name__) 55 | lines += [u'', u'.. htmlonly::', ''] 56 | lines += [u' %s' % x for x in 57 | (app.config.numpydoc_edit_link % v).split("\n")] 58 | 59 | # replace reference numbers so that there are no duplicates 60 | references = [] 61 | for line in lines: 62 | line = line.strip() 63 | m = re.match(r'^.. \[([a-z0-9_.-]+)\]', line, re.I) 64 | if m: 65 | references.append(m.group(1)) 66 | 67 | # start renaming from the longest string, to avoid overwriting parts 68 | references.sort(key=lambda x: -len(x)) 69 | if references: 70 | for i, line in enumerate(lines): 71 | for r in references: 72 | if re.match(r'^\d+$', r): 73 | new_r = "R%d" % (reference_offset[0] + int(r)) 74 | else: 75 | new_r = u"%s%d" % (r, reference_offset[0]) 76 | lines[i] = lines[i].replace(u'[%s]_' % r, 77 | u'[%s]_' % new_r)
78 | lines[i] = lines[i].replace(u'.. [%s]' % r, 79 | u'.. [%s]' % new_r) 80 | 81 | reference_offset[0] += len(references) 82 | 83 | 84 | def mangle_signature(app, what, name, obj, 85 | options, sig, retann): 86 | # Do not try to inspect classes that don't define `__init__` 87 | if (inspect.isclass(obj) and 88 | (not hasattr(obj, '__init__') or 89 | 'initializes x; see ' in pydoc.getdoc(obj.__init__))): 90 | return '', '' 91 | 92 | if not (callable(obj) or hasattr(obj, '__argspec_is_invalid_')): 93 | return 94 | if not hasattr(obj, '__doc__'): 95 | return 96 | 97 | doc = SphinxDocString(pydoc.getdoc(obj)) 98 | if doc['Signature']: 99 | sig = re.sub("^[^(]*", "", doc['Signature']) 100 | return sig, '' 101 | 102 | 103 | def setup(app, get_doc_object_=get_doc_object): 104 | global get_doc_object 105 | get_doc_object = get_doc_object_ 106 | 107 | if sys.version_info[0] < 3: 108 | app.connect(b'autodoc-process-docstring', mangle_docstrings) 109 | app.connect(b'autodoc-process-signature', mangle_signature) 110 | else: 111 | app.connect('autodoc-process-docstring', mangle_docstrings) 112 | app.connect('autodoc-process-signature', mangle_signature) 113 | app.add_config_value('numpydoc_edit_link', None, False) 114 | app.add_config_value('numpydoc_use_plots', None, False) 115 | app.add_config_value('numpydoc_show_class_members', True, True) 116 | 117 | # Extra mangling domains 118 | app.add_domain(NumpyPythonDomain) 119 | app.add_domain(NumpyCDomain) 120 | 121 | #----------------------------------------------------------------------------- 122 | # Docstring-mangling domains 123 | #----------------------------------------------------------------------------- 124 | 125 | try: 126 | import sphinx # lazy to avoid test dependency 127 | except ImportError: 128 | CDomain = PythonDomain = object 129 | else: 130 | from sphinx.domains.c import CDomain 131 | from sphinx.domains.python import PythonDomain 132 | 133 | 134 | class ManglingDomainBase(object): 135 | directive_mangling_map = {} 136 | 137 | def __init__(self, *a, **kw): 138 | super(ManglingDomainBase, self).__init__(*a, **kw) 139 | self.wrap_mangling_directives() 140 | 141 | def wrap_mangling_directives(self): 142 | for name, objtype in self.directive_mangling_map.items(): 143 | self.directives[name] = wrap_mangling_directive( 144 | self.directives[name], objtype) 145 | 146 | 147 | class NumpyPythonDomain(ManglingDomainBase, PythonDomain): 148 | name = 'np' 149 | directive_mangling_map = { 150 | 'function': 'function', 151 | 'class': 'class', 152 | 'exception': 'class', 153 | 'method': 'function', 154 | 'classmethod': 'function', 155 | 'staticmethod': 'function', 156 | 'attribute': 'attribute', 157 | } 158 | 159 | 160 | class NumpyCDomain(ManglingDomainBase, CDomain): 161 | name = 'np-c' 162 | directive_mangling_map = { 163 | 'function': 'function', 164 | 'member': 'attribute', 165 | 'macro': 'function', 166 | 'type': 'class', 167 | 'var': 'object', 168 | } 169 | 170 | 171 | def wrap_mangling_directive(base_directive, objtype): 172 | class directive(base_directive): 173 | def run(self): 174 | env = self.state.document.settings.env 175 | 176 | name = None 177 | if self.arguments: 178 | m = re.match(r'^(.*\s+)?(.*?)(\(.*)?$', self.arguments[0]) 179 | name = m.group(2).strip() 180 | 181 | if not name: 182 | name = self.arguments[0] 183 | 184 | lines = list(self.content) 185 | mangle_docstrings(env.app, objtype, name, None, None, lines) 186 | # local import to avoid testing dependency 187 | from docutils.statemachine import ViewList 188 | self.content = ViewList(lines, self.content.parent) 189 | 190 | return
base_directive.run(self) 191 | 192 | return directive 193 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | exclude = __init__.py 3 | # Ignore over/under indented for visual indent 4 | ignore = E127,E128 5 | max-line-length = 85 6 | 7 | [versioneer] 8 | VCS = git 9 | style = pep440 10 | versionfile_source = dask_searchcv/_version.py 11 | versionfile_build = dask_searchcv/_version.py 12 | tag_prefix = 13 | parentdir_prefix = dask_searchcv- 14 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from os.path import exists 2 | 3 | from setuptools import setup 4 | import versioneer 5 | 6 | install_requires = ["dask[delayed] >= 0.14.0", 7 | "toolz >= 0.8.2", 8 | "scikit-learn >= 0.18.0", 9 | "numpy"] 10 | 11 | setup(name='dask-searchcv', 12 | version=versioneer.get_version(), 13 | cmdclass=versioneer.get_cmdclass(), 14 | license='BSD', 15 | url='http://github.com/dask/dask-searchcv', 16 | maintainer='Jim Crist', 17 | maintainer_email='jcrist@continuum.io', 18 | install_requires=install_requires, 19 | description='Tools for doing hyperparameter search with Scikit-Learn and Dask', 20 | long_description=(open('README.rst').read() if exists('README.rst') 21 | else ''), 22 | packages=['dask_searchcv', 'dask_searchcv.tests']) 23 | --------------------------------------------------------------------------------
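Note on wiring: the vendored ``numpydoc`` extension above registers itself through its ``setup(app)`` hook, but the ``docs/source/conf.py`` that enables it is not reproduced in this listing. A minimal sketch of the relevant fragment follows; the path handling and the extension name are assumptions rather than the project's verified configuration, since ``numpydoc.py`` uses package-relative imports (``from .docscrape_sphinx import get_doc_object``) and therefore must be importable as part of the ``sphinxext`` package:

    # Hypothetical fragment of docs/source/conf.py (not the actual file).
    import os
    import sys

    # Make the directory containing ``sphinxext`` importable so the
    # vendored extension's relative imports resolve.
    sys.path.insert(0, os.path.abspath('.'))

    extensions = [
        'sphinx.ext.autodoc',
        'sphinxext.numpydoc',   # the vendored extension defined above
    ]

    # Config values registered by numpydoc.setup():
    numpydoc_show_class_members = True  # auto-fill Methods/Attributes lists
    numpydoc_use_plots = False          # leave Examples sections unwrapped
    numpydoc_edit_link = None           # no per-object edit links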