├── tests ├── __init__.py ├── test_logger.py ├── test_sklearn_pipeline.py ├── test_default_lightgbm_parameters.py ├── test_reproducibility.py ├── test_imputed_accuracy.py ├── test_utils.py └── test_ImputationKernel.py ├── README_files ├── README_48_0.png ├── README_49_0.png ├── README_50_0.png ├── README_51_0.png ├── README_60_0.png ├── README_61_0.png ├── README_63_0.png ├── README_64_0.png └── README_65_0.png ├── pytest.ini ├── docs ├── ImputedData.rst ├── ImputationKernel.rst ├── Makefile ├── conf.py ├── index.rst └── requirements.txt ├── .readthedocs.yml ├── .gitignore ├── miceforest ├── __init__.py ├── default_lightgbm_parameters.py ├── logger.py ├── utils.py ├── imputed_data.py └── imputation_kernel.py ├── pyproject.toml ├── .github └── workflows │ └── run_tests.yml └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /README_files/README_48_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnotherSamWilson/miceforest/HEAD/README_files/README_48_0.png -------------------------------------------------------------------------------- /README_files/README_49_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnotherSamWilson/miceforest/HEAD/README_files/README_49_0.png -------------------------------------------------------------------------------- /README_files/README_50_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnotherSamWilson/miceforest/HEAD/README_files/README_50_0.png -------------------------------------------------------------------------------- /README_files/README_51_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnotherSamWilson/miceforest/HEAD/README_files/README_51_0.png -------------------------------------------------------------------------------- /README_files/README_60_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnotherSamWilson/miceforest/HEAD/README_files/README_60_0.png -------------------------------------------------------------------------------- /README_files/README_61_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnotherSamWilson/miceforest/HEAD/README_files/README_61_0.png -------------------------------------------------------------------------------- /README_files/README_63_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnotherSamWilson/miceforest/HEAD/README_files/README_63_0.png -------------------------------------------------------------------------------- /README_files/README_64_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnotherSamWilson/miceforest/HEAD/README_files/README_64_0.png -------------------------------------------------------------------------------- /README_files/README_65_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnotherSamWilson/miceforest/HEAD/README_files/README_65_0.png 
-------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | norecursedirs = .git ignore build __pycache__ 3 | filterwarnings= default 4 | testpaths = 5 | tests -------------------------------------------------------------------------------- /docs/ImputedData.rst: -------------------------------------------------------------------------------- 1 | ImputedData Class 2 | ================= 3 | 4 | .. autoclass:: miceforest.imputed_data.ImputedData 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: -------------------------------------------------------------------------------- /docs/ImputationKernel.rst: -------------------------------------------------------------------------------- 1 | ImputationKernel Class 2 | ====================== 3 | 4 | .. autoclass:: miceforest.imputation_kernel.ImputationKernel 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | :inherited-members: -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: "2" 2 | 3 | build: 4 | os: "ubuntu-22.04" 5 | tools: 6 | python: "3.10" 7 | 8 | python: 9 | install: 10 | - requirements: docs/requirements.txt 11 | 12 | sphinx: 13 | configuration: docs/conf.py -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .gitignore 2 | *.egg-info/ 3 | dist/ 4 | build/ 5 | scratch/ 6 | .idea/ 7 | .pytest_cache 8 | **/__pycache__ 9 | .coverage 10 | htmlcov/ 11 | coverage.* 12 | .codecovtoken 13 | examples/icon_small.png 14 | support/ 15 | setupdev.py 16 | */_build/ 17 | *.bat 18 | dev/ 19 | .Rhistory 20 | benchmarks/* 21 | .venv 22 | poetry.lock 23 | pyproject.toml 24 | *.DS_Store* 25 | .devcontainer 26 | Dockerfile 27 | dev_guide.md 28 | .pypirc 29 | -------------------------------------------------------------------------------- /miceforest/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | miceforest, Multiple Imputation by Chained Equations with LightGBM. 3 | 4 | Class / method / function documentation can be found in the readthedocs: 5 | https://miceforest.readthedocs.io/en/latest/index.html 6 | 7 | Extensive tutorials can be found on the github README: 8 | https://github.com/AnotherSamWilson/miceforest 9 | """ 10 | 11 | import importlib.metadata 12 | 13 | from .imputation_kernel import ImputationKernel 14 | from .imputed_data import ImputedData 15 | from .utils import ampute_data 16 | 17 | __version__ = importlib.metadata.version("miceforest") 18 | 19 | 20 | __all__ = [ 21 | "ImputedData", 22 | "ImputationKernel", 23 | "ampute_data", 24 | ] 25 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 
12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /tests/test_logger.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import pytest 4 | 5 | from miceforest.logger import Logger 6 | 7 | 8 | def test_logger_records_time_and_summary_structure(): 9 | logger = Logger( 10 | name="unit-test", timed_levels=["dataset", "iteration"], verbose=False 11 | ) 12 | key = (0, 0) 13 | logger.set_start_time(key) 14 | time.sleep(0.01) 15 | logger.record_time(key) 16 | 17 | assert key in logger.time_seconds 18 | assert key not in logger.started_timers 19 | assert logger.time_seconds[key] >= 0 20 | 21 | summary = logger.get_time_spend_summary() 22 | assert summary.index.names == ["dataset", "iteration"] 23 | assert pytest.approx(summary.iloc[0]) == logger.time_seconds[key] 24 | 25 | 26 | def test_logger_prevents_duplicate_starts_and_missing_records(): 27 | logger = Logger(name="safety", timed_levels=["dataset", "iteration"]) 28 | key = (1, 2) 29 | 30 | logger.set_start_time(key) 31 | with pytest.raises(AssertionError): 32 | logger.set_start_time(key) 33 | 34 | with pytest.raises(AssertionError): 35 | logger.record_time((3, 4)) 36 | 37 | 38 | def test_logger_repr_contains_name(): 39 | logger = Logger(name="pretty", timed_levels=["only"]) 40 | assert "pretty" in repr(logger) 41 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Project information ----------------------------------------------------- 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 8 | 9 | project = 'miceforest' 10 | copyright = '2024, Samuel Von Wilson' 11 | author = 'Samuel Von Wilson' 12 | 13 | # -- General configuration --------------------------------------------------- 14 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 15 | 16 | import os 17 | import sys 18 | sys.path.insert(0, os.path.abspath('../')) # Source code dir relative to this file 19 | 20 | extensions = [ 21 | 'sphinx.ext.autodoc', # Core library for html generation from docstrings 22 | 'sphinx.ext.autosummary', # Create neat summary tables 23 | 'sphinx.ext.napoleon', # The type of docstrings used in miceforest. 
24 | ] 25 | autosummary_generate = True 26 | 27 | templates_path = ['_templates'] 28 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 29 | 30 | 31 | # -- Options for HTML output ------------------------------------------------- 32 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 33 | 34 | html_theme = 'sphinx_rtd_theme' 35 | html_static_path = ['_static'] 36 | -------------------------------------------------------------------------------- /tests/test_sklearn_pipeline.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.preprocessing import StandardScaler 3 | from sklearn.datasets import load_iris 4 | from sklearn.pipeline import Pipeline 5 | import miceforest as mf 6 | import pandas as pd 7 | 8 | 9 | def make_dataset(seed): 10 | 11 | iris = pd.concat(load_iris(return_X_y=True, as_frame=True), axis=1) 12 | del iris["target"] 13 | iris.rename( 14 | { 15 | "sepal length (cm)": "sl", 16 | "sepal width (cm)": "sw", 17 | "petal length (cm)": "pl", 18 | "petal width (cm)": "pw", 19 | }, 20 | axis=1, 21 | inplace=True, 22 | ) 23 | iris_amp = mf.utils.ampute_data(iris, perc=0.20) 24 | 25 | return iris_amp 26 | 27 | 28 | def test_pipeline(): 29 | 30 | iris_amp_train = make_dataset(1) 31 | iris_amp_test = make_dataset(2) 32 | 33 | kernel = mf.ImputationKernel(iris_amp_train, num_datasets=1) 34 | 35 | pipe = Pipeline( 36 | [ 37 | ("impute", kernel), 38 | ("scaler", StandardScaler()), 39 | ] 40 | ) 41 | 42 | # The pipeline can be used as any other estimator 43 | # and avoids leaking the test set into the train set 44 | X_train_t = pipe.fit_transform(X=iris_amp_train, y=None, impute__iterations=2) 45 | X_test_t = pipe.transform(iris_amp_test) 46 | 47 | assert not np.any(np.isnan(X_train_t)) 48 | assert not np.any(np.isnan(X_test_t)) 49 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "miceforest" 3 | license = "MIT" 4 | version = "6.0.5" 5 | description = "Multiple Imputation by Chained Equations with LightGBM" 6 | authors = [{name="Sam Von Wilson"}] 7 | readme = "README.md" 8 | classifiers = [ 9 | 'Natural Language :: English', 10 | 'Programming Language :: Python :: 3.9', 11 | 'Programming Language :: Python :: 3.10', 12 | 'Programming Language :: Python :: 3.11', 13 | 'Programming Language :: Python :: 3.12', 14 | "Operating System :: OS Independent" 15 | ] 16 | dependencies = [ 17 | "lightgbm>=4.1.0", 18 | "pandas>=2.1.0", 19 | "numpy", 20 | "scipy>=1.6.0", 21 | "pyarrow>=6.0.1" 22 | ] 23 | 24 | [project.optional-dependencies] 25 | plotting = [ 26 | "plotnine>=0.13.6", 27 | "matplotlib>=3.3.0", 28 | "matplotlib!=3.9.1" 29 | ] 30 | pipeline = [ 31 | "scikit-learn!=0.22.0" 32 | ] 33 | 34 | [project.urls] 35 | Homepage = "https://github.com/AnotherSamWilson/miceforest" 36 | Issues = "https://github.com/AnotherSamWilson/miceforest/issues" 37 | changelog = "https://github.com/AnotherSamWilson/miceforest/releases" 38 | 39 | [tool.setuptools] 40 | packages = ['miceforest'] 41 | 42 | [tool.poetry.dependencies] 43 | python = "^3.10" 44 | 45 | [tool.poetry.group.dev.dependencies] 46 | black = "^25.9.0" 47 | dill = "^0.4.0" 48 | ipython = "^8.17.2" 49 | pytest = "^8.0.0" 50 | jupyterlab = "^3.5.0" 51 | nbconvert = "^7.16.4" 52 | pandoc = "^2.3" 53 | isort = "^5.13.2" 54 | mypy = "^1.11.0" 55 | build = "^1.2.1" 56 | pytest-cov = 
"^5.0.0" 57 | twine = "^6.2.0" 58 | sphinx = "^7.4.7" 59 | sphinxcontrib-napoleon = "^0.7" 60 | sphinx-rtd-theme = "^2.0.0" 61 | 62 | [tool.mypy] 63 | ignore_missing_imports = true 64 | 65 | [tool.isort] 66 | profile = "black" -------------------------------------------------------------------------------- /tests/test_default_lightgbm_parameters.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from miceforest.default_lightgbm_parameters import _LOG_SPACE_SEARCH, _sample_parameters 5 | 6 | 7 | def test_sample_parameters_handles_scalar_list_and_tuple_inputs(): 8 | rng = np.random.RandomState(10) 9 | params = { 10 | "learning_rate": (0.05, 0.2), 11 | "feature_fraction": [0.5, 0.75, 1.0], 12 | "min_data_in_leaf": (1, 8), 13 | "regular_param": 42, 14 | } 15 | 16 | sampled = _sample_parameters(params, rng, "random") 17 | 18 | assert 0.05 <= sampled["learning_rate"] <= 0.2 19 | assert sampled["feature_fraction"] in params["feature_fraction"] 20 | assert 1 <= sampled["min_data_in_leaf"] <= 8 21 | assert isinstance(sampled["min_data_in_leaf"], int) 22 | assert sampled["regular_param"] == 42 23 | assert "seed" in sampled 24 | assert sampled["seed"] >= 0 25 | 26 | 27 | def test_sample_parameters_respects_log_space_search(): 28 | rng = np.random.RandomState(0) 29 | params = {name: (0.1, 1.0) for name in _LOG_SPACE_SEARCH[:2]} 30 | sampled = _sample_parameters(params, rng, "random") 31 | 32 | for name in params: 33 | assert params[name][0] <= sampled[name] <= params[name][1] 34 | 35 | 36 | def test_sample_parameters_requires_random_strategy(): 37 | rng = np.random.RandomState(1) 38 | with pytest.raises(AssertionError): 39 | _sample_parameters({}, rng, "grid") 40 | 41 | 42 | def test_sample_parameters_validates_tuple_bounds(): 43 | rng = np.random.RandomState(2) 44 | with pytest.raises(AssertionError): 45 | _sample_parameters({"bad": (1, 1)}, rng, "random") 46 | 47 | with pytest.raises(AssertionError): 48 | _sample_parameters({"bad": (1, 2, 3)}, rng, "random") 49 | -------------------------------------------------------------------------------- /.github/workflows/run_tests.yml: -------------------------------------------------------------------------------- 1 | name: tests + mypy 2 | 3 | on: 4 | push: 5 | branches: [ "major_update_6", "master" ] 6 | pull_request: 7 | branches: [ "master" ] 8 | 9 | permissions: 10 | contents: read 11 | 12 | jobs: 13 | build: 14 | runs-on: ubuntu-latest 15 | strategy: 16 | matrix: 17 | python-version: ["3.9", "3.10", "3.11", "3.12"] 18 | steps: 19 | - uses: actions/checkout@v4 20 | - uses: actions/setup-python@v5 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | 24 | # install & configure poetry 25 | - name: Install Poetry 26 | uses: snok/install-poetry@v1 27 | with: 28 | virtualenvs-create: true 29 | virtualenvs-in-project: true 30 | installer-parallel: true 31 | 32 | # # load cached venv if cache exists 33 | # - name: Load cached venv 34 | # id: cached-poetry-dependencies 35 | # uses: actions/cache@v3 36 | # with: 37 | # path: .venv 38 | # key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }} 39 | 40 | # # install dependencies if cache does not exist 41 | # - name: Install dependencies 42 | # if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' 43 | # run: poetry install --no-interaction --no-root 44 | 45 | # install root project 46 | - name: Install project 47 | run: poetry install --no-interaction 
--with dev --all-extras 48 | 49 | - name: MyPy Checks 50 | run: poetry run mypy miceforest --ignore-missing-imports 51 | 52 | - name: Black Formatting - Package 53 | run: poetry run black miceforest --check 54 | 55 | - name: Black Formatting - Tests 56 | run: poetry run black tests --check 57 | 58 | - name: Isort Checks 59 | run: poetry run isort miceforest --diff 60 | 61 | - name: Pytest 62 | run: poetry run pytest --cov=miceforest --cov-report html 63 | 64 | - name: Upload coverage reports to Codecov 65 | run: | 66 | curl -Os https://cli.codecov.io/latest/linux/codecov 67 | chmod +x codecov 68 | poetry run ./codecov --verbose upload-process -t ${{ secrets.CODECOV_TOKEN }} -n 'service'-${{ github.run_id }} -F service -f coverage-service.xml 69 | 70 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. miceforest documentation master file, created by 2 | sphinx-quickstart on Sat Jul 27 20:34:30 2024. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to miceforest's Documentation! 7 | ====================================== 8 | 9 | This documentation is meant to describe class methods and parameters only, 10 | for a thorough walkthrough of usage, please see the 11 | `Github README `_. 12 | 13 | In general, the user will only be interacting with these two classes: 14 | 15 | 16 | .. toctree:: 17 | :maxdepth: 1 18 | :caption: Classes: 19 | 20 | ImputationKernel 21 | ImputedData 22 | 23 | 24 | How miceforest Works 25 | -------------------- 26 | 27 | Multiple Imputation by Chained Equations ‘fills in’ (imputes) missing 28 | data in a dataset through an iterative series of predictive models. In 29 | each iteration, each specified variable in the dataset is imputed using 30 | the other variables in the dataset. These iterations should be run until 31 | it appears that convergence has been met. 32 | 33 | .. image:: https://i.imgur.com/2L403kU.png 34 | :target: https://github.com/AnotherSamWilson/miceforest?tab=readme-ov-file#the-mice-algorithm 35 | 36 | This process is continued until all specified variables have been 37 | imputed. Additional iterations can be run if it appears that the average 38 | imputed values have not converged, although no more than 5 iterations 39 | are usually necessary. 40 | 41 | This package provides fast, memory efficient Multiple Imputation by Chained 42 | Equations (MICE) with lightgbm. The R version of this package may be found 43 | `here `_. 44 | 45 | `miceforest` was designed to be: 46 | 47 | - **Fast** 48 | - Uses lightgbm as a backend 49 | - Has efficient mean matching solutions. 
50 | - Can utilize GPU training 51 | - **Flexible** 52 | - Can impute pandas dataframes 53 | - Handles categorical data automatically 54 | - Fits into a sklearn pipeline 55 | - User can customize every aspect of the imputation process 56 | - **Production Ready** 57 | - Can impute new, unseen datasets quickly 58 | - Kernels are efficiently compressed during saving and loading 59 | - Data can be imputed in place to save memory 60 | - Can build models on non-missing data -------------------------------------------------------------------------------- /miceforest/default_lightgbm_parameters.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .utils import _draw_random_int32 4 | 5 | # A few parameters that generally benefit 6 | # from searching in the log space. 7 | _LOG_SPACE_SEARCH = [ 8 | "min_data_in_leaf", 9 | "min_sum_hessian_in_leaf", 10 | "lambda_l1", 11 | "lambda_l2", 12 | "cat_l2", 13 | "cat_smooth", 14 | "path_smooth", 15 | "min_gain_to_split", 16 | ] 17 | 18 | 19 | # THESE VALUES WILL ALWAYS BE USED WHEN VALUES ARE NOT PASSED BY USER. 20 | # seed is always set by the calling processes _random_state. 21 | # These need to be main parameter names, not aliases 22 | _DEFAULT_LGB_PARAMS = { 23 | "boosting": "random_forest", 24 | "data_sample_strategy": "bagging", 25 | "num_iterations": 48, 26 | "max_depth": 8, 27 | "num_leaves": 128, 28 | "min_data_in_leaf": 1, 29 | "min_sum_hessian_in_leaf": 0.01, 30 | "min_gain_to_split": 0.0, 31 | "bagging_fraction": 0.632, 32 | "feature_fraction_bynode": 0.632, 33 | "bagging_freq": 1, 34 | "verbosity": -1, 35 | } 36 | 37 | 38 | def _sample_parameters(parameters: dict, random_state, parameter_sampling_method: str): 39 | """ 40 | Searches through a parameter set and selects a random 41 | number between the values in any provided tuple of length 2. 42 | """ 43 | assert ( 44 | parameter_sampling_method == "random" 45 | ), "Only random parameter sampling is supported right now." 46 | parameters = parameters.copy() 47 | for p, v in parameters.items(): 48 | if isinstance(v, list): 49 | choice = random_state.choice(v) 50 | elif isinstance(v, tuple): 51 | assert ( 52 | len(v) == 2 53 | ), "Tuples passed must be length 2, representing the bounds." 54 | assert v[0] < v[1], f"{p} lower bound > upper bound" 55 | if p in _LOG_SPACE_SEARCH: 56 | choice = np.exp( 57 | random_state.uniform( 58 | np.log(v[0]), 59 | np.log(v[1]), 60 | size=1, 61 | )[0] 62 | ) 63 | else: 64 | choice = random_state.uniform( 65 | v[0], 66 | v[1], 67 | size=1, 68 | )[0] 69 | if isinstance(v[0], int): 70 | choice = int(choice) 71 | else: 72 | choice = parameters[p] 73 | parameters[p] = choice 74 | 75 | parameters["seed"] = _draw_random_int32(random_state, size=1)[0] 76 | 77 | return parameters 78 | -------------------------------------------------------------------------------- /miceforest/logger.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timedelta 2 | from typing import Any, Dict, List, Optional, Tuple, Union 3 | 4 | from pandas import Series 5 | 6 | 7 | class Logger: 8 | def __init__( 9 | self, 10 | name: str, 11 | timed_levels: List[str], 12 | verbose: bool = False, 13 | ): 14 | """ 15 | miceforest logger. 
16 | 17 | Parameters 18 | ---------- 19 | name: str 20 | Name of this logger, used in the repr. 21 | timed_levels: list[str] 22 | The names of the levels that make up a timing key. 23 | Times are recorded against tuples with one element 24 | per level; for example, with timed_levels 25 | ["dataset", "iteration"], a key looks like (0, 1). 26 | These names become the index names of the summary 27 | returned by get_time_spend_summary(). 28 | verbose: bool 29 | If True, messages passed to log() are printed. 30 | Otherwise the logger is silent. 31 | """ 32 | self.name = name 33 | self.verbose = verbose 34 | self.initialization_time = datetime.now() 35 | self.timed_levels = timed_levels 36 | self.started_timers: dict = {} 37 | 38 | if self.verbose: 39 | print(f"Initialized logger with name {name} and {len(timed_levels)} levels") 40 | 41 | self.time_seconds: Dict[Any, float] = {} 42 | 43 | def __repr__(self): 44 | summary_string = f"miceforest logger: {self.name}" 45 | return summary_string 46 | 47 | def log(self, *args, **kwargs): 48 | if self.verbose: 49 | print(*args, **kwargs) 50 | 51 | def set_start_time(self, time_key: Tuple): 52 | assert len(time_key) == len(self.timed_levels) 53 | assert time_key not in list( 54 | self.started_timers 55 | ), f"Timer {time_key} already started" 56 | self.started_timers[time_key] = datetime.now() 57 | 58 | def record_time(self, time_key: Tuple): 59 | """ 60 | Compares the current time with the start time, and records the time difference 61 | in our time log under the appropriate key. Times accumulate if a key is recorded more than once. 62 | """ 63 | assert time_key in list(self.started_timers), f"Timer {time_key} never started" 64 | seconds = (datetime.now() - self.started_timers[time_key]).total_seconds() 65 | del self.started_timers[time_key] 66 | if time_key in self.time_seconds: 67 | self.time_seconds[time_key] += seconds 68 | else: 69 | self.time_seconds[time_key] = seconds 70 | 71 | def get_time_spend_summary(self): 72 | """ 73 | Returns the total time recorded against each key, as a pandas 74 | Series whose index is named by the timed levels.
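For example, a minimal usage sketch (the level names and key values here are only illustrative):

>>> from miceforest.logger import Logger
>>> logger = Logger(name="demo", timed_levels=["dataset", "iteration"])
>>> logger.set_start_time((0, 1))
>>> logger.record_time((0, 1))
>>> summary = logger.get_time_spend_summary()
>>> list(summary.index.names)
['dataset', 'iteration']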
75 | """ 76 | summary = Series(self.time_seconds) 77 | summary.index.names = self.timed_levels 78 | return summary 79 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | aiofiles==22.1.0 2 | aiosqlite==0.20.0 3 | alabaster==0.7.16 4 | anyio==4.4.0 5 | argon2-cffi==23.1.0 6 | argon2-cffi-bindings==21.2.0 7 | arrow==1.3.0 8 | asttokens==2.4.1 9 | attrs==23.2.0 10 | Babel==2.15.0 11 | backports.tarfile==1.2.0 12 | beautifulsoup4==4.12.3 13 | black==24.4.2 14 | bleach==6.1.0 15 | build==1.2.1 16 | certifi==2024.7.4 17 | cffi==1.16.0 18 | charset-normalizer==3.3.2 19 | click==8.1.7 20 | comm==0.2.2 21 | contourpy==1.2.1 22 | coverage==7.6.0 23 | cryptography==43.0.0 24 | cycler==0.12.1 25 | debugpy==1.8.2 26 | decorator==5.1.1 27 | defusedxml==0.7.1 28 | dill==0.3.8 29 | docutils==0.20.1 30 | entrypoints==0.4 31 | exceptiongroup==1.2.2 32 | executing==2.0.1 33 | fastjsonschema==2.20.0 34 | fonttools==4.53.1 35 | fqdn==1.5.1 36 | idna==3.7 37 | imagesize==1.4.1 38 | importlib_metadata==8.2.0 39 | iniconfig==2.0.0 40 | ipykernel==6.29.5 41 | ipython==8.26.0 42 | ipython-genutils==0.2.0 43 | isoduration==20.11.0 44 | isort==5.13.2 45 | jaraco.classes==3.4.0 46 | jaraco.context==5.3.0 47 | jaraco.functools==4.0.1 48 | jedi==0.19.1 49 | jeepney==0.8.0 50 | Jinja2==3.1.4 51 | joblib==1.4.2 52 | json5==0.9.25 53 | jsonpointer==3.0.0 54 | jsonschema==4.23.0 55 | jsonschema-specifications==2023.12.1 56 | jupyter-events==0.10.0 57 | jupyter-ydoc==0.2.5 58 | jupyter_client==8.6.2 59 | jupyter_core==5.7.2 60 | jupyter_server==2.14.2 61 | jupyter_server_fileid==0.9.2 62 | jupyter_server_terminals==0.5.3 63 | jupyter_server_ydoc==0.8.0 64 | jupyterlab==3.6.7 65 | jupyterlab_pygments==0.3.0 66 | jupyterlab_server==2.27.3 67 | keyring==25.2.1 68 | kiwisolver==1.4.5 69 | lightgbm==4.5.0 70 | markdown-it-py==3.0.0 71 | MarkupSafe==2.1.5 72 | matplotlib==3.9.1 73 | matplotlib-inline==0.1.7 74 | mdurl==0.1.2 75 | -e git+https://github.com/AnotherSamWilson/miceforest.git@51896d0ff654ea809d4fe8fc6bf3868ad82077ac#egg=miceforest 76 | mistune==3.0.2 77 | mizani==0.11.4 78 | more-itertools==10.3.0 79 | mypy==1.11.0 80 | mypy-extensions==1.0.0 81 | nbclassic==1.1.0 82 | nbclient==0.10.0 83 | nbconvert==7.16.4 84 | nbformat==5.10.4 85 | nest-asyncio==1.6.0 86 | nh3==0.2.18 87 | notebook==6.5.4 88 | notebook_shim==0.2.4 89 | numpy==1.26.4 90 | overrides==7.7.0 91 | packaging==24.1 92 | pandas==2.2.0 93 | pandoc==2.3 94 | pandocfilters==1.5.1 95 | parso==0.8.4 96 | pathspec==0.12.1 97 | patsy==0.5.6 98 | pexpect==4.9.0 99 | pillow==10.4.0 100 | pkginfo==1.10.0 101 | platformdirs==4.2.2 102 | plotnine==0.13.6 103 | pluggy==1.5.0 104 | plumbum==1.8.3 105 | ply==3.11 106 | pockets==0.9.1 107 | prometheus_client==0.20.0 108 | prompt_toolkit==3.0.47 109 | psutil==6.0.0 110 | ptyprocess==0.7.0 111 | pure_eval==0.2.3 112 | pyarrow==17.0.0 113 | pycparser==2.22 114 | Pygments==2.18.0 115 | pyparsing==3.1.2 116 | pyproject_hooks==1.1.0 117 | pytest==8.3.2 118 | pytest-cov==5.0.0 119 | python-dateutil==2.9.0.post0 120 | python-json-logger==2.0.7 121 | pytz==2024.1 122 | PyYAML==6.0.1 123 | pyzmq==26.0.3 124 | readme_renderer==43.0 125 | referencing==0.35.1 126 | requests==2.32.3 127 | requests-toolbelt==1.0.0 128 | rfc3339-validator==0.1.4 129 | rfc3986==2.0.0 130 | rfc3986-validator==0.1.1 131 | rich==13.7.1 132 | rpds-py==0.19.1 133 | scikit-learn==1.5.1 134 | scipy==1.14.0 135 | 
seaborn==0.13.2 136 | SecretStorage==3.3.3 137 | Send2Trash==1.8.3 138 | six==1.16.0 139 | sniffio==1.3.1 140 | snowballstemmer==2.2.0 141 | soupsieve==2.5 142 | Sphinx==7.4.7 143 | sphinx-rtd-theme==2.0.0 144 | sphinxcontrib-applehelp==1.0.8 145 | sphinxcontrib-devhelp==1.0.6 146 | sphinxcontrib-htmlhelp==2.0.6 147 | sphinxcontrib-jquery==4.1 148 | sphinxcontrib-jsmath==1.0.1 149 | sphinxcontrib-napoleon==0.7 150 | sphinxcontrib-qthelp==1.0.8 151 | sphinxcontrib-serializinghtml==1.1.10 152 | stack-data==0.6.3 153 | statsmodels==0.14.2 154 | terminado==0.18.1 155 | threadpoolctl==3.5.0 156 | tinycss2==1.3.0 157 | tomli==2.0.1 158 | tornado==6.4.1 159 | traitlets==5.14.3 160 | twine==5.1.1 161 | types-python-dateutil==2.9.0.20240316 162 | typing_extensions==4.12.2 163 | tzdata==2024.1 164 | uri-template==1.3.0 165 | urllib3==2.2.2 166 | wcwidth==0.2.13 167 | webcolors==24.6.0 168 | webencodings==0.5.1 169 | websocket-client==1.8.0 170 | y-py==0.6.2 171 | ypy-websocket==0.8.4 172 | zipp==3.19.2 173 | -------------------------------------------------------------------------------- /tests/test_reproducibility.py: -------------------------------------------------------------------------------- 1 | from sklearn.datasets import load_iris 2 | import pandas as pd 3 | import numpy as np 4 | import miceforest as mf 5 | 6 | 7 | # Make random state and load data 8 | # Define data 9 | random_state = np.random.RandomState(1991) 10 | iris = pd.concat(load_iris(as_frame=True, return_X_y=True), axis=1) 11 | iris["sp"] = iris["target"].astype("category") 12 | del iris["target"] 13 | iris.rename( 14 | { 15 | "sepal length (cm)": "sl", 16 | "sepal width (cm)": "ws", 17 | "petal length (cm)": "pl", 18 | "petal width (cm)": "pw", 19 | }, 20 | axis=1, 21 | inplace=True, 22 | ) 23 | iris["bc"] = pd.Series(np.random.binomial(n=1, p=0.5, size=150)).astype("category") 24 | iris_amp = mf.ampute_data(iris, perc=0.25, random_state=random_state) 25 | rows = iris_amp.shape[0] 26 | random_seed_array = np.random.choice(range(1000), size=rows, replace=False).astype( 27 | "int32" 28 | ) 29 | 30 | 31 | def test_pandas_reproducibility(): 32 | 33 | datasets = 2 34 | kernel = mf.ImputationKernel( 35 | data=iris_amp, num_datasets=datasets, initialize_empty=False, random_state=2 36 | ) 37 | 38 | kernel2 = mf.ImputationKernel( 39 | data=iris_amp, num_datasets=datasets, initialize_empty=False, random_state=2 40 | ) 41 | 42 | assert kernel.complete_data(0).equals( 43 | kernel2.complete_data(0) 44 | ), "random_state initialization failed to be deterministic" 45 | assert kernel.complete_data(1).equals( 46 | kernel2.complete_data(1) 47 | ), "random_state initialization failed to be deterministic" 48 | 49 | # Run mice for 2 iterations 50 | kernel.mice(2) 51 | kernel2.mice(2) 52 | 53 | assert kernel.complete_data(0).equals( 54 | kernel2.complete_data(0) 55 | ), "random_state after mice() failed to be deterministic" 56 | assert kernel.complete_data(1).equals( 57 | kernel2.complete_data(1) 58 | ), "random_state after mice() failed to be deterministic" 59 | 60 | kernel_imputed_as_new = kernel.impute_new_data( 61 | iris_amp, random_state=4, random_seed_array=random_seed_array 62 | ) 63 | 64 | # Generate and impute new data as a reordering of original 65 | new_order = np.arange(rows) 66 | random_state.shuffle(new_order) 67 | new_data = iris_amp.loc[new_order].reset_index(drop=True) 68 | new_seeds = random_seed_array[new_order] 69 | new_imputed = kernel.impute_new_data( 70 | new_data, random_state=4, random_seed_array=new_seeds 71 | ) 72 | 73 | 
# Expect deterministic imputations at the record level, since the seeds were reordered along with the data (see the distilled sketch at the end of this file). 74 | for i in range(datasets): 75 | reordered_kernel_completed = ( 76 | kernel_imputed_as_new.complete_data(dataset=i) 77 | .loc[new_order] 78 | .reset_index(drop=True) 79 | ) 80 | new_data_completed = new_imputed.complete_data(dataset=i) 81 | 82 | assert ( 83 | (reordered_kernel_completed == new_data_completed).all().all() 84 | ), "Seeds did not cause deterministic imputations when data was reordered." 85 | 86 | # Generate and impute new data as a subset of original 87 | new_ind = [0, 1, 4, 7, 8, 10] 88 | new_data = iris_amp.loc[new_ind].reset_index(drop=True) 89 | new_seeds = random_seed_array[new_ind] 90 | new_imputed = kernel.impute_new_data( 91 | new_data, random_state=4, random_seed_array=new_seeds 92 | ) 93 | 94 | # Expect deterministic imputations at the record level, since the seeds were subset along with the data. 95 | for i in range(datasets): 96 | reordered_kernel_completed = ( 97 | kernel_imputed_as_new.complete_data(dataset=i) 98 | .loc[new_ind] 99 | .reset_index(drop=True) 100 | ) 101 | new_data_completed = new_imputed.complete_data(dataset=i) 102 | 103 | assert ( 104 | (reordered_kernel_completed == new_data_completed).all().all() 105 | ), "Seeds did not cause deterministic imputations when data was subset." 106 | 107 | # Generate and impute new data as a reordering of original 108 | new_order = np.arange(rows) 109 | random_state.shuffle(new_order) 110 | new_data = iris_amp.loc[new_order].reset_index(drop=True) 111 | new_imputed = kernel.impute_new_data( 112 | new_data, random_state=4, random_seed_array=random_seed_array 113 | ) 114 | 115 | # The seed array was NOT reordered along with the data this time, so the imputations should not all match at the record level. 116 | for i in range(datasets): 117 | reordered_kernel_completed = ( 118 | kernel_imputed_as_new.complete_data(dataset=i) 119 | .loc[new_order] 120 | .reset_index(drop=True) 121 | ) 122 | new_data_completed = new_imputed.complete_data(dataset=i) 123 | 124 | assert ( 125 | not (reordered_kernel_completed == new_data_completed).all().all() 126 | ), "Mismatched seeds still produced identical imputations for all rows / columns."
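# A distilled, self-contained sketch of the record-level seeding property
# exercised above: imputing the same data twice with the same kernel,
# random_state, and random_seed_array should reproduce the imputations
# exactly. The function name and the single mice iteration are illustrative.
def test_seed_array_determinism_sketch():
    kernel = mf.ImputationKernel(data=iris_amp, num_datasets=1, random_state=5)
    kernel.mice(1)
    seeds = np.arange(rows, dtype="int32")
    imputed_a = kernel.impute_new_data(
        iris_amp, random_state=6, random_seed_array=seeds
    )
    imputed_b = kernel.impute_new_data(
        iris_amp, random_state=6, random_seed_array=seeds
    )
    assert imputed_a.complete_data(0).equals(imputed_b.complete_data(0))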
127 | -------------------------------------------------------------------------------- /tests/test_imputed_accuracy.py: -------------------------------------------------------------------------------- 1 | from sklearn.datasets import load_iris 2 | import pandas as pd 3 | import numpy as np 4 | import miceforest as mf 5 | from sklearn.metrics import roc_auc_score 6 | 7 | 8 | def make_dataset(seed): 9 | 10 | random_state = np.random.RandomState(seed) 11 | iris = pd.concat(load_iris(return_X_y=True, as_frame=True), axis=1) 12 | iris["bi"] = random_state.binomial( 13 | 1, (iris["target"] == 0).map({True: 0.9, False: 0.10}), size=150 14 | ) 15 | iris["bi"] = iris["bi"].astype("category") 16 | iris["sp"] = iris["target"].map({0: "A", 1: "B", 2: "C"}).astype("category") 17 | del iris["target"] 18 | iris.rename( 19 | { 20 | "sepal length (cm)": "sl", 21 | "sepal width (cm)": "sw", 22 | "petal length (cm)": "pl", 23 | "petal width (cm)": "pw", 24 | }, 25 | axis=1, 26 | inplace=True, 27 | ) 28 | iris_amp = mf.utils.ampute_data(iris, perc=0.20, random_state=random_state) 29 | 30 | return iris, iris_amp 31 | 32 | 33 | def get_numeric_performance(kernel, variables, iris): 34 | r_squares = {} 35 | iterations = kernel.iteration_count() 36 | for col in variables: 37 | ind = kernel.na_where[col] 38 | orig = iris.loc[ind, col] 39 | imps = kernel[col, iterations, 0] 40 | r_squares[col] = np.corrcoef(orig, imps)[0, 1] ** 2 41 | r_squares = pd.Series(r_squares) 42 | return r_squares 43 | 44 | 45 | def get_imp_mse(kernel, variables, iris): 46 | mses = {} 47 | iterations = kernel.iteration_count() 48 | for col in variables: 49 | ind = kernel.na_where[col] 50 | orig = iris.loc[ind, col] 51 | imps = kernel[col, iterations, 0] 52 | mses[col] = ((orig - imps) ** 2).sum() 53 | mses = pd.Series(mses) 54 | return mses 55 | 56 | 57 | def get_mean_pred_mse(kernel: mf.ImputationKernel, variables, iris): 58 | mses = {} 59 | for col in variables: 60 | ind = kernel.na_where[col] 61 | orig = iris.loc[ind, col] 62 | target = kernel._get_nonmissing_values(col) 63 | pred = target.mean() 64 | mses[col] = ((orig - pred) ** 2).sum() 65 | mses = pd.Series(mses) 66 | return mses 67 | 68 | 69 | def get_categorical_performance(kernel: mf.ImputationKernel, variables, iris): 70 | 71 | rocs = {} 72 | accs = {} 73 | rand_accs = {} 74 | iterations = kernel.iteration_count() 75 | for col in variables: 76 | ind = kernel.na_where[col] 77 | model = kernel.get_model(col, 0, -1) 78 | cand = kernel._make_label(col, seed=model.params["seed"]) 79 | orig = iris.loc[ind, col] 80 | imps = kernel[col, iterations, 0] 81 | bf = kernel._get_bachelor_features(col) 82 | preds = model.predict(bf) 83 | rocs[col] = roc_auc_score(orig, preds, multi_class="ovr", average="macro") 84 | accs[col] = (imps == orig).mean() 85 | rand_accs[col] = np.sum( 86 | cand.value_counts(normalize=True) * orig.value_counts(normalize=True) 87 | ) 88 | rocs = pd.Series(rocs) 89 | accs = pd.Series(accs) 90 | rand_accs = pd.Series(rand_accs) 91 | return rocs, accs, rand_accs 92 | 93 | 94 | def test_defaults(): 95 | 96 | for i in range(10): 97 | # i = 3 98 | print(i) 99 | iris, iris_amp = make_dataset(i) 100 | kernel_1 = mf.ImputationKernel( 101 | iris_amp, 102 | num_datasets=1, 103 | data_subset=0, 104 | mean_match_candidates=3, 105 | initialize_empty=True, 106 | random_state=i, 107 | ) 108 | kernel_1.mice(4, verbose=False) 109 | kernel_1.complete_data(0, inplace=True) 110 | 111 | rocs, accs, rand_accs = get_categorical_performance( 112 | kernel_1, ["bi", "sp"], iris 113 | ) 114 | 
assert np.all(accs > rand_accs) 115 | assert np.all(rocs > 0.6) 116 | 117 | # sw Just doesn't have the information density to pass this test reliably. 118 | # It's definitely the hardest variable to model. 119 | mses = get_imp_mse(kernel_1, ["sl", "pl", "pw"], iris) 120 | mpses = get_mean_pred_mse(kernel_1, ["sl", "pl", "pw"], iris) 121 | assert np.all(mpses > mses) 122 | 123 | 124 | def test_no_mean_match(): 125 | 126 | for i in range(10): 127 | # i = 0 128 | iris, iris_amp = make_dataset(i) 129 | kernel_1 = mf.ImputationKernel( 130 | iris_amp, 131 | num_datasets=1, 132 | data_subset=0, 133 | mean_match_candidates=0, 134 | initialize_empty=True, 135 | random_state=i, 136 | ) 137 | kernel_1.mice(4, verbose=False) 138 | kernel_1.complete_data(0, inplace=True) 139 | 140 | rocs, accs, rand_accs = get_categorical_performance( 141 | kernel=kernel_1, variables=["bi", "sp"], iris=iris 142 | ) 143 | assert np.all(accs > rand_accs) 144 | assert np.all(rocs > 0.5) 145 | 146 | # sw Just doesn't have the information density to pass this test reliably. 147 | # It's definitely the hardest variable to model. 148 | mses = get_imp_mse(kernel_1, ["sl", "pl", "pw"], iris) 149 | mpses = get_mean_pred_mse(kernel_1, ["sl", "pl", "pw"], iris) 150 | assert np.all(mpses > mses) 151 | 152 | 153 | def test_custom_params(): 154 | 155 | for i in range(10): 156 | # i = 0 157 | iris, iris_amp = make_dataset(i) 158 | kernel_1 = mf.ImputationKernel( 159 | iris_amp, 160 | num_datasets=1, 161 | data_subset=0, 162 | mean_match_candidates=1, 163 | initialize_empty=True, 164 | random_state=i, 165 | ) 166 | kernel_1.mice( 167 | iterations=4, 168 | verbose=False, 169 | boosting="random_forest", 170 | num_iterations=200, 171 | min_data_in_leaf=2, 172 | ) 173 | kernel_1.complete_data(0, inplace=True) 174 | 175 | rocs, accs, rand_accs = get_categorical_performance( 176 | kernel=kernel_1, variables=["bi", "sp"], iris=iris 177 | ) 178 | assert np.all(accs > rand_accs) 179 | assert np.all(rocs > 0.5) 180 | 181 | # sw Just doesn't have the information density to pass this test reliably. 182 | # It's definitely the hardest variable to model. 
183 | mses = get_imp_mse(kernel_1, ["sl", "pl", "pw"], iris) 184 | mpses = get_mean_pred_mse(kernel_1, ["sl", "pl", "pw"], iris) 185 | assert np.all(mpses > mses) 186 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | from miceforest.utils import ( 2 | _draw_random_int32, 3 | _expand_value_to_dict, 4 | _list_union, 5 | ampute_data, 6 | ensure_rng, 7 | get_best_int_downcast, 8 | hash_numpy_int_array, 9 | logodds, 10 | logistic_function, 11 | stratified_categorical_folds, 12 | stratified_continuous_folds, 13 | stratified_subset, 14 | ) 15 | import pytest 16 | import numpy as np 17 | import pandas as pd 18 | 19 | 20 | def test_get_best_int_downcast_selects_smallest_dtype(): 21 | assert get_best_int_downcast(12) == "uint8" 22 | assert get_best_int_downcast(np.iinfo("uint8").max) == "uint8" 23 | assert get_best_int_downcast(np.iinfo("uint16").max) == "uint16" 24 | assert get_best_int_downcast(np.iinfo("uint32").max) == "uint32" 25 | with pytest.raises(ValueError): 26 | get_best_int_downcast(np.iinfo("uint64").max + 1) 27 | 28 | 29 | def test_ampute_data_reproducible_and_scoped(): 30 | data = pd.DataFrame( 31 | { 32 | "a": np.arange(6, dtype="float64"), 33 | "b": np.arange(6, 12, dtype="float64"), 34 | "c": np.arange(12, 18, dtype="float64"), 35 | } 36 | ) 37 | amputed_one = ampute_data(data, variables=["a", "b"], perc=0.5, random_state=7) 38 | amputed_two = ampute_data(data, variables=["a", "b"], perc=0.5, random_state=7) 39 | 40 | assert amputed_one.equals(amputed_two) 41 | assert amputed_one["a"].isna().sum() == 3 42 | assert amputed_one["b"].isna().sum() == 3 43 | assert amputed_one["c"].isna().sum() == 0 44 | assert amputed_one.isna().sum().sum() == 6 45 | 46 | 47 | def test_hash_numpy_int_array_mutates_and_validates_dtype(): 48 | arr = np.array([1, 2, 3], dtype="int32") 49 | baseline = arr.copy() 50 | hash_numpy_int_array(arr) 51 | assert not np.array_equal(arr, baseline) 52 | assert arr.dtype == baseline.dtype 53 | 54 | arr_uint = np.array([1, 2, 3], dtype="uint64") 55 | baseline_uint = arr_uint.copy() 56 | hash_numpy_int_array(arr_uint) 57 | assert not np.array_equal(arr_uint, baseline_uint) 58 | 59 | arr_float = np.array([1.0, 2.0], dtype="float64") 60 | with pytest.raises(ValueError): 61 | hash_numpy_int_array(arr_float) 62 | 63 | 64 | def test_draw_random_int32_respects_bounds_and_dtype(): 65 | rng = np.random.RandomState(4) 66 | samples = _draw_random_int32(rng, size=10) 67 | assert samples.dtype == np.int32 68 | assert np.all(samples >= 0) 69 | assert np.all(samples <= np.iinfo("int32").max) 70 | 71 | 72 | def test_expand_value_to_dict_with_scalar_and_partial_dict(): 73 | keys = ["x", "y"] 74 | scalar_result = _expand_value_to_dict(0, 5, keys) 75 | assert scalar_result == {"x": 5, "y": 5} 76 | 77 | dict_result = _expand_value_to_dict(0, {"x": 1}, keys) 78 | assert dict_result == {"x": 1, "y": 0} 79 | 80 | 81 | def test_list_union_returns_intersection_preserving_order(): 82 | assert _list_union(["a", "b", "c"], ["b", "d", "c"]) == ["b", "c"] 83 | 84 | 85 | def test_logodds_and_logistic_are_inverses_for_valid_probability(): 86 | probability = 0.2 87 | logits = logodds(probability) 88 | recovered = logistic_function(logits) 89 | assert pytest.approx(recovered) == probability 90 | 91 | with pytest.raises(ValueError): 92 | logodds(1.0) 93 | 94 | 95 | def test_stratified_continuous_folds_cover_all_indices(): 96 | y = pd.Series(np.linspace(0.0, 1.0, 
12)) 97 | folds = list(stratified_continuous_folds(y, nfold=4)) 98 | 99 | assert len(folds) == 4 100 | all_validation = np.concatenate([val for _, val in folds]) 101 | assert np.array_equal(np.sort(all_validation), np.arange(len(y))) 102 | for train, val in folds: 103 | assert len(np.intersect1d(train, val)) == 0 104 | assert len(val) == len(y) // 4 105 | 106 | 107 | def test_stratified_categorical_folds_adjusts_fold_count(capsys): 108 | y = pd.Series([0, 0, 0, 1, 1, 2], dtype="int64") 109 | folds = list(stratified_categorical_folds(y, nfold=3)) 110 | captured = capsys.readouterr() 111 | 112 | assert "Decreasing nfold" in captured.out 113 | assert len(folds) == 1 114 | train, val = folds[0] 115 | assert len(val) == len(y) 116 | assert len(train) == 0 117 | 118 | 119 | def test_stratified_categorical_folds_balanced_counts(): 120 | y = pd.Series([0, 0, 1, 1, 2, 2], dtype="int64") 121 | folds = list(stratified_categorical_folds(y, nfold=2)) 122 | 123 | assert len(folds) == 2 124 | all_val = np.concatenate([val for _, val in folds]) 125 | assert np.array_equal(np.sort(all_val), np.arange(len(y))) 126 | 127 | 128 | def test_subset(): 129 | 130 | strat_std_closer = [] 131 | strat_mean_closer = [] 132 | for i in range(1000): 133 | y = pd.Series(np.random.normal(size=1000)) 134 | size = 100 135 | ss_ind = stratified_subset(y, size, groups=10, random_state=i) 136 | y_strat_sub = y[ss_ind] 137 | y_rand_sub = np.random.choice(y, size, replace=False) 138 | 139 | # See which random sample has a closer stdev 140 | y_strat_std_diff = abs(y.std() - y_strat_sub.std()) 141 | y_rand_std_diff = abs(y.std() - y_rand_sub.std()) 142 | strat_std_closer.append(y_strat_std_diff < y_rand_std_diff) 143 | 144 | # See which random sample has a closer mean 145 | y_strat_mean_diff = abs(y.mean() - y_strat_sub.mean()) 146 | y_rand_mean_diff = abs(y.mean() - y_rand_sub.mean()) 147 | strat_mean_closer.append(y_strat_mean_diff < y_rand_mean_diff) 148 | 149 | # Assert that the mean and stdev of the 150 | # stratified random draws are closer to the 151 | # original distribution over 50% of the time. 
152 | assert np.array(strat_std_closer).mean() > 0.5 153 | assert np.array(strat_mean_closer).mean() > 0.5 154 | 155 | 156 | def test_subset_continuous_reproduce(): 157 | # Tests for reproducibility in numeric stratified subsetting 158 | for i in range(100): 159 | y = pd.Series(np.random.normal(size=1000)) 160 | size = 100 161 | 162 | ss1 = stratified_subset(y, size, groups=10, random_state=i) 163 | ss2 = stratified_subset(y, size, groups=10, random_state=i) 164 | 165 | assert np.all(ss1 == ss2) 166 | 167 | 168 | def test_subset_categorical_reproduce(): 169 | # Tests for reproducibility in categorical stratified subsetting 170 | for i in range(100): 171 | y = pd.Series(np.random.randint(low=1, high=10, size=1000)).astype("category") 172 | size = 100 173 | 174 | ss1 = stratified_subset(y, size, groups=10, random_state=i) 175 | ss2 = stratified_subset(y, size, groups=10, random_state=i) 176 | 177 | assert np.all(ss1 == ss2) 178 | -------------------------------------------------------------------------------- /miceforest/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Union 2 | 3 | import numpy as np 4 | from numpy.random import RandomState 5 | from pandas import DataFrame, Series 6 | 7 | 8 | def get_best_int_downcast(x: int): 9 | assert isinstance(x, int) 10 | int_dtypes = ["uint8", "uint16", "uint32", "uint64"] 11 | np_iinfo_max = {dtype: np.iinfo(dtype).max for dtype in int_dtypes} 12 | for dtype, max in np_iinfo_max.items(): 13 | if x <= max: 14 | break 15 | if dtype == "uint64": 16 | raise ValueError("Number too large to downcast") 17 | return dtype 18 | 19 | 20 | def ampute_data( 21 | data: DataFrame, 22 | variables: Optional[List[str]] = None, 23 | perc: float = 0.1, 24 | random_state: Optional[Union[int, np.random.RandomState]] = None, 25 | ): 26 | """ 27 | Ampute Data 28 | 29 | Returns a copy of data with specified variables amputed. 30 | 31 | Parameters 32 | ---------- 33 | data : Pandas DataFrame 34 | The data to ampute 35 | 36 | variables : None or list 37 | If None, all variables are amputed. 38 | 39 | perc : float 40 | The proportion of rows to ampute in each variable (0.1 amputes 10% of the rows). 41 | 42 | random_state: None, int, or np.random.RandomState 43 | The random state to use. 44 | 45 | Returns 46 | ------- 47 | pandas DataFrame 48 | The amputed data 49 | """ 50 | amputed_data = data.copy() 51 | num_rows = amputed_data.shape[0] 52 | amp_rows = int(perc * num_rows) 53 | rs = ensure_rng(random_state) 54 | variables = list(data.columns) if variables is None else variables 55 | 56 | for col in variables: 57 | ind = rs.choice(amputed_data.index, size=amp_rows, replace=False) 58 | amputed_data.loc[ind, col] = np.nan 59 | 60 | return amputed_data 61 | 62 | 63 | def stratified_subset( 64 | y: Series, 65 | size: int, 66 | groups: int, 67 | random_state: Optional[Union[int, np.random.RandomState]], 68 | ): 69 | """ 70 | Subsample y using stratification. y is divided into quantiles, 71 | and then elements are randomly chosen from each quantile to 72 | come up with the subsample. 73 | 74 | Parameters 75 | ---------- 76 | y: pandas Series 77 | The variable to use for stratification. If y is categorical, 78 | its category codes are used as the groups directly and no 79 | quantile groups are created. 80 | size: int 81 | How large the subset should be 82 | groups: int 83 | How many groups to break y into. The more groups, the more 84 | balanced (but less random) the subset will be 85 | random_state: None, int, or np.random.RandomState 86 | The random state to use for the draws from each group.
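For example, a minimal sketch (the sizes here are illustrative):

>>> import numpy as np
>>> from pandas import Series
>>> from miceforest.utils import stratified_subset
>>> y = Series(np.random.normal(size=1000))
>>> idx = stratified_subset(y, size=100, groups=10, random_state=0)
>>> len(idx)  # exactly `size` sorted positional indices, roughly 10 per decile
100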
87 | 88 | Returns 89 | ------- 90 | The indices of y that have been chosen. 91 | 92 | """ 93 | 94 | rs = ensure_rng(random_state=random_state) 95 | 96 | cat = False 97 | if y.dtype.name == "category": 98 | cat = True 99 | y = y.cat.codes 100 | y = y.to_numpy() 101 | 102 | if cat: 103 | digits = y 104 | else: 105 | q = [x / groups for x in range(1, groups)] 106 | bins = np.quantile(y, q) 107 | digits = np.digitize(y, bins, right=True) 108 | 109 | digits_v, digits_c = np.unique(digits, return_counts=True) 110 | digits_i = np.arange(digits_v.shape[0]) 111 | digits_p = digits_c / digits_c.sum() 112 | digits_s = (digits_p * size).round(0).astype("int32") 113 | diff = size - digits_s.sum() 114 | if diff != 0: 115 | digits_fix = rs.choice(digits_i, size=abs(diff), p=digits_p, replace=False) 116 | if diff < 0: 117 | for d in digits_fix: 118 | digits_s[d] -= 1 119 | else: 120 | for d in digits_fix: 121 | digits_s[d] += 1 122 | 123 | sub = np.zeros(shape=size).astype("int32") 124 | added = 0 125 | for d_i in digits_i: 126 | d_v = digits_v[d_i] 127 | n = digits_s[d_i] 128 | ind = np.where(digits == d_v)[0] 129 | choice = rs.choice(ind, size=n, replace=False) 130 | sub[added : (added + n)] = choice 131 | added += n 132 | 133 | sub.sort() 134 | 135 | return sub 136 | 137 | 138 | def stratified_continuous_folds(y: Series, nfold: int): 139 | """ 140 | Create primitive stratified folds for continuous data. 141 | Should be digestible by lightgbm.cv function. 142 | """ 143 | y = y.to_numpy() 144 | elements = y.shape[0] 145 | assert elements >= nfold, "more splits then elements." 146 | sorted = np.argsort(y) 147 | val = [sorted[range(i, len(y), nfold)] for i in range(nfold)] 148 | for v in val: 149 | yield (np.setdiff1d(np.arange(elements), v), v) 150 | 151 | 152 | def stratified_categorical_folds(y: Series, nfold: int): 153 | """ 154 | Create primitive stratified folds for categorical data. 155 | Should be digestible by lightgbm.cv function. 156 | """ 157 | assert isinstance(y, Series), "y must be a pandas Series" 158 | assert y.dtype.name[0:3].lower() == "int", "y should be the category codes" 159 | y = y.to_numpy() 160 | elements = len(y) 161 | uniq, inv, counts = np.unique(y, return_counts=True, return_inverse=True) 162 | assert elements >= nfold, "more splits then elements." 163 | if any(counts < nfold): 164 | print("Decreasing nfold to lowest categorical level count...") 165 | nfold = min(counts) 166 | sorted = np.argsort(inv) 167 | val = [sorted[range(i, len(y), nfold)] for i in range(nfold)] 168 | for v in val: 169 | yield (np.setdiff1d(range(elements), v), v) 170 | 171 | 172 | # https://stackoverflow.com/questions/664014/what-integer-hash-function-are-good-that-accepts-an-integer-hash-key 173 | # This hash performs well enough in testing. 174 | def hash_int32(x: np.ndarray): 175 | """ 176 | A hash function which generates random uniform (enough) 177 | int32 integers. Used in mean matching and initialization. 
178 | """ 179 | assert isinstance(x, np.ndarray) 180 | assert x.dtype in ["uint32", "int32"], "x must be int32" 181 | x = ((x >> 16) ^ x) * 0x45D9F3B 182 | x = ((x >> 16) ^ x) * 0x45D9F3B 183 | x = (x >> 16) ^ x 184 | return x 185 | 186 | 187 | def hash_uint64(x: np.ndarray): 188 | assert isinstance(x, np.ndarray) 189 | assert x.dtype == "uint64", "x must be uint64" 190 | x = (x ^ (x >> 30)) * 0xBF58476D1CE4E5B9 191 | x = (x ^ (x >> 27)) * 0x94D049BB133111EB 192 | x = x ^ (x >> 31) 193 | return x 194 | 195 | 196 | def hash_numpy_int_array(x: np.ndarray, ind: Union[np.ndarray, slice] = slice(None)): 197 | """ 198 | Deterministically set the values of the elements in x 199 | at the locations ind to some uniformly distributed number 200 | within the range of the datatype of x. 201 | 202 | This function acts on x in place 203 | """ 204 | assert isinstance(x, np.ndarray) 205 | if x.dtype in ["uint32", "int32"]: 206 | x[ind] = hash_int32(x[ind]) 207 | elif x.dtype == "uint64": 208 | x[ind] = hash_uint64(x[ind]) 209 | else: 210 | raise ValueError("random_seed_array must be uint32, int32, or uint64 datatype") 211 | 212 | 213 | def _draw_random_int32(random_state, size): 214 | nums = random_state.randint( 215 | low=0, high=np.iinfo("int32").max, size=size, dtype="int32" 216 | ) 217 | return nums 218 | 219 | 220 | def ensure_rng(random_state) -> RandomState: 221 | """ 222 | Creates a random number generator based on an optional seed. This can be 223 | an integer or another random state for a seeded rng, or None for an 224 | unseeded rng. 225 | """ 226 | if random_state is None: 227 | random_state = RandomState() 228 | elif isinstance(random_state, int): 229 | random_state = RandomState(random_state) 230 | else: 231 | assert isinstance(random_state, RandomState) 232 | return random_state 233 | 234 | 235 | # def _ensure_iterable(x): 236 | # """ 237 | # If the object is iterable, return the object. 238 | # Else, return the object in a length 1 list. 239 | # """ 240 | # return x if hasattr(x, "__iter__") else [x] 241 | 242 | 243 | # def _assert_dataset_equivalent(ds1: _t_dat, ds2: _t_dat): 244 | # if isinstance(ds1, DataFrame): 245 | # assert isinstance(ds2, DataFrame) 246 | # assert ds1.equals(ds2) 247 | # else: 248 | # assert isinstance(ds2, np.ndarray) 249 | # np.testing.assert_array_equal(ds1, ds2) 250 | 251 | 252 | # def _ensure_np_array(x): 253 | # if isinstance(x, np.ndarray): 254 | # return x 255 | # if isinstance(x, DataFrame) | isinstance(x, Series): 256 | # return x.values 257 | # else: 258 | # raise ValueError("Can't cast to numpy array") 259 | 260 | 261 | def _expand_value_to_dict(default, value, keys) -> dict: 262 | if isinstance(value, dict): 263 | ret = {key: value.get(key, default) for key in keys} 264 | else: 265 | assert default.__class__ == value.__class__ 266 | ret = {key: value for key in keys} 267 | 268 | return ret 269 | 270 | 271 | def _list_union(x: List, y: List): 272 | return [z for z in x if z in y] 273 | 274 | 275 | def logodds(probability): 276 | try: 277 | odds_ratio = probability / (1 - probability) 278 | log_odds = np.log(odds_ratio) 279 | except ZeroDivisionError: 280 | raise ValueError( 281 | "lightgbm output a probability of 1.0 or 0.0. " 282 | "This is usually because of rare classes. " 283 | "Try adjusting min_data_in_leaf." 
284 | ) 285 | 286 | return log_odds 287 | 288 | 289 | def logistic_function(log_odds): 290 | return 1 / (1 + np.exp(-log_odds)) 291 | -------------------------------------------------------------------------------- /tests/test_ImputationKernel.py: -------------------------------------------------------------------------------- 1 | from sklearn.datasets import load_iris 2 | import pandas as pd 3 | import numpy as np 4 | import miceforest as mf 5 | from datetime import datetime 6 | from matplotlib.pyplot import close 7 | from tempfile import mkstemp 8 | import dill 9 | 10 | 11 | # Make random state and load data 12 | # Define data 13 | random_state = np.random.RandomState(1991) 14 | iris = pd.concat(load_iris(as_frame=True, return_X_y=True), axis=1) 15 | # iris = iris.sample(100000, replace=True) 16 | iris["sp"] = ( 17 | iris["target"] 18 | .map({0: "Category1", 1: "Category2", 2: "Category3"}) 19 | .astype("category") 20 | ) 21 | del iris["target"] 22 | iris.rename( 23 | { 24 | "sepal length (cm)": "sl", 25 | "sepal width (cm)": "ws", 26 | "petal length (cm)": "pl", 27 | "petal width (cm)": "pw", 28 | }, 29 | axis=1, 30 | inplace=True, 31 | ) 32 | iris["bi"] = ( 33 | pd.Series(np.random.binomial(n=1, p=0.5, size=iris.shape[0])) 34 | .map({0: "FOO", 1: "BAR"}) 35 | .astype("category") 36 | ) 37 | iris["ui8"] = iris["sl"].round(0).astype("UInt8") 38 | iris["ws"] = iris["ws"].astype("float32") 39 | iris.reset_index(drop=True, inplace=True) 40 | amputed_variables = ["sl", "ws", "pl", "sp", "bi", "ui8"] 41 | iris_amp = mf.ampute_data( 42 | iris, variables=amputed_variables, perc=0.25, random_state=random_state 43 | ) 44 | na_where = {var: np.where(iris_amp[var].isnull())[0] for var in iris_amp.columns} 45 | notnan_where = { 46 | var: np.setdiff1d(np.arange(iris_amp.shape[0]), na_where[var], assume_unique=True)[ 47 | 0 48 | ] 49 | for var in iris_amp.columns 50 | } 51 | 52 | new_amputed_data = iris_amp.loc[range(20), :].reset_index(drop=True).copy() 53 | new_nonmissing_data = iris.loc[range(20), :].reset_index(drop=True).copy() 54 | 55 | # Make special datasets that have weird edge cases 56 | # Multiple columns with all missing values 57 | # sp is categorical, and pw had no missing 58 | # values in the original kernel data 59 | new_amputed_data_special_1 = iris_amp.loc[range(20), :].reset_index(drop=True).copy() 60 | for col in ["sp", "pw"]: 61 | new_amputed_data_special_1[col] = np.nan 62 | dtype = iris[col].dtype 63 | new_amputed_data_special_1[col] = new_amputed_data_special_1[col].astype(dtype) 64 | 65 | # Some columns with no missing values 66 | new_amputed_data_special_2 = iris_amp.loc[range(20), :].reset_index(drop=True).copy() 67 | new_amputed_data_special_2[["sp", "ui8"]] = iris.loc[range(20), ["sp", "ui8"]] 68 | 69 | 70 | def make_and_test_kernel(**kwargs): 71 | 72 | # kwargs = { 73 | # "data": iris_amp, 74 | # "num_datasets": 2, 75 | # "mean_match_strategy": "normal", 76 | # "save_all_iterations_data": True, 77 | # } 78 | 79 | # Build a normal kernel, run mice, save, load, and run mice again 80 | kernel = mf.ImputationKernel(**kwargs) 81 | assert kernel.iteration_count() == 0 82 | kernel.mice(iterations=2, verbose=True) 83 | assert kernel.iteration_count() == 2 84 | new_file, filename = mkstemp() 85 | with open(filename, "wb") as file: 86 | dill.dump(kernel, file) 87 | del kernel 88 | with open(filename, "rb") as file: 89 | kernel = dill.load(file) 90 | kernel.mice(iterations=1, verbose=True) 91 | assert kernel.iteration_count() == 3 92 | 93 | modeled_variables = 
kernel.model_training_order 94 | imputed_variables = kernel.imputed_variables 95 | 96 | # pw has no missing values. 97 | assert "pw" not in imputed_variables 98 | 99 | # Make a completed dataset 100 | completed_data = kernel.complete_data(dataset=0, inplace=False) 101 | 102 | # Make sure the data was imputed 103 | assert all(completed_data[imputed_variables].isnull().sum() == 0) 104 | 105 | # Make sure the dtypes didn't change 106 | for col, series in iris_amp.items(): 107 | dtype = series.dtype 108 | assert completed_data[col].dtype == dtype 109 | 110 | # Make sure the working data wasn't imputed 111 | for var, naw in na_where.items(): 112 | if len(naw) > 0: 113 | assert kernel.working_data.loc[naw, var].isnull().mean() == 1.0 114 | 115 | # Make sure the original nonmissing data wasn't changed 116 | for var, naw in notnan_where.items(): 117 | assert completed_data.loc[naw, var] == iris_amp.loc[naw, var] 118 | 119 | # Impute the data in place now 120 | kernel.complete_data(0, inplace=True) 121 | 122 | # Assert we actually imputed the working data 123 | assert all(kernel.working_data[imputed_variables].isnull().sum() == 0) 124 | 125 | # Assert the original data was not touched 126 | assert all(iris_amp[imputed_variables].isnull().sum() > 0) 127 | 128 | # Make sure the models were trained the way we expect 129 | for variable in modeled_variables: 130 | if variable == "sp": 131 | objective = "multiclass" 132 | elif variable == "bi": 133 | objective = "binary" 134 | else: 135 | objective = "regression" 136 | assert ( 137 | kernel.get_model(variable=variable, dataset=0, iteration=1).params[ 138 | "objective" 139 | ] 140 | == objective 141 | ) 142 | assert ( 143 | kernel.get_model(variable=variable, dataset=0, iteration=2).params[ 144 | "objective" 145 | ] 146 | == objective 147 | ) 148 | assert ( 149 | kernel.get_model(variable=variable, dataset=1, iteration=1).params[ 150 | "objective" 151 | ] 152 | == objective 153 | ) 154 | assert ( 155 | kernel.get_model(variable=variable, dataset=1, iteration=2).params[ 156 | "objective" 157 | ] 158 | == objective 159 | ) 160 | 161 | # Impute a new dataset, and complete the data 162 | imputed_new_data = kernel.impute_new_data(new_amputed_data, verbose=True) 163 | imputed_dataset_0 = imputed_new_data.complete_data( 164 | dataset=0, iteration=2, inplace=False 165 | ) 166 | imputed_dataset_1 = imputed_new_data.complete_data( 167 | dataset=1, iteration=2, inplace=False 168 | ) 169 | 170 | # Assert we didn't just impute the same thing for all values 171 | assert not np.all(imputed_dataset_0 == imputed_dataset_1) 172 | 173 | # Make sure we can impute the special cases 174 | imputed_data_special_1 = kernel.impute_new_data(new_amputed_data_special_1) 175 | 176 | # Before we do anything else, make sure saving / loading works 177 | new_file, filename = mkstemp() 178 | with open(filename, "wb") as file: 179 | dill.dump(imputed_data_special_1, file) 180 | del imputed_data_special_1 181 | with open(filename, "rb") as file: 182 | imputed_data_special_1 = dill.load(file) 183 | 184 | imputed_data_special_2 = kernel.impute_new_data(new_amputed_data_special_2) 185 | imputed_dataset_special_1 = imputed_data_special_1.complete_data(0) 186 | imputed_dataset_special_2 = imputed_data_special_2.complete_data(0) 187 | assert not np.any(imputed_dataset_special_1[modeled_variables].isnull()) 188 | assert not np.any(imputed_dataset_special_2[modeled_variables].isnull()) 189 | 190 | # Reproducibility 191 | random_seed_array = np.random.randint( 192 | 9999, 
size=new_amputed_data_special_1.shape[0], dtype="uint32" 193 | ) 194 | imputed_data_special_3 = kernel.impute_new_data( 195 | new_data=new_amputed_data_special_1, 196 | random_seed_array=random_seed_array, 197 | random_state=1, 198 | ) 199 | imputed_data_special_4 = kernel.impute_new_data( 200 | new_data=new_amputed_data_special_1, 201 | random_seed_array=random_seed_array, 202 | random_state=1, 203 | ) 204 | assert imputed_data_special_3.complete_data(0).equals( 205 | imputed_data_special_4.complete_data(0) 206 | ) 207 | 208 | # Ensure kernel imputes new data on a subset of datasets deterministically 209 | if kernel.num_datasets > 1: 210 | datasets = list(range(kernel.num_datasets)) 211 | datasets.remove(0) 212 | imputed_data_special_5 = kernel.impute_new_data( 213 | new_data=new_amputed_data_special_1, 214 | datasets=datasets, 215 | random_seed_array=random_seed_array, 216 | random_state=1, 217 | verbose=True, 218 | ) 219 | imputed_data_special_6 = kernel.impute_new_data( 220 | new_data=new_amputed_data_special_1, 221 | datasets=datasets, 222 | random_seed_array=random_seed_array, 223 | random_state=1, 224 | ) 225 | assert imputed_data_special_5.complete_data(1).equals( 226 | imputed_data_special_6.complete_data(1) 227 | ) 228 | 229 | mv = kernel.modeled_variables 230 | 231 | # Test tuning parameters 232 | kernel.tune_parameters( 233 | optimization_steps=2, 234 | use_gbdt=True, 235 | random_state=1, 236 | variable_parameters={ 237 | mv[0]: { 238 | "min_data_in_leaf": (1, 10), 239 | "cat_l2": 0.5, 240 | } 241 | }, 242 | extra_trees=[True, False], 243 | ) 244 | op = kernel.optimal_parameters[mv[0]] 245 | assert "extra_trees" in list(op) 246 | assert op["cat_l2"] == 0.5 247 | assert 1 <= op["min_data_in_leaf"] <= 10 248 | 249 | kernel.tune_parameters( 250 | optimization_steps=2, 251 | use_gbdt=False, 252 | random_state=1, 253 | variable_parameters={ 254 | mv[0]: { 255 | "min_data_in_leaf": (1, 10), 256 | "cat_l2": 0.5, 257 | } 258 | }, 259 | extra_trees=[True, False], 260 | ) 261 | op = kernel.optimal_parameters[mv[0]] 262 | assert "extra_trees" in list(op) 263 | assert op["cat_l2"] == 0.5 264 | assert 1 <= op["min_data_in_leaf"] <= 10 265 | 266 | # Test plotting 267 | kernel.plot_imputed_distributions() 268 | kernel.plot_feature_importance(dataset=0) 269 | kernel.plot_mean_convergence() 270 | 271 | return kernel 272 | 273 | 274 | def test_defaults(): 275 | 276 | kernel_normal = make_and_test_kernel( 277 | data=iris_amp, 278 | num_datasets=2, 279 | mean_match_strategy="normal", 280 | save_all_iterations_data=True, 281 | ) 282 | kernel_fast = make_and_test_kernel( 283 | data=iris_amp, 284 | num_datasets=2, 285 | mean_match_strategy="fast", 286 | save_all_iterations_data=True, 287 | ) 288 | kernel_shap = make_and_test_kernel( 289 | data=iris_amp, 290 | num_datasets=2, 291 | mean_match_strategy="shap", 292 | save_all_iterations_data=True, 293 | ) 294 | kernel_iwp = make_and_test_kernel( 295 | data=iris_amp, 296 | num_datasets=2, 297 | mean_match_candidates=0, 298 | save_all_iterations_data=True, 299 | ) 300 | 301 | 302 | def test_complex(): 303 | 304 | # Customize everything. 
305 | vs = { 306 | "sl": ["ws", "pl", "pw", "sp", "bi"], 307 | "ws": ["sl"], 308 | "pl": ["sp", "bi"], 309 | # 'sp': ['sl', 'ws', 'pl', 'pw', 'bc'], # Purposely don't train a variable that does have missing values 310 | "pw": ["sl", "ws", "pl", "sp", "bi"], 311 | "bi": ["ws", "pl", "sp"], 312 | "ui8": ["sp", "ws"], 313 | } 314 | mmc = {"sl": 4, "ws": 0, "bi": 5} 315 | ds = {"sl": int(iris_amp.shape[0] / 2), "ws": 50} 316 | 317 | imputed_var_names = list(vs) 318 | non_imputed_var_names = [c for c in iris_amp if c not in imputed_var_names] 319 | 320 | # Build a normal kernel, run mice, save, load, and run mice again 321 | kernel = make_and_test_kernel( 322 | data=iris_amp, 323 | num_datasets=2, 324 | variable_schema=vs, 325 | mean_match_candidates=mmc, 326 | data_subset=ds, 327 | mean_match_strategy="normal", 328 | save_all_iterations_data=True, 329 | ) 330 | assert kernel.data_subset == { 331 | "sl": 75, 332 | "ws": 50, 333 | "pl": 0, 334 | "bi": 0, 335 | "ui8": 0, 336 | "pw": 0, 337 | }, "mean_match_subset initialization failed" 338 | 339 | kernel_fast = make_and_test_kernel( 340 | data=iris_amp, 341 | num_datasets=2, 342 | variable_schema=vs, 343 | mean_match_candidates=mmc, 344 | data_subset=ds, 345 | mean_match_strategy="fast", 346 | save_all_iterations_data=True, 347 | ) 348 | 349 | mmc_shap = mmc.copy() 350 | mmc_shap["ws"] = 1 351 | kernel_shap = make_and_test_kernel( 352 | data=iris_amp, 353 | num_datasets=2, 354 | variable_schema=vs, 355 | mean_match_candidates=mmc_shap, 356 | data_subset=ds, 357 | mean_match_strategy="shap", 358 | save_all_iterations_data=True, 359 | ) 360 | 361 | mixed_mms = {"sl": "shap", "ws": "fast", "ui8": "fast", "bi": "normal"} 362 | kernel_mixed = make_and_test_kernel( 363 | data=iris_amp, 364 | num_datasets=2, 365 | variable_schema=vs, 366 | mean_match_candidates=mmc, 367 | data_subset=ds, 368 | mean_match_strategy=mixed_mms, 369 | save_all_iterations_data=True, 370 | ) 371 | 372 | 373 | def test_object_column(): 374 | 375 | # Customize everything. 
376 | vs = { 377 | "sl": ["ws", "pl", "pw", "sp", "bi"], 378 | "ws": ["sl"], 379 | "pl": ["sp", "bi"], 380 | # 'sp': ['sl', 'ws', 'pl', 'pw', 'bc'], # Purposely don't train a variable that does have missing values 381 | "pw": ["sl", "ws", "pl", "sp", "bi"], 382 | "bi": ["ws", "pl", "sp"], 383 | "ui8": ["sp", "ws"], 384 | } 385 | mmc = {"sl": 4, "ws": 0, "bi": 5} 386 | ds = {"sl": int(iris_amp.shape[0] / 2), "ws": 50} 387 | 388 | iris_amp["obj_col"] = iris_amp["sl"].astype("object") 389 | 390 | imputed_var_names = list(vs) 391 | non_imputed_var_names = [c for c in iris_amp if c not in imputed_var_names] 392 | kernel = mf.ImputationKernel( 393 | data=iris_amp, 394 | num_datasets=2, 395 | variable_schema=vs, 396 | mean_match_candidates=mmc, 397 | data_subset=ds, 398 | mean_match_strategy="normal", 399 | save_all_iterations_data=True, 400 | ) 401 | 402 | assert "obj_col" not in kernel.variable_schema 403 | assert "obj_col" not in kernel.all_var_in_schema 404 | -------------------------------------------------------------------------------- /miceforest/imputed_data.py: -------------------------------------------------------------------------------- 1 | import importlib.metadata 2 | from io import BytesIO 3 | from itertools import combinations 4 | from typing import Any, Dict, List, Optional, Union 5 | from warnings import warn 6 | 7 | import numpy as np 8 | from pandas import DataFrame, MultiIndex, RangeIndex, Series, concat, read_parquet 9 | 10 | from .utils import get_best_int_downcast, hash_numpy_int_array 11 | 12 | 13 | class ImputedData: 14 | def __init__( 15 | self, 16 | impute_data: DataFrame, 17 | # num_datasets: int = 5, 18 | datasets: List[int], 19 | variable_schema: Optional[Union[List[str], Dict[str, List[str]]]] = None, 20 | save_all_iterations_data: bool = True, 21 | copy_data: bool = True, 22 | random_seed_array: Optional[np.ndarray] = None, 23 | ): 24 | # All references to the data should be through self. 25 | self.working_data = impute_data.copy() if copy_data else impute_data 26 | self.shape = self.working_data.shape 27 | self.save_all_iterations_data = save_all_iterations_data 28 | self.datasets = datasets 29 | 30 | assert isinstance( 31 | self.working_data.index, RangeIndex 32 | ), "Please reset the index on the dataframe" 33 | 34 | column_names = self.working_data.columns 35 | assert np.all( 36 | [isinstance(col, str) for col in column_names] 37 | ), "Column names must be strings" 38 | 39 | self.column_names = column_names 40 | pd_dtypes_orig = self.working_data.dtypes 41 | 42 | # Collect info about what data is missing. 43 | na_where = {} 44 | for col in column_names: 45 | nas = np.where(self.working_data[col].isnull())[0] 46 | if len(nas) == 0: 47 | best_downcast = "uint8" 48 | else: 49 | best_downcast = get_best_int_downcast(int(nas.max())) 50 | na_where[col] = nas.astype(best_downcast) 51 | na_counts = {col: len(nw) for col, nw in na_where.items()} 52 | self.vars_with_any_missing = [ 53 | col for col, count in na_counts.items() if count > 0 54 | ] 55 | 56 | # If variable_schema was passed, use that as the 57 | # list of variables that should have models trained. 58 | # Otherwise, only train models on variables that have 59 | # missing values. 
60 | if variable_schema is None: 61 | modeled_variables = self.vars_with_any_missing.copy() 62 | variable_schema = { 63 | target: [ 64 | regressor for regressor in self.column_names if regressor != target 65 | ] 66 | for target in modeled_variables 67 | } 68 | elif isinstance(variable_schema, list): 69 | variable_schema = { 70 | target: [ 71 | regressor for regressor in self.column_names if regressor != target 72 | ] 73 | for target in variable_schema 74 | } 75 | elif isinstance(variable_schema, dict): 76 | # Don't alter the original dict out of scope 77 | variable_schema = variable_schema.copy() 78 | for target, regressors in variable_schema.items(): 79 | if target in regressors: 80 | raise ValueError(f"{target} being used to impute itself") 81 | 82 | self.variable_schema = variable_schema 83 | 84 | self.modeled_variables = list(self.variable_schema) 85 | self.imputed_variables = [ 86 | col for col in self.modeled_variables if col in self.vars_with_any_missing 87 | ] 88 | 89 | # This should be all variables in the schema, not all variables in the dataset. 90 | self.all_var_in_schema = set( 91 | self.modeled_variables 92 | + [y for x in self.variable_schema.values() for y in x] 93 | ) 94 | 95 | for col in self.all_var_in_schema: 96 | assert pd_dtypes_orig[col].name != "object", ( 97 | "Cannot model an object column, please convert to int or categorical, or " 98 | "specify a variable_schema that does not use the object column." 99 | ) 100 | 101 | if random_seed_array is not None: 102 | assert isinstance(random_seed_array, np.ndarray) 103 | assert ( 104 | random_seed_array.shape[0] == self.shape[0] 105 | ), "random_seed_array must be the same length as data." 106 | # Our hashing scheme doesn't work for specifically the value 0. 107 | # Set any values == 0 to the value 1. 
108 | random_seed_array = random_seed_array.copy() 109 | zero_value_seeds = random_seed_array == 0 110 | random_seed_array[zero_value_seeds] = 1 111 | hash_numpy_int_array(random_seed_array) 112 | self.random_seed_array: Optional[np.ndarray] = random_seed_array 113 | else: 114 | self.random_seed_array = None 115 | 116 | self.na_counts = na_counts 117 | self.na_where = na_where 118 | self.num_datasets = len(datasets) 119 | self.initialized = False 120 | self.imputed_variable_count = len(self.imputed_variables) 121 | self.modeled_variable_count = len(self.modeled_variables) 122 | 123 | # Create a multiindexed dataframe to store our imputation values 124 | iv_multiindex = MultiIndex.from_product( 125 | [[0], datasets], names=("iteration", "dataset") 126 | ) 127 | self.imputation_values = { 128 | var: DataFrame(index=na_where[var], columns=iv_multiindex).astype( 129 | pd_dtypes_orig[var] 130 | ) 131 | for var in self.imputed_variables 132 | } 133 | 134 | # Create an iteration counter 135 | self.iteration_tab = {} 136 | for variable in self.modeled_variables: 137 | for dataset in datasets: 138 | self.iteration_tab[variable, dataset] = 0 139 | 140 | # Save the version of miceforest that was used to make this kernel 141 | self.version = importlib.metadata.version("miceforest") 142 | 143 | # Subsetting allows us to get to the imputation values: 144 | def __getitem__(self, tup): 145 | variable, iteration, dataset = tup 146 | return self.imputation_values[variable].loc[:, (iteration, dataset)] 147 | 148 | def __setitem__(self, tup, newitem): 149 | variable, iteration, dataset = tup 150 | imputation_iteration = self.iteration_count(dataset=dataset, variable=variable) 151 | 152 | # Don't throw this warning on initialization 153 | if (iteration <= imputation_iteration) and (iteration > 0): 154 | warn( 155 | f"Overwriting Variable: {variable} Dataset: {dataset} Iteration: iteration" 156 | ) 157 | 158 | self.imputation_values[variable].loc[:, (iteration, dataset)] = newitem 159 | 160 | def __delitem__(self, tup): 161 | variable, iteration, dataset = tup 162 | self.imputation_values[variable].drop( 163 | [(iteration, dataset)], axis=1, inplace=True 164 | ) 165 | 166 | def __getstate__(self): 167 | """ 168 | For pickling 169 | """ 170 | # Copy the entire object, minus the big stuff 171 | state = { 172 | key: value 173 | for key, value in self.__dict__.items() 174 | if key not in ["imputation_values"] 175 | }.copy() 176 | 177 | state["imputation_values"] = {} 178 | 179 | for col, df in self.imputation_values.items(): 180 | byte_stream = BytesIO() 181 | df.to_parquet(byte_stream) 182 | state["imputation_values"][col] = byte_stream 183 | 184 | return state 185 | 186 | def __setstate__(self, state): 187 | """ 188 | For unpickling 189 | """ 190 | self.__dict__ = state 191 | 192 | for col, bytes in self.imputation_values.items(): 193 | self.imputation_values[col] = read_parquet(bytes) 194 | 195 | def __repr__(self): 196 | summary_string = f'\n{" " * 14}Class: ImputedData\n{self._ids_info()}' 197 | return summary_string 198 | 199 | def _ids_info(self): 200 | summary_string = f"""\ 201 | Datasets: {self.num_datasets} 202 | Iterations: {self.iteration_count()} 203 | Data Samples: {self.shape[0]} 204 | Data Columns: {self.shape[1]} 205 | Imputed Variables: {self.imputed_variable_count} 206 | Modeled Variables: {self.modeled_variable_count} 207 | All Iterations Saved: {self.save_all_iterations_data} 208 | """ 209 | return summary_string 210 | 211 | def _get_nonmissing_index(self, variable: str): 212 | na_where = 
self.na_where[variable] 213 | dtype = na_where.dtype 214 | non_missing_ind = np.setdiff1d( 215 | np.arange(self.shape[0], dtype=dtype), na_where, assume_unique=True 216 | ) 217 | return non_missing_ind 218 | 219 | def _get_nonmissing_values(self, variable: str): 220 | ind = self._get_nonmissing_index(variable) 221 | return self.working_data.loc[ind, variable] 222 | 223 | def _ampute_original_data(self): 224 | """Need to put self.working_data back in its original form""" 225 | for variable in self.imputed_variables: 226 | na_where = self.na_where[variable] 227 | self.working_data.loc[na_where, variable] = np.nan 228 | 229 | def _get_hashed_seeds(self, variable: str): 230 | if self.random_seed_array is not None: 231 | na_where = self.na_where[variable] 232 | hashed_seeds = self.random_seed_array[na_where].copy() 233 | hash_numpy_int_array(self.random_seed_array, ind=na_where) 234 | return hashed_seeds 235 | else: 236 | return None 237 | 238 | def _get_bachelor_features(self, variable): 239 | na_where = self.na_where[variable] 240 | predictors = self.variable_schema[variable] 241 | bachelor_features = self.working_data.loc[na_where, predictors] 242 | return bachelor_features 243 | 244 | def iteration_count( 245 | self, 246 | dataset: Union[slice, int] = slice(None), 247 | variable: Union[slice, str] = slice(None), 248 | ): 249 | """ 250 | Grabs the iteration count for specified variables, datasets. 251 | If the iteration count is not consistent across the provided 252 | datasets/variables, an error will be thrown. Providing None 253 | will use all datasets/variables. 254 | 255 | This is to ensure the process is in a consistent state when 256 | the iteration count is needed. 257 | 258 | Parameters 259 | ---------- 260 | datasets: None or int 261 | The datasets to check the iteration count for. 262 | If :code:`None`, all datasets are assumed (and assured) 263 | to have the same iteration count, otherwise error. 264 | variables: str or None 265 | The variable to check the iteration count for. 266 | If :code:`None`, all variables are assumed (and assured) 267 | to have the same iteration count, otherwise error. 268 | 269 | Returns 270 | ------- 271 | An integer representing the iteration count. 272 | """ 273 | 274 | iteration_tab = Series(self.iteration_tab) 275 | iteration_tab.index.names = ["variable", "dataset"] 276 | 277 | iterations = np.unique(iteration_tab.loc[variable, dataset]) 278 | if iterations.shape[0] > 1: 279 | raise ValueError("Multiple iteration counts found") 280 | else: 281 | return iterations[0] 282 | 283 | def complete_data( 284 | self, 285 | dataset: int = 0, 286 | iteration: int = -1, 287 | inplace: bool = False, 288 | variables: Optional[List[str]] = None, 289 | ): 290 | """ 291 | Return dataset with missing values imputed. 292 | 293 | Parameters 294 | ---------- 295 | dataset: int 296 | The dataset to complete. 297 | iteration: int 298 | Impute data with values obtained at this iteration. 299 | If :code:`-1`, returns the most up-to-date iterations, 300 | even if different between variables. If not -1, 301 | iteration must have been saved in imputed values. 302 | inplace: bool 303 | Should the data be completed in place? If True, 304 | self.working_data is imputed,and nothing is returned. 305 | This is useful if the dataset is very large. If 306 | False, a copy of the data is returned, with missing 307 | values imputed. 308 | 309 | Returns 310 | ------- 311 | The completed data, with values imputed for specified variables. 
312 | 313 | """ 314 | 315 | # Return a copy if not inplace. 316 | impute_data = self.working_data if inplace else self.working_data.copy() 317 | 318 | # Figure out which variables we need to impute. 319 | # Never impute variables that are not in imputed_variables. 320 | imp_vars = self.imputed_variables if variables is None else variables 321 | assert set(imp_vars).issubset( 322 | set(self.imputed_variables) 323 | ), "Not all variables specified were imputed." 324 | 325 | for variable in imp_vars: 326 | if iteration == -1: 327 | iteration = self.iteration_count(dataset=dataset, variable=variable) 328 | na_where = self.na_where[variable] 329 | impute_data.loc[na_where, variable] = self[variable, iteration, dataset] 330 | 331 | if not inplace: 332 | return impute_data 333 | 334 | def plot_imputed_distributions( 335 | self, variables: Optional[List[str]] = None, iteration: int = -1 336 | ): 337 | """ 338 | Plot the imputed value distributions. 339 | Red lines are the distribution of original data 340 | Black lines are the distribution of the imputed values. 341 | 342 | Parameters 343 | ---------- 344 | datasets: None, int, list[int] 345 | variables: None, list[str] 346 | The variables to plot. If None, all numeric variables 347 | are plotted. 348 | iteration: int 349 | The iteration to plot the distribution for. 350 | If None, the latest iteration is plotted. 351 | save_all_iterations must be True if specifying 352 | an iteration. 353 | adj_args 354 | Additional arguments passed to plt.subplots_adjust() 355 | 356 | """ 357 | 358 | try: 359 | from plotnine import ( 360 | aes, 361 | facet_wrap, 362 | geom_density, 363 | ggplot, 364 | ggtitle, 365 | scale_color_manual, 366 | theme, 367 | xlab, 368 | ) 369 | except ImportError: 370 | raise ImportError("plotnine must be installed to plot distributions.") 371 | 372 | if iteration == -1: 373 | iteration = self.iteration_count() 374 | 375 | colors = {str(i): "black" for i in range(self.num_datasets)} 376 | colors["-1"] = "red" 377 | 378 | num_vars = self.working_data.select_dtypes("number").columns.to_list() 379 | 380 | if variables is None: 381 | variables = [var for var in self.imputed_variables if var in num_vars] 382 | else: 383 | variables = [var for var in variables if var in num_vars] 384 | 385 | dat = DataFrame() 386 | for variable in variables: 387 | 388 | imps = self.imputation_values[variable].loc[:, iteration].melt() 389 | imps["variable"] = variable 390 | ind = self._get_nonmissing_index(variable) 391 | orig = self.working_data.loc[ind, variable].rename("value").to_frame() 392 | orig["dataset"] = -1 393 | orig["variable"] = variable 394 | dat = concat([dat, imps, orig], axis=0) 395 | 396 | dat["dataset"] = dat["dataset"].astype("string") 397 | 398 | fig = ( 399 | ggplot() 400 | + geom_density( 401 | data=dat, mapping=aes(x="value", group="dataset", color="dataset") 402 | ) 403 | + facet_wrap("variable", scales="free") 404 | + scale_color_manual(values=colors) 405 | + ggtitle("Distribution Plots") 406 | + xlab("") 407 | + theme(legend_position="none") 408 | ) 409 | 410 | return fig 411 | 412 | def plot_mean_convergence( 413 | self, 414 | variables: Optional[List[str]] = None, 415 | ): 416 | """ 417 | Plots the average value and standard deviation of imputations over each iteration. 418 | The lines show the average imputation value for a dataset over the iteration. 419 | The bars show the average standard deviation of the imputation values within datasets. 
420 | 421 | Parameters 422 | ---------- 423 | variables: Optional[List[str]], default=None 424 | The variables to plot. By default, all numeric, imputed variables are plotted. 425 | """ 426 | 427 | try: 428 | from plotnine import ( 429 | aes, 430 | element_text, 431 | facet_wrap, 432 | geom_errorbar, 433 | geom_line, 434 | geom_point, 435 | ggplot, 436 | ggtitle, 437 | theme, 438 | theme_538, 439 | xlab, 440 | ylab, 441 | ) 442 | except ImportError: 443 | raise ImportError("plotnine must be installed to plot distributions.") 444 | 445 | num_vars = self.working_data.select_dtypes("number").columns.to_list() 446 | imp_vars = self.imputed_variables 447 | imp_num_vars = [v for v in num_vars if v in imp_vars] 448 | if variables is None: 449 | variables = imp_num_vars 450 | else: 451 | variables = [v for v in variables if v in imp_num_vars] 452 | 453 | plot_dat = DataFrame() 454 | for variable in variables: 455 | dat = self.imputation_values[variable].melt(col_level="iteration") 456 | dat["dataset"] = self.imputation_values[variable].melt(col_level="dataset")[ 457 | "dataset" 458 | ] 459 | dat = ( 460 | dat.groupby(["dataset", "iteration"]) 461 | .agg({"value": ["mean", "std"]}) 462 | .reset_index() 463 | ) 464 | dat["middle"] = dat[("value", "mean")] 465 | dat["upper"] = dat["middle"] + dat[("value", "std")] 466 | dat["lower"] = dat["middle"] - dat[("value", "std")] 467 | del dat["value"] 468 | dat.columns = dat.columns.droplevel(1) 469 | iter_dat = dat.groupby("iteration").agg( 470 | {"lower": "mean", "middle": "mean", "upper": "mean"} 471 | ) 472 | dat["lower"] = dat.iteration.map(iter_dat["lower"]) 473 | dat["stdavg"] = dat.iteration.map(iter_dat["middle"]) 474 | dat["upper"] = dat.iteration.map(iter_dat["upper"]) 475 | dat["variable"] = variable 476 | plot_dat = concat([dat, plot_dat], axis=0) 477 | 478 | fig = ( 479 | ggplot(plot_dat, aes(x="iteration", y="middle", group="dataset")) 480 | + geom_line() 481 | + geom_errorbar( 482 | aes(x="iteration", ymin="lower", ymax="upper", group="dataset") 483 | ) 484 | + geom_point(aes(x="iteration", y="stdavg")) 485 | + facet_wrap("variable", scales="free") 486 | + ggtitle("Mean Convergence Plot") 487 | + xlab("") 488 | + ylab("") 489 | + theme( 490 | plot_title=element_text(ha="left", size=20), 491 | ) 492 | + theme_538() 493 | ) 494 | 495 | return fig 496 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![DOI](https://zenodo.org/badge/289387436.svg)](https://zenodo.org/badge/latestdoi/289387436) 2 | [![Downloads](https://static.pepy.tech/badge/miceforest)](https://pepy.tech/project/miceforest) 3 | [![Pypi](https://img.shields.io/pypi/v/miceforest.svg)](https://pypi.python.org/pypi/miceforest) 4 | [![Conda 5 | Version](https://img.shields.io/conda/vn/conda-forge/miceforest.svg)](https://anaconda.org/conda-forge/miceforest) 6 | [![PyVersions](https://img.shields.io/pypi/pyversions/miceforest.svg?logo=python&logoColor=white)](https://pypi.org/project/miceforest/) 7 | [![tests + 8 | mypy](https://github.com/AnotherSamWilson/miceforest/actions/workflows/run_tests.yml/badge.svg)](https://github.com/AnotherSamWilson/miceforest/actions/workflows/run_tests.yml) 9 | [![Documentation 10 | Status](https://readthedocs.org/projects/miceforest/badge/?version=latest)](https://miceforest.readthedocs.io/en/latest/?badge=latest) 11 | 
[![CodeCov](https://codecov.io/gh/AnotherSamWilson/miceforest/branch/master/graphs/badge.svg?branch=master&service=github)](https://codecov.io/gh/AnotherSamWilson/miceforest) 12 | 13 | 14 | 15 | 16 | 17 | # miceforest: Fast, Memory Efficient Imputation with LightGBM 18 | 19 | 20 | 21 | Fast, memory efficient Multiple Imputation by Chained Equations (MICE) 22 | with lightgbm. The R version of this package may be found 23 | [here](https://github.com/FarrellDay/miceRanger). 24 | 25 | `miceforest` was designed to be: 26 | 27 | - **Fast** 28 | - Uses lightgbm as a backend 29 | - Has efficient mean matching solutions. 30 | - Can utilize GPU training 31 | - **Flexible** 32 | - Can impute pandas dataframes and numpy arrays 33 | - Handles categorical data automatically 34 | - Fits into a sklearn pipeline 35 | - User can customize every aspect of the imputation process 36 | - **Production Ready** 37 | - Can impute new, unseen datasets quickly 38 | - Kernels are efficiently compressed during saving and loading 39 | - Data can be imputed in place to save memory 40 | - Can build models on non-missing data 41 | 42 | 43 | This document contains a thorough walkthrough of the package, 44 | benchmarks, and an introduction to multiple imputation. More information 45 | on MICE can be found in Stef van Buuren’s excellent online book, which 46 | you can find 47 | [here](https://stefvanbuuren.name/fimd/ch-introduction.html). 48 | 49 | #### Table of Contents: 50 | 51 | - [Classes](https://github.com/AnotherSamWilson/miceforest?tab=readme-ov-file#classes) 52 | - [Basic Usage](https://github.com/AnotherSamWilson/miceforest?tab=readme-ov-file#basic-usage) 53 | - [Example](https://github.com/AnotherSamWilson/miceforest?tab=readme-ov-file#basic-usage) 54 | - [Customizing LightGBM Parameters](https://github.com/AnotherSamWilson/miceforest?tab=readme-ov-file#customizing-lightgbm-parameters) 55 | - [Available Mean Match Schemes](https://github.com/AnotherSamWilson/miceforest?tab=readme-ov-file#adjusting-the-mean-matching-scheme) 56 | - [Imputing New Data with Existing Models](https://github.com/AnotherSamWilson/miceforest?tab=readme-ov-file#imputing-new-data-with-existing-models) 57 | - [Saving and Loading Kernels](https://github.com/AnotherSamWilson/miceforest?tab=readme-ov-file#saving-and-loading-kernels) 58 | - [Implementing sklearn Pipelines](https://github.com/AnotherSamWilson/miceforest?tab=readme-ov-file#saving-and-loading-kernels) 59 | - [Advanced Features](https://github.com/AnotherSamWilson/miceforest?tab=readme-ov-file#advanced-features) 60 | - [Building Models on Nonmissing Data](https://github.com/AnotherSamWilson/miceforest?tab=readme-ov-file#building-models-on-nonmissing-data) 61 | - [Tuning Parameters](https://github.com/AnotherSamWilson/miceforest?tab=readme-ov-file#tuning-parameters) 62 | - [On Reproducibility](https://github.com/AnotherSamWilson/miceforest?tab=readme-ov-file#on-reproducibility) 63 | - [How to Make the Process Faster](https://github.com/AnotherSamWilson/miceforest?tab=readme-ov-file#how-to-make-the-process-faster) 64 | - [Imputing Data In Place](https://github.com/AnotherSamWilson/miceforest?tab=readme-ov-file#imputing-data-in-place) 65 | - [Diagnostic Plotting](https://github.com/AnotherSamWilson/miceforest?tab=readme-ov-file#diagnostic-plotting) 66 | - [Feature Importance](https://github.com/AnotherSamWilson/miceforest?tab=readme-ov-file#feature-importance) 67 | - [Imputed Distributions](https://github.com/AnotherSamWilson/miceforest?tab=readme-ov-file#plot-imputed-distributions) 68 | 
- [Using the Imputed Data](https://github.com/AnotherSamWilson/miceforest?tab=readme-ov-file#using-the-imputed-data) 69 | - [The MICE Algorithm](https://github.com/AnotherSamWilson/miceforest?tab=readme-ov-file#the-mice-algorithm) 70 | - [Introduction](https://github.com/AnotherSamWilson/miceforest?tab=readme-ov-file#the-mice-algorithm) 71 | - [Common Use Cases](https://github.com/AnotherSamWilson/miceforest?tab=readme-ov-file#common-use-cases) 72 | - [Predictive Mean Matching](https://github.com/AnotherSamWilson/miceforest?tab=readme-ov-file#predictive-mean-matching) 73 | - [Effects of Mean Matching](https://github.com/AnotherSamWilson/miceforest?tab=readme-ov-file#effects-of-mean-matching) 74 | 75 | ## Installation 76 | 77 | This package can be installed using either pip or conda, through 78 | conda-forge: 79 | 80 | ``` bash 81 | # Using pip 82 | $ pip install miceforest --no-cache-dir 83 | 84 | # Using conda 85 | $ conda install -c conda-forge miceforest 86 | ``` 87 | 88 | You can also download the latest development version from this 89 | repository. If you want to install from github with conda, you must 90 | first run `conda install pip git`. 91 | 92 | ``` bash 93 | $ pip install git+https://github.com/AnotherSamWilson/miceforest.git 94 | ``` 95 | 96 | ## Classes 97 | 98 | miceforest has 2 main classes which the user will interact with: 99 | 100 | - [`ImputationKernel`](https://miceforest.readthedocs.io/en/latest/ImputationKernel.html) 101 | - This class contains the raw data off of which the `mice` algorithm 102 | is performed. During this process, models will be trained, and the 103 | imputed (predicted) values will be stored. These values can be used 104 | to fill in the missing values of the raw data. The raw data can be 105 | copied, or referenced directly. Models can be saved, and used to 106 | impute new datasets. 107 | - [`ImputedData`](https://miceforest.readthedocs.io/en/latest/ImputedData.html) 108 | - The result of `ImputationKernel.impute_new_data(new_data)`. This 109 | contains the raw data in `new_data` as well as the imputed values. 110 | 111 | 112 | ## Basic Usage 113 | 114 | We will be looking at a few simple examples of imputation. We need to 115 | load the packages, and define the data: 116 | 117 | 118 | ```python 119 | import miceforest as mf 120 | from sklearn.datasets import load_iris 121 | import pandas as pd 122 | import numpy as np 123 | 124 | # Load data and introduce missing values 125 | iris = pd.concat(load_iris(as_frame=True,return_X_y=True),axis=1) 126 | iris.rename({"target": "species"}, inplace=True, axis=1) 127 | iris['species'] = iris['species'].astype('category') 128 | iris_amp = mf.ampute_data(iris,perc=0.25,random_state=1991) 129 | ``` 130 | 131 | If you only want to create a single imputed dataset, you can use 132 | [`ImputationKernel`](https://miceforest.readthedocs.io/en/latest/ImputationKernel.html) 133 | with some default settings: 134 | 135 | 136 | ```python 137 | # Create kernel. 138 | kds = mf.ImputationKernel( 139 | iris_amp, 140 | random_state=1991 141 | ) 142 | 143 | # Run the MICE algorithm for 2 iterations 144 | kds.mice(2) 145 | 146 | # Return the completed dataset. 147 | iris_complete = kds.complete_data() 148 | ``` 149 | 150 | There are also an array of plotting functions available, these are 151 | discussed below in the section [Diagnostic 152 | Plotting](https://github.com/AnotherSamWilson/miceforest?tab=readme-ov-file#diagnostic-plotting). 153 | 154 | We usually don’t want to impute just a single dataset. 
In statistics, 155 | multiple imputation is a process by which the uncertainty/other effects 156 | caused by missing values can be examined by creating multiple different 157 | imputed datasets. 158 | [`ImputationKernel`](https://miceforest.readthedocs.io/en/latest/ImputationKernel.html) 159 | can contain an arbitrary number of different datasets, all of which have 160 | gone through mutually exclusive imputation processes: 161 | 162 | 163 | ```python 164 | # Create kernel. 165 | kernel = mf.ImputationKernel( 166 | iris_amp, 167 | num_datasets=4, 168 | random_state=1 169 | ) 170 | 171 | # Run the MICE algorithm for 2 iterations on each of the datasets 172 | kernel.mice(2) 173 | 174 | # Printing the kernel will show you some high level information. 175 | print(kernel) 176 | ``` 177 | 178 | 179 | Class: ImputationKernel 180 | Datasets: 4 181 | Iterations: 2 182 | Data Samples: 150 183 | Data Columns: 5 184 | Imputed Variables: 5 185 | Modeled Variables: 5 186 | All Iterations Saved: True 187 | 188 | 189 | 190 | After we have run mice, we can obtain our completed dataset directly 191 | from the kernel: 192 | 193 | 194 | ```python 195 | completed_dataset = kernel.complete_data(dataset=2) 196 | print(completed_dataset.isnull().sum(0)) 197 | ``` 198 | 199 | sepal length (cm) 0 200 | sepal width (cm) 0 201 | petal length (cm) 0 202 | petal width (cm) 0 203 | species 0 204 | dtype: int64 205 | 206 | 207 | ## Customizing LightGBM Parameters 208 | 209 | Parameters can be passed directly to lightgbm in several different ways. 210 | Parameters you wish to apply globally to every model can simply be 211 | passed as kwargs to `mice`: 212 | 213 | 214 | ```python 215 | # Run the MICE algorithm for 1 more iteration on the kernel with new parameters 216 | kernel.mice(iterations=1, n_estimators=50) 217 | ``` 218 | 219 | You can also pass pass variable-specific arguments to 220 | `variable_parameters` in mice. For instance, let’s say you noticed the 221 | imputation of the `[species]` column was taking a little longer, because 222 | it is multiclass. You could decrease the n\_estimators specifically for 223 | that column with: 224 | 225 | 226 | ```python 227 | # Run the MICE algorithm for 2 more iterations on the kernel 228 | kernel.mice( 229 | iterations=1, 230 | variable_parameters={'species': {'n_estimators': 25}}, 231 | n_estimators=50 232 | ) 233 | 234 | # Let's get the actual models for these variables: 235 | species_model = kernel.get_model(dataset=0,variable="species") 236 | sepalwidth_model = kernel.get_model(dataset=0,variable="sepal width (cm)") 237 | 238 | print( 239 | f"""Species used {str(species_model.params["num_iterations"])} iterations 240 | Sepal Width used {str(sepalwidth_model.params["num_iterations"])} iterations 241 | """ 242 | ) 243 | ``` 244 | 245 | Species used 25 iterations 246 | Sepal Width used 50 iterations 247 | 248 | 249 | 250 | In this scenario, any parameters specified in `variable_parameters` 251 | takes presidence over the kwargs. 252 | 253 | Since we can pass any parameters we want to LightGBM, we can completely 254 | customize how our models are built. That includes how the data should be 255 | modeled. If your data contains count data, or any other data which can 256 | be parameterized by lightgbm, you can simply specify that variable to be 257 | modeled with the corresponding objective function. 258 | 259 | For example, let’s pretend `sepal width (cm)` is a count field which can 260 | be parameterized by a Poisson distribution. 
Let’s also change our 261 | boosting method to gradient boosted trees: 262 | 263 | 264 | ```python 265 | # Create kernel. 266 | cust_kernel = mf.ImputationKernel( 267 | iris_amp, 268 | num_datasets=1, 269 | random_state=1 270 | ) 271 | 272 | cust_kernel.mice( 273 | iterations=1, 274 | variable_parameters={'sepal width (cm)': {'objective': 'poisson'}}, 275 | boosting = 'gbdt', 276 | min_sum_hessian_in_leaf=0.01 277 | ) 278 | ``` 279 | 280 | Other nice parameters like `monotone_constraints` can also be passed. 281 | Setting the parameter `device: 'gpu'` will utilize GPU learning, if 282 | LightGBM is set up to do this on your machine. 283 | 284 | ## Adjusting The Mean Matching Scheme 285 | 286 | Note: It is probably a good idea to read [this 287 | section](https://github.com/AnotherSamWilson/miceforest?tab=readme-ov-file#predictive-mean-matching) 288 | first, to get some context on how mean matching works. 289 | 290 | There are 4 imputation strategies employed by `miceforest`: 291 | - **Fast** Mean Matching: Available only on binary and categorical variables. Chooses a class randomly based on the predicted probabilities output by lightgbm. 292 | - **Normal** Mean Matching: Employs mean matching as described in the section below. 293 | - **Shap** Mean Matching: Runs a nearest neighbor search on the shap values of the bachelor predictions in the shap values of the candidate predictions. Finds the `mean_match_candidates` nearest neighbors, and chooses one randomly as the imputation value. 294 | - Value Imputation: Uses the value output by lightgbm as the imputation value. Skips mean matching entirely. To use, set `mean_match_candidates = 0`. 295 | 296 | Here is the code required to use each method: 297 | 298 | 299 | ```python 300 | # Create kernel. 301 | cust_kernel = mf.ImputationKernel( 302 | iris_amp, 303 | num_datasets=1, 304 | random_state=1, 305 | mean_match_strategy={ 306 | 'sepal length (cm)': 'normal', 307 | 'sepal width (cm)': 'shap', 308 | 'species': 'fast', 309 | }, 310 | mean_match_candidates={ 311 | 'petal length (cm)': 0, 312 | } 313 | ) 314 | 315 | cust_kernel.mice( 316 | iterations=1, 317 | ) 318 | ``` 319 | 320 | ## Imputing New Data with Existing Models 321 | 322 | Multiple Imputation can take a long time. If you wish to impute a 323 | dataset using the MICE algorithm, but don’t have time to train new 324 | models, it is possible to impute new datasets using a `ImputationKernel` 325 | object. The `impute_new_data()` function uses the models collected by 326 | `ImputationKernel` to perform multiple imputation without updating the 327 | models at each iteration: 328 | 329 | 330 | ```python 331 | # Our 'new data' is just the first 15 rows of iris_amp 332 | from datetime import datetime 333 | 334 | # Define our new data as the first 15 rows 335 | new_data = iris_amp.iloc[range(15)].reset_index(drop=True) 336 | 337 | start_t = datetime.now() 338 | new_data_imputed = cust_kernel.impute_new_data(new_data=new_data) 339 | print(f"New Data imputed in {(datetime.now() - start_t).total_seconds()} seconds") 340 | ``` 341 | 342 | New Data imputed in 0.035129 seconds 343 | 344 | 345 | ## Saving and Loading Kernels 346 | 347 | Saving `miceforest` kernels is efficient. During the pickling process, the following steps are taken: 348 | 349 | 1. Convert working data to parquet bytes. 350 | 2. Serialize the kernel. 351 | 4. Save to a file. 
352 | 353 | You can save and load the kernel like any other object using `pickle` or `dill`: 354 | 355 | 356 | 357 | ```python 358 | from tempfile import mkstemp 359 | import dill 360 | new_file, filename = mkstemp() 361 | 362 | with open(filename, "wb") as f: 363 | dill.dump(kernel, f) 364 | 365 | with open(filename, "rb") as f: 366 | kernel_from_pickle = dill.load(f) 367 | ``` 368 | 369 | ## Implementing sklearn Pipelines 370 | 371 | `miceforest` kernels can be fit into sklearn pipelines to impute training and scoring 372 | datasets: 373 | 374 | 375 | ```python 376 | import numpy as np 377 | from sklearn.preprocessing import StandardScaler 378 | from sklearn.datasets import make_classification 379 | from sklearn.model_selection import train_test_split 380 | from sklearn.pipeline import Pipeline 381 | import miceforest as mf 382 | 383 | kernel = mf.ImputationKernel(iris_amp, num_datasets=1, random_state=1) 384 | 385 | pipe = Pipeline([ 386 | ('impute', kernel), 387 | ('scaler', StandardScaler()), 388 | ]) 389 | 390 | # The pipeline can be used as any other estimator 391 | # and avoids leaking the test set into the train set 392 | X_train_t = pipe.fit_transform( 393 | X=iris_amp, 394 | y=None, 395 | impute__iterations=2 396 | ) 397 | X_test_t = pipe.transform(new_data) 398 | 399 | # Show that neither now have missing values. 400 | assert not np.any(np.isnan(X_train_t)) 401 | assert not np.any(np.isnan(X_test_t)) 402 | ``` 403 | 404 | # Advanced Features 405 | 406 | ## Building Models on Nonmissing Data 407 | 408 | The MICE process itself is used to impute missing data in a dataset. 409 | However, sometimes a variable can be fully recognized in the training 410 | data, but needs to be imputed later on in a different dataset. It is 411 | possible to train models to impute variables even if they have no 412 | missing values by specifying them in the `variable_schema` parameter. 413 | In this case, `variable_schema` is treated as the list of variables 414 | to train models on. 415 | 416 | 417 | ```python 418 | # Set petal length (cm) in our amputed data 419 | # to original values with no missing data. 420 | iris_amp['sepal width (cm)'] = iris['sepal width (cm)'].copy() 421 | iris_amp.isnull().sum() 422 | ``` 423 | 424 | 425 | 426 | 427 | sepal length (cm) 37 428 | sepal width (cm) 0 429 | petal length (cm) 37 430 | petal width (cm) 37 431 | species 37 432 | dtype: int64 433 | 434 | 435 | 436 | 437 | ```python 438 | kernel = mf.ImputationKernel( 439 | data=iris_amp, 440 | variable_schema=iris_amp.columns.to_list(), 441 | num_datasets=1, 442 | random_state=1, 443 | ) 444 | kernel.mice(1) 445 | ``` 446 | 447 | 448 | ```python 449 | # Remember, the dataset we are imputing does have 450 | # missing values in the sepal width (cm) column 451 | new_data.isnull().sum() 452 | ``` 453 | 454 | 455 | 456 | 457 | sepal length (cm) 4 458 | sepal width (cm) 3 459 | petal length (cm) 1 460 | petal width (cm) 3 461 | species 3 462 | dtype: int64 463 | 464 | 465 | 466 | 467 | ```python 468 | new_data_imp = kernel.impute_new_data(new_data) 469 | new_data_imp = new_data_imp.complete_data() 470 | 471 | # All columns have been imputed. 472 | new_data_imp.isnull().sum() 473 | ``` 474 | 475 | 476 | 477 | 478 | sepal length (cm) 0 479 | sepal width (cm) 0 480 | petal length (cm) 0 481 | petal width (cm) 0 482 | species 0 483 | dtype: int64 484 | 485 | 486 | 487 | ## Tuning Parameters 488 | 489 | `miceforest` allows you to tune the parameters on a kernel dataset. 
490 | These parameters can then be used to build the models in future 491 | iterations of mice. In its most simple invocation, you can just call the 492 | function with the desired optimization steps: 493 | 494 | 495 | ```python 496 | optimal_params = kernel.tune_parameters( 497 | dataset=0, 498 | use_gbdt=True, 499 | num_iterations=500, 500 | random_state=1, 501 | ) 502 | kernel.mice(1, variable_parameters=optimal_params) 503 | pd.DataFrame(optimal_params) 504 | ``` 505 | 506 | 507 | 508 | 509 |
|                         | sepal length (cm) | petal length (cm) | petal width (cm) | species    |
|-------------------------|-------------------|-------------------|------------------|------------|
| boosting                | gbdt              | gbdt              | gbdt             | gbdt       |
| data_sample_strategy    | bagging           | bagging           | bagging          | bagging    |
| num_iterations          | 142               | 248               | 262              | 172        |
| max_depth               | 4                 | 4                 | 5                | 5          |
| num_leaves              | 12                | 17                | 21               | 9          |
| min_data_in_leaf        | 2                 | 2                 | 15               | 5          |
| min_sum_hessian_in_leaf | 0.01              | 0.01              | 0.01             | 0.01       |
| min_gain_to_split       | 0.0               | 0.0               | 0.0              | 0.0        |
| bagging_fraction        | 0.580973          | 0.501521          | 0.586709         | 0.795465   |
| feature_fraction_bynode | 0.922566          | 0.299912          | 0.503182         | 0.237637   |
| bagging_freq            | 1                 | 1                 | 1                | 1          |
| verbosity               | -1                | -1                | -1               | -1         |
| learning_rate           | 0.02              | 0.02              | 0.02             | 0.02       |
| objective               | regression        | regression        | regression       | multiclass |
| num_class               | NaN               | NaN               | NaN              | 3          |
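Each column in the table above is the tuned parameter set for one variable. The object returned by `tune_parameters` is a plain dictionary keyed by variable name, so individual entries can be inspected or overridden before being passed back to `mice`. A minimal sketch, reusing `optimal_params` and `kernel` from above (the overridden value is only illustrative):

```python
# Look at the tuned parameters for a single variable
print(optimal_params['species'])

# Override one value before running another iteration of mice
optimal_params['species']['learning_rate'] = 0.05
kernel.mice(1, variable_parameters=optimal_params)
```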
642 | 643 | 644 | 645 | This will perform 10 fold cross validation on random samples of 646 | parameters. By default, all variables models are tuned. 647 | 648 | The parameter tuning is pretty flexible. If you wish to set some model 649 | parameters static, or to change the bounds that are searched in, you can 650 | simply pass this information to either the `variable_parameters` 651 | parameter, `**kwbounds`, or both: 652 | 653 | 654 | ```python 655 | optimal_params = kernel.tune_parameters( 656 | dataset=0, 657 | variables = ['sepal width (cm)','species','petal width (cm)'], 658 | variable_parameters = { 659 | 'sepal width (cm)': {'bagging_fraction': 0.5}, 660 | 'species': {'bagging_freq': (5,10)} 661 | }, 662 | use_gbdt=True, 663 | optimization_steps=5, 664 | extra_trees = [True, False] 665 | ) 666 | 667 | kernel.mice(1, variable_parameters=optimal_params) 668 | ``` 669 | 670 | In this example, we did a few things - we specified that only `sepal 671 | width (cm)`, `species`, and `petal width (cm)` should be tuned. We also 672 | specified some specific parameters in `variable_parameters`. Notice that 673 | `bagging_fraction` was passed as a scalar, `0.5`. This means that, for 674 | the variable `sepal width (cm)`, the parameter `bagging_fraction` will 675 | be set as that number and not be tuned. We did the opposite for 676 | `bagging_freq`. We specified bounds that the process should search in. 677 | We also passed the argument `extra_trees` as a list. Since it was passed 678 | to \*\*kwbounds, this parameter will apply to all variables that are 679 | being tuned. Passing values as a list tells the process that it should 680 | randomly sample values from the list, instead of treating them as set of 681 | counts to search within. 682 | 683 | Additionally, we set `use_gbdt=True`. This switches the process to use 684 | gradient boosted trees, instead of random forests. Typically, gradient 685 | boosted trees will perform better. The optimal `num_iterations` is also 686 | determined by early stopping in cross validation. 687 | 688 | The tuning process follows these rules for different parameter values it 689 | finds: 690 | 691 | - Scalar: That value is used, and not tuned. 692 | - Tuple: Should be length 2. Treated as the lower and upper bound to 693 | search in. 694 | - List: Treated as a distinct list of values to try randomly. 695 | 696 | 697 | ## On Reproducibility 698 | 699 | `miceforest` allows for different “levels” of reproducibility, global 700 | and record-level. 701 | 702 | ##### **Global Reproducibility** 703 | 704 | Global reproducibility ensures that the same values will be imputed if 705 | the same code is run multiple times. To ensure global reproducibility, 706 | all the user needs to do is set a `random_state` when the kernel is 707 | initialized. 708 | 709 | ##### **Record-Level Reproducibility** 710 | 711 | Sometimes we want to obtain reproducible imputations at the record 712 | level, without having to pass the same dataset. This is possible by 713 | passing a list of record-specific seeds to the `random_seed_array` 714 | parameter. This is useful if imputing new data multiple times, and you 715 | would like imputations for each row to match each time it is imputed. 
716 | 717 | 718 | 719 | ```python 720 | # Define seeds for the data, and impute iris 721 | import numpy as np 722 | random_seed_array = np.random.randint(0, 9999, size=iris_amp.shape[0], dtype='uint32') 723 | iris_imputed = kernel.impute_new_data( 724 | iris_amp, 725 | random_state=4, 726 | random_seed_array=random_seed_array 727 | ) 728 | 729 | # Select a random sample 730 | new_inds = np.random.choice(150, size=15) 731 | new_data = iris_amp.loc[new_inds].reset_index(drop=True) 732 | new_seeds = random_seed_array[new_inds] 733 | new_imputed = kernel.impute_new_data( 734 | new_data, 735 | random_state=4, 736 | random_seed_array=new_seeds 737 | ) 738 | 739 | # We imputed the same values for the 15 values each time, 740 | # because each record was associated with the same seed. 741 | assert new_imputed.complete_data(0).equals( 742 | iris_imputed.complete_data(0).loc[new_inds].reset_index(drop=True) 743 | ) 744 | ``` 745 | 746 | ## How to Make the Process Faster 747 | 748 | Multiple Imputation is one of the most robust ways to handle missing 749 | data - but it can take a long time. There are several strategies you can 750 | use to decrease the time a process takes to run: 751 | 752 | - Decrease `data_subset`. By default all non-missing datapoints for 753 | each variable are used to train the model and perform mean matching. 754 | This can cause the model training nearest-neighbors search to take a 755 | long time for large data. A subset of these points can be searched 756 | instead by using `data_subset`. 757 | - If categorical columns are taking a long time, you can set 758 | `mean_match_strategy="fast"`. You can also set different parameters 759 | specifically for categorical columns, like smaller `bagging_fraction` 760 | or `num_iterations`, or try grouping the categories before they are 761 | imputed. Model training time for categorical variables is linear with 762 | the number of distinct categories. 763 | - Decrease `mean_match_candidates`. The maximum number of neighbors 764 | that are considered with the default parameters is 10. However, for 765 | large datasets, this can still be an expensive operation. Consider 766 | explicitly setting `mean_match_candidates` lower. Setting 767 | `mean_match_candidates=0` will skip mean matching entirely, and 768 | just use the lightgbm predictions as the imputation values. 769 | - Use different lightgbm parameters. lightgbm is usually not the 770 | problem, however if a certain variable has a large number of 771 | classes, then the max number of trees actually grown is (\# classes) 772 | \* (n\_estimators). You can specifically decrease the bagging 773 | fraction or n\_estimators for large multi-class variables, or grow 774 | less trees in general. 775 | 776 | ## Imputing Data In Place 777 | 778 | It is possible to run the entire process without copying the dataset. If 779 | `copy_data=False`, then the data is referenced directly: 780 | 781 | 782 | 783 | ```python 784 | kernel_inplace = mf.ImputationKernel( 785 | iris_amp, 786 | num_datasets=1, 787 | copy_data=False, 788 | random_state=1, 789 | ) 790 | kernel_inplace.mice(2) 791 | ``` 792 | 793 | Note, that this probably won’t (but could) change the original dataset 794 | in undesirable ways. Throughout the `mice` procedure, imputed values are 795 | stored directly in the original data. At the end, the missing values are 796 | put back as `np.NaN`. 797 | 798 | We can also complete our original data in place. 
This is useful if the dataset is large, and copies can’t be made in 799 | memory: 800 | 801 | 802 | ```python 803 | kernel_inplace.complete_data(dataset=0, inplace=True) 804 | print(iris_amp.isnull().sum(0)) 805 | ``` 806 | 807 | sepal length (cm) 0 808 | sepal width (cm) 0 809 | petal length (cm) 0 810 | petal width (cm) 0 811 | species 0 812 | dtype: int64 813 | 814 | 815 | # Diagnostic Plotting 816 | 817 | As of now, there are 2 diagnostic plot available. More coming soon! 818 | 819 | ### Feature Importance 820 | 821 | 822 | ```python 823 | kernel.plot_feature_importance(dataset=0) 824 | ``` 825 | 826 | 827 | 828 | ![png](README_files/README_49_0.png) 829 | 830 | 831 | 832 | ### Plot Imputed Distributions 833 | 834 | 835 | ```python 836 | kernel.plot_imputed_distributions() 837 | ``` 838 | 839 | 840 | 841 | ![png](README_files/README_51_0.png) 842 | 843 | 844 | 845 | ## Using the Imputed Data 846 | 847 | To return the imputed data simply use the `complete_data` method: 848 | 849 | 850 | ```python 851 | dataset_1 = kernel.complete_data(0) 852 | ``` 853 | 854 | This will return a single specified dataset. Multiple datasets are 855 | typically created so that some measure of confidence around each 856 | prediction can be created. 857 | 858 | Since we know what the original data looked like, we can cheat and see 859 | how well the imputations compare to the original data: 860 | 861 | 862 | ```python 863 | acclist = [] 864 | iterations = kernel.iteration_count()+1 865 | for iteration in range(iterations): 866 | species_na_count = kernel.na_counts['species'] 867 | compdat = kernel.complete_data(dataset=0,iteration=iteration) 868 | 869 | # Record the accuract of the imputations of species. 870 | acclist.append( 871 | round(1-sum(compdat['species'] != iris['species'])/species_na_count,2) 872 | ) 873 | 874 | # acclist shows the accuracy of the imputations over the iterations. 875 | acclist = pd.Series(acclist).rename("Species Imputation Accuracy") 876 | acclist.index = range(iterations) 877 | acclist.index.name = "Iteration" 878 | acclist 879 | ``` 880 | 881 | 882 | 883 | 884 | Iteration 885 | 0 0.35 886 | 1 0.81 887 | 2 0.81 888 | 3 0.84 889 | Name: Species Imputation Accuracy, dtype: float64 890 | 891 | 892 | 893 | In this instance, we went from a low accuracy (what is expected with 894 | random sampling) to a much higher accuracy. 895 | 896 | ## The MICE Algorithm 897 | 898 | Multiple Imputation by Chained Equations ‘fills in’ (imputes) missing 899 | data in a dataset through an iterative series of predictive models. In 900 | each iteration, each specified variable in the dataset is imputed using 901 | the other variables in the dataset. These iterations should be run until 902 | it appears that convergence has been met. 903 | 904 | 905 | 906 | This process is continued until all specified variables have been 907 | imputed. Additional iterations can be run if it appears that the average 908 | imputed values have not converged, although no more than 5 iterations 909 | are usually necessary. 910 | 911 | ### Common Use Cases 912 | 913 | ##### **Data Leakage:** 914 | 915 | MICE is particularly useful if missing values are associated with the 916 | target variable in a way that introduces leakage. For instance, let’s 917 | say you wanted to model customer retention at the time of sign up. A 918 | certain variable is collected at sign up or 1 month after sign up. The 919 | absence of that variable is a data leak, since it tells you that the 920 | customer did not retain for 1 month. 
921 | 922 | ##### **Funnel Analysis:** 923 | 924 | Information is often collected at different stages of a ‘funnel’. MICE 925 | can be used to make educated guesses about the characteristics of 926 | entities at different points in a funnel. 927 | 928 | ##### **Confidence Intervals:** 929 | 930 | MICE can be used to impute missing values, however it is important to 931 | keep in mind that these imputed values are a prediction. Creating 932 | multiple datasets with different imputed values allows you to do two 933 | types of inference: 934 | 935 | - Imputed Value Distribution: A profile can be built for each imputed 936 | value, allowing you to make statements about the likely distribution 937 | of that value. 938 | - Model Prediction Distribution: With multiple datasets, you can build 939 | multiple models and create a distribution of predictions for each 940 | sample. Those samples with imputed values which were not able to be 941 | imputed with much confidence would have a larger variance in their 942 | predictions. 943 | 944 | 945 | ## Predictive Mean Matching 946 | 947 | `miceforest` can make use of a procedure called predictive mean matching 948 | (PMM) to select which values are imputed. PMM involves selecting a 949 | datapoint from the original, nonmissing data (candidates) which has a 950 | predicted value close to the predicted value of the missing sample 951 | (bachelors). The closest N (`mean_match_candidates` parameter) values 952 | are selected, from which a value is chosen at random. This can be 953 | specified on a column-by-column basis. Going into more detail from our 954 | example above, we see how this works in practice: 955 | 956 | 957 | 958 | This method is very useful if you have a variable which needs imputing 959 | which has any of the following characteristics: 960 | 961 | - Multimodal 962 | - Integer 963 | - Skewed 964 | 965 | 966 | ### Effects of Mean Matching 967 | 968 | As an example, let’s construct a dataset with some of the above 969 | characteristics: 970 | 971 | 972 | ```python 973 | randst = np.random.RandomState(1991) 974 | # random uniform variable 975 | nrws = 1000 976 | uniform_vec = randst.uniform(size=nrws) 977 | 978 | def make_bimodal(mean1,mean2,size): 979 | bimodal_1 = randst.normal(size=nrws, loc=mean1) 980 | bimodal_2 = randst.normal(size=nrws, loc=mean2) 981 | bimdvec = [] 982 | for i in range(size): 983 | bimdvec.append(randst.choice([bimodal_1[i], bimodal_2[i]])) 984 | return np.array(bimdvec) 985 | 986 | # Make 2 Bimodal Variables 987 | close_bimodal_vec = make_bimodal(2,-2,nrws) 988 | far_bimodal_vec = make_bimodal(3,-3,nrws) 989 | 990 | 991 | # Highly skewed variable correlated with Uniform_Variable 992 | skewed_vec = np.exp(uniform_vec*randst.uniform(size=nrws)*3) + randst.uniform(size=nrws)*3 993 | 994 | # Integer variable correlated with Close_Bimodal_Variable and Uniform_Variable 995 | integer_vec = np.round(uniform_vec + close_bimodal_vec/3 + randst.uniform(size=nrws)*2) 996 | 997 | # Make a DataFrame 998 | dat = pd.DataFrame( 999 | { 1000 | 'uniform_var':uniform_vec, 1001 | 'close_bimodal_var':close_bimodal_vec, 1002 | 'far_bimodal_var':far_bimodal_vec, 1003 | 'skewed_var':skewed_vec, 1004 | 'integer_var':integer_vec 1005 | } 1006 | ) 1007 | 1008 | # Ampute the data. 
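# (Note: ampute_data introduces missingness at random; with perc=0.25,
# roughly a quarter of the values are set to NaN, reproducibly via random_state.)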
1009 | ampdat = mf.ampute_data(dat,perc=0.25,random_state=randst)
1010 | ```
1011 |
1012 |
1013 | ```python
1014 | import plotnine as p9
1015 | import itertools
1016 |
1017 | def plot_matrix(df, columns):
1018 |     pdf = []
1019 |     for a1, b1 in itertools.combinations(columns, 2):
1020 |         for (a,b) in ((a1, b1), (b1, a1)):
1021 |             sub = df[[a, b]].rename(columns={a: "x", b: "y"}).assign(a=a, b=b)
1022 |             pdf.append(sub)
1023 |
1024 |     g = (
1025 |         p9.ggplot(pd.concat(pdf))
1026 |         + p9.geom_point(p9.aes('x','y'))
1027 |         + p9.facet_grid('b~a', scales='free')
1028 |         + p9.theme(figure_size=(7, 7))
1029 |         + p9.xlab("") + p9.ylab("")
1030 |     )
1031 |     return g
1032 |
1033 | plot_matrix(dat, dat.columns)
1034 | ```
1035 |
1036 |
1037 |
1038 | ![png](README_files/README_61_0.png)
1039 |
1040 |
1041 |
1042 | We can see how our variables are distributed and correlated in the graph
1043 | above. Now let’s run our imputation process twice, once using mean
1044 | matching, and once using the model prediction.
1045 |
1046 |
1047 | ```python
1048 | kernel_mean_match = mf.ImputationKernel(
1049 |     data=ampdat,
1050 |     num_datasets=3,
1051 |     mean_match_candidates=5,
1052 |     random_state=1
1053 | )
1054 | kernel_mean_match.mice(2)
1055 | kernel_no_mean_match = mf.ImputationKernel(
1056 |     data=ampdat,
1057 |     num_datasets=3,
1058 |     mean_match_candidates=0,
1059 |     random_state=1
1060 | )
1061 | kernel_no_mean_match.mice(2)
1062 | ```
1063 |
1064 |
1065 | ```python
1066 | kernel_mean_match.plot_imputed_distributions()
1067 | ```
1068 |
1069 |
1070 |
1071 | ![png](README_files/README_64_0.png)
1072 |
1073 |
1074 |
1075 |
1076 | ```python
1077 | kernel_no_mean_match.plot_imputed_distributions()
1078 | ```
1079 |
1080 |
1081 |
1082 | ![png](README_files/README_65_0.png)
1083 |
1084 |
1085 |
1086 | You can see the effects that mean matching has, depending on the
1087 | distribution of the data. Simply returning the value from the model
1088 | prediction, while it may provide a better ‘fit’, will not produce
1089 | imputations with a similar distribution to the original. Whether that
1090 | trade-off is acceptable depends on your goal.
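To put a rough number on that visual difference, one option (a sketch; the choice of summary statistic here is ours, not something the library prescribes) is to restrict attention to the amputed cells and compare the spread of the imputed values against the spread of the true values that were removed:

```python
# Look only at the cells that were amputed, then compare the standard deviation
# of the imputed values against that of the original (known) values.
imputed_cells = ampdat.isnull()
spread = pd.DataFrame({
    "original": dat[imputed_cells].std(),
    "mean_match": kernel_mean_match.complete_data(0)[imputed_cells].std(),
    "no_mean_match": kernel_no_mean_match.complete_data(0)[imputed_cells].std(),
})
print(spread.round(3))
```

Because mean matching draws imputations from observed candidate values, their spread tends to stay close to the original; raw model predictions are typically less variable.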
1091 | -------------------------------------------------------------------------------- /miceforest/imputation_kernel.py: -------------------------------------------------------------------------------- 1 | from copy import copy 2 | from io import BytesIO 3 | from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, Union 4 | from warnings import warn 5 | 6 | import numpy as np 7 | from lightgbm import Booster, Dataset, cv, early_stopping, log_evaluation, train 8 | from lightgbm.basic import _ConfigAliases 9 | from pandas import Categorical, DataFrame, MultiIndex, Series, read_parquet 10 | from pandas.api.types import is_integer_dtype 11 | from scipy.spatial import KDTree 12 | 13 | from .default_lightgbm_parameters import _DEFAULT_LGB_PARAMS, _sample_parameters 14 | from .imputed_data import ImputedData 15 | from .logger import Logger 16 | from .utils import ( 17 | _draw_random_int32, 18 | _expand_value_to_dict, 19 | _list_union, 20 | ensure_rng, 21 | logodds, 22 | stratified_categorical_folds, 23 | stratified_continuous_folds, 24 | ) 25 | 26 | _DEFAULT_DATA_SUBSET = 0 27 | _DEFAULT_MEANMATCH_CANDIDATES = 5 28 | _DEFAULT_MEANMATCH_STRATEGY = "normal" 29 | _MICE_TIMED_LEVELS = ["Dataset", "Iteration", "Variable", "Event"] 30 | _IMPUTE_NEW_DATA_TIMED_LEVELS = ["Dataset", "Iteration", "Variable", "Event"] 31 | _TUNING_TIMED_LEVELS = ["Variable", "Iteration"] 32 | _PRE_LINK_DATATYPE = "float16" 33 | 34 | 35 | class ImputationKernel(ImputedData): 36 | """ 37 | Creates a kernel dataset. This dataset can perform MICE on itself, 38 | and impute new data from models obtained during MICE. 39 | 40 | Parameters 41 | ---------- 42 | data : pandas.DataFrame. 43 | The data to be imputed. 44 | variable_schema : None or List[str] or Dict[str, str], default=None 45 | Specifies the feature - target relationships used to train models. 46 | This parameter also controls which models are built. Models can be built 47 | even if a variable contains no missing values, or is not being imputed. 48 | 49 | - If :code:`None`, all columns with missing values will have models trained, and all 50 | columns will be used as features in these models. 51 | - If :code:`List[str]`, all columns in data are used to impute the variables in the list 52 | - If :code:`Dict[str, str]` the values will be used to impute the keys. 53 | 54 | No models will be trained for variables not specified by variable_schema 55 | (either by None, a list, or in dict keys). 56 | imputation_order : str, default="ascending" 57 | The order the imputations should occur in: 58 | 59 | - :code:`ascending`: variables are imputed from least to most missing 60 | - :code:`descending`: most to least missing 61 | - :code:`roman`: from left to right in the dataset 62 | - :code:`arabic`: from right to left in the dataset. 63 | 64 | data_subset: None or int or Dict[str, int], default=0 65 | Subsets the data used to train the model for each variable, which can save a significant amount of time. 66 | The number of rows used for model training and mean matching (candidates) is 67 | :code:`(# rows in raw data) - (# missing variable values)` 68 | for each variable. :code:`data_subset` takes a random sample from these candidates. 69 | 70 | - If :code:`int`, must be >= 0. Interpreted as the number of candidates. 71 | - If :code:`0`, no subsetting is done. 72 | - If :code:`Dict[str, int]`, keys must be variable names, and values must follow two above rules. 
73 | 74 | This can also help with memory consumption, as the candidate data must be copied to 75 | make a feature dataset for lightgbm. It is recommended to carefully select this value 76 | for each variable if dealing with very large data that barely fits into memory. 77 | 78 | mean_match_strategy: str or Dict[str, str], default="normal" 79 | There are 3 mean matching strategies included in miceforest: 80 | 81 | - :code:`normal` - this is the default. For all predictions, K-nearest-neighbors 82 | is performed on the candidate predictions and bachelor predictions. 83 | The top MMC closest candidate values are chosen at random. 84 | - :code:`fast` - Only available for categorical and binary columns. A value 85 | is selected at random weighted by the class probabilities. 86 | - :code:`shap` - Similar to "normal" but more robust. A K-nearest-neighbors 87 | search is performed on the shap values of the candidate predictions 88 | and the bachelor predictions. A value from the top MMC closest candidate 89 | values is chosen at random. 90 | 91 | A dict of strategies by variable can be passed as well. Any unmentioned variables 92 | will be set to the default, "normal". 93 | 94 | .. code-block:: python 95 | 96 | mean_match_strategy = { 97 | 'column_1': 'fast', 98 | 'column_2': 'shap', 99 | } 100 | 101 | Special rules are enacted when :code:`mean_match_candidates==0` for a 102 | variable. See the mean_match_candidates parameter for more information. 103 | 104 | mean_match_candidates: int or Dict[str, int] 105 | The number of nearest neighbors to choose an imputation value from randomly when mean matching. 106 | 107 | Special rules apply when this value is set to 0. This will skip mean matching entirely. 108 | The algorithm that applies depends on the objective type: 109 | 110 | - :code:`Regression`: The bachelor predictions are used as the imputation values. 111 | - :code:`Binary`: The class with the higher probability is chosen. 112 | - :code:`Multiclass`: The class with the highest probability is chosen. 113 | 114 | Setting mmc to 0 will result in much faster process times, but has a few downsides: 115 | 116 | - Imputation values for regression variables might no longer be valid values. 117 | Mean matching ensures that the imputed values have been realized in the data before. 118 | - Random variability from mean matching is often desired to get a more accurate 119 | view of the variability in imputed "confidence" 120 | 121 | initialize_empty: bool, default=False 122 | If :code:`True`, missing data is not filled in randomly before model training starts. 123 | 124 | save_all_iterations_data: bool, default=True 125 | Setting to False will cause the process to not store the models and 126 | candidate values obtained at each iteration. This can save significant 127 | amounts of memory, but it means :code:`impute_new_data()` will not be callable. 128 | 129 | copy_data: bool, default=True 130 | Should the dataset be referenced directly? If False, this will cause 131 | the dataset to be altered in place. If a copy is created, it is saved 132 | in self.working_data. There are different ways in which the dataset 133 | can be altered. 134 | 135 | random_state: None, int, or numpy.random.RandomState 136 | The random_state ensures script reproducibility. It only ensures reproducible 137 | results if the same script is called multiple times. It does not guarantee 138 | reproducible results at the record level if a record is imputed multiple 139 | different times. 
If reproducible record-results are desired, a seed must be 140 | passed for each record in the :code:`random_seed_array` parameter. 141 | """ 142 | 143 | def __init__( 144 | self, 145 | data: DataFrame, 146 | num_datasets: int = 1, 147 | variable_schema: Optional[Union[List[str], Dict[str, List[str]]]] = None, 148 | imputation_order: Literal[ 149 | "ascending", "descending", "roman", "latin" 150 | ] = "ascending", 151 | mean_match_candidates: Union[ 152 | int, Dict[str, int] 153 | ] = _DEFAULT_MEANMATCH_CANDIDATES, 154 | mean_match_strategy: Optional[ 155 | Union[str, Dict[str, str]] 156 | ] = _DEFAULT_MEANMATCH_STRATEGY, 157 | data_subset: Union[int, Dict[str, int]] = _DEFAULT_DATA_SUBSET, 158 | initialize_empty: bool = False, 159 | save_all_iterations_data: bool = True, 160 | copy_data: bool = True, 161 | random_state: Optional[Union[int, np.random.RandomState]] = None, 162 | ): 163 | 164 | datasets = list(range(num_datasets)) 165 | 166 | super().__init__( 167 | impute_data=data, 168 | datasets=datasets, 169 | variable_schema=variable_schema, 170 | save_all_iterations_data=save_all_iterations_data, 171 | copy_data=copy_data, 172 | random_seed_array=None, 173 | ) 174 | 175 | # Model Training / Imputation Order: 176 | # Variables with missing data are always trained 177 | # first, according to imputation_order. Afterwards, 178 | # variables with no missing values have models trained. 179 | if imputation_order in ["ascending", "descending"]: 180 | _na_counts = { 181 | key: value 182 | for key, value in self.na_counts.items() 183 | if key in self.imputed_variables 184 | } 185 | self.imputation_order = list( 186 | Series(_na_counts).sort_values(ascending=True).index 187 | ) 188 | if imputation_order == "descending": 189 | self.imputation_order.reverse() 190 | elif imputation_order == "roman": 191 | self.imputation_order = self.imputed_variables.copy() 192 | elif imputation_order == "arabic": 193 | self.imputation_order = self.imputed_variables.copy() 194 | self.imputation_order.reverse() 195 | else: 196 | raise ValueError("imputation_order not recognized.") 197 | 198 | modeled_but_not_imputed_variables = [ 199 | col for col in self.modeled_variables if col not in self.imputed_variables 200 | ] 201 | model_training_order = self.imputation_order + modeled_but_not_imputed_variables 202 | self.model_training_order = model_training_order 203 | 204 | self.initialize_empty = initialize_empty 205 | self.save_all_iterations_data = save_all_iterations_data 206 | 207 | # Models are stored in a dict, keys are (variable, iteration, dataset) 208 | self.models: Dict[Tuple[str, int, int], Booster] = {} 209 | 210 | # Candidate preds are stored the same as models. 211 | self.candidate_preds: Dict[str, DataFrame] = {} 212 | 213 | # Optimal parameters can only be found on 1 dataset at the current iteration. 214 | self.optimal_parameters: Dict[str, Dict[str, Any]] = {} 215 | 216 | # Determine available candidates and interpret data subset. 217 | available_candidates = { 218 | v: (self.shape[0] - self.na_counts[v]) for v in self.model_training_order 219 | } 220 | data_subset = _expand_value_to_dict( 221 | _DEFAULT_DATA_SUBSET, data_subset, keys=self.model_training_order 222 | ) 223 | for col in self.model_training_order: 224 | assert ( 225 | data_subset[col] <= available_candidates[col] 226 | ), f"data_subset is more than available candidates for {col}" 227 | self.available_candidates = available_candidates 228 | self.data_subset = data_subset 229 | 230 | # Collect category information. 
231 | categorical_columns: List[str] = [ 232 | var 233 | for var, dtype in self.working_data.dtypes.items() 234 | if dtype.name == "category" 235 | ] 236 | category_counts = { 237 | col: len(self.working_data[col].cat.categories) 238 | for col in categorical_columns 239 | } 240 | numeric_columns = [ 241 | col for col in self.working_data.columns if col not in categorical_columns 242 | ] 243 | binary_columns = [] 244 | for col, count in category_counts.items(): 245 | if count == 2: 246 | binary_columns.append(col) 247 | categorical_columns.remove(col) 248 | 249 | # Probably a better way of doing this 250 | assert set(categorical_columns).isdisjoint(set(numeric_columns)) 251 | assert set(categorical_columns).isdisjoint(set(binary_columns)) 252 | assert set(binary_columns).isdisjoint(set(numeric_columns)) 253 | 254 | self.category_counts = category_counts 255 | self.modeled_categorical_columns = _list_union( 256 | categorical_columns, self.model_training_order 257 | ) 258 | self.modeled_numeric_columns = _list_union( 259 | numeric_columns, self.model_training_order 260 | ) 261 | self.modeled_binary_columns = _list_union( 262 | binary_columns, self.model_training_order 263 | ) 264 | predictor_columns = sum(self.variable_schema.values(), []) 265 | self.predictor_columns = [ 266 | col for col in data.columns if col in predictor_columns 267 | ] 268 | 269 | # Make sure all pandas categorical levels are used. 270 | rare_level_cols = [] 271 | for col in self.modeled_categorical_columns: 272 | value_counts = data[col].value_counts(normalize=True) 273 | if np.any(value_counts < 0.002): 274 | rare_level_cols.append(col) 275 | if rare_level_cols: 276 | warn( 277 | f"{','.join(rare_level_cols)} have very rare categories, it is a good " 278 | "idea to group these, or set the min_data_in_leaf parameter to prevent " 279 | "lightgbm from outputting 0.0 probabilities." 280 | ) 281 | 282 | self.mean_match_candidates = _expand_value_to_dict( 283 | _DEFAULT_MEANMATCH_CANDIDATES, 284 | mean_match_candidates, 285 | self.model_training_order, 286 | ) 287 | self.mean_match_strategy = _expand_value_to_dict( 288 | _DEFAULT_MEANMATCH_STRATEGY, mean_match_strategy, self.model_training_order 289 | ) 290 | 291 | for col in self.model_training_order: 292 | mmc = self.mean_match_candidates[col] 293 | mms = self.mean_match_strategy[col] 294 | assert not ((mmc == 0) and (mms == "shap")), ( 295 | f"Failing because {col} mean_match_candidates == 0 and " 296 | "mean_match_strategy == shap. This implies an unintentional setup." 297 | ) 298 | 299 | # Determine if the mean matching scheme will 300 | # require candidate information for each variable 301 | self.mean_matching_requires_candidates = [] 302 | for variable in self.model_training_order: 303 | mean_match_strategy = self.mean_match_strategy[variable] 304 | if (mean_match_strategy in ["normal", "shap"]) or ( 305 | variable in self.modeled_numeric_columns 306 | ): 307 | self.mean_matching_requires_candidates.append(variable) 308 | 309 | self.loggers: List[Logger] = [] 310 | 311 | # Manage randomness 312 | self._completely_random_kernel = random_state is None 313 | self._random_state = ensure_rng(random_state) 314 | 315 | # Set initial imputations (iteration 0). 
316 | self._initialize_dataset(self, random_state=self._random_state) 317 | 318 | # Save for use later 319 | self.optimal_parameter_losses: Dict[str, float] = dict() 320 | self.optimal_parameters = dict() 321 | 322 | def __getstate__(self): 323 | """ 324 | For pickling 325 | """ 326 | # Copy the entire object, minus the big stuff 327 | 328 | special_handling = ["imputation_values"] 329 | if self.save_all_iterations_data: 330 | special_handling.append("candidate_preds") 331 | 332 | state = { 333 | key: value 334 | for key, value in self.__dict__.items() 335 | if key not in special_handling 336 | }.copy() 337 | 338 | state["imputation_values"] = {} 339 | state["candidate_preds"] = {} 340 | 341 | for col, df in self.imputation_values.items(): 342 | byte_stream = BytesIO() 343 | df.to_parquet(byte_stream) 344 | state["imputation_values"][col] = byte_stream 345 | for col, df in self.candidate_preds.items(): 346 | byte_stream = BytesIO() 347 | df.to_parquet(byte_stream) 348 | state["candidate_preds"][col] = byte_stream 349 | 350 | return state 351 | 352 | def __setstate__(self, state): 353 | """ 354 | For unpickling 355 | """ 356 | self.__dict__ = state 357 | 358 | for col, bytes in self.imputation_values.items(): 359 | self.imputation_values[col] = read_parquet(bytes) 360 | 361 | if self.save_all_iterations_data: 362 | for col, bytes in self.candidate_preds.items(): 363 | self.candidate_preds[col] = read_parquet(bytes) 364 | 365 | def __repr__(self): 366 | summary_string = f'\n{" " * 14}Class: ImputationKernel\n{self._ids_info()}' 367 | return summary_string 368 | 369 | def _initialize_dataset(self, imputed_data, random_state): 370 | """ 371 | Sets initial imputation values for iteration 0. 372 | If "random", draw values from the working data at random. 373 | If "empty", keep the values missing, since missing values 374 | can be handled natively by lightgbm. 375 | """ 376 | 377 | assert not imputed_data.initialized, "dataset has already been initialized" 378 | 379 | if self.initialize_empty: 380 | # The default value when initialized is np.nan, nothing to do here 381 | pass 382 | else: 383 | for variable in imputed_data.imputed_variables: 384 | # Pulls from the kernel working data 385 | candidate_values = self._get_nonmissing_values(variable) 386 | candidate_num = candidate_values.shape[0] 387 | 388 | # Pulls from the ImputedData 389 | missing_ind = imputed_data.na_where[variable] 390 | missing_num = imputed_data.na_counts[variable] 391 | 392 | for dataset in imputed_data.datasets: 393 | # Initialize using the random_state if no record seeds were passed. 394 | if imputed_data.random_seed_array is None: 395 | imputation_values = candidate_values.sample( 396 | n=missing_num, replace=True, random_state=random_state 397 | ) 398 | imputation_values.index = missing_ind 399 | imputed_data[variable, 0, dataset] = imputation_values 400 | else: 401 | assert ( 402 | len(imputed_data.random_seed_array) == imputed_data.shape[0] 403 | ), "The random_seed_array did not match the number of rows being imputed." 
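# When record-level seeds are provided, the iteration-0 fill is deterministic:
# each record's hashed seed selects a candidate via (seed % number of candidates).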
404 | hashed_seeds = imputed_data._get_hashed_seeds(variable=variable) 405 | selection_ind = hashed_seeds % candidate_num 406 | imputation_values = candidate_values.iloc[selection_ind] 407 | imputation_values.index = missing_ind 408 | imputed_data[variable, 0, dataset] = imputation_values 409 | 410 | imputed_data.initialized = True 411 | 412 | @staticmethod 413 | def _uncover_aliases(params): 414 | """ 415 | Switches all aliases in the parameter dict to their 416 | True name, easiest way to avoid duplicate parameters. 417 | """ 418 | alias_dict = _ConfigAliases._get_all_param_aliases() 419 | for param in list(params): 420 | for true_name, aliases in alias_dict.items(): 421 | if param in aliases: 422 | params[true_name] = params.pop(param) 423 | 424 | def _make_lgb_params( 425 | self, 426 | variable: str, 427 | default_parameters: dict, 428 | variable_parameters: dict, 429 | **kwlgb, 430 | ): 431 | """ 432 | Builds the parameters for a lightgbm model. Infers objective based on 433 | datatype of the response variable, assigns a random seed, finds 434 | aliases in the user supplied parameters, and returns a final dict. 435 | 436 | Parameters 437 | ---------- 438 | variable: int 439 | The variable to be modeled 440 | 441 | default_parameters: dict 442 | The base set of parameters that should be used. 443 | 444 | variable_parameters: dict 445 | Variable specific parameters. These are supplied by the user. 446 | 447 | kwlgb: dict 448 | Any additional parameters that should take presidence 449 | over the defaults. 450 | """ 451 | 452 | seed = _draw_random_int32(self._random_state, size=1)[0] 453 | 454 | if variable in self.modeled_categorical_columns: 455 | n_c = self.category_counts[variable] 456 | obj = {"objective": "multiclass", "num_class": n_c} 457 | elif variable in self.modeled_binary_columns: 458 | obj = {"objective": "binary"} 459 | else: 460 | obj = {"objective": "regression"} 461 | 462 | lgb_params = default_parameters.copy() 463 | lgb_params.update(obj) 464 | lgb_params["seed"] = seed 465 | 466 | self._uncover_aliases(lgb_params) 467 | self._uncover_aliases(kwlgb) 468 | self._uncover_aliases(variable_parameters) 469 | 470 | # Priority is [variable specific] > [global in kwargs] > [defaults] 471 | lgb_params.update(kwlgb) 472 | lgb_params.update(variable_parameters) 473 | 474 | return lgb_params 475 | 476 | # WHEN TUNING, THESE PARAMETERS OVERWRITE THE DEFAULTS ABOVE 477 | # These need to be main parameter names, not aliases 478 | def _make_tuning_space( 479 | self, 480 | variable: str, 481 | variable_parameters: dict, 482 | use_gbdt: bool, 483 | min_samples: int, 484 | max_samples: int, 485 | **kwargs, 486 | ): 487 | 488 | # Start with the default parameters, update with the search space 489 | params = _DEFAULT_LGB_PARAMS.copy() 490 | search_space = { 491 | "min_data_in_leaf": (min_samples, max_samples), 492 | "max_depth": (2, 6), 493 | "num_leaves": (2, 25), 494 | "bagging_fraction": (0.1, 1.0), 495 | "feature_fraction_bynode": (0.1, 1.0), 496 | } 497 | params.update(search_space) 498 | 499 | # Set our defaults if using gbdt 500 | if use_gbdt: 501 | params["boosting"] = "gbdt" 502 | params["learning_rate"] = 0.02 503 | params["num_iterations"] = 250 504 | 505 | params = self._make_lgb_params( 506 | variable=variable, 507 | default_parameters=params, 508 | variable_parameters=variable_parameters, 509 | **kwargs, 510 | ) 511 | 512 | return params 513 | 514 | @staticmethod 515 | def _get_oof_performance( 516 | parameters: dict, 517 | folds: Generator, 518 | train_set: Dataset, 519 | ): 
520 | """ 521 | Performance is gathered from built-in lightgbm.cv out of fold metric. 522 | Optimal number of iterations is also obtained. 523 | """ 524 | 525 | num_iterations = parameters.pop("num_iterations") 526 | lgbcv = cv( 527 | params=parameters, 528 | train_set=train_set, 529 | folds=folds, 530 | num_boost_round=num_iterations, 531 | return_cvbooster=True, 532 | callbacks=[ 533 | early_stopping(stopping_rounds=10, verbose=False), 534 | log_evaluation(period=0), 535 | ], 536 | ) 537 | best_iteration = lgbcv["cvbooster"].best_iteration # type: ignore 538 | loss_metric_key = list(lgbcv)[0] 539 | loss: float = np.min(lgbcv[loss_metric_key]) # type: ignore 540 | 541 | return loss, best_iteration 542 | 543 | def _get_nonmissing_subset_index(self, variable: str, seed: int): 544 | """ 545 | Get random indices for a subset of the data in which variable is not missing. 546 | Used to create feature / label for training. 547 | 548 | replace = False because it would NOT mimic bagging for random forests. 549 | """ 550 | 551 | data_subset = self.data_subset[variable] 552 | available_candidates = self.available_candidates[variable] 553 | nonmissing_ind = self._get_nonmissing_index(variable=variable) 554 | if (data_subset == 0) or (data_subset >= available_candidates): 555 | subset_index = nonmissing_ind 556 | else: 557 | rs = np.random.RandomState(seed) 558 | subset_index = rs.choice(nonmissing_ind, size=data_subset, replace=False) 559 | return subset_index 560 | 561 | def _make_label(self, variable: str, seed: int): 562 | """ 563 | Returns a reproducible subset of the non-missing values of a variable. 564 | """ 565 | # Don't subset at all if data_subset == 0 or we want more than there are candidates 566 | 567 | subset_index = self._get_nonmissing_subset_index(variable=variable, seed=seed) 568 | label = self.working_data.loc[subset_index, variable].copy() 569 | return label 570 | 571 | def _make_features_label(self, variable: str, seed: int): 572 | """ 573 | Makes a reproducible set of features and 574 | target needed to train a lightgbm model. 575 | """ 576 | subset_index = self._get_nonmissing_subset_index(variable=variable, seed=seed) 577 | predictor_columns = self.variable_schema[variable] 578 | features = self.working_data.loc[ 579 | subset_index, predictor_columns + [variable] 580 | ].copy() 581 | label = features.pop(variable) 582 | return features, label 583 | 584 | @staticmethod 585 | def _mean_match_nearest_neighbors( 586 | mean_match_candidates: int, 587 | bachelor_preds: DataFrame, 588 | candidate_preds: DataFrame, 589 | candidate_values: Series, 590 | random_state: np.random.RandomState, 591 | hashed_seeds: Optional[np.ndarray] = None, 592 | ) -> Series: 593 | """ 594 | Determines the values of candidates which will be used to impute the bachelors 595 | """ 596 | 597 | assert mean_match_candidates > 0, "Do not use nearest_neighbors with 0 mmc." 598 | num_bachelors = bachelor_preds.shape[0] 599 | 600 | # balanced_tree = False fixes a recursion issue for some reason. 601 | # https://github.com/scipy/scipy/issues/14799 602 | kd_tree = KDTree(candidate_preds, leafsize=16, balanced_tree=False) 603 | _, knn_indices = kd_tree.query( 604 | bachelor_preds, k=mean_match_candidates, workers=-1 605 | ) 606 | 607 | # We can skip the random selection process if mean_match_candidates == 1 608 | if mean_match_candidates == 1: 609 | index_choice = knn_indices 610 | 611 | else: 612 | # Use the random_state if seed_array was not passed. 
Faster 613 | if hashed_seeds is None: 614 | ind = random_state.randint(mean_match_candidates, size=(num_bachelors)) 615 | # Use the random_seed_array if it was passed. Deterministic. 616 | else: 617 | ind = hashed_seeds % mean_match_candidates 618 | 619 | index_choice = knn_indices[np.arange(num_bachelors), ind] 620 | 621 | imp_values = candidate_values.iloc[index_choice] 622 | 623 | return imp_values 624 | 625 | @staticmethod 626 | def _mean_match_binary_fast( 627 | mean_match_candidates: int, 628 | bachelor_preds: DataFrame, 629 | random_state: np.random.RandomState, 630 | hashed_seeds: Optional[np.ndarray], 631 | ) -> np.ndarray: 632 | """ 633 | Chooses 0/1 randomly weighted by probability obtained from prediction. 634 | If mean_match_candidates is 0, choose class with highest probability. 635 | 636 | Returns a np.ndarray, because these get set to categorical later on. 637 | """ 638 | if mean_match_candidates == 0: 639 | imp_values = np.floor(bachelor_preds + 0.5) 640 | 641 | else: 642 | num_bachelors = bachelor_preds.shape[0] 643 | if hashed_seeds is None: 644 | imp_values = random_state.binomial(n=1, p=bachelor_preds) 645 | else: 646 | imp_values = [] 647 | for i in range(num_bachelors): 648 | np.random.seed(seed=hashed_seeds[i]) 649 | imp_values.append(np.random.binomial(n=1, p=bachelor_preds.iloc[i])) 650 | 651 | imp_values = np.array(imp_values) 652 | 653 | imp_values.shape = (-1,) 654 | 655 | return imp_values 656 | 657 | @staticmethod 658 | def _mean_match_multiclass_fast( 659 | mean_match_candidates: int, 660 | bachelor_preds: DataFrame, 661 | random_state: np.random.RandomState, 662 | hashed_seeds: Optional[np.ndarray], 663 | ): 664 | """ 665 | If mean_match_candidates is 0, choose class with highest probability. 666 | Otherwise, randomly choose class weighted by class probabilities. 667 | 668 | Returns a np.ndarray, because these get set to categorical later on. 
669 | """ 670 | if mean_match_candidates == 0: 671 | imp_values = np.argmax(bachelor_preds, axis=1) 672 | 673 | else: 674 | num_bachelors = bachelor_preds.shape[0] 675 | bachelor_preds = bachelor_preds.cumsum(axis=1).to_numpy() 676 | 677 | if hashed_seeds is None: 678 | compare = random_state.uniform(0, 1, size=(num_bachelors, 1)) 679 | imp_values = (bachelor_preds < compare).sum(1) 680 | 681 | else: 682 | dtype = hashed_seeds.dtype 683 | dtype_max = np.iinfo(dtype).max 684 | compare = np.abs(hashed_seeds / dtype_max) 685 | compare.shape = (-1, 1) 686 | imp_values = (bachelor_preds < compare).sum(1) 687 | 688 | imp_values.shape = (-1,) 689 | 690 | return imp_values 691 | 692 | def _mean_match_fast( 693 | self, 694 | variable: str, 695 | mean_match_candidates: int, 696 | bachelor_preds: np.ndarray, 697 | random_state: np.random.RandomState, 698 | hashed_seeds: Optional[np.ndarray], 699 | ): 700 | """ 701 | Dispatcher and formatter for the fast mean matching functions 702 | """ 703 | if variable in self.modeled_categorical_columns: 704 | imputation_values = self._mean_match_multiclass_fast( 705 | mean_match_candidates=mean_match_candidates, 706 | bachelor_preds=bachelor_preds, 707 | random_state=random_state, 708 | hashed_seeds=hashed_seeds, 709 | ) 710 | elif variable in self.modeled_binary_columns: 711 | imputation_values = self._mean_match_binary_fast( 712 | mean_match_candidates=mean_match_candidates, 713 | bachelor_preds=bachelor_preds, 714 | random_state=random_state, 715 | hashed_seeds=hashed_seeds, 716 | ) 717 | else: 718 | raise ValueError("Shouldnt be able to get here") 719 | 720 | dtype = self.working_data[variable].dtype 721 | imputation_values = Categorical.from_codes(codes=imputation_values, dtype=dtype) 722 | 723 | return imputation_values 724 | 725 | def _impute_with_predictions( 726 | self, 727 | variable: str, 728 | lgbmodel: Booster, 729 | bachelor_features: DataFrame, 730 | ): 731 | bachelor_preds = lgbmodel.predict( 732 | bachelor_features, 733 | pred_contrib=False, 734 | raw_score=False, 735 | ) 736 | assert isinstance(bachelor_preds, np.ndarray) 737 | dtype = self.working_data[variable].dtype 738 | if variable in self.modeled_numeric_columns: 739 | if is_integer_dtype(dtype): 740 | bachelor_preds = bachelor_preds.round(0) 741 | return Series(bachelor_preds, dtype=dtype) 742 | else: 743 | if variable in self.modeled_binary_columns: 744 | selection_ind = (bachelor_preds > 0.5).astype("uint8") 745 | else: 746 | assert ( 747 | variable in self.modeled_categorical_columns 748 | ), f"{variable} is not in numeric, binary or categorical columns" 749 | selection_ind = np.argmax(bachelor_preds, axis=1) 750 | values = dtype.categories[selection_ind] 751 | return Series(values, dtype=dtype) 752 | 753 | def _get_candidate_preds_mice( 754 | self, 755 | variable: str, 756 | lgbmodel: Booster, 757 | candidate_features: DataFrame, 758 | dataset: int, 759 | iteration: int, 760 | ) -> DataFrame: 761 | """ 762 | This function also records the candidate predictions 763 | """ 764 | shap = self.mean_match_strategy[variable] == "shap" 765 | fast = self.mean_match_strategy[variable] == "fast" 766 | logistic = variable not in self.modeled_numeric_columns 767 | 768 | assert hasattr( 769 | lgbmodel, "train_set" 770 | ), "Model was passed that does not have training data." 
771 | if shap: 772 | candidate_preds = lgbmodel.predict( 773 | candidate_features, 774 | pred_contrib=True, 775 | ) 776 | candidate_preds = candidate_preds.astype(_PRE_LINK_DATATYPE) # type: ignore 777 | else: 778 | candidate_preds = lgbmodel._Booster__inner_predict(0) # type: ignore 779 | if logistic and not (shap or fast): 780 | candidate_preds = logodds(candidate_preds).astype(_PRE_LINK_DATATYPE) 781 | 782 | candidate_preds = self._prepare_prediction_multiindex( 783 | variable=variable, 784 | preds=candidate_preds, 785 | shap=shap, 786 | dataset=dataset, 787 | iteration=iteration, 788 | ) 789 | 790 | if self.save_all_iterations_data: 791 | self._record_candidate_preds( 792 | variable=variable, 793 | candidate_preds=candidate_preds, 794 | ) 795 | 796 | return candidate_preds 797 | 798 | def _get_candidate_preds_from_store( 799 | self, 800 | variable: str, 801 | dataset: int, 802 | iteration: int, 803 | ) -> DataFrame: 804 | """ 805 | Mean matching requires 2D array, so always return a dataframe 806 | """ 807 | ret = self.candidate_preds[variable][iteration][[dataset]] 808 | assert isinstance(ret, DataFrame) 809 | return ret 810 | 811 | def _get_bachelor_preds( 812 | self, 813 | variable: str, 814 | lgbmodel: Booster, 815 | bachelor_features: DataFrame, 816 | dataset: int, 817 | iteration: int, 818 | ) -> DataFrame: 819 | 820 | shap = self.mean_match_strategy[variable] == "shap" 821 | fast = self.mean_match_strategy[variable] == "fast" 822 | logistic = variable not in self.modeled_numeric_columns 823 | 824 | bachelor_preds = lgbmodel.predict( 825 | bachelor_features, 826 | pred_contrib=shap, 827 | ) 828 | assert isinstance(bachelor_preds, np.ndarray) 829 | 830 | if shap: 831 | bachelor_preds = bachelor_preds.astype(_PRE_LINK_DATATYPE) 832 | 833 | # We want the logods if running k-nearest 834 | # neighbors on logistic-link predictions 835 | if logistic and not (shap or fast): 836 | bachelor_preds = logodds(bachelor_preds).astype(_PRE_LINK_DATATYPE) 837 | 838 | bachelor_preds = self._prepare_prediction_multiindex( 839 | variable=variable, 840 | preds=bachelor_preds, 841 | shap=shap, 842 | dataset=dataset, 843 | iteration=iteration, 844 | ) 845 | 846 | return bachelor_preds 847 | 848 | def _record_candidate_preds( 849 | self, 850 | variable: str, 851 | candidate_preds: DataFrame, 852 | ): 853 | 854 | assign_col_index = candidate_preds.columns 855 | 856 | if variable not in self.candidate_preds.keys(): 857 | inferred_iteration = assign_col_index.get_level_values("iteration").unique() 858 | assert ( 859 | len(inferred_iteration) == 1 860 | ), f"Malformed iteration multiindex for {variable}: {assign_col_index}" 861 | inferred_iteration = inferred_iteration[0] 862 | assert ( 863 | inferred_iteration == 1 864 | ), "Adding initial candidate preds after iteration 1." 
865 | self.candidate_preds[variable] = candidate_preds 866 | else: 867 | self.candidate_preds[variable][assign_col_index] = candidate_preds 868 | 869 | def _prepare_prediction_multiindex( 870 | self, 871 | variable: str, 872 | preds: np.ndarray, 873 | shap: bool, 874 | dataset: int, 875 | iteration: int, 876 | ) -> DataFrame: 877 | 878 | multiclass = variable in self.modeled_categorical_columns 879 | cols = self.variable_schema[variable] + ["Intercept"] 880 | 881 | if shap: 882 | 883 | if multiclass: 884 | 885 | categories = self.working_data[variable].dtype.categories 886 | cat_count = self.category_counts[variable] 887 | preds_df = DataFrame(preds, columns=cols * cat_count) 888 | del preds_df["Intercept"] 889 | cols.remove("Intercept") 890 | assign_col_index = MultiIndex.from_product( 891 | [[iteration], [dataset], categories, cols], 892 | names=("iteration", "dataset", "categories", "predictor"), 893 | ) 894 | preds_df.columns = assign_col_index 895 | 896 | else: 897 | preds_df = DataFrame(preds, columns=cols) 898 | del preds_df["Intercept"] 899 | cols.remove("Intercept") 900 | assign_col_index = MultiIndex.from_product( 901 | [[iteration], [dataset], cols], 902 | names=("iteration", "dataset", "predictor"), 903 | ) 904 | preds_df.columns = assign_col_index 905 | 906 | else: 907 | 908 | if multiclass: 909 | 910 | categories = self.working_data[variable].dtype.categories 911 | preds_df = DataFrame(preds, columns=categories) 912 | assign_col_index = MultiIndex.from_product( 913 | [[iteration], [dataset], categories], 914 | names=("iteration", "dataset", "categories"), 915 | ) 916 | preds_df.columns = assign_col_index 917 | 918 | else: 919 | 920 | preds_df = DataFrame(preds, columns=[variable]) 921 | assign_col_index = MultiIndex.from_product( 922 | [[iteration], [dataset]], names=("iteration", "dataset") 923 | ) 924 | preds_df.columns = assign_col_index 925 | 926 | return preds_df 927 | 928 | def _mean_match_mice( 929 | self, 930 | variable: str, 931 | lgbmodel: Booster, 932 | bachelor_features: DataFrame, 933 | candidate_features: DataFrame, 934 | candidate_values: Series, 935 | dataset: int, 936 | iteration: int, 937 | ): 938 | mean_match_candidates = self.mean_match_candidates[variable] 939 | using_candidate_data = variable in self.mean_matching_requires_candidates 940 | 941 | use_mean_matching = mean_match_candidates > 0 942 | if not use_mean_matching: 943 | imputation_values = self._impute_with_predictions( 944 | variable=variable, 945 | lgbmodel=lgbmodel, 946 | bachelor_features=bachelor_features, 947 | ) 948 | return imputation_values 949 | 950 | # Get bachelor predictions 951 | bachelor_preds = self._get_bachelor_preds( 952 | variable=variable, 953 | lgbmodel=lgbmodel, 954 | bachelor_features=bachelor_features, 955 | dataset=dataset, 956 | iteration=iteration, 957 | ) 958 | 959 | if using_candidate_data: 960 | 961 | candidate_preds = self._get_candidate_preds_mice( 962 | variable=variable, 963 | lgbmodel=lgbmodel, 964 | candidate_features=candidate_features, 965 | dataset=dataset, 966 | iteration=iteration, 967 | ) 968 | 969 | # By now, a numeric variable will be post-link, and 970 | # categorical / binary variables will be pre-link. 
971 | imputation_values = self._mean_match_nearest_neighbors( 972 | mean_match_candidates=mean_match_candidates, 973 | bachelor_preds=bachelor_preds, 974 | candidate_preds=candidate_preds, 975 | candidate_values=candidate_values, 976 | random_state=self._random_state, 977 | hashed_seeds=None, 978 | ) 979 | 980 | else: 981 | 982 | imputation_values = self._mean_match_fast( 983 | variable=variable, 984 | mean_match_candidates=mean_match_candidates, 985 | bachelor_preds=bachelor_preds, 986 | random_state=self._random_state, 987 | hashed_seeds=None, 988 | ) 989 | 990 | return imputation_values 991 | 992 | def _mean_match_ind( 993 | self, 994 | variable: str, 995 | lgbmodel: Booster, 996 | bachelor_features: DataFrame, 997 | dataset: int, 998 | iteration: int, 999 | hashed_seeds: Optional[np.ndarray] = None, 1000 | ): 1001 | mean_match_candidates = self.mean_match_candidates[variable] 1002 | using_candidate_data = variable in self.mean_matching_requires_candidates 1003 | use_mean_matching = mean_match_candidates > 0 1004 | 1005 | if not use_mean_matching: 1006 | imputation_values = self._impute_with_predictions( 1007 | variable=variable, 1008 | lgbmodel=lgbmodel, 1009 | bachelor_features=bachelor_features, 1010 | ) 1011 | return imputation_values 1012 | 1013 | # Get bachelor predictions 1014 | bachelor_preds = self._get_bachelor_preds( 1015 | variable=variable, 1016 | lgbmodel=lgbmodel, 1017 | bachelor_features=bachelor_features, 1018 | dataset=dataset, 1019 | iteration=iteration, 1020 | ) 1021 | 1022 | if using_candidate_data: 1023 | 1024 | candidate_preds = self._get_candidate_preds_from_store( 1025 | variable=variable, 1026 | dataset=dataset, 1027 | iteration=iteration, 1028 | ) 1029 | 1030 | candidate_values = self._make_label( 1031 | variable=variable, seed=lgbmodel.params["seed"] 1032 | ) 1033 | 1034 | # By now, a numeric variable will be post-link, and 1035 | # categorical / binary variables will be pre-link. 1036 | imputation_values = self._mean_match_nearest_neighbors( 1037 | mean_match_candidates=mean_match_candidates, 1038 | bachelor_preds=bachelor_preds, 1039 | candidate_preds=candidate_preds, 1040 | candidate_values=candidate_values, 1041 | random_state=self._random_state, 1042 | hashed_seeds=hashed_seeds, 1043 | ) 1044 | 1045 | else: 1046 | 1047 | imputation_values = self._mean_match_fast( 1048 | variable=variable, 1049 | mean_match_candidates=mean_match_candidates, 1050 | bachelor_preds=bachelor_preds, 1051 | random_state=self._random_state, 1052 | hashed_seeds=hashed_seeds, 1053 | ) 1054 | 1055 | return imputation_values 1056 | 1057 | def mice( 1058 | self, 1059 | iterations: int, 1060 | verbose: bool = False, 1061 | variable_parameters: Dict[str, Any] = {}, 1062 | **kwlgb, 1063 | ): 1064 | """ 1065 | Perform MICE on a given dataset. 1066 | 1067 | Multiple Imputation by Chained Equations (MICE) is an 1068 | iterative method which fills in (imputes) missing data 1069 | points in a dataset by modeling each column using the 1070 | other columns, and then inferring the missing data. 1071 | 1072 | For more information on MICE, and missing data in 1073 | general, see Stef van Buuren's excellent online book: 1074 | https://stefvanbuuren.name/fimd/ch-introduction.html 1075 | 1076 | For detailed usage information, see this project's 1077 | README on the github repository: 1078 | https://github.com/AnotherSamWilson/miceforest 1079 | 1080 | Parameters 1081 | ---------- 1082 | iterations: int 1083 | The number of iterations to run. 
1084 | 1085 | verbose: bool 1086 | Should information about the process be printed? 1087 | 1088 | variable_parameters: None or dict 1089 | Model parameters can be specified by variable here. Keys should 1090 | be variable names or indices, and values should be a dict of 1091 | parameter which should apply to that variable only. 1092 | 1093 | .. code-block:: python 1094 | 1095 | variable_parameters = { 1096 | 'column': { 1097 | 'min_sum_hessian_in_leaf: 25.0, 1098 | 'extra_trees': True, 1099 | } 1100 | } 1101 | 1102 | kwlgb: 1103 | Additional parameters to pass to lightgbm. Applied to all models. 1104 | 1105 | """ 1106 | 1107 | current_iterations = self.iteration_count() 1108 | start_iter = current_iterations + 1 1109 | end_iter = current_iterations + iterations + 1 1110 | logger = Logger( 1111 | name=f"MICE Iterations {current_iterations + 1} - {current_iterations + iterations}", 1112 | timed_levels=_MICE_TIMED_LEVELS, 1113 | verbose=verbose, 1114 | ) 1115 | 1116 | if len(variable_parameters) > 0: 1117 | assert isinstance( 1118 | variable_parameters, dict 1119 | ), "variable_parameters should be a dict." 1120 | assert set(variable_parameters).issubset(self.model_training_order), ( 1121 | "Variables in variable_parameters will not have models trained. " 1122 | "Check kernel.model_training_order" 1123 | ) 1124 | 1125 | for iteration in range(start_iter, end_iter, 1): 1126 | # absolute_iteration = self.iteration_count(datasets=dataset) 1127 | logger.log(str(iteration) + " ", end="") 1128 | 1129 | for dataset in self.datasets: 1130 | logger.log("Dataset " + str(dataset)) 1131 | 1132 | # Set self.working_data to the most current iteration. 1133 | self.complete_data(dataset=dataset, inplace=True) 1134 | 1135 | for variable in self.model_training_order: 1136 | logger.log(" | " + variable, end="") 1137 | 1138 | # Define the lightgbm parameters 1139 | lgbpars = self._make_lgb_params( 1140 | variable=variable, 1141 | default_parameters=_DEFAULT_LGB_PARAMS.copy(), 1142 | variable_parameters=variable_parameters.get(variable, dict()), 1143 | **kwlgb, 1144 | ) 1145 | 1146 | time_key = dataset, iteration, variable, "Prepare XY" 1147 | logger.set_start_time(time_key) 1148 | ( 1149 | candidate_features, 1150 | candidate_values, 1151 | ) = self._make_features_label( 1152 | variable=variable, seed=lgbpars["seed"] 1153 | ) 1154 | 1155 | # lightgbm requires integers for label. Categories won't work. 1156 | if candidate_values.dtype.name == "category": 1157 | label = candidate_values.cat.codes 1158 | else: 1159 | label = candidate_values 1160 | 1161 | num_iterations = lgbpars.pop("num_iterations") 1162 | train_pointer = Dataset( 1163 | data=candidate_features, 1164 | label=label, 1165 | ) 1166 | logger.record_time(time_key) 1167 | 1168 | time_key = dataset, iteration, variable, "Training" 1169 | logger.set_start_time(time_key) 1170 | current_model = train( 1171 | params=lgbpars, 1172 | train_set=train_pointer, 1173 | num_boost_round=num_iterations, 1174 | keep_training_booster=True, 1175 | ) 1176 | logger.record_time(time_key) 1177 | 1178 | # Only perform mean matching and insertion 1179 | # if variable is being imputed. 
1180 | if variable in self.imputation_order: 1181 | time_key = dataset, iteration, variable, "Mean Matching" 1182 | logger.set_start_time(time_key) 1183 | bachelor_features = self._get_bachelor_features( 1184 | variable=variable 1185 | ) 1186 | imputation_values = self._mean_match_mice( 1187 | variable=variable, 1188 | lgbmodel=current_model, 1189 | bachelor_features=bachelor_features, 1190 | candidate_features=candidate_features, 1191 | candidate_values=candidate_values, 1192 | dataset=dataset, 1193 | iteration=iteration, 1194 | ) 1195 | imputation_values.index = self.na_where[variable] 1196 | logger.record_time(time_key) 1197 | 1198 | assert imputation_values.shape == ( 1199 | self.na_counts[variable], 1200 | ), f"{variable} mean matching returned malformed array" 1201 | 1202 | # Insert the imputation_values we obtained 1203 | self[variable, iteration, dataset] = imputation_values 1204 | 1205 | if not self.save_all_iterations_data: 1206 | del self[variable, iteration - 1, dataset] 1207 | 1208 | else: 1209 | 1210 | # This is called to save the candidate predictions 1211 | _ = self._get_candidate_preds_mice( 1212 | variable=variable, 1213 | lgbmodel=current_model, 1214 | candidate_features=candidate_features, 1215 | dataset=dataset, 1216 | iteration=iteration, 1217 | ) 1218 | del _ 1219 | 1220 | # Save the model, if we should be 1221 | if self.save_all_iterations_data: 1222 | self.models[variable, iteration, dataset] = ( 1223 | current_model.free_dataset() 1224 | ) 1225 | 1226 | self.iteration_tab[variable, dataset] += 1 1227 | 1228 | logger.log("\n", end="") 1229 | 1230 | self._ampute_original_data() 1231 | self.loggers.append(logger) 1232 | 1233 | def get_model( 1234 | self, 1235 | variable: str, 1236 | dataset: int, 1237 | iteration: int = -1, 1238 | ): 1239 | """ 1240 | Returns the model trained for the specified variable, dataset, iteration. 1241 | Model must have been saved. 1242 | 1243 | Parameters 1244 | ---------- 1245 | variable: str 1246 | The variable 1247 | 1248 | dataset: int 1249 | The dataset 1250 | 1251 | iteration: str 1252 | The iteration. Use -1 for the latest. 1253 | """ 1254 | # Allow passing -1 to get the latest iteration's model 1255 | if iteration == -1: 1256 | iteration = self.iteration_count(dataset=dataset, variable=variable) 1257 | try: 1258 | model = self.models[variable, iteration, dataset] 1259 | except KeyError: 1260 | raise ValueError("Model was not saved.") 1261 | return model 1262 | 1263 | def fit(self, X, y, **fit_params): 1264 | """ 1265 | Method for fitting a kernel when used in a sklearn pipeline. 1266 | Should not be called by the user directly. 1267 | """ 1268 | assert self.num_datasets == 1, ( 1269 | "miceforest kernel should be initialized with datasets=1 if " 1270 | "being used in a sklearn pipeline." 1271 | ) 1272 | assert X.equals(self.working_data), ( 1273 | "It looks like this kernel is being used in a sklearn pipeline. " 1274 | "The data passed in fit() should be the same as the data that " 1275 | "was originally passed to the kernel. If this kernel is not being " 1276 | "used in an sklearn pipeline, please just use the mice() method." 1277 | ) 1278 | self.mice(**fit_params) 1279 | return self 1280 | 1281 | def transform(self, X, y=None): 1282 | """ 1283 | Method for calling a kernel when used in a sklearn pipeline. 1284 | Should not be called by the user directly. 
1285 | """ 1286 | 1287 | new_dat = self.impute_new_data(X, datasets=[0]) 1288 | return new_dat.complete_data(dataset=0, inplace=False) 1289 | 1290 | def tune_parameters( 1291 | self, 1292 | dataset: int = 0, 1293 | variables: Optional[List[str]] = None, 1294 | variable_parameters: Dict[str, Any] = dict(), 1295 | parameter_sampling_method: Literal["random"] = "random", 1296 | max_reattempts: int = 5, 1297 | use_gbdt: bool = True, 1298 | nfold: int = 10, 1299 | optimization_steps: int = 5, 1300 | random_state: Optional[Union[int, np.random.RandomState]] = None, 1301 | verbose: bool = False, 1302 | **kwargs, 1303 | ): 1304 | """ 1305 | Perform hyperparameter tuning on models at the current iteration. 1306 | This method is not meant to be robust, but to get a decent set of 1307 | parameters to help with imputation. A few notes: 1308 | 1309 | - The parameters are tuned on the data that would currently be returned by 1310 | complete_data(dataset). It is usually a good idea to run at least 1 iteration 1311 | of mice with the default parameters to get a more accurate idea of the 1312 | real optimal parameters, since Missing At Random (MAR) data imputations 1313 | tend to converge over time. 1314 | - num_iterations is treated as the maximum number of boosting rounds to run 1315 | in lightgbm.cv. It is NEVER optimized. The num_iterations that is returned 1316 | is the best_iteration returned by lightgbm.cv. num_iterations can be passed to 1317 | limit the boosting rounds, but the returned value will always be obtained 1318 | from best_iteration. 1319 | - lightgbm parameters are chosen in the following order of priority: 1320 | - Anything specified in variable_parameters 1321 | - Parameters specified globally in `**kwbounds` 1322 | - Default tuning space (miceforest.default_lightgbm_parameters) 1323 | - Default parameters (miceforest.default_lightgbm_parameters.default_parameters) 1324 | - See examples for a detailed run-through. See 1325 | https://github.com/AnotherSamWilson/miceforest#Tuning-Parameters 1326 | for even more detailed examples. 1327 | 1328 | Parameters 1329 | ---------- 1330 | dataset: int (required) 1331 | The dataset to run parameter tuning on. Tuning parameters on 1 dataset usually results 1332 | in acceptable parameters for all datasets. However, tuning results are still stored 1333 | seperately for each dataset. 1334 | variables: None or List[str] 1335 | - If None, default hyper-parameter spaces are selected based on kernel data, and 1336 | all variables with missing values are tuned. 1337 | - If list, must either be indexes or variable names corresponding to the variables 1338 | that are to be tuned. 1339 | 1340 | variable_parameters: None or dict 1341 | Defines the tuning space. Dict keys must be variable names or indices, and a subset 1342 | of the variables parameter. Values must be a dict with lightgbm parameter names as 1343 | keys, and values that abide by the following rules: 1344 | 1345 | - **scalar**: If a single value is passed, that parameter will be used to build the 1346 | model, and will not be tuned. 1347 | - **tuple**: If a tuple is passed, it must have length = 2 and will be interpreted as 1348 | the bounds to search within for that parameter. 1349 | - **list**: If a list is passed, values will be randomly selected from the list. 1350 | 1351 | example: If you wish to tune the imputation model for the 4th variable with specific 1352 | bounds and parameters, you could pass: 1353 | 1354 | .. 
code-block:: python 1355 | 1356 | variable_parameters = { 1357 | 'column': { 1358 | 'learning_rate: 0.01', 1359 | 'min_sum_hessian_in_leaf: (0.1, 10), 1360 | 'extra_trees': [True, False] 1361 | } 1362 | } 1363 | 1364 | All models for variable 'column' will have a learning_rate = 0.01. The process will randomly 1365 | search within the bounds (0.1, 10) for min_sum_hessian_in_leaf, and extra_trees will 1366 | be randomly selected from the list. Also note, the variable name for the 4th column 1367 | could also be passed instead of the integer 4. All other variables will be tuned with 1368 | the default search space, unless `**kwbounds` are passed. 1369 | 1370 | parameter_sampling_method: str 1371 | If :code:`random`, parameters are randomly selected. 1372 | Other methods will be added in future releases. 1373 | 1374 | max_reattempts: int 1375 | The maximum number of failures (or non-learners) before the process stops, and moves to the 1376 | next variable. Failures can be caused by bad parameters passed to lightgbm. Non-learners 1377 | occur when trees cannot possibly be built (i.e. if :code:`min_data_in_leaf > dataset.shape[0]`). 1378 | 1379 | use_gbdt: bool 1380 | Whether the models should use gradient boosting instead of random forests. 1381 | If True, the optimal number of iterations will be found in lgb.cv, along 1382 | with the other parameters. 1383 | 1384 | nfold: int 1385 | The number of folds to perform cross validation with. More folds takes longer, but 1386 | Gives a more accurate distribution of the error metric. 1387 | 1388 | optimization_steps: int 1389 | How many steps to run the process for. 1390 | 1391 | random_state: int or np.random.RandomState or None (default=None) 1392 | The random state of the process. Ensures reproduceability. If None, the random state 1393 | of the kernel is used. Beware, this permanently alters the random state of the kernel 1394 | and ensures non-reproduceable results, unless the entire process up to this point 1395 | is re-run. 1396 | 1397 | verbose: bool 1398 | Whether to print progress. 1399 | 1400 | kwbounds: 1401 | Any additional arguments that you want to apply globally to every variable. 1402 | For example, if you want to limit the number of iterations, you could pass 1403 | num_iterations = x to this functions, and it would apply globally. Custom 1404 | bounds can also be passed. 1405 | 1406 | 1407 | Returns 1408 | ------- 1409 | optimal_parameters: dict 1410 | A dict of the optimal parameters found for each variable. 
1411 | This can be passed directly to the :code:`variable_parameters` parameter in :code:`mice()` 1412 | 1413 | """ 1414 | 1415 | random_state = ensure_rng(random_state) 1416 | 1417 | if variables is None: 1418 | variables = self.imputation_order 1419 | 1420 | self.complete_data(dataset, inplace=True) 1421 | 1422 | logger = Logger( 1423 | name=f"tune: {optimization_steps}", 1424 | timed_levels=_TUNING_TIMED_LEVELS, 1425 | verbose=verbose, 1426 | ) 1427 | 1428 | for variable in variables: 1429 | 1430 | logger.log(f"Optimizing {variable}") 1431 | 1432 | seed = _draw_random_int32(random_state=random_state, size=1) 1433 | 1434 | ( 1435 | candidate_features, 1436 | candidate_values, 1437 | ) = self._make_features_label(variable=variable, seed=seed) 1438 | 1439 | min_samples = ( 1440 | self.category_counts[variable] 1441 | if variable in self.modeled_categorical_columns 1442 | else 1 1443 | ) 1444 | max_samples = int(candidate_features.shape[0] / 5) 1445 | 1446 | assert isinstance( 1447 | variable_parameters, dict 1448 | ), "variable_parameters should be a dict" 1449 | vp = variable_parameters.get(variable, dict()).copy() 1450 | 1451 | tuning_space = self._make_tuning_space( 1452 | variable=variable, 1453 | variable_parameters=vp, 1454 | use_gbdt=use_gbdt, 1455 | min_samples=min_samples, 1456 | max_samples=max_samples, 1457 | **kwargs, 1458 | ) 1459 | 1460 | # lightgbm requires integers for label. Categories won't work. 1461 | if candidate_values.dtype.name == "category": 1462 | cat_cols = ( 1463 | self.modeled_categorical_columns + self.modeled_binary_columns 1464 | ) 1465 | assert variable in cat_cols, ( 1466 | "Something went wrong in definining categorical " 1467 | f"status of variable {variable}. Please open an issue." 1468 | ) 1469 | candidate_values = candidate_values.cat.codes 1470 | is_cat = True 1471 | else: 1472 | is_cat = False 1473 | 1474 | for step in range(optimization_steps): 1475 | 1476 | # Make multiple attempts to learn something. 1477 | non_learners = 0 1478 | while non_learners < max_reattempts: 1479 | 1480 | # Sample parameters 1481 | sampled_parameters = _sample_parameters( 1482 | parameters=tuning_space, 1483 | random_state=random_state, 1484 | parameter_sampling_method=parameter_sampling_method, 1485 | ) 1486 | logger.log( 1487 | f" Step {step} - Parameters: {sampled_parameters}", end="" 1488 | ) 1489 | 1490 | # Pointer and folds need to be re-initialized after every run. 
1491 | train_set = Dataset( 1492 | data=candidate_features, 1493 | label=candidate_values, 1494 | ) 1495 | if is_cat: 1496 | folds = stratified_categorical_folds(candidate_values, nfold) 1497 | else: 1498 | folds = stratified_continuous_folds(candidate_values, nfold) 1499 | 1500 | try: 1501 | loss, best_iteration = self._get_oof_performance( 1502 | parameters=sampled_parameters.copy(), 1503 | folds=folds, 1504 | train_set=train_set, 1505 | ) 1506 | except Exception as err: 1507 | non_learners += 1 1508 | logger.log(f" - Lightgbm Error {err=}, {type(err)=}") 1509 | continue 1510 | 1511 | if best_iteration > 1: 1512 | logger.log(f" - Success - Loss: {loss}") 1513 | break 1514 | else: 1515 | logger.log(" - Non-Learner") 1516 | non_learners += 1 1517 | 1518 | best_loss = self.optimal_parameter_losses.get(variable, np.inf) 1519 | if loss < best_loss: 1520 | del sampled_parameters["seed"] 1521 | sampled_parameters["num_iterations"] = best_iteration 1522 | self.optimal_parameters[variable] = sampled_parameters 1523 | self.optimal_parameter_losses[variable] = loss 1524 | 1525 | self._ampute_original_data() 1526 | return self.optimal_parameters 1527 | 1528 | def impute_new_data( 1529 | self, 1530 | new_data: DataFrame, 1531 | datasets: Optional[List[int]] = None, 1532 | iterations: Optional[int] = None, 1533 | save_all_iterations_data: bool = True, 1534 | copy_data: bool = True, 1535 | random_state: Optional[Union[int, np.random.RandomState]] = None, 1536 | random_seed_array: Optional[np.ndarray] = None, 1537 | verbose: bool = False, 1538 | ) -> ImputedData: 1539 | """ 1540 | Impute a new dataset. 1541 | 1542 | Uses the models obtained while running MICE to impute new data, 1543 | without fitting new models. Pulls mean matching candidates from 1544 | the original data. 1545 | 1546 | The kernel must have been created with save_all_iterations_data=True; 1547 | otherwise the models and imputation values needed to recreate the 1548 | procedure are not available. The model obtained at each iteration 1549 | is used to impute the new data for that iteration. If the specified 1550 | iterations is greater than the number of iterations run so far using 1551 | mice, the last model is used for each additional iteration. 1552 | 1553 | Only basic column and dtype checks are performed. It is up to the user to ensure that the 1554 | kernel data matches the new data being imputed. 1555 | 1556 | Parameters 1557 | ---------- 1558 | new_data: pandas.DataFrame 1559 | The new data to impute. 1560 | 1561 | datasets: int or List[int], default = None 1562 | The datasets from the kernel to use to impute the new data. 1563 | If :code:`None`, all datasets from the kernel are used. 1564 | 1565 | iterations: int, default=None 1566 | The number of iterations to run. 1567 | If :code:`None`, the same number of iterations run so far in mice is used. 1568 | 1569 | save_all_iterations_data: bool, default=True 1570 | Should the imputation values of all iterations be archived? 1571 | If :code:`False`, only the latest imputation values are saved. 1572 | 1573 | copy_data: boolean, default=True 1574 | Should the dataset be copied before imputation? If :code:`False`, the dataset is referenced 1575 | directly, which will cause it to be altered in place. 1576 | 1577 | random_state: None or int or np.random.RandomState (default=None) 1578 | The random state of the process. Ensures reproducibility. If :code:`None`, the random state 1579 | of the kernel is used. Beware, this permanently alters the random state of the kernel 1580 | and ensures non-reproducible results, unless the entire process up to this point 1581 | is re-run.
1582 | 1583 | random_seed_array: None or np.ndarray[uint32, int32, uint64] 1584 | Record-level seeds. 1585 | 1586 | Ensures deterministic imputations at the record level. random_seed_array causes 1587 | deterministic imputations for each record no matter what dataset each record is 1588 | imputed with, assuming the same number of iterations and datasets are used. 1589 | If :code:`random_seed_array` is passed, random_state must also be passed. 1590 | 1591 | Record-level imputations are deterministic if the following conditions are met: 1592 | 1) The associated value in :code:`random_seed_array` is the same. 1593 | 2) The same kernel is used. 1594 | 3) The same number of iterations are run. 1595 | 4) The same number of datasets are run. 1596 | 1597 | Note: Using this parameter may change the global numpy seed by calling :code:`np.random.seed()` 1598 | 1599 | verbose: boolean, default=False 1600 | Should information about the process be printed? 1601 | 1602 | Returns 1603 | ------- 1604 | miceforest.ImputedData 1605 | 1606 | """ 1607 | 1608 | assert self.save_all_iterations_data, ( 1609 | "Cannot recreate imputation procedure, data was not saved during MICE. " 1610 | "To save this data, set save_all_iterations_data to True when making kernel." 1611 | ) 1612 | 1613 | # datasets = list(range(self.num_datasets)) if datasets is None else datasets 1614 | datasets = self.datasets if datasets is None else datasets 1615 | kernel_iterations = self.iteration_count() 1616 | iterations = kernel_iterations if iterations is None else iterations 1617 | logger = Logger( 1618 | name=f"Impute New Data {0}-{iterations}", 1619 | timed_levels=_IMPUTE_NEW_DATA_TIMED_LEVELS, 1620 | verbose=verbose, 1621 | ) 1622 | 1623 | assert isinstance(new_data, DataFrame) 1624 | assert self.working_data.columns.equals( 1625 | new_data.columns 1626 | ), "Different columns from original dataset." 1627 | assert np.all( 1628 | [ 1629 | self.working_data[col].dtype == new_data[col].dtype 1630 | for col in self.column_names 1631 | ] 1632 | ), "Column types are not the same as the original data. Check categorical columns." 1633 | 1634 | imputed_data = ImputedData( 1635 | impute_data=new_data, 1636 | # num_datasets=len(datasets), 1637 | datasets=datasets, 1638 | variable_schema=self.variable_schema.copy(), 1639 | save_all_iterations_data=save_all_iterations_data, 1640 | copy_data=copy_data, 1641 | random_seed_array=random_seed_array, 1642 | ) 1643 | new_imputation_order = [ 1644 | col 1645 | for col in self.model_training_order 1646 | if col in imputed_data.vars_with_any_missing 1647 | ] 1648 | 1649 | ### Manage Randomness. 1650 | if random_state is None: 1651 | assert ( 1652 | random_seed_array is None 1653 | ), "random_state is also required when using random_seed_array" 1654 | random_state = self._random_state 1655 | else: 1656 | random_state = ensure_rng(random_state) 1657 | 1658 | self._initialize_dataset( 1659 | imputed_data, 1660 | random_state=random_state, 1661 | ) 1662 | 1663 | for iteration in range(1, iterations + 1): 1664 | logger.log(str(iteration) + " ", end="") 1665 | 1666 | for dataset in datasets: 1667 | logger.log("Dataset " + str(dataset)) 1668 | self.complete_data(dataset=dataset, inplace=True) 1669 | imputed_data.complete_data(dataset=dataset, inplace=True) 1670 | 1671 | for variable in new_imputation_order: 1672 | logger.log(" | " + variable, end="") 1673 | 1674 | # Select our model. 
1675 | current_model = self.get_model( 1676 | variable=variable, dataset=dataset, iteration=iteration 1677 | ) 1678 | 1679 | time_key = dataset, iteration, variable, "Getting Bachelor Features" 1680 | logger.set_start_time(time_key) 1681 | bachelor_features = imputed_data._get_bachelor_features(variable) 1682 | hashed_seeds = imputed_data._get_hashed_seeds(variable) 1683 | logger.record_time(time_key) 1684 | 1685 | time_key = dataset, iteration, variable, "Mean Matching" 1686 | logger.set_start_time(time_key) 1687 | na_where = imputed_data.na_where[variable] 1688 | imputation_values = self._mean_match_ind( 1689 | variable=variable, 1690 | lgbmodel=current_model, 1691 | bachelor_features=bachelor_features, 1692 | dataset=dataset, 1693 | iteration=iteration, 1694 | hashed_seeds=hashed_seeds, 1695 | ) 1696 | # self.cycle_random_seed_array(variable) 1697 | imputation_values.index = na_where 1698 | logger.record_time(time_key) 1699 | 1700 | assert imputation_values.shape == ( 1701 | imputed_data.na_counts[variable], 1702 | ), f"{variable} mean matching returned malformed array" 1703 | 1704 | # Insert the imputation_values we obtained 1705 | imputed_data[variable, iteration, dataset] = imputation_values 1706 | 1707 | if not imputed_data.save_all_iterations_data: 1708 | del imputed_data[variable, iteration - 1, dataset] 1709 | 1710 | logger.log("\n", end="") 1711 | 1712 | imputed_data._ampute_original_data() 1713 | self.loggers.append(logger) 1714 | 1715 | return imputed_data 1716 | 1717 | def get_feature_importance( 1718 | self, 1719 | dataset: int = 0, 1720 | iteration: int = -1, 1721 | importance_type: str = "split", 1722 | normalize: bool = True, 1723 | ) -> DataFrame: 1724 | """ 1725 | Return a matrix of feature importance. The cells 1726 | represent the normalized feature importance of the 1727 | columns to impute the rows. This is calculated 1728 | internally by lightgbm.Booster.feature_importance(). 1729 | 1730 | Parameters 1731 | ---------- 1732 | dataset: int 1733 | The dataset to get the feature importance for. 1734 | 1735 | iteration: int 1736 | The iteration to return the feature importance for. 1737 | The model must be saved to return importance. 1738 | Use -1 to specify the latest iteration. 1739 | 1740 | importance_type: str 1741 | Passed to :code:`lgb.feature_importance()` 1742 | 1743 | normalize: bool 1744 | Whether to normalize the values within 1745 | each modeled variable to sum to 1. 1746 | 1747 | Returns 1748 | ------- 1749 | pandas.DataFrame of importance values. Rows are imputed variables, and columns are predictor variables. 
1750 | 1751 | """ 1752 | 1753 | if iteration == -1: 1754 | iteration = self.iteration_count(dataset=dataset) 1755 | 1756 | modeled_vars = [ 1757 | col for col in self.working_data.columns if col in self.model_training_order 1758 | ] 1759 | 1760 | importance_matrix = DataFrame( 1761 | index=modeled_vars, columns=self.predictor_columns 1762 | ) 1763 | for modeled_variable in modeled_vars: 1764 | predictor_vars = self.variable_schema[modeled_variable] 1765 | importances = self.get_model( 1766 | variable=modeled_variable, dataset=dataset, iteration=iteration 1767 | ).feature_importance(importance_type=importance_type) 1768 | importances = Series(importances, index=predictor_vars) 1769 | importance_matrix.loc[modeled_variable, predictor_vars] = importances 1770 | 1771 | importance_matrix = importance_matrix.astype("float64") 1772 | 1773 | if normalize: 1774 | importance_matrix /= importance_matrix.sum(1).to_numpy().reshape(-1, 1) 1775 | 1776 | return importance_matrix 1777 | 1778 | def plot_feature_importance( 1779 | self, 1780 | dataset, 1781 | importance_type: str = "split", 1782 | normalize: bool = True, 1783 | iteration: int = -1, 1784 | ): 1785 | """ 1786 | Plot the feature importance. See get_feature_importance() 1787 | for more details. 1788 | 1789 | Parameters 1790 | ---------- 1791 | dataset: int 1792 | The dataset to plot the feature importance for. 1793 | 1794 | importance_type: str 1795 | Passed to lgb.feature_importance() 1796 | 1797 | normalize: bool 1798 | Should the values be normalized within each modeled variable? 1799 | If False, values are raw from Booster.feature_importance() 1800 | 1801 | iteration: int 1802 | The iteration to plot the feature importance for. Use -1 to specify the latest iteration. 1803 | 1804 | """ 1805 | 1806 | try: 1807 | from plotnine import ( 1808 | aes, 1809 | element_blank, 1810 | element_text, 1811 | geom_label, 1812 | geom_tile, 1813 | ggplot, 1814 | ggtitle, 1815 | scale_fill_distiller, 1816 | theme, 1817 | xlab, 1818 | ylab, 1819 | ) 1820 | except ImportError: 1821 | raise ImportError("plotnine must be installed to plot importance") 1822 | 1823 | importance_matrix = self.get_feature_importance( 1824 | dataset=dataset, 1825 | iteration=iteration, 1826 | normalize=normalize, 1827 | importance_type=importance_type, 1828 | ) 1829 | importance_matrix = importance_matrix.reset_index().melt(id_vars="index") 1830 | importance_matrix["Importance"] = importance_matrix["value"].round(2) 1831 | importance_matrix = importance_matrix.dropna() 1832 | 1833 | fig = ( 1834 | ggplot(importance_matrix, aes(x="variable", y="index", fill="Importance")) 1835 | + geom_tile(show_legend=False) 1836 | + ylab("Modeled Variable") 1837 | + xlab("Predictor") 1838 | + ggtitle("Feature Importance") 1839 | + geom_label(aes(label="Importance"), fill="white", size=8) 1840 | + scale_fill_distiller(palette=1, direction=1) 1841 | + theme( 1842 | axis_text_x=element_text(rotation=30, hjust=1), 1843 | plot_title=element_text(ha="left", size=20), 1844 | panel_background=element_blank(), 1845 | figure_size=(6, 6), 1846 | ) 1847 | ) 1848 | 1849 | return fig 1850 | --------------------------------------------------------------------------------
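A minimal usage sketch of the tuning workflow documented in imputation_kernel.py above, assuming the tuning method is exposed as ImputationKernel.tune_parameters. The file path, the column name "smoking_status", and the kernel constructor arguments are illustrative assumptions; only the variable_parameters format and the hand-off to mice() come from the docstring.

import miceforest as mf
import pandas as pd

df = pd.read_csv("data_with_missingness.csv")  # hypothetical dataset with missing values

# Constructor arguments are assumptions; see the class documentation for the full signature.
kernel = mf.ImputationKernel(df, random_state=42)

variable_parameters = {
    "smoking_status": {
        "learning_rate": 0.01,                 # scalar: fixed for every model of this variable
        "min_sum_hessian_in_leaf": (0.1, 10),  # tuple: random search within these bounds
        "extra_trees": [True, False],          # list: randomly chosen from these values
    }
}

optimal_parameters = kernel.tune_parameters(
    dataset=0,
    variable_parameters=variable_parameters,
    use_gbdt=True,
    nfold=5,
    optimization_steps=5,
    random_state=1,
)

# The returned dict can be passed directly to mice(), as noted in the docstring.
kernel.mice(iterations=2, variable_parameters=optimal_parameters)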
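The record-level seeding described in the impute_new_data docstring can be exercised roughly as below; the new-data construction and the seed generation are illustrative assumptions, and only the random_state / random_seed_array pairing is taken from the docstring.

import numpy as np

new_rows = df.sample(n=100, random_state=0)  # stand-in for genuinely new data

# One integer seed per record. Per the docstring, random_state must also be
# passed whenever random_seed_array is supplied.
record_seeds = np.random.default_rng(0).integers(
    low=0, high=2**32, size=len(new_rows), dtype=np.uint32
)

new_imputed = kernel.impute_new_data(
    new_data=new_rows,
    random_state=4,
    random_seed_array=record_seeds,
)

# ImputedData provides complete_data(); retrieving a completed copy of dataset 0
# is assumed to be the non-inplace default here.
completed = new_imputed.complete_data(dataset=0)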
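Finally, a short sketch of pulling the importance matrix and the corresponding plot from a fitted kernel; the output file name is an assumption, and plotnine must be installed for the plotting call.

# Rows are imputed variables, columns are predictors; values are normalized per row.
importance = kernel.get_feature_importance(dataset=0, iteration=-1, normalize=True)
print(importance.round(2))

# plot_feature_importance returns a plotnine ggplot object, which can be saved or displayed.
fig = kernel.plot_feature_importance(dataset=0)
fig.save("feature_importance.png")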