├── .circleci └── config.yml ├── .gitignore ├── CHANGELOG.rst ├── CONTRIBUTING.rst ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.rst ├── VERSION ├── docs ├── Makefile ├── apidoc.rst ├── changelog.rst ├── conda.yml ├── conf.py ├── hyperparams.rst ├── index.rst └── quickstart.rst ├── notebooks └── readme_example.ipynb ├── readthedocs.yml ├── requirements ├── base.txt └── dev.txt ├── setup.cfg ├── setup.py ├── skconfig ├── __init__.py ├── condition.py ├── distribution.py ├── exceptions.py ├── forbidden.py ├── mapping.py ├── parameter │ ├── __init__.py │ ├── base.py │ ├── convience.py │ ├── interval.py │ └── types.py ├── sampler.py └── validator.py └── tests └── test_parameter.py /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | test_py: &test_py 2 | docker: 3 | - image: circleci/python:$PY_VERSION 4 | steps: 5 | - checkout 6 | - restore_cache: 7 | key: > 8 | v1-dependency-cache 9 | -{{ checksum "requirements/base.txt" }} 10 | -{{ checksum "requirements/dev.txt" }} 11 | - run: 12 | name: Install dependencies 13 | command: | 14 | python -m venv venv 15 | . venv/bin/activate 16 | pip install numpy 17 | pip install -e .[dev] 18 | - save_cache: 19 | key: > 20 | v1-dependency-cache 21 | -{{ checksum "requirements/base.txt" }} 22 | -{{ checksum "requirements/dev.txt" }} 23 | paths: 24 | - "venv" 25 | - run: 26 | name: Run tests 27 | command: | 28 | . venv/bin/activate 29 | make lint 30 | pytest 31 | - run: 32 | name: Report converage 33 | command: | 34 | if [ "${PY_VERSION}" == "3.6" ]; then 35 | . venv/bin/activate 36 | codecov 37 | fi 38 | 39 | version: 2 40 | jobs: 41 | test_py36: 42 | environment: 43 | - PY_VERSION: 3.6 44 | <<: *test_py 45 | deploy: 46 | docker: 47 | - image: circleci/python:3.6 48 | steps: 49 | - checkout 50 | - run: 51 | name: Check git tag and version are the same 52 | command: | 53 | VERSION=$(cat VERSION) 54 | [ "$VERSION" = "$CIRCLE_TAG" ] 55 | - run: 56 | name: Init .pypirc 57 | command: | 58 | echo -e "[pypi]" >> ~/.pypirc 59 | echo -e "username = $PYPI_USERNAME" >> ~/.pypirc 60 | echo -e "password = $PYPI_PASSWORD" >> ~/.pypirc 61 | - run: 62 | name: Create packages 63 | command: make release 64 | - run: 65 | name: Upload to pypi 66 | command: | 67 | . venv/bin/activate 68 | twine upload dist/* 69 | 70 | workflows: 71 | version: 2 72 | buildall: 73 | jobs: 74 | - test_py36 75 | - deploy: 76 | requires: 77 | - test_py36 78 | filters: 79 | tags: 80 | only: /[0-9]+(\.[0-9]+)*/ 81 | branches: 82 | ignore: /.*/ 83 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # Jupyter Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # SageMath parsed files 81 | *.sage.py 82 | 83 | # dotenv 84 | .env 85 | 86 | # virtualenv 87 | .venv 88 | venv/ 89 | ENV/ 90 | 91 | # Spyder project settings 92 | .spyderproject 93 | .spyproject 94 | 95 | # Rope project settings 96 | .ropeproject 97 | 98 | # mkdocs documentation 99 | /site 100 | 101 | # mypy 102 | .mypy_cache/ -------------------------------------------------------------------------------- /CHANGELOG.rst: -------------------------------------------------------------------------------- 1 | Release Notes 2 | ============= 3 | 4 | 0.1.0 (2019-03-18) 5 | --------------------- 6 | 7 | - First alpha version 8 | -------------------------------------------------------------------------------- /CONTRIBUTING.rst: -------------------------------------------------------------------------------- 1 | .. highlight:: shell 2 | 3 | ============ 4 | Contributing 5 | ============ 6 | 7 | Contributions are welcome, and they are greatly appreciated! Every little bit 8 | helps, and credit will always be given. 9 | 10 | You can contribute in many ways: 11 | 12 | 1. Report Bugs on https://github.com/thomasjpfan/skconfig/issues. 13 | 2. Fix Bugs on GitHub issues 14 | 3. Implement Features 15 | 4. Write Documentation 16 | 5. Submit Feedback on GitHub issues 17 | 18 | Development 19 | ----------- 20 | 21 | The development version can be installed by: 22 | 23 | .. code-block:: bash 24 | 25 | git clone https://github.com/thomasjpfan/skconfig 26 | cd skconfig 27 | make dev 28 | 29 | Then we can lint ``make lint`` and tests by running ``make test``. 30 | 31 | Pull Request Guidelines 32 | ----------------------- 33 | 34 | Before you submit a pull request, check that it meets these guidelines: 35 | 36 | 1. The pull request should include tests. 37 | 2. If the pull request adds functionality, the docs should be updated. Put 38 | your new functionality into a function with a docstring, and add the 39 | feature to the list in README.rst. 40 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2018 Thomas J Fan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include README.rst 3 | include CHANGELOG.rst 4 | 5 | recursive-include tests * 6 | recursive-exclude * __pycache__ 7 | recursive-exclude * *.py[co] 8 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: dev dev_conda package clean lint docs 2 | 3 | lint: 4 | flake8 skconfig tests 5 | 6 | dev: 7 | pip install numpy 8 | pip install -e .[dev] 9 | 10 | release: 11 | python setup.py sdist 12 | python setup.py bdist_wheel 13 | 14 | clean: 15 | rm -rf dist build */*.egg-info *.egg-info 16 | $(MAKE) -C docs clean 17 | 18 | docs: 19 | $(MAKE) -C docs html 20 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | skconfig 2 | ======== 3 | 4 | Scikit-learn sampling and validation library. 5 | 6 | Features 7 | -------- 8 | 9 | ``skconfig`` is provides two key features: Validation of parameters for 10 | scikit-learn models, and sampling these parameters. The sampling depends on 11 | `ConfigSpace `_. 12 | 13 | Validation 14 | .......... 15 | 16 | ``skconfig`` creates a DSL for defining the search space for a sklearn model. 17 | For example, we can defined a ``LogRegressionValidator`` as follows: 18 | 19 | .. code-block:: python 20 | 21 | class LogRegressionValidator(BaseValidator): 22 | estimator = LogisticRegression 23 | 24 | penalty = StringParam("l2", "l1") 25 | dual = BoolParam() 26 | tol = FloatIntervalParam(lower=0, include_lower=False) 27 | C = FloatIntervalParam(lower=0) 28 | fit_intercept = BoolParam() 29 | intercept_scaling = FloatIntervalParam(lower=0, include_lower=False) 30 | class_weight = NoneParam() 31 | random_state = UnionParam(IntParam(), NoneParam()) 32 | solver = StringParam("newton-cg", "lbfgs", "liblinear", "sag", "saga", "warn") 33 | max_iter = IntIntervalParam(lower=1) 34 | multi_class = StringParam("ovr", "multinomial", "auto", "warn") 35 | verbose = IntParam() 36 | warm_start = BoolParam() 37 | n_jobs = UnionParam(NoneParam(), IntIntervalParam(lower=-1)) 38 | 39 | forbiddens = [ 40 | ForbiddenAnd([ForbiddenEquals("penalty", "l1"), 41 | ForbiddenIn("solver", ["newton-cg", "sag", "lbfgs"])]), 42 | ForbiddenAnd([ForbiddenEquals("solver", "liblinear"), 43 | ForbiddenEquals("multi_class", "multinomial")]), 44 | ] 45 | 46 | 47 | With this validator object, we can validate a set of parameters: 48 | 49 | .. code-block:: python 50 | 51 | validator = LogRegressionValidator() 52 | 53 | # Does not raise an exception 54 | validator.validate_params(multi_class="ovr") 55 | 56 | # These will raise an exception 57 | validator.validate_params(penalty="hello world") 58 | validator.validate_params(solver="liblinear", multi_class="multinomial") 59 | validator.validate_params(penalty="l1", solver="sag") 60 | 61 | params_dict = {"penalty": "l1", "solver": "sag"} 62 | validator.validate_params(**params_dict) 63 | 64 | Or validate a estimator: 65 | 66 | .. code-block:: python 67 | 68 | est = LogisticRegression(solver="liblienar") 69 | validator.validate_estimator(est) # Will not raise 70 | 71 | Sampling 72 | ........ 73 | 74 | To sample the parameter space, a ``skconfig`` has a DSL for defining the 75 | distribution to be sampled from: 76 | 77 | .. code-block:: python 78 | 79 | sampler = Sampler( 80 | validator, 81 | dual=UniformBoolDistribution(), 82 | C=UniformFloatDistribution(0.0, 1.0), 83 | solver=CategoricalDistribution( 84 | ["newton-cg", "lbfgs", "liblinear", "sag", "saga"]), 85 | random_state=UnionDistribution( 86 | ConstantDistribution(None), UniformIntDistribution(0, 10)), 87 | penalty=CategoricalDistribution(["l2", "l1"]), 88 | multi_class=CategoricalDistribution(["ovr", "multinomial"]) 89 | ) 90 | 91 | To sample from we call `sample`: 92 | 93 | .. code-block:: python 94 | 95 | params_sample = sampler.sample(5) 96 | 97 | which returns a list of 5 parameter dicts to be passed to `set_params`: 98 | 99 | .. code-block:: python 100 | 101 | [{'C': 0.38684515891991544, 102 | 'dual': True, 103 | 'multi_class': 'ovr', 104 | 'penalty': 'l2', 105 | 'solver': 'lbfgs', 106 | 'random_state': 1}, 107 | {'C': 0.017914312843795077, 108 | 'dual': True, 109 | 'multi_class': 'ovr', 110 | 'penalty': 'l2', 111 | 'solver': 'lbfgs', 112 | 'random_state': 0}, 113 | {'C': 0.7044064976675997, 114 | 'dual': True, 115 | 'multi_class': 'ovr', 116 | 'penalty': 'l2', 117 | 'solver': 'liblinear', 118 | 'random_state': 7}, 119 | {'C': 0.9066951378139576, 120 | 'dual': False, 121 | 'multi_class': 'ovr', 122 | 'penalty': 'l2', 123 | 'solver': 'sag', 124 | 'random_state': 10}, 125 | {'C': 0.10402966368097444, 126 | 'dual': True, 127 | 'multi_class': 'multinomial', 128 | 'penalty': 'l2', 129 | 'solver': 'saga', 130 | 'random_state': 7}] 131 | 132 | To create an estimator from the first paramter item in ``params_sample``: 133 | 134 | .. code-block:: python 135 | 136 | est = LogisticRegression(**params_sample[0]) 137 | # or 138 | est.set_params(**params_sample[0]) 139 | 140 | Serialization 141 | ............. 142 | 143 | The sampler can be serialized into a json: 144 | 145 | .. code-block:: python 146 | 147 | import json 148 | json_serialized = json.dumps(sampler.to_dict(), indent=2) 149 | print(json_serialized) 150 | 151 | which outputs: 152 | 153 | .. code-block:: python 154 | 155 | { 156 | "dual": { 157 | "default": true, 158 | "type": "UniformBoolDistribution" 159 | }, 160 | "C": { 161 | "lower": 0.0, 162 | "upper": 1.0, 163 | "default": 0.0, 164 | "log": false, 165 | "type": "UniformFloatDistribution" 166 | }, 167 | "solver": { 168 | "choices": [ 169 | "newton-cg", 170 | "lbfgs", 171 | "liblinear", 172 | "sag", 173 | "saga" 174 | ], 175 | "default": "newton-cg", 176 | "type": "CategoricalDistribution" 177 | }, 178 | "random_state": { 179 | "type": "UnionDistribution", 180 | "dists": [ 181 | { 182 | "type": "ConstantDistribution", 183 | "value": null 184 | }, 185 | { 186 | "lower": 0, 187 | "upper": 10, 188 | "default": 0, 189 | "log": false, 190 | "type": "UniformIntDistribution" 191 | } 192 | ] 193 | }, 194 | "penalty": { 195 | "choices": [ 196 | "l2", 197 | "l1" 198 | ], 199 | "default": "l2", 200 | "type": "CategoricalDistribution" 201 | }, 202 | "multi_class": { 203 | "choices": [ 204 | "ovr", 205 | "multinomial" 206 | ], 207 | "default": "ovr", 208 | "type": "CategoricalDistribution" 209 | } 210 | } 211 | 212 | To load the sampler from json 213 | 214 | .. code-block:: python 215 | 216 | sampler_dict = json.loads(json_serialized) 217 | sampler_new = Sampler(validator).from_dict(sampler_dict) 218 | 219 | 220 | Installation 221 | ------------ 222 | 223 | You can install skconfig directly from pypi: 224 | 225 | .. code-block:: bash 226 | 227 | pip install git+https://github.com/thomasjpfan/skconfig 228 | 229 | Development 230 | ----------- 231 | 232 | The development version can be installed by running ``make dev``. Then we can lint ``make lint`` and tests by running ``pytest``. 233 | -------------------------------------------------------------------------------- /VERSION: -------------------------------------------------------------------------------- 1 | 0.1.0 2 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = python -msphinx 7 | SPHINXPROJ = skconfig 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /docs/apidoc.rst: -------------------------------------------------------------------------------- 1 | API Documentation 2 | ================= 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | add 8 | -------------------------------------------------------------------------------- /docs/changelog.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../CHANGELOG.rst 2 | -------------------------------------------------------------------------------- /docs/conda.yml: -------------------------------------------------------------------------------- 1 | name: skconfig_docs 2 | 3 | dependencies: 4 | - python=3.6.* 5 | - scikit-learn=0.20.3 6 | - sphinx 7 | - sphinx_rtd_theme 8 | - pip: 9 | - configspace==0.4.9 10 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import sys 5 | sys.path.insert(0, os.path.abspath('..')) 6 | 7 | 8 | extensions = ['sphinx.ext.autodoc', 9 | 'sphinx.ext.viewcode', 10 | 'sphinx.ext.napoleon', 11 | 'sphinx.ext.intersphinx'] 12 | 13 | intersphinx_mapping = { 14 | 'python': ('https://docs.python.org/3', None), 15 | 'numpy': ('http://docs.scipy.org/doc/numpy/', None), 16 | } 17 | 18 | templates_path = ['_templates'] 19 | 20 | source_suffix = '.rst' 21 | master_doc = 'index' 22 | 23 | project = u'skconfig' 24 | copyright = u"2019, Thomas J Fan" 25 | author = u"Thomas J Fan" 26 | 27 | with open('../VERSION', 'r') as f: 28 | release = f.read().strip() 29 | version = release.rsplit('.', 1)[0] 30 | 31 | language = 'en' 32 | 33 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 34 | 35 | pygments_style = 'sphinx' 36 | 37 | todo_include_todos = False 38 | 39 | # HTML output theme 40 | html_theme = 'sphinx_rtd_theme' 41 | 42 | htmlhelp_basename = 'skconfigdoc' 43 | 44 | # Latex 45 | latex_elements = {} 46 | 47 | latex_documents = [ 48 | (master_doc, 'skconfig.tex', 49 | u'skconfig Documentation', 50 | u'Thomas J Fan', 'manual'), 51 | ] 52 | 53 | # Manual page output 54 | man_pages = [ 55 | (master_doc, 'skconfig', 56 | u'skconfig Documentation', 57 | [author], 1) 58 | ] 59 | 60 | # Textinfo output 61 | texinfo_documents = [ 62 | (master_doc, 'skconfig', 63 | u'skconfig Documentation', 64 | author, 65 | 'skconfig', 66 | 'One line description of project.', 67 | 'Miscellaneous'), 68 | ] 69 | -------------------------------------------------------------------------------- /docs/hyperparams.rst: -------------------------------------------------------------------------------- 1 | Add Module 2 | ========== 3 | 4 | .. automodule:: skconfig.hyperparams 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../README.rst 2 | 3 | Contents 4 | -------- 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | 9 | quickstart 10 | apidoc 11 | changelog 12 | 13 | Index 14 | ----- 15 | * :ref:`genindex` 16 | -------------------------------------------------------------------------------- /docs/quickstart.rst: -------------------------------------------------------------------------------- 1 | Quickstart 2 | ========== 3 | 4 | Installation 5 | ------------ 6 | 7 | You can install skconfig directly from pypi: 8 | 9 | .. code-block:: bash 10 | 11 | pip install git+https://github.com/thomasjpfan/skconfig 12 | -------------------------------------------------------------------------------- /notebooks/readme_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 18, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from skconfig.condition import EqualsCondition\n", 10 | "from skconfig.validator import BaseValidator\n", 11 | "from skconfig.forbidden import ForbiddenEquals, ForbiddenIn, ForbiddenAnd\n", 12 | "from skconfig.parameter import StringParam, BoolParam, FloatIntervalParam, NoneParam, UnionParam, IntParam, IntIntervalParam\n", 13 | "from skconfig.distribution import UniformBoolDistribution, UniformFloatDistribution, CategoricalDistribution, UnionDistribution, ConstantDistribution, UniformIntDistribution\n", 14 | "from skconfig.sampler import Sampler\n", 15 | "from sklearn.linear_model import LogisticRegression" 16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "metadata": {}, 21 | "source": [ 22 | "## Validation" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "metadata": {}, 28 | "source": [ 29 | "skconfig creates a DSL for defining the search space for a sklearn model. For example, we can defined a LogRegressionValidator as follows:" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 19, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "class LogRegressionValidator(BaseValidator):\n", 39 | " estimator = LogisticRegression\n", 40 | " \n", 41 | " penalty = StringParam(\"l2\", \"l1\")\n", 42 | " dual = BoolParam()\n", 43 | " tol = FloatIntervalParam(lower=0, include_lower=False)\n", 44 | " C = FloatIntervalParam(lower=0)\n", 45 | " fit_intercept = BoolParam()\n", 46 | " intercept_scaling = FloatIntervalParam(lower=0, include_lower=False)\n", 47 | " class_weight = NoneParam()\n", 48 | " random_state = UnionParam(IntParam(), NoneParam())\n", 49 | " solver = StringParam(\"newton-cg\", \"lbfgs\", \"liblinear\", \"sag\", \"saga\", \"warn\")\n", 50 | " max_iter = IntIntervalParam(lower=1)\n", 51 | " multi_class = StringParam(\"ovr\", \"multinomial\", \"auto\", \"warn\")\n", 52 | " verbose = IntParam()\n", 53 | " warm_start = BoolParam()\n", 54 | " n_jobs = UnionParam(NoneParam(), IntIntervalParam(lower=-1))\n", 55 | " \n", 56 | " forbiddens = [\n", 57 | " ForbiddenAnd([ForbiddenEquals(\"penalty\", \"l1\"), ForbiddenIn(\"solver\", [\"newton-cg\", \"sag\", \"lbfgs\"])]),\n", 58 | " ForbiddenAnd([ForbiddenEquals(\"solver\", \"liblinear\"), ForbiddenEquals(\"multi_class\", \"multinomial\")]),\n", 59 | " ]" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "metadata": {}, 65 | "source": [ 66 | "With this validator object, we can validate a set of parameters:" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 23, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "validator = LogRegressionValidator()\n", 76 | "\n", 77 | "validator.validate_params(multi_class=\"ovr\") # Does not raise" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 15, 83 | "metadata": {}, 84 | "outputs": [ 85 | { 86 | "ename": "InvalidParamChoices", 87 | "evalue": "penalty must be one of ('l2', 'l1')", 88 | "output_type": "error", 89 | "traceback": [ 90 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 91 | "\u001b[0;31mInvalidParamChoices\u001b[0m Traceback (most recent call last)", 92 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mvalidator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalidate_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpenalty\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"hello world\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 93 | "\u001b[0;32m~/Documents/Repos_Desk/skconfig/skconfig/validator.py\u001b[0m in \u001b[0;36mvalidate_params\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mall_kwargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mcondition_names\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 37\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters_\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalidate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 38\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0;31m# Check conditions\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 94 | "\u001b[0;32m~/Documents/Repos_Desk/skconfig/skconfig/parameter/types.py\u001b[0m in \u001b[0;36mvalidate\u001b[0;34m(self, name, value)\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mInvalidParamType\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtype_str\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchoices\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 51\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mInvalidParamChoices\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchoices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 52\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 95 | "\u001b[0;31mInvalidParamChoices\u001b[0m: penalty must be one of ('l2', 'l1')" 96 | ] 97 | } 98 | ], 99 | "source": [ 100 | "validator.validate_params(penalty=\"hello world\")" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 4, 106 | "metadata": {}, 107 | "outputs": [ 108 | { 109 | "ename": "ForbiddenValue", 110 | "evalue": "solver and multi_class with value liblinear and multinomial is forbidden", 111 | "output_type": "error", 112 | "traceback": [ 113 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 114 | "\u001b[0;31mForbiddenValue\u001b[0m Traceback (most recent call last)", 115 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mvalidator\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mLogRegressionValidator\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mvalidator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalidate_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msolver\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"liblinear\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmulti_class\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"multinomial\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 116 | "\u001b[0;32m~/Documents/Repos_Desk/skconfig/skconfig/validator.py\u001b[0m in \u001b[0;36mvalidate_params\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;31m# check for forbidden\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mforbidden\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforbiddens\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 32\u001b[0;31m \u001b[0mforbidden\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_forbidden\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mall_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 33\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0mcondition_names\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mcond\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchild\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mcond\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconditions\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 117 | "\u001b[0;32m~/Documents/Repos_Desk/skconfig/skconfig/forbidden.py\u001b[0m in \u001b[0;36mis_forbidden\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0mnames_str\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\" and \"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnames\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0mvalues_str\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\" and \"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mv\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mvalues\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 57\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mForbiddenValue\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnames_str\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalues_str\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 58\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__repr__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 118 | "\u001b[0;31mForbiddenValue\u001b[0m: solver and multi_class with value liblinear and multinomial is forbidden" 119 | ] 120 | } 121 | ], 122 | "source": [ 123 | "validator.validate_params(solver=\"liblinear\", multi_class=\"multinomial\")" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 5, 129 | "metadata": {}, 130 | "outputs": [ 131 | { 132 | "ename": "ForbiddenValue", 133 | "evalue": "penalty and solver with value l1 and ['newton-cg', 'sag', 'lbfgs'] is forbidden", 134 | "output_type": "error", 135 | "traceback": [ 136 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 137 | "\u001b[0;31mForbiddenValue\u001b[0m Traceback (most recent call last)", 138 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mvalidator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalidate_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpenalty\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"l1\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msolver\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"sag\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 139 | "\u001b[0;32m~/Documents/Repos_Desk/skconfig/skconfig/validator.py\u001b[0m in \u001b[0;36mvalidate_params\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;31m# check for forbidden\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mforbidden\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforbiddens\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 32\u001b[0;31m \u001b[0mforbidden\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_forbidden\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mall_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 33\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0mcondition_names\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mcond\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchild\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mcond\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconditions\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 140 | "\u001b[0;32m~/Documents/Repos_Desk/skconfig/skconfig/forbidden.py\u001b[0m in \u001b[0;36mis_forbidden\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0mnames_str\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\" and \"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnames\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0mvalues_str\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\" and \"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mv\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mvalues\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 57\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mForbiddenValue\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnames_str\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalues_str\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 58\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__repr__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 141 | "\u001b[0;31mForbiddenValue\u001b[0m: penalty and solver with value l1 and ['newton-cg', 'sag', 'lbfgs'] is forbidden" 142 | ] 143 | } 144 | ], 145 | "source": [ 146 | "validator.validate_params(penalty=\"l1\", solver=\"sag\")" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 6, 152 | "metadata": {}, 153 | "outputs": [ 154 | { 155 | "ename": "ForbiddenValue", 156 | "evalue": "penalty and solver with value l1 and ['newton-cg', 'sag', 'lbfgs'] is forbidden", 157 | "output_type": "error", 158 | "traceback": [ 159 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 160 | "\u001b[0;31mForbiddenValue\u001b[0m Traceback (most recent call last)", 161 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mparams_dict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m\"penalty\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m\"l1\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"solver\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m\"sag\"\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mvalidator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalidate_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mparams_dict\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 162 | "\u001b[0;32m~/Documents/Repos_Desk/skconfig/skconfig/validator.py\u001b[0m in \u001b[0;36mvalidate_params\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;31m# check for forbidden\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mforbidden\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforbiddens\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 32\u001b[0;31m \u001b[0mforbidden\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_forbidden\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mall_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 33\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0mcondition_names\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mcond\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchild\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mcond\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconditions\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 163 | "\u001b[0;32m~/Documents/Repos_Desk/skconfig/skconfig/forbidden.py\u001b[0m in \u001b[0;36mis_forbidden\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0mnames_str\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\" and \"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnames\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0mvalues_str\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\" and \"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mv\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mvalues\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 57\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mForbiddenValue\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnames_str\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalues_str\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 58\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__repr__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 164 | "\u001b[0;31mForbiddenValue\u001b[0m: penalty and solver with value l1 and ['newton-cg', 'sag', 'lbfgs'] is forbidden" 165 | ] 166 | } 167 | ], 168 | "source": [ 169 | "params_dict = {\"penalty\": \"l1\", \"solver\": \"sag\"}\n", 170 | "validator.validate_params(**params_dict)" 171 | ] 172 | }, 173 | { 174 | "cell_type": "markdown", 175 | "metadata": {}, 176 | "source": [ 177 | "Or validate a estimator:" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 7, 183 | "metadata": {}, 184 | "outputs": [], 185 | "source": [ 186 | "est = LogisticRegression()\n", 187 | "validator.validate_estimator(est)" 188 | ] 189 | }, 190 | { 191 | "cell_type": "markdown", 192 | "metadata": {}, 193 | "source": [ 194 | "## Sampling\n", 195 | "\n", 196 | "To sample the parameter space, a skconfig has a DSL for defining the distribution to be sampled from:" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 8, 202 | "metadata": {}, 203 | "outputs": [], 204 | "source": [ 205 | "validator = LogRegressionValidator()\n", 206 | "sampler = Sampler(validator, \n", 207 | " dual=UniformBoolDistribution(),\n", 208 | " C=UniformFloatDistribution(0.0, 1.0),\n", 209 | " solver=CategoricalDistribution([\"newton-cg\", \"lbfgs\", \"liblinear\", \"sag\", \"saga\"]),\n", 210 | " random_state=UnionDistribution(ConstantDistribution(None), UniformIntDistribution(0, 10)),\n", 211 | " penalty=CategoricalDistribution([\"l2\", \"l1\"]),\n", 212 | " multi_class=CategoricalDistribution([\"ovr\", \"multinomial\"])\n", 213 | ")" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": 12, 219 | "metadata": {}, 220 | "outputs": [ 221 | { 222 | "data": { 223 | "text/plain": [ 224 | "[{'C': 0.49609443571092127,\n", 225 | " 'dual': False,\n", 226 | " 'multi_class': 'ovr',\n", 227 | " 'penalty': 'l2',\n", 228 | " 'solver': 'saga',\n", 229 | " 'random_state': None},\n", 230 | " {'C': 0.7169968253334416,\n", 231 | " 'dual': True,\n", 232 | " 'multi_class': 'multinomial',\n", 233 | " 'penalty': 'l2',\n", 234 | " 'solver': 'saga',\n", 235 | " 'random_state': 0},\n", 236 | " {'C': 0.7899051166909798,\n", 237 | " 'dual': False,\n", 238 | " 'multi_class': 'multinomial',\n", 239 | " 'penalty': 'l2',\n", 240 | " 'solver': 'lbfgs',\n", 241 | " 'random_state': None},\n", 242 | " {'C': 0.4268914739541635,\n", 243 | " 'dual': False,\n", 244 | " 'multi_class': 'ovr',\n", 245 | " 'penalty': 'l1',\n", 246 | " 'solver': 'saga',\n", 247 | " 'random_state': 2},\n", 248 | " {'C': 0.3945090918540334,\n", 249 | " 'dual': True,\n", 250 | " 'multi_class': 'multinomial',\n", 251 | " 'penalty': 'l2',\n", 252 | " 'solver': 'lbfgs',\n", 253 | " 'random_state': 4}]" 254 | ] 255 | }, 256 | "execution_count": 12, 257 | "metadata": {}, 258 | "output_type": "execute_result" 259 | } 260 | ], 261 | "source": [ 262 | "params_sample = sampler.sample(5)\n", 263 | "params_sample" 264 | ] 265 | }, 266 | { 267 | "cell_type": "markdown", 268 | "metadata": {}, 269 | "source": [ 270 | "Create an estimator from the first param from params_sample" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": 10, 276 | "metadata": {}, 277 | "outputs": [ 278 | { 279 | "data": { 280 | "text/plain": [ 281 | "{'C': 0.7215182701146491,\n", 282 | " 'class_weight': None,\n", 283 | " 'dual': False,\n", 284 | " 'fit_intercept': True,\n", 285 | " 'intercept_scaling': 1,\n", 286 | " 'max_iter': 100,\n", 287 | " 'multi_class': 'multinomial',\n", 288 | " 'n_jobs': None,\n", 289 | " 'penalty': 'l2',\n", 290 | " 'random_state': None,\n", 291 | " 'solver': 'newton-cg',\n", 292 | " 'tol': 0.0001,\n", 293 | " 'verbose': 0,\n", 294 | " 'warm_start': False}" 295 | ] 296 | }, 297 | "execution_count": 10, 298 | "metadata": {}, 299 | "output_type": "execute_result" 300 | } 301 | ], 302 | "source": [ 303 | "est = LogisticRegression(**params_sample[0])\n", 304 | "est.get_params()" 305 | ] 306 | }, 307 | { 308 | "cell_type": "markdown", 309 | "metadata": {}, 310 | "source": [ 311 | "## Serialization\n", 312 | "\n", 313 | "The sampler can be serialized into a json:" 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": 24, 319 | "metadata": {}, 320 | "outputs": [ 321 | { 322 | "data": { 323 | "application/json": { 324 | "C": { 325 | "default": 0, 326 | "log": false, 327 | "lower": 0, 328 | "type": "UniformFloatDistribution", 329 | "upper": 1 330 | }, 331 | "dual": { 332 | "default": true, 333 | "type": "UniformBoolDistribution" 334 | }, 335 | "multi_class": { 336 | "choices": [ 337 | "ovr", 338 | "multinomial" 339 | ], 340 | "default": "ovr", 341 | "type": "CategoricalDistribution" 342 | }, 343 | "penalty": { 344 | "choices": [ 345 | "l2", 346 | "l1" 347 | ], 348 | "default": "l2", 349 | "type": "CategoricalDistribution" 350 | }, 351 | "random_state": { 352 | "dists": [ 353 | { 354 | "type": "ConstantDistribution", 355 | "value": null 356 | }, 357 | { 358 | "default": 0, 359 | "log": false, 360 | "lower": 0, 361 | "type": "UniformIntDistribution", 362 | "upper": 10 363 | } 364 | ], 365 | "type": "UnionDistribution" 366 | }, 367 | "solver": { 368 | "choices": [ 369 | "newton-cg", 370 | "lbfgs", 371 | "liblinear", 372 | "sag", 373 | "saga" 374 | ], 375 | "default": "newton-cg", 376 | "type": "CategoricalDistribution" 377 | } 378 | }, 379 | "text/plain": [ 380 | "" 381 | ] 382 | }, 383 | "execution_count": 24, 384 | "metadata": { 385 | "application/json": { 386 | "expanded": false, 387 | "root": "root" 388 | } 389 | }, 390 | "output_type": "execute_result" 391 | } 392 | ], 393 | "source": [ 394 | "import json\n", 395 | "from IPython.display import JSON\n", 396 | "\n", 397 | "serialized = sampler.to_dict()\n", 398 | "json_serialized = json.dumps(serialized, indent=2)\n", 399 | "JSON(serialized)" 400 | ] 401 | }, 402 | { 403 | "cell_type": "code", 404 | "execution_count": 25, 405 | "metadata": {}, 406 | "outputs": [ 407 | { 408 | "data": { 409 | "text/plain": [ 410 | "dual: UniformBoolDistribution(default=True)\n", 411 | "C: UniformFloatDistribution(lower=0.0, upper=1.0, default=0.0, log=False)\n", 412 | "solver: CategoricalDistribution(choices=['newton-cg', 'lbfgs', 'liblinear', 'sag', 'saga'], default=newton-cg)\n", 413 | "random_state: UnionDistribution(dists=[{'type': 'ConstantDistribution', 'value': None}, {'lower': 0, 'upper': 10, 'default': 0, 'log': False, 'type': 'UniformIntDistribution'}])\n", 414 | "penalty: CategoricalDistribution(choices=['l2', 'l1'], default=l2)\n", 415 | "multi_class: CategoricalDistribution(choices=['ovr', 'multinomial'], default=ovr)" 416 | ] 417 | }, 418 | "execution_count": 25, 419 | "metadata": {}, 420 | "output_type": "execute_result" 421 | } 422 | ], 423 | "source": [ 424 | "sampler_dict = json.loads(json_serialized)\n", 425 | "sampler_new = Sampler(validator).from_dict(sampler_dict)\n", 426 | "sampler_new" 427 | ] 428 | } 429 | ], 430 | "metadata": { 431 | "kernelspec": { 432 | "display_name": "Python [conda env:skconfig]", 433 | "language": "python", 434 | "name": "conda-env-skconfig-py" 435 | }, 436 | "language_info": { 437 | "codemirror_mode": { 438 | "name": "ipython", 439 | "version": 3 440 | }, 441 | "file_extension": ".py", 442 | "mimetype": "text/x-python", 443 | "name": "python", 444 | "nbconvert_exporter": "python", 445 | "pygments_lexer": "ipython3", 446 | "version": "3.7.2" 447 | } 448 | }, 449 | "nbformat": 4, 450 | "nbformat_minor": 2 451 | } 452 | -------------------------------------------------------------------------------- /readthedocs.yml: -------------------------------------------------------------------------------- 1 | conda: 2 | file: docs/conda.yml 3 | formats: 4 | - none 5 | -------------------------------------------------------------------------------- /requirements/base.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scikit-learn==0.20.3 3 | configspace==0.4.9 4 | -------------------------------------------------------------------------------- /requirements/dev.txt: -------------------------------------------------------------------------------- 1 | flake8 2 | pytest 3 | twine 4 | pytest-cov 5 | codecov 6 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [tool:pytest] 2 | addopts = 3 | --cov=skconfig 4 | testpaths = tests/ 5 | 6 | [flake8] 7 | max-line-length = 119 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | 5 | from codecs import open 6 | from setuptools import setup, find_packages 7 | 8 | here = os.path.abspath(os.path.dirname(__file__)) 9 | 10 | with open('requirements/base.txt') as f: 11 | install_requires = f.read().strip().split('\n') 12 | 13 | with open('requirements/dev.txt') as f: 14 | dev_requires = f.read().strip().split('\n') 15 | 16 | with open('VERSION', 'r') as f: 17 | version = f.read().rstrip() 18 | 19 | with open('README.rst', 'r', 'utf-8') as f: 20 | readme = f.read() 21 | 22 | setup( 23 | name='skconfig', 24 | version=version, 25 | description='Scikit-learn sampling and validation library', 26 | long_description=readme, 27 | author='Thomas J Fan', 28 | author_email='thomasjpfan@gmail.com', 29 | url='https://github.com/thomasjpfan/skconfig', 30 | packages=find_packages(include=['skconfig']), 31 | install_requires=install_requires, 32 | include_package_data=True, 33 | python_requires='>=3.5', 34 | zip_safe=False, 35 | license='MIT', 36 | classifiers=[ 37 | 'Development Status :: 2 - Pre-Alpha', 38 | 'Intended Audience :: Developers', 39 | 'Programming Language :: Python', 40 | 'Topic :: Software Development', 41 | 'Topic :: Scientific/Engineering', 42 | 'Natural Language :: English', 43 | "License :: OSI Approved :: MIT License", 44 | 'Programming Language :: Python :: 3 :: Only', 45 | 'Programming Language :: Python :: 3.5', 46 | 'Programming Language :: Python :: 3.6', 47 | ], 48 | extras_require={ 49 | 'dev': dev_requires, 50 | }) 51 | -------------------------------------------------------------------------------- /skconfig/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thomasjpfan/skconfig/962eb6486f1d11fc4858396189e54957bd95db07/skconfig/__init__.py -------------------------------------------------------------------------------- /skconfig/condition.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from .exceptions import SKConfigValueError 3 | 4 | 5 | class Condition(metaclass=ABCMeta): 6 | def __init__(self, child, parent, conditioned_value): 7 | self.child = child 8 | self.parent = parent 9 | self.conditioned_value = conditioned_value 10 | 11 | @abstractmethod 12 | def is_active(self, **kwargs): 13 | ... 14 | 15 | 16 | class EqualsCondition(Condition): 17 | def is_active(self, **kwargs): 18 | value = kwargs.get(self.parent) 19 | if value is None: 20 | return False 21 | return value == self.conditioned_value 22 | 23 | def __repr__(self): 24 | return "Child: {} Condition: {} == {}".format(self.child, self.parent, 25 | self.conditioned_value) 26 | 27 | 28 | class NotEqualsCondition(Condition): 29 | def is_active(self, **kwargs): 30 | value = kwargs.get(self.parent) 31 | if value is None: 32 | return False 33 | return value != self.conditioned_value 34 | 35 | def __repr__(self): 36 | return "Child: {} Condition: {} != {}".format(self.child, self.parent, 37 | self.conditioned_value) 38 | 39 | 40 | class LessThanCondition(Condition): 41 | def is_active(self, **kwargs): 42 | value = kwargs.get(self.parent) 43 | if value is None: 44 | return True 45 | return value < self.conditioned_value 46 | 47 | def __repr__(self): 48 | return "Child: {} Condition: {} < {}".format(self.child, self.parent, 49 | self.conditioned_value) 50 | 51 | 52 | class GreaterThanCondition(Condition): 53 | def is_active(self, **kwargs): 54 | value = kwargs.get(self.parent) 55 | if value is None: 56 | return True 57 | return value > self.conditioned_value 58 | 59 | def __repr__(self): 60 | return "Child: {} Condition: {} > {}".format(self.child, self.parent, 61 | self.conditioned_value) 62 | 63 | 64 | class InCondition(Condition): 65 | def is_active(self, **kwargs): 66 | value = kwargs.get(self.parent) 67 | if value is None: 68 | return True 69 | return value in self.conditioned_value 70 | 71 | def __repr__(self): 72 | return "Child: {} Condition: {} in {}".format(self.child, self.parent, 73 | self.conditioned_value) 74 | 75 | 76 | class AndCondition(Condition): 77 | def __init__(self, *conditions): 78 | name = {c.child for c in conditions} 79 | if len(name) != 1: 80 | raise SKConfigValueError("multiple names given: {}".format(name)) 81 | self.child = name.pop() 82 | self.conditions = conditions 83 | 84 | def is_active(self, **kwargs): 85 | for condition in self.conditions: 86 | if not condition.is_active(**kwargs): 87 | return False 88 | return True 89 | 90 | def __repr__(self): 91 | return " & ".join("({})".format(c) for c in self.conditions) 92 | 93 | 94 | class OrCondition(Condition): 95 | def __init__(self, *conditions): 96 | name = {c.child for c in conditions} 97 | if len(name) != 1: 98 | raise SKConfigValueError("multiple names given: {}".format(name)) 99 | self.child = name.pop() 100 | self.conditions = conditions 101 | 102 | def is_active(self, **kwargs): 103 | for condition in self.conditions: 104 | if condition.is_active(**kwargs): 105 | return True 106 | return False 107 | 108 | def __repr__(self): 109 | return " | ".join("({})".format(c) for c in self.conditions) 110 | -------------------------------------------------------------------------------- /skconfig/distribution.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta 2 | from inspect import getfullargspec 3 | 4 | from ConfigSpace import EqualsCondition 5 | from ConfigSpace.hyperparameters import CategoricalHyperparameter 6 | from ConfigSpace.hyperparameters import UniformIntegerHyperparameter 7 | from ConfigSpace.hyperparameters import UniformFloatHyperparameter 8 | from ConfigSpace.hyperparameters import Constant 9 | 10 | 11 | class BaseDistribution(metaclass=ABCMeta): 12 | def to_dict(self): 13 | arg_spec = getfullargspec(self.__init__) 14 | args = arg_spec.args[1:] 15 | args.extend(arg_spec.kwonlyargs) 16 | output = {arg: getattr(self, arg) for arg in args} 17 | output['type'] = self.__class__.__name__ 18 | return output 19 | 20 | @classmethod 21 | def from_dict(cls, p_dict): 22 | type_of_dict = p_dict['type'] 23 | assert type_of_dict == cls.__name__ 24 | return cls(**p_dict) 25 | 26 | def __repr__(self): 27 | as_dict = self.to_dict() 28 | type_of_dict = as_dict['type'] 29 | del as_dict['type'] 30 | value_str = ", ".join("{}={}".format(k, v) for k, v in as_dict.items()) 31 | return "{}({})".format(type_of_dict, value_str) 32 | 33 | def post_process(self, name, config_space_dict, value=None): 34 | if value is not None: 35 | config_space_dict[name] = value 36 | return config_space_dict 37 | 38 | def value_to_name_value(self, name, value): 39 | return (name, value) 40 | 41 | def child_name(self, name): 42 | return name 43 | 44 | def is_constant(self): 45 | return False 46 | 47 | 48 | class UnionDistribution(BaseDistribution): 49 | def __init__(self, *dists, **kwargs): 50 | self.dists = list(dists) 51 | 52 | def to_dict(self): 53 | output = {'type': self.__class__.__name__} 54 | output_dists = [] 55 | for dist in self.dists: 56 | output_dists.append(dist.to_dict()) 57 | output['dists'] = output_dists 58 | return output 59 | 60 | @classmethod 61 | def from_dict(cls, p_dict): 62 | type_of_dict = p_dict['type'] 63 | assert type_of_dict == cls.__name__ 64 | dists = p_dict['dists'] 65 | dist_objs = [] 66 | for dist in dists: 67 | dist_objs.append(load_dist_dict(dist)) 68 | return cls(*dist_objs) 69 | 70 | def add_to_config_space(self, name, cs): 71 | control_name = "{}:control".format(name) 72 | 73 | type_name_to_dist = self.type_name_to_dist(name) 74 | type_names = list(type_name_to_dist) 75 | control = CategoricalHyperparameter( 76 | name=control_name, choices=type_names, default_value=type_names[0]) 77 | 78 | cs.add_hyperparameter(control) 79 | for type_name, dist in type_name_to_dist.items(): 80 | cs_hp = dist.add_to_config_space(type_name, cs) 81 | cs.add_condition(EqualsCondition(cs_hp, control, type_name)) 82 | 83 | def post_process(self, name, config_space_dict): 84 | type_name_to_dist = self.type_name_to_dist(name) 85 | control_name = "{}:control".format(name) 86 | control_value = config_space_dict[control_name] 87 | dist = type_name_to_dist[control_value] 88 | 89 | actual_value = config_space_dict[control_value] 90 | dist.post_process(name, config_space_dict, value=actual_value) 91 | 92 | del config_space_dict[control_name] 93 | del config_space_dict[control_value] 94 | return config_space_dict 95 | 96 | def type_name_to_dist(self, name): 97 | type_name_to_dist = {} 98 | for dist in self.dists: 99 | type_name = "{}:{}".format(name, dist.dtype.__name__) 100 | type_name_to_dist[type_name] = dist 101 | return type_name_to_dist 102 | 103 | @property 104 | def type_to_dist(self): 105 | if hasattr(self, "_type_to_dist"): 106 | return self._type_to_dist 107 | self._type_to_dist = {} 108 | for dist in self.dists: 109 | self._type_to_dist[dist.dtype] = dist 110 | return self._type_to_dist 111 | 112 | def value_to_name_value(self, name, value): 113 | for dist in self.dists: 114 | if isinstance(value, dist.dtype): 115 | return "{}:{}".format(name, dist.dtype.__name__), value 116 | raise TypeError("Unrecognized type for name, {}, with value {}".format( 117 | name, value)) 118 | 119 | def child_name(self, name): 120 | return "{}:control".format(name) 121 | 122 | def in_distrubution(self, value): 123 | for dist_type, dist in self.type_to_dist.items(): 124 | if isinstance(value, dist_type): 125 | return dist.in_distrubution(value) 126 | return False 127 | 128 | 129 | class UniformBoolDistribution(BaseDistribution): 130 | dtype = bool 131 | 132 | def __init__(self, default=True, **kwargs): 133 | self.default = default 134 | 135 | def add_to_config_space(self, name, cs): 136 | default_value = 'T' if self.default else 'F' 137 | hp = CategoricalHyperparameter( 138 | name=name, choices=['T', 'F'], default_value=default_value) 139 | cs.add_hyperparameter(hp) 140 | return hp 141 | 142 | def post_process(self, name, config_space_dict, value=None): 143 | value = config_space_dict[name] 144 | config_space_dict[name] = value == 'T' 145 | return config_space_dict 146 | 147 | def in_distrubution(self, value): 148 | return value in [True, False] 149 | 150 | def value_to_name(self, name, value): 151 | if value: 152 | return name, 'T' 153 | return name, 'F' 154 | 155 | 156 | class UniformIntDistribution(BaseDistribution): 157 | dtype = int 158 | 159 | def __init__(self, lower, upper, default=None, log=False, **kwargs): 160 | self.lower = lower 161 | self.upper = upper 162 | self.log = log 163 | self.default = self.lower if default is None else default 164 | 165 | def add_to_config_space(self, name, cs): 166 | hp = UniformIntegerHyperparameter( 167 | name=name, 168 | lower=self.lower, 169 | upper=self.upper, 170 | log=self.log, 171 | default_value=self.default) 172 | cs.add_hyperparameter(hp) 173 | return hp 174 | 175 | def in_distrubution(self, value): 176 | return self.lower <= value <= self.upper 177 | 178 | 179 | class UniformFloatDistribution(BaseDistribution): 180 | dtype = float 181 | 182 | def __init__(self, lower, upper, default=None, log=False, **kwargs): 183 | self.lower = lower 184 | self.upper = upper 185 | self.log = log 186 | self.default = self.lower if default is None else default 187 | 188 | def add_to_config_space(self, name, cs): 189 | hp = UniformFloatHyperparameter( 190 | name=name, 191 | lower=self.lower, 192 | upper=self.upper, 193 | log=self.log, 194 | default_value=self.default) 195 | cs.add_hyperparameter(hp) 196 | return hp 197 | 198 | def in_distrubution(self, value): 199 | return self.lower <= value <= self.upper 200 | 201 | 202 | class CategoricalDistribution(BaseDistribution): 203 | dtype = str 204 | 205 | def __init__(self, choices, default=None, **kwargs): 206 | self.choices = choices 207 | self.default = self.choices[0] if default is None else default 208 | 209 | def add_to_config_space(self, name, cs): 210 | hp = CategoricalHyperparameter( 211 | name=name, choices=self.choices, default_value=self.default) 212 | cs.add_hyperparameter(hp) 213 | return hp 214 | 215 | def in_distrubution(self, value): 216 | return value in self.choices 217 | 218 | 219 | class ConstantDistribution(BaseDistribution): 220 | def __init__(self, value, **kwargs): 221 | self.value = value 222 | self.dtype = type(value) 223 | 224 | def to_dict(self): 225 | return {'type': self.__class__.__name__, 'value': self.value} 226 | 227 | @classmethod 228 | def from_dict(cls, p_dict): 229 | assert p_dict['type'] == cls.__name__ 230 | return cls(**p_dict) 231 | 232 | def add_to_config_space(self, name, cs): 233 | if self.value is None: 234 | hp = Constant(name, type(None).__name__) 235 | else: 236 | hp = Constant(name, self.value) 237 | cs.add_hyperparameter(hp) 238 | return hp 239 | 240 | def post_process(self, name, config_space_dict, value=None): 241 | config_space_dict[name] = self.value 242 | return config_space_dict 243 | 244 | def in_distrubution(self, value): 245 | return self.value == value 246 | 247 | def is_constant(self): 248 | return True 249 | 250 | 251 | def load_dist_dict(dist_dict): 252 | supported_dists = [ 253 | UniformBoolDistribution, UniformIntDistribution, 254 | UniformFloatDistribution, CategoricalDistribution, 255 | ConstantDistribution, UnionDistribution 256 | ] 257 | name_to_dist_cls = {d.__name__: d for d in supported_dists} 258 | dist_cls = name_to_dist_cls[dist_dict['type']] 259 | return dist_cls.from_dict(dist_dict) 260 | -------------------------------------------------------------------------------- /skconfig/exceptions.py: -------------------------------------------------------------------------------- 1 | class SKConfigValueError(ValueError): 2 | pass 3 | 4 | 5 | class InvalidParam(SKConfigValueError): 6 | pass 7 | 8 | 9 | class InvalidParamName(InvalidParam): 10 | def __init__(self, name): 11 | super().__init__("{} is a invalid parameter name".format(name)) 12 | 13 | 14 | class InvalidParamType(InvalidParam): 15 | def __init__(self, name, correct_type): 16 | super().__init__("{} must be of type {}".format(name, correct_type)) 17 | 18 | 19 | class InvalidParamRange(InvalidParam): 20 | def __init__(self, 21 | name, 22 | value, 23 | lower=None, 24 | upper=None, 25 | include_lower=True, 26 | include_upper=True): 27 | if lower is None and upper is None: 28 | raise ValueError("Lower and uppoer bounds cannot both be None") 29 | 30 | msg_list = ["{} with value {} not in range:".format(name, value)] 31 | if lower is not None: 32 | if include_lower: 33 | lower_str = "[{},".format(lower) 34 | else: 35 | lower_str = "({},".format(lower) 36 | msg_list.append(lower_str) 37 | else: 38 | msg_list.append("(-inf") 39 | 40 | if upper is not None: 41 | if include_upper: 42 | upper_str = "{}]".format(upper) 43 | else: 44 | upper_str = "{})".format(upper) 45 | msg_list.append(upper_str) 46 | else: 47 | msg_list.append("inf)") 48 | super().__init__(" ".join(msg_list)) 49 | 50 | 51 | class InvalidParamChoices(InvalidParam): 52 | def __init__(self, name, choices): 53 | super().__init__("{} must be one of {}".format(name, choices)) 54 | 55 | 56 | class ForbiddenValue(SKConfigValueError): 57 | def __init__(self, name, value): 58 | super().__init__("{} with value {} is forbidden".format(name, value)) 59 | 60 | 61 | class InactiveConditionedValue(SKConfigValueError): 62 | def __init__(self, name, condition): 63 | super().__init__("{} has an unmet condition: {}".format( 64 | name, condition)) 65 | -------------------------------------------------------------------------------- /skconfig/forbidden.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from .exceptions import ForbiddenValue 3 | 4 | 5 | class ForbiddenClause(metaclass=ABCMeta): 6 | def __init__(self, name, value): 7 | self.name = name 8 | self.value = value 9 | 10 | @abstractmethod 11 | def is_forbidden(self, **kwargs): 12 | ... 13 | 14 | def __repr__(self): 15 | return "{self.__class__.__name__}: {self.name}, {self.value}".format( 16 | self=self) 17 | 18 | 19 | class ForbiddenIn(ForbiddenClause): 20 | def is_forbidden(self, **kwargs): 21 | value = kwargs.get(self.name) 22 | if value is None: 23 | return 24 | if value in self.value: 25 | raise ForbiddenValue(self.name, value) 26 | 27 | 28 | class ForbiddenEquals(ForbiddenClause): 29 | def is_forbidden(self, **kwargs): 30 | value = kwargs.get(self.name) 31 | if value is None: 32 | return 33 | if value == self.value: 34 | raise ForbiddenValue(self.name, value) 35 | 36 | 37 | class ForbiddenAnd(ForbiddenClause): 38 | def __init__(self, forbidden_clauses): 39 | self.forbidden_clauses = forbidden_clauses 40 | 41 | def is_forbidden(self, **kwargs): 42 | names = [] 43 | values = [] 44 | for for_clauses in self.forbidden_clauses: 45 | value = kwargs.get(for_clauses.name) 46 | if value is None: 47 | return 48 | try: 49 | for_clauses.is_forbidden(**kwargs) 50 | except ForbiddenValue: 51 | names.append(for_clauses.name) 52 | values.append(for_clauses.value) 53 | 54 | if len(names) == len(self.forbidden_clauses): 55 | names_str = " and ".join(names) 56 | values_str = " and ".join(str(v) for v in values) 57 | raise ForbiddenValue(names_str, values_str) 58 | 59 | def __repr__(self): 60 | names = [] 61 | for clause in self.forbidden_clauses: 62 | names.append(str(clause)) 63 | return "ForbiddenAnd: ({})".format(", ".join(names)) 64 | -------------------------------------------------------------------------------- /skconfig/mapping.py: -------------------------------------------------------------------------------- 1 | from .condition import (EqualsCondition, NotEqualsCondition, LessThanCondition, 2 | GreaterThanCondition, InCondition, AndCondition, 3 | OrCondition) 4 | from .forbidden import (ForbiddenEquals, ForbiddenIn, ForbiddenAnd) 5 | 6 | from ConfigSpace import EqualsCondition as CSEqualsCondition 7 | from ConfigSpace import NotEqualsCondition as CSNotEqualsCondition 8 | from ConfigSpace import LessThanCondition as CSLessThanCondition 9 | from ConfigSpace import GreaterThanCondition as CSGreaterThanCondition 10 | from ConfigSpace import InCondition as CSInCondition 11 | from ConfigSpace import AndConjunction as CSAndCondition 12 | from ConfigSpace import OrConjunction as CSOrCondition 13 | from ConfigSpace import ForbiddenAndConjunction as CSForbiddenAnd 14 | from ConfigSpace import ForbiddenEqualsClause as CSForbiddenEqual 15 | from ConfigSpace import ForbiddenInClause as CSForbiddenIn 16 | 17 | 18 | def skconfig_obj_to_config_space(skconfig_obj, cs): 19 | if isinstance(skconfig_obj, EqualsCondition): 20 | child_hp = cs.get_hyperparameter(skconfig_obj.child) 21 | parent_hp = cs.get_hyperparameter(skconfig_obj.parent) 22 | return CSEqualsCondition(child_hp, parent_hp, 23 | skconfig_obj.conditioned_value) 24 | elif isinstance(skconfig_obj, NotEqualsCondition): 25 | child_hp = cs.get_hyperparameter(skconfig_obj.child) 26 | parent_hp = cs.get_hyperparameter(skconfig_obj.parent) 27 | return CSNotEqualsCondition(child_hp, parent_hp, 28 | skconfig_obj.conditioned_value) 29 | elif isinstance(skconfig_obj, LessThanCondition): 30 | child_hp = cs.get_hyperparameter(skconfig_obj.child) 31 | parent_hp = cs.get_hyperparameter(skconfig_obj.parent) 32 | return CSLessThanCondition(child_hp, parent_hp, 33 | skconfig_obj.conditioned_value) 34 | elif isinstance(skconfig_obj, GreaterThanCondition): 35 | child_hp = cs.get_hyperparameter(skconfig_obj.child) 36 | parent_hp = cs.get_hyperparameter(skconfig_obj.parent) 37 | return CSGreaterThanCondition(child_hp, parent_hp, 38 | skconfig_obj.conditioned_value) 39 | elif isinstance(skconfig_obj, InCondition): 40 | child_hp = cs.get_hyperparameter(skconfig_obj.child) 41 | parent_hp = cs.get_hyperparameter(skconfig_obj.parent) 42 | return CSInCondition(child_hp, parent_hp, 43 | skconfig_obj.conditioned_value) 44 | elif isinstance(skconfig_obj, AndCondition): 45 | output = [] 46 | for cond in skconfig_obj.conditons: 47 | output.append(skconfig_obj_to_config_space(cond, cs)) 48 | return CSAndCondition(*output) 49 | elif isinstance(skconfig_obj, OrCondition): 50 | output = [] 51 | for cond in skconfig_obj.conditons: 52 | output.append(skconfig_obj_to_config_space(cond, cs)) 53 | return CSOrCondition(*output) 54 | elif isinstance(skconfig_obj, ForbiddenEquals): 55 | hp = cs.get_hyperparameter(skconfig_obj.name) 56 | return CSForbiddenEqual(hp, skconfig_obj.value) 57 | elif isinstance(skconfig_obj, ForbiddenIn): 58 | hp = cs.get_hyperparameter(skconfig_obj.name) 59 | return CSForbiddenIn(hp, skconfig_obj.value) 60 | elif isinstance(skconfig_obj, ForbiddenAnd): 61 | output = [] 62 | for forb in skconfig_obj.forbidden_clauses: 63 | output.append(skconfig_obj_to_config_space(forb, cs)) 64 | return CSForbiddenAnd(*output) 65 | raise TypeError("Unable to recognize type: {}".format(skconfig_obj)) 66 | -------------------------------------------------------------------------------- /skconfig/parameter/__init__.py: -------------------------------------------------------------------------------- 1 | from .interval import FloatIntervalParam 2 | from .interval import IntIntervalParam 3 | from .types import BoolParam 4 | from .types import NoneParam 5 | from .types import FloatParam 6 | from .types import IntParam 7 | from .types import StringParam 8 | from .types import UnionParam 9 | 10 | __all__ = [ 11 | "BoolParam", "FloatIntervalParam", "IntIntervalParam", "NoneParam", 12 | "FloatParam", "IntParam", "StringParam", "UnionParam" 13 | ] 14 | -------------------------------------------------------------------------------- /skconfig/parameter/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | 4 | class Param(metaclass=ABCMeta): 5 | @abstractmethod 6 | def validate(self, param): 7 | ... 8 | -------------------------------------------------------------------------------- /skconfig/parameter/convience.py: -------------------------------------------------------------------------------- 1 | from numpy.random import RandomState 2 | from .types import NoneParam 3 | from .types import ObjectParam 4 | from .types import IntParam 5 | from .types import UnionParam 6 | 7 | 8 | def RandomStateParam(): 9 | return UnionParam(NoneParam(), IntParam(), ObjectParam(RandomState)) 10 | -------------------------------------------------------------------------------- /skconfig/parameter/interval.py: -------------------------------------------------------------------------------- 1 | from .base import Param 2 | from ..exceptions import InvalidParamRange 3 | from ..exceptions import InvalidParamType 4 | 5 | 6 | class NumericalInterval(Param): 7 | def __init__(self, 8 | lower=None, 9 | upper=None, 10 | include_lower=True, 11 | include_upper=True): 12 | self.lower = lower 13 | self.upper = upper 14 | self.include_lower = include_lower 15 | self.include_upper = include_upper 16 | 17 | def validate(self, name, value): 18 | if not isinstance(value, self.value_type): 19 | raise InvalidParamType(name, self.type_str) 20 | 21 | if self.lower is not None: 22 | if self.include_lower: 23 | lower_in_range = self.lower <= value 24 | else: 25 | lower_in_range = self.lower < value 26 | if not lower_in_range: 27 | raise InvalidParamRange( 28 | name, 29 | value, 30 | lower=self.lower, 31 | upper=self.upper, 32 | include_lower=self.include_lower, 33 | include_upper=self.include_upper) 34 | 35 | if self.upper is not None: 36 | if self.include_upper: 37 | upper_in_range = value <= self.upper 38 | else: 39 | upper_in_range = value < self.upper 40 | if not upper_in_range: 41 | raise InvalidParamRange( 42 | name, 43 | value, 44 | lower=self.lower, 45 | upper=self.upper, 46 | include_lower=self.include_lower, 47 | include_upper=self.include_upper) 48 | 49 | 50 | class FloatIntervalParam(NumericalInterval): 51 | value_type = (float, int) 52 | type_str = 'float, int' 53 | 54 | 55 | class IntIntervalParam(NumericalInterval): 56 | value_type = int 57 | type_str = 'int' 58 | -------------------------------------------------------------------------------- /skconfig/parameter/types.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | from .base import Param 4 | from ..exceptions import InvalidParamType 5 | from ..exceptions import InvalidParamChoices 6 | from ..exceptions import SKConfigValueError 7 | 8 | 9 | class TypedParam(Param): 10 | def validate(self, name, value): 11 | if not isinstance(value, self.value_type): 12 | raise InvalidParamType(name, self.type_str) 13 | 14 | 15 | class BoolParam(TypedParam): 16 | value_type = bool 17 | type_str = 'bool' 18 | 19 | 20 | class NoneParam(TypedParam): 21 | value_type = type(None) 22 | type_str = 'NoneType' 23 | 24 | 25 | class FloatParam(TypedParam): 26 | value_type = (float, int) 27 | type_str = 'float, int' 28 | 29 | 30 | class IntParam(TypedParam): 31 | value_type = int 32 | type_str = 'int' 33 | 34 | 35 | class StringParam(Param): 36 | value_type = str 37 | type_str = 'str' 38 | 39 | def __init__(self, *choices): 40 | if not isinstance(choices, tuple) or not choices: 41 | raise SKConfigValueError("choices must be all strings") 42 | for choice in choices: 43 | if not isinstance(choice, self.value_type): 44 | raise SKConfigValueError("choices must be all strings") 45 | self.choices = choices 46 | 47 | def validate(self, name, value): 48 | if not isinstance(value, self.value_type): 49 | raise InvalidParamType(name, self.type_str) 50 | if value not in self.choices: 51 | raise InvalidParamChoices(name, self.choices) 52 | 53 | 54 | class CallableParam(TypedParam): 55 | value_type = Callable 56 | type_str = 'callable' 57 | 58 | 59 | class ObjectParam(Param): 60 | value_type = object 61 | type_str = 'object' 62 | 63 | def __init__(self, *objects): 64 | self.objects = objects 65 | 66 | def validate(self, name, value): 67 | if not isinstance(value, self.objects): 68 | raise InvalidParamType(name, self.type_str) 69 | 70 | 71 | class UnionParam(Param): 72 | def __init__(self, *parameters): 73 | if not isinstance(parameters, tuple) or not parameters: 74 | raise SKConfigValueError("parameters must be a non empty") 75 | for parameter in parameters: 76 | if not isinstance(parameter, Param): 77 | raise SKConfigValueError("parameters must be all Param") 78 | self.parameters = parameters 79 | 80 | def validate(self, name, value): 81 | for param in self.parameters: 82 | if isinstance(value, param.value_type): 83 | param.validate(name, value) 84 | break 85 | else: 86 | p_types = [param.type_str for param in self.parameters] 87 | raise InvalidParamType(name, p_types) 88 | -------------------------------------------------------------------------------- /skconfig/sampler.py: -------------------------------------------------------------------------------- 1 | from contextlib import suppress 2 | 3 | import ConfigSpace as CS 4 | from .distribution import load_dist_dict 5 | from .exceptions import SKConfigValueError 6 | from .distribution import BaseDistribution 7 | from .condition import AndCondition 8 | from .condition import OrCondition 9 | from .condition import InCondition 10 | from .condition import EqualsCondition 11 | from .condition import Condition 12 | from .forbidden import ForbiddenAnd 13 | from .forbidden import ForbiddenEquals 14 | from .forbidden import ForbiddenIn 15 | 16 | from .mapping import skconfig_obj_to_config_space 17 | 18 | 19 | class Sampler: 20 | def __init__(self, validator, **kwargs): 21 | self.hps = {} 22 | for k, v in kwargs.items(): 23 | if k in validator.parameters_: 24 | if isinstance(v, BaseDistribution): 25 | self.hps[k] = v 26 | else: 27 | raise SKConfigValueError( 28 | "{} is not a distribution".format(k)) 29 | else: 30 | raise SKConfigValueError("{} is an invalid key".format(k)) 31 | self.validator = validator 32 | self._generate_config_space() 33 | 34 | def to_dict(self): 35 | output = {} 36 | for k, v in self.hps.items(): 37 | output[k] = v.to_dict() 38 | return output 39 | 40 | def from_dict(self, p_dict): 41 | for k, v in p_dict.items(): 42 | self.hps[k] = load_dist_dict(v) 43 | self._generate_config_space() 44 | return self 45 | 46 | def __repr__(self): 47 | lines = [] 48 | for k, v in self.hps.items(): 49 | lines.append("{}: {}".format(k, v)) 50 | return "\n".join(lines) 51 | 52 | def _generate_config_space(self): 53 | active_params = set(self.hps) 54 | active_conditions = [] 55 | active_forbiddens = [] 56 | 57 | # Find activate conditions 58 | for cond in self.validator.conditions: 59 | if cond.child not in active_params: 60 | continue 61 | 62 | active_cond = self._get_active_condition(cond) 63 | if active_cond is None: 64 | with suppress(KeyError): 65 | active_params.remove(cond.child) 66 | continue 67 | active_conditions.append(active_cond) 68 | 69 | # Find activate forbiddens 70 | for forb in self.validator.forbiddens: 71 | active_forb = self._get_active_forbidden(forb, active_params) 72 | if active_forb is None: 73 | continue 74 | active_forbiddens.append(active_forb) 75 | 76 | # Create configuration space 77 | config_space = CS.ConfigurationSpace() 78 | self.config_space = config_space 79 | 80 | for name in active_params: 81 | self.hps[name].add_to_config_space(name, config_space) 82 | 83 | self.normalized_conditions = [ 84 | self._normalize_condition_names(cond) for cond in active_conditions 85 | ] 86 | 87 | self.normalized_forbiddens = [ 88 | self._normalize_forbidden_names(forb) for forb in active_forbiddens 89 | ] 90 | 91 | self.cs_conditions = [ 92 | skconfig_obj_to_config_space(cond, config_space) 93 | for cond in self.normalized_conditions 94 | ] 95 | self.cs_forbiddens = [ 96 | skconfig_obj_to_config_space(forb, config_space) 97 | for forb in self.normalized_forbiddens 98 | ] 99 | config_space.add_conditions(self.cs_conditions) 100 | config_space.add_forbidden_clauses(self.cs_forbiddens) 101 | 102 | def sample(self, size=1): 103 | configs = self.config_space.sample_configuration(size) 104 | if size == 1: 105 | configs = [configs] 106 | output = [] 107 | for config in configs: 108 | config_dict = config.get_dictionary() 109 | for name, dist in self.hps.items(): 110 | dist.post_process(name, config_dict) 111 | output.append(config_dict) 112 | return output 113 | 114 | def _get_active_condition(self, cond): 115 | if isinstance(cond, OrCondition): 116 | output = [] 117 | for inner_cond in cond.conditions: 118 | active_cond = self._get_active_condition(inner_cond) 119 | if active_cond is None: 120 | continue 121 | output.append(active_cond) 122 | if not output: 123 | return None 124 | return OrCondition(*output) 125 | 126 | if isinstance(cond, AndCondition): 127 | for inner_cond in cond.conditions: 128 | active_cond = self._get_active_condition(inner_cond) 129 | if active_cond is None: 130 | return None 131 | return cond 132 | 133 | parent = cond.parent 134 | if parent not in self.hps: 135 | return None 136 | conditioned_value = cond.conditioned_value 137 | dist = self.hps[parent] 138 | if dist.in_distrubution(conditioned_value): 139 | return cond 140 | 141 | def _get_active_forbidden(self, forbidden, active_params): 142 | if isinstance(forbidden, ForbiddenAnd): 143 | values = [] 144 | for forb in forbidden.forbidden_clauses: 145 | name = forb.name 146 | if name not in active_params: 147 | return 148 | dist = self.hps[name] 149 | active_forb = self._get_active_forbidden(forb, active_params) 150 | if active_forb is None: 151 | return 152 | values.append(active_forb) 153 | return ForbiddenAnd(values) 154 | elif isinstance(forbidden, ForbiddenIn): 155 | name = forbidden.name 156 | if name not in active_params: 157 | return 158 | 159 | values = [] 160 | dist = self.hps[name] 161 | for value in forbidden.value: 162 | if (dist.in_distrubution(value) and not dist.is_constant()): 163 | values.append(value) 164 | if not values: 165 | return 166 | return ForbiddenIn(name, values) 167 | elif isinstance(forbidden, ForbiddenEquals): 168 | name = forbidden.name 169 | if name not in active_params: 170 | return 171 | 172 | dist = self.hps[name] 173 | if (dist.in_distrubution(forbidden.value) 174 | and not dist.is_constant()): 175 | return forbidden 176 | return 177 | raise TypeError("Unrecognized type {}".format(forbidden)) 178 | 179 | def _normalize_condition_names(self, condition): 180 | if isinstance(condition, (AndCondition, OrCondition)): 181 | output = [] 182 | for cond in condition.conditions: 183 | new_conditions = self._normalize_condition_names(cond) 184 | output.extend(new_conditions) 185 | return condition.__class__(*output) 186 | elif isinstance(condition, InCondition): 187 | output = [] 188 | for c_value in condition.conditioned_value: 189 | child = condition.child 190 | parent = condition.parent 191 | child_dist = self.hps[child] 192 | parent_dist = self.hps[parent] 193 | child = child_dist.child_name(child) 194 | parent, c_value = parent_dist.value_to_name_value( 195 | parent, c_value) 196 | output.append(EqualsCondition(child, parent, c_value)) 197 | return OrCondition(*output) 198 | elif isinstance(condition, Condition): 199 | child = condition.child 200 | parent = condition.parent 201 | conditioned_value = condition.conditioned_value 202 | child_dist = self.hps[child] 203 | parent_dist = self.hps[parent] 204 | 205 | child = child_dist.child_name(child) 206 | parent, conditioned_value = parent_dist.value_to_name_value( 207 | parent, conditioned_value) 208 | 209 | return condition.__class__(child, parent, conditioned_value) 210 | raise TypeError("Unrecognized type {}".format(condition)) 211 | 212 | def _normalize_forbidden_names(self, forbidden): 213 | if isinstance(forbidden, ForbiddenAnd): 214 | output = [] 215 | for forb in forbidden.forbidden_clauses: 216 | output.append(self._normalize_forbidden_names(forb)) 217 | return ForbiddenAnd(output) 218 | elif isinstance(forbidden, ForbiddenIn): 219 | values = [] 220 | dist = self.hps[forbidden.name] 221 | f_name = None 222 | for value in forbidden.value: 223 | name, value = dist.value_to_name_value(forbidden.name, value) 224 | if f_name is None: 225 | f_name = name 226 | elif f_name != name: 227 | raise ValueError("ForbiddenIn must be the same type: {}". 228 | format(forbidden)) 229 | values.append(value) 230 | return ForbiddenIn(f_name, values) 231 | elif isinstance(forbidden, ForbiddenEquals): 232 | dist = self.hps[forbidden.name] 233 | name, value = dist.value_to_name_value(forbidden.name, 234 | forbidden.value) 235 | return forbidden.__class__(name, value) 236 | raise TypeError("Unrecognized type {}".format(forbidden)) 237 | -------------------------------------------------------------------------------- /skconfig/validator.py: -------------------------------------------------------------------------------- 1 | from .parameter.base import Param 2 | from .exceptions import InvalidParamName 3 | from .exceptions import InactiveConditionedValue 4 | from .exceptions import SKConfigValueError 5 | 6 | 7 | class BaseValidator: 8 | conditions = [] 9 | forbiddens = [] 10 | estimator = None 11 | 12 | def __init__(self): 13 | # check parameters forbidden and conditions are compalible 14 | self.parameters_ = { 15 | name: param 16 | for name, param in self.__class__.__dict__.items() 17 | if isinstance(param, Param) 18 | } 19 | if self.estimator is None: 20 | raise SKConfigValueError("estimator must be defined") 21 | 22 | def validate_params(self, **kwargs): 23 | # Check kwargs get in params 24 | for name in kwargs: 25 | if name not in self.parameters_: 26 | raise InvalidParamName(name) 27 | 28 | all_kwargs = {**self.estimator().get_params(), **kwargs} 29 | 30 | # check for forbidden 31 | for forbidden in self.forbiddens: 32 | forbidden.is_forbidden(**all_kwargs) 33 | 34 | condition_names = [cond.child for cond in self.conditions] 35 | for name, value in all_kwargs.items(): 36 | if name not in condition_names: 37 | self.parameters_[name].validate(name, value) 38 | 39 | # Check conditions 40 | for condition in self.conditions: 41 | name = condition.child 42 | value = all_kwargs.get(name) 43 | if condition.is_active(**all_kwargs): 44 | self.parameters_[name].validate(name, value) 45 | elif value is not None: 46 | raise InactiveConditionedValue(name, condition) 47 | 48 | def validate_estimator(self, estimator): 49 | self.validate_params(**estimator.get_params()) 50 | -------------------------------------------------------------------------------- /tests/test_parameter.py: -------------------------------------------------------------------------------- 1 | from skconfig.parameter import BoolParam 2 | 3 | 4 | def test_bool_param_type(): 5 | p = BoolParam() 6 | assert p.value_type == bool 7 | --------------------------------------------------------------------------------