├── sklearn_callbacks ├── tests │ ├── __init__.py │ ├── test_base.py │ ├── test_convergence_monitor.py │ ├── test_debug.py │ ├── test_computational_graph.py │ └── test_estimators.py ├── __init__.py ├── _debug.py ├── _convergence_monitor.py ├── _progressbar.py └── _computational_graph.py ├── .gitignore ├── setup.py ├── doc └── static │ └── img │ ├── progressbar-sgd.gif │ ├── convergence-monitor.png │ └── progressbar-pipeline.gif ├── pyproject.toml ├── examples ├── progressbar-sgdclassifier.py ├── convergence-monitor-ridge.py ├── logging-pipeline.py └── progressbar-pipeline.py ├── setup.cfg ├── .pre-commit-config.yaml ├── LICENSE ├── .circleci └── config.yml └── README.md /sklearn_callbacks/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sklearn_callbacks/tests/test_base.py: -------------------------------------------------------------------------------- 1 | def test_fake(): 2 | pass 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info 2 | __pycache__ 3 | .mypy_cache 4 | .ipynb_checkpoints 5 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | if __name__ == "__main__": 4 | setup() 5 | -------------------------------------------------------------------------------- /doc/static/img/progressbar-sgd.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rth/sklearn-callbacks/HEAD/doc/static/img/progressbar-sgd.gif -------------------------------------------------------------------------------- /doc/static/img/convergence-monitor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rth/sklearn-callbacks/HEAD/doc/static/img/convergence-monitor.png -------------------------------------------------------------------------------- /doc/static/img/progressbar-pipeline.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rth/sklearn-callbacks/HEAD/doc/static/img/progressbar-pipeline.gif -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=42", "wheel", "setuptools_scm[toml]>=3.4"] 3 | 4 | # enable versionning with setuptools_scm 5 | [tool.setuptools_scm] 6 | 7 | 8 | [tool.black] 9 | line-length = 79 10 | -------------------------------------------------------------------------------- /examples/progressbar-sgdclassifier.py: -------------------------------------------------------------------------------- 1 | from sklearn.datasets import make_classification 2 | from sklearn.linear_model import SGDClassifier 3 | 4 | from sklearn_callbacks import ProgressBar 5 | 6 | 7 | X, y = make_classification(n_samples=200000, n_features=200, random_state=0) 8 | 9 | est = SGDClassifier(max_iter=100, tol=1e-4) 10 | 11 | pbar = ProgressBar() 12 | est._set_callbacks(pbar) 13 | 14 | est.fit(X, y) 15 | 16 | pbar.pbar.close() 17 | -------------------------------------------------------------------------------- /sklearn_callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | # pkg_resources is installed with setuptools 2 | from pkg_resources import get_distribution, DistributionNotFound 3 | 4 | try: 5 | __version__ = get_distribution(__name__).version 6 | except DistributionNotFound: 7 | # package is not installed 8 | pass 9 | 10 | from ._debug import DebugCallback 11 | from ._progressbar import ProgressBar 12 | from ._convergence_monitor import ConvergenceMonitor 13 | 14 | __all__ = ["DebugCallback", "ProgressBar", "ConvergenceMonitor"] 15 | -------------------------------------------------------------------------------- /examples/convergence-monitor-ridge.py: -------------------------------------------------------------------------------- 1 | from sklearn.linear_model import Ridge 2 | from sklearn.model_selection import train_test_split 3 | from sklearn.datasets import make_regression 4 | from sklearn.pipeline import make_pipeline 5 | from sklearn.preprocessing import StandardScaler 6 | 7 | from sklearn_callbacks import ConvergenceMonitor 8 | 9 | X, y = make_regression(random_state=0) 10 | X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) 11 | 12 | conv_mon = ConvergenceMonitor("mean_absolute_error", X_test, y_test) 13 | 14 | pipe = make_pipeline(StandardScaler(), Ridge(solver="sag", alpha=1)) 15 | pipe._set_callbacks(conv_mon) 16 | _ = pipe.fit(X_train, y_train) 17 | 18 | conv_mon.plot() 19 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = sklearn-callbacks 3 | description = "Experimental callbacks for scikit-learn: progress bars, monitoring convergence etc." 4 | url = "https://github.com/rth/sklearn-callbacks" 5 | long_description = file: README.md 6 | license = BSD-3-clause 7 | classifiers = 8 | Development Status :: 4 - Beta 9 | Natural Language :: English 10 | Programming Language :: Python 11 | Programming Language :: Python :: 3 12 | Programming Language :: Python :: 3.6 13 | Programming Language :: Python :: 3.7 14 | Programming Language :: Python :: 3.8 15 | 16 | [options] 17 | zip_safe = False 18 | include_package_data = True 19 | packages = find: 20 | install_requires = 21 | tqdm 22 | -------------------------------------------------------------------------------- /examples/logging-pipeline.py: -------------------------------------------------------------------------------- 1 | from sklearn.compose import make_column_transformer 2 | from sklearn.datasets import make_classification 3 | from sklearn.impute import SimpleImputer 4 | from sklearn.linear_model import SGDClassifier 5 | from sklearn.pipeline import make_pipeline 6 | from sklearn.preprocessing import MinMaxScaler, StandardScaler 7 | from sklearn_callbacks import DebugCallback 8 | 9 | X, y = make_classification(n_samples=10000, n_features=100, random_state=0) 10 | 11 | pipe = make_pipeline( 12 | SimpleImputer(), 13 | make_column_transformer( 14 | (StandardScaler(), slice(0, 80)), (MinMaxScaler(), slice(80, 90)), 15 | ), 16 | SGDClassifier(max_iter=20), 17 | verbose=1, 18 | ) 19 | 20 | 21 | pbar = DebugCallback() 22 | # pipe._set_callbacks(pbar) 23 | 24 | _ = pipe.fit(X, y) 25 | -------------------------------------------------------------------------------- /sklearn_callbacks/tests/test_convergence_monitor.py: -------------------------------------------------------------------------------- 1 | from sklearn.linear_model import Ridge 2 | from sklearn.model_selection import train_test_split 3 | from sklearn.datasets import make_regression 4 | from sklearn.pipeline import make_pipeline 5 | from sklearn.preprocessing import StandardScaler 6 | 7 | from sklearn_callbacks import ConvergenceMonitor 8 | 9 | 10 | def test_convergence_ridge(): 11 | X, y = make_regression(random_state=0) 12 | # X, y = load_diabetes(return_X_y=True) 13 | X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) 14 | 15 | conv_mon = ConvergenceMonitor("mean_absolute_error", X_test, y_test) 16 | 17 | pipe = make_pipeline(StandardScaler(), Ridge(solver="sag", alpha=1)) 18 | pipe._set_callbacks(conv_mon) 19 | _ = pipe.fit(X_train, y_train) 20 | assert len(conv_mon.data) > 1 21 | -------------------------------------------------------------------------------- /sklearn_callbacks/tests/test_debug.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | from sklearn.datasets import load_iris 4 | from sklearn.exceptions import ConvergenceWarning 5 | from sklearn.linear_model import LogisticRegression 6 | from sklearn_callbacks import DebugCallback 7 | 8 | 9 | def test_debug_callback(): 10 | X, y = load_iris(return_X_y=True) 11 | 12 | callback = DebugCallback(verbose=False) 13 | 14 | est = LogisticRegression(max_iter=3) 15 | est._set_callbacks(callback) 16 | with warnings.catch_warnings(): 17 | warnings.filterwarnings("ignore", category=ConvergenceWarning) 18 | est.fit(X, y) 19 | 20 | log_expected = [ 21 | r"fit_begin LogisticRegression\(max_iter=3\)", 22 | "iter_end coef=.*", 23 | "iter_end coef=.*", 24 | "iter_end coef=.*", 25 | ] 26 | callback.check_log_expected(log_expected) 27 | -------------------------------------------------------------------------------- /examples/progressbar-pipeline.py: -------------------------------------------------------------------------------- 1 | from sklearn.datasets import make_classification 2 | from sklearn.linear_model import LogisticRegression 3 | from sklearn.pipeline import make_pipeline 4 | from sklearn.compose import make_column_transformer 5 | from sklearn.preprocessing import StandardScaler, MinMaxScaler 6 | from sklearn.impute import SimpleImputer 7 | 8 | from sklearn_callbacks import ProgressBar 9 | 10 | X, y = make_classification(n_samples=500000, n_features=200, random_state=0) 11 | 12 | pipe = make_pipeline( 13 | SimpleImputer(), 14 | make_column_transformer( 15 | (StandardScaler(), slice(0, 80)), 16 | (MinMaxScaler(), slice(80, 120)), 17 | (StandardScaler(with_mean=False), slice(120, 180)), 18 | ), 19 | LogisticRegression(), 20 | ) 21 | 22 | 23 | pbar = ProgressBar() 24 | pipe._set_callbacks(pbar) 25 | 26 | _ = pipe.fit(X, y) 27 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v2.3.0 4 | hooks: 5 | - id: check-yaml 6 | - id: end-of-file-fixer 7 | - id: trailing-whitespace 8 | - repo: https://github.com/psf/black 9 | rev: 19.10b0 10 | hooks: 11 | - id: black 12 | - repo: https://gitlab.com/pycqa/flake8 13 | rev: 3.7.7 14 | hooks: 15 | - id: flake8 16 | types: 17 | - file 18 | args: [--select=F401,F405] 19 | - repo: https://github.com/pre-commit/mirrors-mypy 20 | rev: v0.730 21 | hooks: 22 | - id: mypy 23 | args: 24 | - --ignore-missing-imports 25 | - --follow-imports 26 | - skip 27 | files: sklearn_callbacks/ 28 | - repo: https://github.com/pre-commit/mirrors-isort 29 | rev: v4.3.21 30 | hooks: 31 | - id: isort 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, Roman Yurchak 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | jobs: 4 | python38: 5 | docker: 6 | - image: circleci/python:3.8.2-buster 7 | environment: 8 | - OMP_NUM_THREADS: 2 9 | - MKL_NUM_THREADS: 2 10 | steps: 11 | - checkout 12 | - restore_cache: 13 | keys: 14 | - ccache-{{ .Branch }} 15 | - ccache 16 | - run: | 17 | set -ex 18 | python -m venv ./venv 19 | source ./venv/bin/activate 20 | pip install -q numpy scipy pandas matplotlib pytest 21 | pip install https://github.com/rth/scikit-learn/archive/progress-bar.zip 22 | - save_cache: 23 | key: ccache-{{ .Branch }}-{{ .BuildNum }} 24 | paths: 25 | - ~/.ccache 26 | - ~/.cache/pip 27 | - run: | 28 | source ./venv/bin/activate 29 | pip install -e . 30 | - run: | 31 | set -ex 32 | source ./venv/bin/activate 33 | pytest sklearn_callbacks/ 34 | 35 | lint: 36 | docker: 37 | - image: circleci/python:3.7 38 | steps: 39 | - checkout 40 | - run: 41 | name: dependencies 42 | command: sudo pip install flake8 mypy black 43 | 44 | - run: 45 | name: black 46 | command: black --check 47 | 48 | - run: 49 | name: flake8 50 | command: flake8 --select=F401,F405 51 | 52 | - run: 53 | name: mypy 54 | command: mypy --ignore-missing-imports sklearn_callbacks 55 | 56 | 57 | workflows: 58 | version: 2 59 | build-and-test: 60 | jobs: 61 | - lint 62 | - python38 63 | -------------------------------------------------------------------------------- /sklearn_callbacks/_debug.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | import sys 4 | from typing import List 5 | 6 | from sklearn._callbacks import BaseCallback 7 | 8 | 9 | class DebugCallback(BaseCallback): 10 | def __init__(self, verbose=True): 11 | self.verbose = verbose 12 | self.formatter = logging.Formatter( 13 | fmt="%(asctime)s %(levelname)-8s %(message)s", 14 | ) 15 | self.log = [] 16 | self.handler = logging.StreamHandler(stream=sys.stdout) 17 | self.handler.setFormatter(self.formatter) 18 | self.logger = logging.getLogger("sklearn") 19 | self.logger.setLevel(logging.DEBUG) 20 | self.logger.addHandler(self.handler) 21 | 22 | def add_message(self, msg): 23 | self.log.append(msg) 24 | if self.verbose: 25 | self.logger.info(msg) 26 | 27 | def on_fit_begin(self, estimator, X, y): 28 | self.add_message("fit_begin " + str(estimator)) 29 | 30 | def check_log_expected(self, log: List[str]): 31 | """Check that the recored log matches expected values 32 | 33 | Parameters 34 | ---------- 35 | log 36 | list of regexp with the expected lines for each log entry. 37 | """ 38 | assert len(self.log) == len(log) 39 | 40 | for val, expected in zip(self.log, log): 41 | if not re.match(expected, val): 42 | raise AssertionError( 43 | f"Expected regexp {expected} does not match '{val}'." 44 | ) 45 | 46 | def on_iter_end(self, **kwargs): 47 | 48 | self.add_message( 49 | "iter_end " 50 | + ", ".join(f"{key}={val}" for key, val in kwargs.items()) 51 | ) 52 | -------------------------------------------------------------------------------- /sklearn_callbacks/tests/test_computational_graph.py: -------------------------------------------------------------------------------- 1 | from sklearn.compose import make_column_transformer 2 | from sklearn.linear_model import LogisticRegression 3 | from sklearn.pipeline import make_pipeline 4 | from sklearn.preprocessing import MinMaxScaler, StandardScaler 5 | from sklearn_callbacks._computational_graph import ComputeGraph 6 | 7 | 8 | def test_graph_single_estimator(): 9 | est = StandardScaler() 10 | graph = ComputeGraph.from_estimator(est) 11 | assert str(graph.root_node) == "StandardScaler 0 / 1" 12 | assert len(graph) == 1 13 | assert graph.root_node.depth == 0 14 | 15 | 16 | def test_graph_pipeline(): 17 | pipe = make_pipeline(StandardScaler(), LogisticRegression(max_iter=7)) 18 | graph = ComputeGraph.from_estimator(pipe) 19 | assert len(graph) == 3 20 | graph_flat_str = [(node.depth, str(node)) for node in graph] 21 | assert graph_flat_str == [ 22 | (0, "Pipeline 0 / 2"), 23 | (1, "StandardScaler 0 / 1"), 24 | (1, "LogisticRegression 0 / 7"), 25 | ] 26 | 27 | 28 | def test_graph_pipeline_column_transformer(): 29 | pipe_0 = make_pipeline(MinMaxScaler()) 30 | pipe = make_pipeline( 31 | make_column_transformer((StandardScaler(), [0, 1]), (pipe_0, [2, 3]),), 32 | LogisticRegression(max_iter=7), 33 | ) 34 | 35 | graph = ComputeGraph.from_estimator(pipe) 36 | assert len(graph) == 6 37 | graph_flat_str = [(node.depth, str(node)) for node in graph] 38 | assert graph_flat_str == [ 39 | (0, "Pipeline 0 / 2"), 40 | (1, "ColumnTransformer 0 / 2"), 41 | (2, "StandardScaler 0 / 1"), 42 | (2, "Pipeline 0 / 1"), 43 | (3, "MinMaxScaler 0 / 1"), 44 | (1, "LogisticRegression 0 / 7"), 45 | ] 46 | graph.update_state(pipe_0) 47 | 48 | graph_flat_str = [(node.depth, str(node)) for node in graph] 49 | assert graph_flat_str == [ 50 | (0, "Pipeline 0 / 2"), 51 | (1, "ColumnTransformer 1 / 2"), # parent progress changed 52 | (2, "StandardScaler 0 / 1"), 53 | (2, "Pipeline 0 / 1"), 54 | (3, "MinMaxScaler 0 / 1"), 55 | (1, "LogisticRegression 0 / 7"), 56 | ] 57 | -------------------------------------------------------------------------------- /sklearn_callbacks/tests/test_estimators.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | from sklearn.compose import make_column_transformer 4 | from sklearn.datasets import load_iris 5 | from sklearn.exceptions import ConvergenceWarning 6 | from sklearn.linear_model import LogisticRegression 7 | from sklearn.pipeline import make_pipeline 8 | from sklearn.preprocessing import MinMaxScaler, StandardScaler 9 | from sklearn_callbacks import DebugCallback 10 | 11 | 12 | def test_pipeline(): 13 | X, y = load_iris(return_X_y=True) 14 | 15 | callback = DebugCallback(verbose=False) 16 | print("") 17 | 18 | pipe = make_pipeline(StandardScaler(), LogisticRegression(max_iter=3)) 19 | pipe._set_callbacks(callback) 20 | with warnings.catch_warnings(): 21 | warnings.filterwarnings("ignore", category=ConvergenceWarning) 22 | pipe.fit(X, y) 23 | 24 | log_expected = [ 25 | "fit_begin Pipeline", 26 | "fit_begin StandardScaler", 27 | "fit_begin StandardScaler", # why second time? 28 | r"fit_begin LogisticRegression\(max_iter=3\)", 29 | "iter_end coef=.*", 30 | "iter_end coef=.*", 31 | "iter_end coef=.*", 32 | ] 33 | callback.check_log_expected(log_expected) 34 | 35 | 36 | def test_pipeline_column_transformer(): 37 | X, y = load_iris(return_X_y=True) 38 | 39 | callback = DebugCallback(verbose=False) 40 | 41 | pipe = make_pipeline( 42 | make_column_transformer( 43 | (StandardScaler(), [0, 1]), (MinMaxScaler(), [2, 3]), 44 | ), 45 | LogisticRegression(max_iter=3), 46 | ) 47 | pipe._set_callbacks(callback) 48 | 49 | with warnings.catch_warnings(): 50 | warnings.filterwarnings("ignore", category=ConvergenceWarning) 51 | pipe.fit(X, y) 52 | 53 | log_expected = [ 54 | "fit_begin Pipeline", 55 | "fit_begin ColumnTransformer", 56 | "fit_begin StandardScaler", 57 | "fit_begin StandardScaler", # why second time? 58 | "fit_begin MinMaxScaler", 59 | r"fit_begin LogisticRegression\(max_iter=3\)", 60 | "iter_end coef=.*", 61 | "iter_end coef=.*", 62 | "iter_end coef=.*", 63 | ] 64 | callback.check_log_expected(log_expected) 65 | -------------------------------------------------------------------------------- /sklearn_callbacks/_convergence_monitor.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import sklearn.metrics 4 | from sklearn._callbacks import BaseCallback 5 | from sklearn.base import clone 6 | from sklearn.linear_model._base import LinearModel 7 | 8 | 9 | class ConvergenceMonitor(BaseCallback): 10 | """Monitor model convergence. 11 | 12 | Currently only a few linear models are supported 13 | (e.g. ``Ridge(solver="sag")``) 14 | 15 | Parameters 16 | ---------- 17 | metric 18 | metric to evaluate 19 | X_test, y_test 20 | optional validation data 21 | """ 22 | 23 | def __init__(self, metric: str, X_test=None, y_test=None): 24 | self.metric = metric 25 | self.metric_func = getattr(sklearn.metrics, metric, None) 26 | if self.metric_func is None: 27 | raise ValueError(f"uknown metric={metric}") 28 | self.data: List[Dict] = [] 29 | self.X_test = X_test 30 | self.y_test = y_test 31 | 32 | def on_fit_begin(self, estimator, X, y): 33 | if not isinstance(estimator, LinearModel): 34 | # not implemented 35 | return 36 | self.X_train = X 37 | self.y_train = y 38 | # Explicitly clone later, so that estimator 39 | # attributes can still be modified in fit 40 | self.estimator = estimator 41 | 42 | def on_iter_end(self, **kwargs): 43 | coef = kwargs.get("coef", None) 44 | intercept = kwargs.get("intercept", None) 45 | if coef is None or intercept is None: 46 | raise NotImplementedError 47 | 48 | # create a new estimator with updated coefs 49 | est = clone(self.estimator) 50 | est.coef_ = coef.reshape(-1) 51 | est.intercept_ = intercept.reshape(-1) 52 | 53 | y_pred = est.predict(self.X_train) 54 | score_train = self.metric_func(self.y_train, y_pred) 55 | res = {"score_train": score_train} 56 | if self.X_test is not None: 57 | y_pred = est.predict(self.X_test) 58 | score_test = self.metric_func(self.y_test, y_pred) 59 | res["score_test"] = score_test 60 | self.data.append(res) 61 | 62 | def plot(self, ax=None): 63 | import pandas as pd 64 | import matplotlib.pyplot as plt 65 | 66 | if ax is None: 67 | fig, ax = plt.subplots() 68 | df = pd.DataFrame(self.data) 69 | df.plot(ax=ax) 70 | ax.set_xlabel("Number of iterations") 71 | ax.set_ylabel(self.metric) 72 | with sklearn.config_context(print_changed_only=True): 73 | ax.set_title(str(self.estimator)) 74 | -------------------------------------------------------------------------------- /sklearn_callbacks/_progressbar.py: -------------------------------------------------------------------------------- 1 | from sklearn._callbacks import BaseCallback 2 | 3 | from ._computational_graph import ComputeGraph 4 | 5 | 6 | class TqdmPbar: 7 | def __init__(self, name, **kwargs): 8 | from tqdm.auto import tqdm 9 | 10 | self.pbar = tqdm(**kwargs) 11 | self.n_steps = kwargs.get("total", None) 12 | self.n_iter = 0 13 | self.name = name 14 | 15 | def update(self, node, n_iter=None, **kwargs): 16 | desc = node.name 17 | if "loss" in kwargs: 18 | desc += f", loss={kwargs['loss']:8g}" 19 | elif "score" in kwargs: 20 | desc += f", score={kwargs['score']:.4f}" 21 | if desc is not None: 22 | self.pbar.set_description(desc) 23 | if n_iter is None: 24 | n_iter = node.n_iter + 1 25 | self.pbar.update(max(n_iter - self.n_iter, 0)) 26 | self.n_iter = max(n_iter, self.n_iter) 27 | 28 | def close(self): 29 | self.pbar.close() 30 | 31 | def finalize(self): 32 | if self.n_steps and not None and self.n_iter < self.n_steps: 33 | self.pbar.update(self.n_steps - self.n_iter) 34 | return self 35 | 36 | 37 | def _get_node_at_depth(node_init, depth=1): 38 | if depth > node_init.depth: 39 | raise ValueError 40 | 41 | node = node_init 42 | while True: 43 | node_depth = node.depth 44 | if node_depth == 1: 45 | return node 46 | node = node.parent 47 | 48 | 49 | class ProgressBar(BaseCallback): 50 | def __init__(self): 51 | self.pbar = None 52 | self.pbar2 = None 53 | self.compute_graph = None 54 | 55 | def on_fit_begin(self, estimator, X, y): 56 | if self.compute_graph is None: 57 | # assume this first call was made from the root node. 58 | self.compute_graph = ComputeGraph.from_estimator(estimator) 59 | self.compute_graph.update_state(estimator) 60 | root = self.compute_graph.root_node 61 | if self.pbar is None: 62 | self.pbar = TqdmPbar( 63 | total=root.n_steps, name=root.name, desc=root.name, leave=False 64 | ) 65 | self.pbar.update(root) 66 | current_node = self.compute_graph.current_node 67 | if current_node.depth >= 1: 68 | node = _get_node_at_depth(current_node, depth=1) 69 | 70 | if self.pbar2 is not None and self.pbar2.name != node.name: 71 | self.pbar2.finalize().close() 72 | self.pbar2 = None 73 | 74 | if self.pbar2 is None: 75 | self.pbar2 = TqdmPbar( 76 | total=node.n_steps, 77 | name=node.name, 78 | desc=node.name, 79 | leave=False, 80 | ) 81 | 82 | self.pbar2.update(node) 83 | 84 | else: 85 | if self.pbar2 is not None: 86 | self.pbar2.close() 87 | 88 | def on_iter_end(self, **kwargs): 89 | self.compute_graph.current_node.n_iter += 1 90 | 91 | root = self.compute_graph.root_node 92 | 93 | self.pbar.update(root, **kwargs) 94 | 95 | current_node = self.compute_graph.current_node 96 | if current_node.depth >= 1 and self.pbar2 is not None: 97 | node = _get_node_at_depth(current_node, depth=1) 98 | self.pbar2.update(node, **kwargs) 99 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # sklearn-callbacks 2 | 3 | [![CircleCI](https://circleci.com/gh/rth/sklearn-callbacks.svg?style=svg)](https://circleci.com/gh/rth/sklearn-callbacks) 4 | 5 | Experimental callbacks for scikit-learn: progress bars, monitoring convergence etc. 6 | 7 | ## Install 8 | 9 | This package require a patched scikit-learn 0.24.0dev0 from [scikit-learn#16925](https://github.com/scikit-learn/scikit-learn/pull/16925), 10 | ``` 11 | pip install https://github.com/rth/scikit-learn/archive/progress-bar.zip 12 | pip install git+https://github.com/rth/sklearn-callbacks.git 13 | ``` 14 | 15 | ## Usage 16 | 17 | ### Progress bars 18 | 19 | This package implements progress bars for estimators with iterative solvers, 20 | ```py 21 | from sklearn.datasets import make_classification 22 | from sklearn.linear_model import SGDClassifier 23 | from sklearn_callbacks import ProgressBar 24 | 25 | X, y = make_classification(n_samples=200000, n_features=200, random_state=0) 26 | 27 | est = SGDClassifier(max_iter=100, tol=1e-4) 28 | est._set_callbacks(ProgressBar()) 29 | 30 | est.fit(X, y) 31 | ``` 32 | ![SGD progress bar](./doc/static/img/progressbar-sgd.gif?raw=true "SGD progress bar") 33 | 34 | more complex scikit-learn pipelines are also supported, 35 | ```py 36 | # see details for full list of imports 37 | from sklearn_callbacks import ProgressBar 38 | 39 | X, y = make_classification(n_samples=500000, n_features=200, random_state=0) 40 | 41 | pipe = make_pipeline( 42 | SimpleImputer(), 43 | make_column_transformer( 44 | (StandardScaler(), slice(0, 80)), 45 | (MinMaxScaler(), slice(80, 120)), 46 | (StandardScaler(with_mean=False), slice(120, 180)), 47 | ), 48 | LogisticRegression(), 49 | ) 50 | 51 | 52 | pipe._set_callbacks(ProgressBar()) 53 | pipe.fit(X, y) 54 | ``` 55 | 56 |
57 | 58 | ```py 59 | from sklearn.datasets import make_classification 60 | from sklearn.linear_model import LogisticRegression 61 | from sklearn.pipeline import make_pipeline 62 | from sklearn.compose import make_column_transformer 63 | from sklearn.preprocessing import StandardScaler, MinMaxScaler 64 | from sklearn.impute import SimpleImputer 65 | 66 | from sklearn_callbacks import ProgressBar 67 | 68 | X, y = make_classification(n_samples=500000, n_features=200, random_state=0) 69 | 70 | pipe = make_pipeline( 71 | SimpleImputer(), 72 | make_column_transformer( 73 | (StandardScaler(), slice(0, 80)), 74 | (MinMaxScaler(), slice(80, 120)), 75 | (StandardScaler(with_mean=False), slice(120, 180)), 76 | ), 77 | LogisticRegression(), 78 | ) 79 | 80 | pipe._set_callbacks(ProgressBar()) 81 | 82 | pipe.fit(X, y) 83 | ``` 84 |
85 | 86 | ![pipeline progress bar](./doc/static/img/progressbar-pipeline.gif?raw=true "pipeline progress bar") 87 | 88 | ### Monitoring convergence 89 | 90 | ```py 91 | from sklearn.linear_model import Ridge 92 | from sklearn.model_selection import train_test_split 93 | from sklearn.datasets import make_regression 94 | from sklearn.pipeline import make_pipeline 95 | from sklearn.preprocessing import StandardScaler 96 | 97 | from sklearn_callbacks import ConvergenceMonitor 98 | 99 | X, y = make_regression(random_state=0) 100 | X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) 101 | 102 | conv_mon = ConvergenceMonitor("mean_absolute_error", X_test, y_test) 103 | 104 | pipe = make_pipeline(StandardScaler(), Ridge(solver="sag", alpha=1)) 105 | pipe._set_callbacks(conv_mon) 106 | _ = pipe.fit(X_train, y_train) 107 | 108 | conv_mon.plot() 109 | ``` 110 | 111 | ![convergence monitor](./doc/static/img/convergence-monitor.png?raw=true "convergence monitor") 112 | 113 | 114 | 115 | ## License 116 | 117 | This project is distributed under the BSD 3-clause license. 118 | -------------------------------------------------------------------------------- /sklearn_callbacks/_computational_graph.py: -------------------------------------------------------------------------------- 1 | # Building an approximate graph for scikit-learn calculations. 2 | # 3 | # The goal is to determine for each callback call which estimator 4 | # it belongs to, and what impact it has on the overall progress 5 | # of computations. Very experimental. 6 | from sklearn.base import BaseEstimator 7 | 8 | from collections.abc import Iterator 9 | 10 | 11 | class ComputeNode: 12 | """A node in a scikit-learn computational graph""" 13 | 14 | def __init__(self, _id, name, parent=None, children=None, max_iter=None): 15 | self._id = _id 16 | self.name = name 17 | if children is None: 18 | self.children = [] 19 | else: 20 | self.children = children 21 | self.parent = parent 22 | self.n_iter = 0 23 | self.max_iter = max_iter 24 | 25 | @classmethod 26 | def from_estimator(cls, estimator, max_depth=0, parent=None): 27 | """Create a computational node from estimator. 28 | 29 | Set max_depth=0, to avoid recursive build. And max_depth=-1 30 | to remove limitation on the recursion depth 31 | """ 32 | name = estimator.__class__.__name__ 33 | _id = id(estimator) 34 | max_iter = estimator.get_params().get("max_iter", None) 35 | node = cls(_id, name, max_iter=max_iter, parent=parent) 36 | 37 | if max_depth == 0: 38 | return node 39 | 40 | if max_depth == -1: 41 | child_max_depth = -1 42 | else: 43 | child_max_depth = max_depth - 1 44 | for attr_name in getattr(estimator, "_required_parameters", []): 45 | # likely a meta-estimator 46 | if attr_name not in ["steps", "transformers"]: 47 | continue 48 | for attr in getattr(estimator, attr_name): 49 | if isinstance(attr, BaseEstimator): 50 | node.children.append( 51 | cls.from_estimator( 52 | attr, max_depth=child_max_depth, parent=node 53 | ) 54 | ) 55 | elif ( 56 | hasattr(attr, "__len__") 57 | and len(attr) >= 2 58 | and isinstance(attr[1], BaseEstimator) 59 | ): 60 | # e.g. Pipeline or ColumnTransformer 61 | node.children.append( 62 | cls.from_estimator( 63 | attr[1], max_depth=child_max_depth, parent=node 64 | ) 65 | ) 66 | return node 67 | 68 | @property 69 | def root(self): 70 | """Find the root node""" 71 | if self.parent is None: 72 | return self 73 | else: 74 | return self.parent.root 75 | 76 | @property 77 | def depth(self): 78 | if self.parent is None: 79 | return 0 80 | else: 81 | return self.parent.depth + 1 82 | 83 | @property 84 | def n_steps(self): 85 | if len(self.children): 86 | return len(self.children) 87 | elif self.max_iter is not None: 88 | return self.max_iter 89 | else: 90 | return 1 91 | 92 | def __repr__(self): 93 | return f"{self.name} {self.n_iter} / {self.n_steps}" 94 | 95 | def next(self) -> "ComputeNode": 96 | """Depth first tree traversal""" 97 | if self.children: 98 | # go down the graph 99 | return self.children[0] 100 | # one (or several levels up) and onto the next child 101 | prev_node = self 102 | while True: 103 | if prev_node.parent is None: 104 | raise StopIteration 105 | node = prev_node.parent 106 | idx = node.children.index(prev_node) 107 | if idx + 1 < len(node.children): 108 | return node.children[idx + 1] 109 | else: 110 | # go one level up 111 | prev_node = node 112 | 113 | 114 | class ComputeGraph(Iterator): 115 | def __init__(self, root: ComputeNode): 116 | self.root_node = root 117 | self.current_node = root 118 | self._id_map = {node._id: node for node in self} 119 | 120 | @classmethod 121 | def from_estimator(cls, estimator, max_depth=-1): 122 | """Build the computational graph from the root estimator""" 123 | root_node = ComputeNode.from_estimator(estimator, max_depth=max_depth) 124 | return cls(root_node) 125 | 126 | def __next__(self): 127 | node = self.current_node.next() 128 | self.current_node = node 129 | return node 130 | 131 | def __iter__(self): 132 | yield self.root_node 133 | node_prev = self.root_node 134 | while True: 135 | try: 136 | node = node_prev.next() 137 | except StopIteration: 138 | break 139 | yield node 140 | node_prev = node 141 | 142 | def __repr__(self): 143 | out = [] 144 | for node in self: 145 | indent = node.depth 146 | out.append("{}- {}".format(" " * indent, node)) 147 | return "\n".join(out) 148 | 149 | def __len__(self): 150 | return len([None for el in self]) 151 | 152 | def update_state(self, estimator): 153 | """Set the next active state 154 | 155 | All earlier node are assumed to be computed, but only 156 | parent n_iter is properly updated. 157 | """ 158 | node = self._id_map.get(id(estimator), None) 159 | name = estimator.__class__.__name__ 160 | if node is None: 161 | next_node = self.current_node.next() 162 | if self.current_node.name == name: 163 | return 164 | elif next_node.name == name: 165 | node = next_node 166 | else: 167 | raise ValueError(f"Could not identify state for {estimator}") 168 | 169 | self.current_node = node 170 | # update parent progress 171 | if node.parent is not None: 172 | idx = node.parent.children.index(node) 173 | node.parent.n_iter = idx 174 | --------------------------------------------------------------------------------