├── tests
│   ├── __init__.py
│   ├── utrees
│   │   ├── __init__.py
│   │   ├── test_baltobot.py
│   │   └── test_utrees.py
│   └── test_imports.py
├── .gitattributes
├── paper
│   ├── wave-demo.png
│   ├── iris-vertical.png
│   ├── poisson-demo.png
│   ├── moons-generation.png
│   ├── moons-imputation.png
│   ├── wave-time-pdfat2.png
│   ├── wave-baltobot-readme.jpg
│   ├── baltobotmodel.py
│   ├── moons.ipynb
│   ├── baltobot-simple-demo.ipynb
│   └── iris.ipynb
├── .pre-commit-config.yaml
├── utrees
│   ├── __init__.py
│   ├── kdi_quantizer.py
│   ├── baltobot.py
│   └── unmasking_trees.py
├── requirements.txt
├── LICENSE
├── Results
│   ├── latex_generation_wmissing_baltobot4d8f71b.txt
│   ├── latex_imputation_baltobot4d8f71b.txt
│   ├── latex_generation_baltobot4d8f71b.txt
│   ├── latex_imputation_ablation_baltobot4d8f71b.txt
│   ├── tabular_imputation_results_kmeans.txt
│   ├── tabular_imputation_results_tobt.txt
│   ├── tabular_imputation_results.txt
│   ├── tabular_imputation_results_baltobot4d8f71b.txt
│   ├── latex_tables.txt
│   ├── imputation_script.R
│   ├── generation_script_nmiss50.R
│   └── generation_script.R
├── pyproject.toml
├── setup.py
├── .gitignore
└── README.md
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/utrees/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.pdf binary
2 | *.eps binary
3 | *.png binary
4 |
--------------------------------------------------------------------------------
/paper/wave-demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calvinmccarter/unmasking-trees/HEAD/paper/wave-demo.png
--------------------------------------------------------------------------------
/paper/iris-vertical.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calvinmccarter/unmasking-trees/HEAD/paper/iris-vertical.png
--------------------------------------------------------------------------------
/paper/poisson-demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calvinmccarter/unmasking-trees/HEAD/paper/poisson-demo.png
--------------------------------------------------------------------------------
/paper/moons-generation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calvinmccarter/unmasking-trees/HEAD/paper/moons-generation.png
--------------------------------------------------------------------------------
/paper/moons-imputation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calvinmccarter/unmasking-trees/HEAD/paper/moons-imputation.png
--------------------------------------------------------------------------------
/paper/wave-time-pdfat2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calvinmccarter/unmasking-trees/HEAD/paper/wave-time-pdfat2.png
--------------------------------------------------------------------------------
/paper/wave-baltobot-readme.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calvinmccarter/unmasking-trees/HEAD/paper/wave-baltobot-readme.jpg
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | files: ^utrees/
2 | repos:
3 | - repo: https://github.com/psf/black
4 | rev: 22.10.0
5 | hooks:
6 | - id: black
7 |
--------------------------------------------------------------------------------
/tests/test_imports.py:
--------------------------------------------------------------------------------
1 | def test_import_packages():
2 | """Test that importing our package works."""
3 | import utrees
4 | from utrees import UnmaskingTrees
5 | from utrees import KDIQuantizer
6 |
--------------------------------------------------------------------------------
/utrees/__init__.py:
--------------------------------------------------------------------------------
1 | from utrees.baltobot import Baltobot
2 | from utrees.baltobot import NanTabPFNClassifier
3 | from utrees.kdi_quantizer import KDIQuantizer
4 | from utrees.unmasking_trees import UnmaskingTrees
5 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | black
2 | numpy
3 | pre-commit>=2.2.0
4 | pytest>=5.4.1
5 | scikit-learn>=0.23
6 | scipy>=1.0
7 | tqdm
8 | torch
9 | lightgbm
10 | xgboost
11 | pandas
12 | matplotlib
13 | jupyterlab
14 | seaborn
15 |
--------------------------------------------------------------------------------
/tests/utrees/test_baltobot.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 | from sklearn.datasets import make_moons
5 |
6 | from utrees import Baltobot
7 |
8 |
9 | def test_moons():
10 | n_upper = 100
11 | n_lower = 100
12 | n_generate = 123
13 | n = n_upper + n_lower
14 | data, labels = make_moons(
15 | (n_upper, n_lower), shuffle=False, noise=0.1, random_state=12345)
16 | X = data[:, 0:1]
17 | y = data[:, 1]
18 | balto = Baltobot(random_state=12345)
19 | balto.fit(X, y)
20 | sampley = balto.sample(X)
21 | assert sampley.shape == (n, )
22 | evaly = np.linspace(-3, 3, 1000)
23 | evalX = np.full((1000, 1), fill_value=0.)
24 | scores = balto.score_samples(evalX, evaly)
25 | assert scores.shape == (1000,)
26 | probs = np.exp(scores)
27 | assert 0 <= np.min(probs)
28 | assert 0.99 < np.sum(probs)
29 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Calvin McCarter
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Results/latex_generation_wmissing_baltobot4d8f71b.txt:
--------------------------------------------------------------------------------
1 | % latex table generated in R 4.4.1 by xtable 1.8-4 package
2 | % Tue Sep 3 11:00:16 2024
3 | \begin{table}[ht]
4 | \centering
5 | \begin{tabular}{rlllllllll}
6 | \hline
7 | & W\_train & W\_test & coverage\_train & coverage\_test & R2\_fake & F1\_fake & class\_score & percent\_bias & coverage\_rate \\
8 | \hline
9 | GaussianCopula & 7 (0.3) & 7.1 (0.2) & 7.2 (0.3) & 7.1 (0.3) & 6.3 (0.4) & 6.6 (0.3) & 6.7 (0.4) & 5.5 (1) & 7.7 (0.6) \\
10 | TVAE & 5.2 (0.3) & 4.9 (0.3) & 5.7 (0.3) & 5.8 (0.2) & 6 (1) & 5.8 (0.5) & 5.8 (0.4) & 8 (0.4) & 6.2 (1) \\
11 | CTGAN & 8.3 (0.2) & 8.4 (0.2) & 8.4 (0.2) & 8.3 (0.2) & 8.3 (0.3) & 8.4 (0.2) & 6.5 (0.2) & 4.8 (1.2) & 7.1 (0.7) \\
12 | CTABGAN & 6.7 (0.4) & 6.5 (0.4) & 7.1 (0.3) & 6.8 (0.3) & 7.3 (0.6) & 7.1 (0.4) & 6.6 (0.3) & 7.5 (1) & 6.1 (0.6) \\
13 | Stasy & 5.9 (0.2) & 6.1 (0.3) & 5.3 (0.2) & 5.1 (0.3) & 5.8 (0.9) & 4.4 (0.4) & 5.3 (0.4) & 3.7 (0.4) & 4.6 (1.1) \\
14 | TabDDPM & 3 (0.7) & 3.4 (0.7) & 2.3 (0.5) & 2.9 (0.6) & 1.7 (0.3) & 3.3 (0.6) & 3.9 (0.6) & 3.8 (1.2) & 2 (0.5) \\
15 | Forest-VP & 3.7 (0.2) & 3.2 (0.3) & 3.9 (0.2) & 3.8 (0.3) & 3.2 (0.3) & 2.3 (0.3) & 4.2 (0.4) & 4.2 (0.8) & 4.5 (1.1) \\
16 | Forest-Flow & 3 (0.3) & 2.6 (0.3) & 2.6 (0.3) & 2.7 (0.2) & 3 (0.7) & 3.7 (0.3) & 5 (0.5) & 3.8 (0.9) & 3.2 (0.8) \\
17 | UTrees & 2.1 (0.2) & 2.8 (0.3) & 2.5 (0.2) & 2.5 (0.2) & 3.3 (0.8) & 3.5 (0.5) & 1 (0) & 3.7 (0.9) & 3.7 (1) \\
18 | \hline
19 | \end{tabular}
20 | \end{table}
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "utrees"
7 | version = "0.3.0"
8 | description = "Unmasking trees for tabular data generation and imputation"
9 | readme = "README.md"
10 | authors = [
11 | { name = "Calvin McCarter", email = "mccarter.calvin@gmail.com" },
12 | ]
13 | maintainers = [
14 | { name = "Calvin McCarter", email = "mccarter.calvin@gmail.com" },
15 | ]
16 | keywords = [
17 | "generation",
18 | "imputation",
19 | "tabular",
20 | ]
21 | classifiers = [
22 | "Development Status :: 3 - Alpha",
23 | "Intended Audience :: Developers",
24 | "Intended Audience :: Science/Research",
25 | "License :: OSI Approved",
26 | "Operating System :: MacOS",
27 | "Operating System :: Microsoft :: Windows",
28 | "Operating System :: POSIX",
29 | "Operating System :: Unix",
30 | "Programming Language :: Python",
31 | "Programming Language :: Python :: 3.6",
32 | "Programming Language :: Python :: 3.7",
33 | "Programming Language :: Python :: 3.8",
34 | "Programming Language :: Python :: 3.9",
35 | "Topic :: Scientific/Engineering",
36 | "Topic :: Software Development",
37 | ]
38 | dependencies = [
39 | "kditransform",
40 | "numba >= 0.48",
41 | "numpy",
42 | "pandas",
43 | "scikit-learn >= 0.23",
44 | "scipy >= 1.0",
45 | "tqdm",
46 | "xgboost",
47 | ]
48 |
49 | [project.urls]
50 | Homepage = "http://github.com/calvinmccarter/unmasking-trees"
51 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 |
4 | def readme():
5 | with open("README.md") as readme_file:
6 | return readme_file.read()
7 |
8 |
9 | configuration = {
10 | "name": "utrees",
11 | "version": "0.3.0",
12 | "description": "Unmasking trees for tabular data generation and imputation",
13 | "long_description": readme(),
14 | "long_description_content_type": "text/markdown",
15 | "classifiers": [
16 | "Development Status :: 3 - Alpha",
17 | "Intended Audience :: Science/Research",
18 | "Intended Audience :: Developers",
19 | "License :: OSI Approved",
20 | "Programming Language :: Python",
21 | "Topic :: Software Development",
22 | "Topic :: Scientific/Engineering",
23 | "Operating System :: Microsoft :: Windows",
24 | "Operating System :: POSIX",
25 | "Operating System :: Unix",
26 | "Operating System :: MacOS",
27 | "Programming Language :: Python :: 3.6",
28 | "Programming Language :: Python :: 3.7",
29 | "Programming Language :: Python :: 3.8",
30 | "Programming Language :: Python :: 3.9",
31 | ],
32 | "keywords": "tabular,imputation,generation",
33 | "url": "http://github.com/calvinmccarter/unmasking-trees",
34 | "author": "Calvin McCarter",
35 | "author_email": "mccarter.calvin@gmail.com",
36 | "maintainer": "Calvin McCarter",
37 | "maintainer_email": "mccarter.calvin@gmail.com",
38 | "packages": ["utrees"],
39 | "install_requires": [
40 | "kditransform",
41 | "numba >= 0.48",
42 | "numpy",
43 | "pandas",
44 | "scikit-learn >= 0.23",
45 | "scipy >= 1.0",
46 | "tqdm",
47 | "xgboost",
48 | ],
49 | "ext_modules": [],
50 | "cmdclass": {},
51 | "test_suite": "nose.collector",
52 | "tests_require": ["nose"],
53 | "data_files": (),
54 | "zip_safe": True,
55 | }
56 |
57 | setup(**configuration)
58 |
--------------------------------------------------------------------------------
/paper/baltobotmodel.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from typing import Optional
3 |
4 | from jaxtyping import Float
5 | from numpy import ndarray
6 | from sklearn.base import MultiOutputMixin
7 | from skopt.space import Categorical
8 | from skopt.space import Integer
9 | from skopt.space import Real
10 |
11 | import numpy as np
12 | import utrees
13 |
14 | from .base_model import ProbabilisticModel
15 |
16 |
17 | class BaltoBotModel(ProbabilisticModel, MultiOutputMixin):
18 | """
19 | Wrapping the BaltoBot model as a ProbabilisticModel.
20 | """
21 |
22 | def __init__(
23 | self,
24 | learning_rate: float = 0.3,
25 | max_leaves: int = 0,
26 | subsample: float = 1,
27 | seed: int = 0,
28 | ):
29 | super().__init__(seed)
30 | self.subsample = subsample
31 | self.max_leaves = max_leaves
32 | self.learning_rate = learning_rate
33 | self.model = None
34 |
35 | def fit(
36 | self,
37 | X: Float[ndarray, "batch x_dim"],
38 | y: Float[ndarray, "batch y_dim"],
39 | cat_idx: Optional[List[int]] = None,
40 | ) -> "ProbabilisticModel":
41 | self.model = utrees.Baltobot(
42 | tabpfn=True,
43 | random_state=self.seed,
44 | xgboost_kwargs={
45 | "subsample": self.subsample,
46 | "max_leaves": self.max_leaves,
47 | "learning_rate": self.learning_rate,
48 | },
49 | )
50 | assert y.shape[1] == 1
51 | self.model.fit(X, y.ravel())
52 | return self
53 |
54 | def predict(self, X: Float[ndarray, "batch x_dim"]) -> Float[ndarray, "batch y_dim"]:
55 | y_samples = self.sample(X, n_samples=50)
56 | y_samples = y_samples.mean(axis=0)
57 | return y_samples
58 |
59 | def sample(
60 | self, X: Float[ndarray, "batch x_dim"], n_samples=100, seed=None
61 | ) -> Float[ndarray, "n_samples batch y_dim"]:
62 | batch = X.shape[0]
63 | y_samples = np.zeros((n_samples, batch, 1))
64 | for n in range(n_samples):
65 | y_samples[n, :, 0] = self.model.sample(X)
66 | return y_samples
67 |
68 | @staticmethod
69 | def search_space() -> dict:
70 | return {
71 | "learning_rate": Real(0.05, 0.5, "log-uniform"),
72 | "max_leaves": Categorical([0, 25, 50]),
73 | "subsample": Real(0.3, 1., "log-uniform"),
74 | }
75 |
76 | def get_extra_stats(self) -> dict:
77 | return {}
78 |
--------------------------------------------------------------------------------
/Results/latex_imputation_baltobot4d8f71b.txt:
--------------------------------------------------------------------------------
1 | % latex table generated in R 4.4.1 by xtable 1.8-4 package
2 | % Mon Sep 2 13:56:23 2024
3 | \begin{table}[ht]
4 | \centering
5 | \begin{tabular}{rlllllllll}
6 | \hline
7 | & MinMAE & AvgMAE & W\_train & W\_test & MedianMAD & R2\_imp & F1\_imp & PercentBias & CoverageRate \\
8 | \hline
9 | KNN & 0.16 (0.03) & 0.16 (0.03) & 0.42 (0.08) & 1.89 (0.49) & 0 (0) & 0.59 (0.09) & 0.75 (0.04) & 1.27 (0.25) & 0.4 (0.11) \\
10 | ICE & 0.1 (0.01) & 0.21 (0.03) & 0.52 (0.09) & 1.99 (0.49) & 0.69 (0.1) & 0.59 (0.09) & 0.74 (0.04) & 1.05 (0.29) & 0.39 (0.09) \\
11 | MICE-Forest & 0.08 (0.02) & 0.13 (0.03) & 0.34 (0.07) & 1.86 (0.48) & 0.29 (0.08) & 0.61 (0.1) & 0.76 (0.04) & 0.61 (0.2) & 0.75 (0.11) \\
12 | MissForest & 0.1 (0.03) & 0.12 (0.03) & 0.32 (0.07) & 1.85 (0.48) & 0.1 (0.03) & 0.61 (0.1) & 0.76 (0.04) & 0.62 (0.22) & 0.79 (0.08) \\
13 | Softimpute & 0.22 (0.03) & 0.22 (0.03) & 0.53 (0.07) & 1.99 (0.48) & 0 (0) & 0.58 (0.09) & 0.74 (0.04) & 1.18 (0.34) & 0.31 (0.09) \\
14 | OT & 0.14 (0.02) & 0.19 (0.03) & 0.56 (0.1) & 1.98 (0.49) & 0.28 (0.05) & 0.59 (0.1) & 0.75 (0.04) & 1.09 (0.27) & 0.39 (0.12) \\
15 | GAIN & 0.16 (0.03) & 0.17 (0.03) & 0.49 (0.11) & 1.95 (0.51) & 0.01 (0) & 0.6 (0.1) & 0.75 (0.04) & 1.04 (0.25) & 0.54 (0.12) \\
16 | Forest-VP & 0.14 (0.04) & 0.17 (0.03) & 0.55 (0.13) & 1.96 (0.5) & 0.25 (0.03) & 0.61 (0.1) & 0.74 (0.04) & 0.81 (0.25) & 0.57 (0.14) \\
17 | utrees & 0.08 (0.02) & 0.14 (0.03) & 0.37 (0.08) & 1.87 (0.48) & 0.27 (0.07) & 0.61 (0.1) & 0.76 (0.04) & 0.55 (0.19) & 0.71 (0.13) \\
18 | Oracle & 0 (0) & 0 (0) & 0 (0) & 1.87 (0.49) & 0 (0) & 0.64 (0.09) & 0.78 (0.04) & 0 (0) & 1 (0) \\
19 | \hline
20 | \end{tabular}
21 | \end{table}
22 | % latex table generated in R 4.4.1 by xtable 1.8-4 package
23 | % Mon Sep 2 13:56:23 2024
24 | \begin{table}[ht]
25 | \centering
26 | \begin{tabular}{rlllllllll}
27 | \hline
28 | & MinMAE & AvgMAE & W\_train & W\_test & MedianMAD & R2\_imp & F1\_imp & PercentBias & CoverageRate \\
29 | \hline
30 | KNN & 5.5 (0.5) & 6.3 (0.4) & 4.9 (0.4) & 5 (0.4) & 8.4 (0) & 6.5 (1) & 5.7 (1.1) & 6.2 (1) & 5.4 (0.6) \\
31 | ICE & 6.8 (0.4) & 4.7 (0.4) & 7 (0.5) & 7.2 (0.4) & 1.6 (0.2) & 6.2 (1) & 7 (0.6) & 5.7 (0.9) & 5.3 (0.6) \\
32 | MICE-Forest & 3.9 (0.4) & 2.5 (0.4) & 2.9 (0.2) & 3 (0.2) & 3.6 (0.2) & 3.7 (1.4) & 3.2 (1) & 5.5 (1.2) & 4.3 (0.6) \\
33 | MissForest & 2.7 (0.5) & 4 (0.4) & 1.8 (0.3) & 2 (0.3) & 5.5 (0.2) & 3.8 (1.4) & 2.5 (0.5) & 5.5 (1.5) & 3.3 (0.5) \\
34 | Softimpute & 6.7 (0.4) & 7.6 (0.4) & 7.1 (0.5) & 7.3 (0.5) & 8.4 (0) & 6 (0.9) & 7.8 (0.4) & 6.3 (0.9) & 6.7 (0.4) \\
35 | OT & 5.9 (0.4) & 6.1 (0.3) & 6 (0.5) & 6 (0.5) & 3.7 (0.3) & 6.2 (0.5) & 6.8 (0.6) & 5.5 (0.8) & 4.8 (0.5) \\
36 | GAIN & 4.7 (0.4) & 6.5 (0.3) & 6 (0.3) & 6 (0.2) & 6.9 (0.1) & 5.7 (0.8) & 5.4 (0.8) & 4.7 (1) & 5 (0.6) \\
37 | Forest-VP & 5.3 (0.4) & 4 (0.5) & 5.8 (0.3) & 5.1 (0.4) & 3.2 (0.4) & 4.5 (0.9) & 4.6 (0.8) & 3.3 (0.6) & 5.5 (0.7) \\
38 | utrees & 3.5 (0.5) & 3.2 (0.5) & 3.5 (0.4) & 3.5 (0.5) & 3.8 (0.2) & 2.5 (0.6) & 2.2 (0.6) & 2.3 (0.9) & 4.7 (0.6) \\
39 | \hline
40 | \end{tabular}
41 | \end{table}
--------------------------------------------------------------------------------
/Results/latex_generation_baltobot4d8f71b.txt:
--------------------------------------------------------------------------------
1 | % latex table generated in R 4.4.1 by xtable 1.8-4 package
2 | % Sat Sep 7 22:19:01 2024
3 | \begin{table}[ht]
4 | \centering
5 | \begin{tabular}{rlllllllll}
6 | \hline
7 | & W\_train & W\_test & coverage\_train & coverage\_test & R2\_fake & F1\_fake & class\_score & percent\_bias & coverage\_rate \\
8 | \hline
9 | GaussianCopula & 2.74 (0.56) & 2.99 (0.61) & 0.18 (0.04) & 0.37 (0.06) & 0.2 (0.14) & 0.46 (0.06) & NA (0.05) & 2.27 (0.77) & 0.23 (0.12) \\
10 | TVAE & 2.12 (0.58) & 2.35 (0.63) & 0.33 (0.04) & 0.63 (0.04) & -0.47 (0.61) & 0.52 (0.08) & NA (0.01) & 4.15 (1.97) & 0.26 (0.09) \\
11 | CTGAN & 3.58 (0.99) & 3.74 (1.01) & 0.12 (0.03) & 0.28 (0.04) & -0.43 (0.08) & 0.35 (0.04) & NA (0.01) & 2.48 (1.3) & 0.2 (0.08) \\
12 | CTAB-GAN+ & 2.71 (0.81) & 2.89 (0.83) & 0.22 (0.04) & 0.44 (0.05) & 0.05 (0.12) & 0.44 (0.05) & NA (0.02) & 2.95 (1.04) & 0.26 (0.07) \\
13 | STaSy & 3.41 (1.39) & 3.66 (1.42) & 0.38 (0.05) & 0.63 (0.05) & -4.21 (4.44) & 0.61 (0.06) & NA (0.02) & 1.23 (0.44) & 0.45 (0.12) \\
14 | TabDDPM & 4.27 (1.89) & 4.79 (1.89) & 0.76 (0.06) & 0.8 (0.06) & 0.6 (0.11) & 0.66 (0.06) & NA (0.03) & 0.76 (0.28) & 0.72 (0.11) \\
15 | Forest-VP & 1.46 (0.4) & 1.94 (0.5) & 0.67 (0.05) & 0.84 (0.03) & 0.55 (0.1) & 0.73 (0.04) & NA (0.01) & 0.94 (0.3) & 0.52 (0.15) \\
16 | Forest-Flow & 1.36 (0.39) & 1.9 (0.5) & 0.83 (0.03) & 0.9 (0.03) & 0.57 (0.11) & 0.73 (0.04) & NA (0.01) & 0.83 (0.23) & 0.63 (0.11) \\
17 | UTrees & 1.54 (0.47) & 2 (0.55) & 0.73 (0.04) & 0.87 (0.02) & 0.49 (0.1) & 0.7 (0.04) & NA (NA) & 1.2 (0.28) & 0.4 (0.09) \\
18 | Oracle & 0 (0) & 1.81 (0.47) & 0.99 (0.01) & 0.91 (0.04) & 0.64 (0.09) & 0.77 (0.04) & NA (0) & 0 (0) & 1 (0) \\
19 | \hline
20 | \end{tabular}
21 | \end{table}
22 | % latex table generated in R 4.4.1 by xtable 1.8-4 package
23 | % Sat Sep 7 22:19:01 2024
24 | \begin{table}[ht]
25 | \centering
26 | \begin{tabular}{rlllllllll}
27 | \hline
28 | & W\_train & W\_test & coverage\_train & coverage\_test & R2\_fake & F1\_fake & class\_score & percent\_bias & coverage\_rate \\
29 | \hline
30 | GaussianCopula & 7.1 (0.3) & 7.2 (0.3) & 7.3 (0.3) & 7.4 (0.3) & 6.2 (0.2) & 6.4 (0.3) & 7 (0.4) & 6.5 (1.1) & 7.5 (0.7) \\
31 | TVAE & 5.3 (0.2) & 5.1 (0.2) & 5.7 (0.2) & 5.7 (0.2) & 6.5 (0.7) & 6 (0.5) & 5.5 (0.3) & 7.3 (0.6) & 6.7 (0.6) \\
32 | CTGAN & 8.4 (0.1) & 8.4 (0.2) & 8.3 (0.2) & 8.1 (0.2) & 8.5 (0.2) & 8.3 (0.2) & 6.7 (0.3) & 5.3 (1.1) & 7.2 (0.5) \\
33 | CTAB-GAN+ & 6.8 (0.3) & 6.7 (0.3) & 7.2 (0.3) & 7.1 (0.3) & 6.8 (0.4) & 6.9 (0.4) & 6.9 (0.3) & 7.7 (0.8) & 6.7 (0.8) \\
34 | STaSy & 6.1 (0.2) & 6.3 (0.2) & 5.3 (0.2) & 5.4 (0.2) & 6 (1.2) & 5.1 (0.3) & 6.1 (0.3) & 4.5 (0.8) & 4.2 (1.1) \\
35 | TabDDPM & 3 (0.7) & 3.9 (0.6) & 2.8 (0.5) & 3.4 (0.5) & 1.2 (0.2) & 3.8 (0.6) & 3.2 (0.4) & 3 (0.9) & 1.4 (0.2) \\
36 | Forest-VP & 3.2 (0.2) & 2.8 (0.2) & 3.6 (0.3) & 3.3 (0.3) & 2.8 (0.3) & 2.2 (0.3) & 4.3 (0.4) & 3.2 (0.9) & 3.5 (0.8) \\
37 | Forest-Flow & 1.9 (0.2) & 1.5 (0.2) & 1.7 (0.2) & 1.8 (0.2) & 2.3 (0.4) & 2.4 (0.3) & 4.3 (0.4) & 2.8 (0.5) & 2.7 (0.4) \\
38 | UTrees & 3.1 (0.1) & 3.1 (0.2) & 3.1 (0.2) & 2.8 (0.2) & 4.7 (0.3) & 3.9 (0.3) & 1 (0) & 4.7 (0.7) & 5.2 (0.9) \\
39 | \hline
40 | \end{tabular}
41 | \end{table}
--------------------------------------------------------------------------------
/Results/latex_imputation_ablation_baltobot4d8f71b.txt:
--------------------------------------------------------------------------------
1 | % latex table generated in R 4.4.1 by xtable 1.8-4 package
2 | % Tue Sep 3 11:40:24 2024
3 | \begin{table}[ht]
4 | \centering
5 | \begin{tabular}{rlllllllll}
6 | \hline
7 | & MinMAE & AvgMAE & W\_train & W\_test & MedianMAD & R2\_imp & F1\_imp & PercentBias & CoverageRate \\
8 | \hline
9 | KNN & 0.16 (0.03) & 0.16 (0.03) & 0.42 (0.08) & 1.89 (0.49) & 0 (0) & 0.59 (0.09) & 0.75 (0.04) & 1.27 (0.25) & 0.4 (0.11) \\
10 | ICE & 0.1 (0.01) & 0.21 (0.03) & 0.52 (0.09) & 1.99 (0.49) & 0.69 (0.1) & 0.59 (0.09) & 0.74 (0.04) & 1.05 (0.29) & 0.39 (0.09) \\
11 | MICE-Forest & 0.08 (0.02) & 0.13 (0.03) & 0.34 (0.07) & 1.86 (0.48) & 0.29 (0.08) & 0.61 (0.1) & 0.76 (0.04) & 0.61 (0.2) & 0.75 (0.11) \\
12 | MissForest & 0.1 (0.03) & 0.12 (0.03) & 0.32 (0.07) & 1.85 (0.48) & 0.1 (0.03) & 0.61 (0.1) & 0.76 (0.04) & 0.62 (0.22) & 0.79 (0.08) \\
13 | Softimpute & 0.22 (0.03) & 0.22 (0.03) & 0.53 (0.07) & 1.99 (0.48) & 0 (0) & 0.58 (0.09) & 0.74 (0.04) & 1.18 (0.34) & 0.31 (0.09) \\
14 | OT & 0.14 (0.02) & 0.19 (0.03) & 0.56 (0.1) & 1.98 (0.49) & 0.28 (0.05) & 0.59 (0.1) & 0.75 (0.04) & 1.09 (0.27) & 0.39 (0.12) \\
15 | GAIN & 0.16 (0.03) & 0.17 (0.03) & 0.49 (0.11) & 1.95 (0.51) & 0.01 (0) & 0.6 (0.1) & 0.75 (0.04) & 1.04 (0.25) & 0.54 (0.12) \\
16 | Forest-VP & 0.14 (0.04) & 0.17 (0.03) & 0.55 (0.13) & 1.96 (0.5) & 0.25 (0.03) & 0.61 (0.1) & 0.74 (0.04) & 0.81 (0.25) & 0.57 (0.14) \\
17 | utrees-kmeans & 0.1 (0.02) & 0.15 (0.03) & 0.43 (0.09) & 1.9 (0.5) & 0.28 (0.06) & 0.61 (0.1) & 0.76 (0.04) & 0.63 (0.21) & 0.72 (0.13) \\
18 | utrees-kdi & 0.1 (0.02) & 0.14 (0.03) & 0.42 (0.09) & 1.89 (0.49) & 0.27 (0.06) & 0.61 (0.1) & 0.76 (0.04) & 0.68 (0.24) & 0.68 (0.14) \\
19 | utrees & 0.08 (0.02) & 0.14 (0.03) & 0.37 (0.08) & 1.87 (0.48) & 0.27 (0.07) & 0.61 (0.1) & 0.76 (0.04) & 0.55 (0.19) & 0.71 (0.13) \\
20 | Oracle & 0 (0) & 0 (0) & 0 (0) & 1.87 (0.49) & 0 (0) & 0.64 (0.09) & 0.78 (0.04) & 0 (0) & 1 (0) \\
21 | \hline
22 | \end{tabular}
23 | \end{table}
24 | % latex table generated in R 4.4.1 by xtable 1.8-4 package
25 | % Tue Sep 3 11:40:24 2024
26 | \begin{table}[ht]
27 | \centering
28 | \begin{tabular}{rlllllllll}
29 | \hline
30 | & MinMAE & AvgMAE & W\_train & W\_test & MedianMAD & R2\_imp & F1\_imp & PercentBias & CoverageRate \\
31 | \hline
32 | KNN & 6.8 (0.6) & 7.8 (0.6) & 6 (0.4) & 6.1 (0.5) & 10.4 (0) & 8.2 (1.3) & 7 (1.5) & 7.5 (1.5) & 6.5 (0.8) \\
33 | ICE & 8.3 (0.5) & 5.8 (0.5) & 8.5 (0.6) & 8.8 (0.5) & 1.9 (0.4) & 8 (1.1) & 9 (0.6) & 7.2 (1.1) & 6.4 (0.8) \\
34 | MICE-Forest & 4.8 (0.6) & 3.3 (0.6) & 3.5 (0.3) & 3.4 (0.3) & 4.6 (0.4) & 4.3 (1.8) & 4.3 (1.3) & 6.8 (1.6) & 4.8 (0.7) \\
35 | MissForest & 3.3 (0.7) & 5 (0.6) & 2.2 (0.4) & 2.3 (0.4) & 7.2 (0.3) & 4.7 (1.8) & 3.3 (0.9) & 6.8 (1.9) & 3.8 (0.6) \\
36 | Softimpute & 8.3 (0.5) & 9.3 (0.5) & 8.8 (0.6) & 8.9 (0.6) & 10.4 (0) & 7.5 (1.2) & 9.8 (0.4) & 8.3 (0.9) & 7.9 (0.6) \\
37 | OT & 7.2 (0.5) & 7.6 (0.4) & 7.4 (0.6) & 7.4 (0.6) & 4.8 (0.4) & 8.2 (0.5) & 8.8 (0.6) & 7.3 (0.7) & 5.8 (0.7) \\
38 | GAIN & 5.8 (0.5) & 8.3 (0.4) & 7.2 (0.5) & 7.5 (0.4) & 8.9 (0.1) & 7.5 (0.8) & 7.4 (0.8) & 6.7 (1) & 6.1 (0.8) \\
39 | Forest-VP & 6.4 (0.5) & 4.8 (0.6) & 7 (0.4) & 6.1 (0.5) & 3.8 (0.5) & 6.5 (0.9) & 6.6 (0.8) & 4.5 (0.8) & 6.5 (0.8) \\
40 | utrees-kmeans & 6 (0.6) & 5.8 (0.5) & 6.3 (0.6) & 6.1 (0.6) & 4.1 (0.3) & 4 (0.7) & 2.9 (0.6) & 3.8 (1) & 6 (0.7) \\
41 | utrees-kdi & 5.1 (0.5) & 5.1 (0.5) & 5.4 (0.6) & 5.6 (0.5) & 4.8 (0.3) & 4.5 (0.9) & 4 (0.5) & 3.5 (1.2) & 6.4 (0.7) \\
42 | utrees & 3.8 (0.5) & 3.2 (0.5) & 3.8 (0.4) & 3.8 (0.5) & 5 (0.3) & 2.7 (0.6) & 2.9 (0.8) & 3.5 (0.8) & 5.8 (0.7) \\
43 | \hline
44 | \end{tabular}
45 | \end{table}
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# unmasking-trees 😷➡️🥳 🌲🌲🌲

[![PyPI version](https://badge.fury.io/py/utrees.svg)](https://badge.fury.io/py/utrees)
[![Downloads](https://static.pepy.tech/badge/utrees)](https://pepy.tech/project/utrees)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)

UnmaskingTrees is a method for tabular data generation and imputation. It's an order-agnostic autoregressive diffusion model, wherein a training dataset is constructed by incrementally masking features in random order. Per-feature gradient-boosted trees are then trained to unmask each feature. Read more about it in my [paper](https://arxiv.org/abs/2407.05593)!

To better model conditional distributions that are multi-modal ("modal" as in "modes", not as in "modalities"), we hierarchically partition each feature, recursively training XGBoost classifiers at each node of the binary "meta-tree". This approach to conditional modeling of individual features, dubbed BaltoBot, outperforms quantile regression and diffusion-based probabilistic prediction. You can also customize quantization via the `quantize_cols` parameter of the `fit` method: provide a list of length `n_dims`, with values in `('continuous', 'categorical', 'integer')`. Given `categorical`, quantization is currently skipped for that feature.
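
`Baltobot` can also be used on its own for conditional density estimation of a single target. The following minimal sketch mirrors `tests/utrees/test_baltobot.py`; the evaluation grid is arbitrary:

```
import numpy as np
from sklearn.datasets import make_moons

from utrees import Baltobot

# Model p(y | x) on Two Moons: condition on the first coordinate,
# treat the second coordinate as the target.
data, labels = make_moons((100, 100), shuffle=False, noise=0.1, random_state=12345)
X, y = data[:, 0:1], data[:, 1]

balto = Baltobot(random_state=12345)
balto.fit(X, y)
sampley = balto.sample(X)  # one draw of y per row of X, size (200,)

# Log-density of p(y | x=0) evaluated on a grid of y values.
evaly = np.linspace(-3, 3, 1000)
evalX = np.zeros((1000, 1))
scores = balto.score_samples(evalX, evaly)  # size (1000,)
```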

Here's how well UnmaskingTrees works on imputation with the Two Moons synthetic dataset:

![Two Moons imputation](https://raw.githubusercontent.com/calvinmccarter/unmasking-trees/HEAD/paper/moons-imputation.png)

Here's how well BaltoBot works on probabilistic prediction:

![BaltoBot probabilistic prediction](https://raw.githubusercontent.com/calvinmccarter/unmasking-trees/HEAD/paper/wave-baltobot-readme.jpg)

## Installation

### Installation from PyPI
```
pip install utrees
```

### Installation from source
After cloning this repo, install the dependencies on the command line, then install utrees and run the tests:
```
pip install -r requirements.txt
pip install -e .
pytest
```

## Usage

Check out [this notebook](https://github.com/calvinmccarter/unmasking-trees/blob/master/paper/moons.ipynb) with the Two Moons example, or [this one](https://github.com/calvinmccarter/unmasking-trees/blob/master/paper/iris.ipynb) with the Iris dataset.

### Synthetic data generation

You can fit `utrees.UnmaskingTrees` the way you would an sklearn model, with the added option that you can call `fit` with `quantize_cols`, a list specifying which columns are continuous (and therefore need to be discretized). By default, all columns are assumed to contain continuous features.

```
import numpy as np
from sklearn.datasets import make_moons
from utrees import UnmaskingTrees
data, labels = make_moons((100, 100), shuffle=False, noise=0.1, random_state=123)  # size (200, 2)
utree = UnmaskingTrees().fit(data)
```

Then, you can generate new data:

```
newdata = utree.generate(n_generate=123)  # size (123, 2)
```
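
If some columns hold categorical or integer features, declare them via `quantize_cols` at fit time. Here is a sketch on the Iris dataset, following the pattern in `tests/utrees/test_utrees.py`:

```
from sklearn.datasets import load_iris

iris = load_iris()
# Four continuous measurements plus a categorical species label.
Xy = np.concatenate((iris['data'], iris['target'][:, None]), axis=1)  # size (150, 5)
utree = UnmaskingTrees(random_state=12345)
utree.fit(Xy, quantize_cols=['continuous'] * 4 + ['categorical'])
Xy_gen = utree.generate(n_generate=150)  # last column stays in {0, 1, 2}
```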

### Missing data imputation

You can fit your `UnmaskingTrees` model on data with missing elements, provided as `np.nan`. You can then impute the missing values, potentially with multiple imputations per missing element. Given an array of size `(n_samples, n_dims)`, you will get back an array of size `(n_impute, n_samples, n_dims)`, in which the NaNs have been replaced while the observed values are unchanged.

```
data4impute = data.copy()
data4impute[:, 1] = np.nan
X = np.concatenate([data, data4impute], axis=0)  # size (400, 2)
utree = UnmaskingTrees().fit(X)
imputeddata = utree.impute(n_impute=5)  # size (5, 400, 2)
```

You can also provide a totally new dataset to be imputed, so the model performs imputation without retraining:

```
utree = UnmaskingTrees().fit(data)
imputeddata = utree.impute(n_impute=5, X=data4impute)  # size (5, 200, 2)
```

### Hyperparameters

- `depth`: depth of the balanced binary tree used to recursively quantize each feature; see the sketch below.
- `duplicate_K`: number of random masking orders per actual sample. The training dataset will be of size `(n_samples * n_dims * duplicate_K, n_dims)`.
- `xgboost_kwargs`: dict of extra arguments to pass to `XGBClassifier`.
- `strategy`: how to quantize continuous features (`'kdiquantile'`, `'quantile'`, `'uniform'`, or `'kmeans'`).
- `random_state`: controls randomness.

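A sketch of a non-default configuration, reusing `data` from above; the particular values are illustrative, not recommendations:

```
utree = UnmaskingTrees(
    depth=4,  # deeper tree -> finer-grained quantization
    duplicate_K=50,  # more masking orders -> larger training set
    strategy='kdiquantile',
    xgboost_kwargs={'learning_rate': 0.1, 'subsample': 0.8},
    random_state=12345,
)
utree.fit(data)
```
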
## Citing this method

Please consider citing the UnmaskingTrees [arXiv preprint](https://arxiv.org/pdf/2407.05593). The bibtex is:

```
@article{mccarter2024unmasking,
  title={Unmasking Trees for Tabular Data},
  author={McCarter, Calvin},
  journal={arXiv preprint arXiv:2407.05593},
  year={2024}
}
```

Also, please consider citing ForestDiffusion ([code](https://github.com/SamsungSAILMontreal/ForestDiffusion) and [paper](https://arxiv.org/abs/2309.09968)), which this work builds on.
--------------------------------------------------------------------------------
/tests/utrees/test_utrees.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import time
4 |
5 | from sklearn.datasets import (
6 | load_iris,
7 | make_moons,
8 | )
9 | from sklearn.neighbors import KernelDensity
10 |
11 | from utrees import UnmaskingTrees
12 |
13 | CREATE_PLOTS = False
14 |
15 | @pytest.mark.parametrize("strategy, depth, duplicate_K, min_score", [
16 | ("kdiquantile", 3, 10, -2.),
17 | ("kdiquantile", 3, 10, -2.),
18 | ("kdiquantile", 4, 10, -2.),
19 | ("kdiquantile", 4, 50, -2.),
20 | ("kdiquantile", 5, 50, -2.),
21 | ])
22 | def test_moons_generate(strategy, depth, duplicate_K, min_score):
23 | n_upper = 100
24 | n_lower = 100
25 | n_generate = 123
26 | n = n_upper + n_lower
27 | data, labels = make_moons(
28 | (n_upper, n_lower), shuffle=False, noise=0.1, random_state=12345)
29 |
30 | utree = UnmaskingTrees(
31 | depth=depth,
32 | strategy=strategy,
33 | duplicate_K=duplicate_K,
34 | random_state=12345,
35 | )
36 | utree.fit(data)
37 | newdata = utree.generate(n_generate=n_generate)
38 | if CREATE_PLOTS:
39 | import matplotlib.pyplot as plt
40 | plt.figure()
41 | plt.scatter(newdata[:, 0], newdata[:, 1]);
42 | plt.savefig(f'test_moons_generate-{strategy}-{depth}-{duplicate_K}.pdf')
43 | assert newdata.shape == (n_generate, 2)
44 |
45 | kde = KernelDensity(kernel='gaussian', bandwidth=0.2).fit(data)
46 | scores = kde.score_samples(newdata)
47 | assert scores.mean() >= min_score
48 |
49 |
50 | @pytest.mark.parametrize("strategy, depth, duplicate_K, cast_float32, min_score, k", [
51 | ("kdiquantile", 3, 10, True, -1.5, 1),
52 | ("kdiquantile", 3, 10, True, -1.5, 1),
53 | ("kdiquantile", 4, 10, True, -1.5, 1),
54 | ("kdiquantile", 4, 50, True, -1.5, 1),
55 | ("kdiquantile", 5, 50, True, -1.5, 1),
56 | ("kdiquantile", 3, 10, False, -1.5, 1),
57 | ("kdiquantile", 3, 10, False, -1.5, 1),
58 | ("kdiquantile", 4, 10, False, -1.5, 1),
59 | ("kdiquantile", 4, 50, False, -1.5, 1),
60 | ("kdiquantile", 5, 50, False, -1.5, 1),
61 | ("kdiquantile", 3, 10, False, -1.5, 3),
62 | ("kdiquantile", 3, 10, False, -1.5, 3),
63 | ("kdiquantile", 4, 10, False, -1.5, 3),
64 | ("kdiquantile", 4, 50, False, -1.5, 3),
65 | ("kdiquantile", 5, 50, False, -1.5, 3),
66 | ])
67 | def test_moons_impute(strategy, depth, duplicate_K, cast_float32, min_score, k):
68 | n_upper = 100
69 | n_lower = 100
70 | n = n_upper + n_lower
71 | data, labels = make_moons(
72 | (n_upper, n_lower), shuffle=False, noise=0.1, random_state=12345)
73 | data4impute = data.copy()
74 | data4impute[:, 1] = np.nan
75 | X = np.concatenate([data, data4impute], axis=0)
76 |
77 | utree = UnmaskingTrees(
78 | depth=depth,
79 | strategy=strategy,
80 | duplicate_K=duplicate_K,
81 | cast_float32=cast_float32,
82 | random_state=12345,
83 | )
84 | utree.fit(X)
85 | kde = KernelDensity(kernel='gaussian', bandwidth=0.2).fit(data)
86 |
87 | # Tests using fitted data
88 | imputedX = utree.impute(n_impute=k)
89 | assert imputedX.shape == (k, X.shape[0], X.shape[1])
90 | nannystate = ~np.isnan(X)
91 | for kk in range(k):
92 | imputedXcur = imputedX[kk, :, :]
93 | # Assert observed data is unchanged
94 | np.testing.assert_equal(imputedXcur[nannystate], X[nannystate])
95 | # Assert all NaNs removed
96 | assert np.isnan(imputedXcur).sum() == 0
97 | scores = kde.score_samples(imputedXcur)
98 | # Assert imputed data is close to observed data distribution
99 | assert scores.mean() >= min_score
100 | if CREATE_PLOTS:
101 | import matplotlib.pyplot as plt
102 | plt.figure()
103 | plt.scatter(imputedX[0, :, 0], imputedX[0, :, 1]);
104 | plt.savefig(f'test_moons_impute-{strategy}-{depth}-{duplicate_K}.pdf')
105 |
106 | # Tests providing data to impute
107 | imputedX = utree.impute(n_impute=k, X=data4impute)
108 | assert imputedX.shape == (k, data4impute.shape[0], data4impute.shape[1])
109 | nannystate = ~np.isnan(data4impute)
110 | for kk in range(k):
111 | imputedXcur = imputedX[kk, :, :]
112 |
113 | np.testing.assert_equal(imputedXcur[nannystate], data4impute[nannystate])
114 | scores = kde.score_samples(imputedXcur)
115 | assert scores.mean() >= min_score
116 |
117 | def test_moons_score_samples():
118 | n_upper = 100
119 | n_lower = 100
120 | n_generate = 123
121 | n = n_upper + n_lower
122 | data, labels = make_moons(
123 | (n_upper, n_lower), shuffle=False, noise=0.1, random_state=12345)
124 |
125 | utree = UnmaskingTrees(random_state=12345)
126 | utree.fit(data)
127 | scores = utree.score_samples(data)
128 | assert scores.shape == (200,)
129 | assert np.min(scores) > -0.5
130 |
131 | @pytest.mark.parametrize('target_type', [
132 | 'continuous', 'categorical', 'integer',
133 | ])
134 | def test_iris(target_type):
135 | my_data = load_iris()
136 | X, y = my_data['data'], my_data['target']
137 | n = X.shape[0]
138 | Xy = np.concatenate((X, np.expand_dims(y, axis=1)), axis=1)
139 | ut_model = UnmaskingTrees(random_state=12345)
140 | ut_model.fit(Xy, quantize_cols=['continuous']*4 + [target_type])
141 | Xy_gen_utrees = ut_model.generate(n_generate=n)
142 | col_names = my_data['feature_names'] + ['target_names']
143 | petalw = col_names.index('petal width (cm)')
144 | petall = col_names.index('petal length (cm)')
145 | assert np.unique(Xy_gen_utrees[:, -1]).size == 3
146 | assert np.unique(Xy_gen_utrees[:, petall]).size > n - 10
147 | assert 50 <= np.unique(Xy_gen_utrees[:, petalw]).size < 100
148 |
--------------------------------------------------------------------------------
/Results/tabular_imputation_results_kmeans.txt:
--------------------------------------------------------------------------------
1 | iris,MCAR(0.2),utrees,0.06436957134970737,0.09820029507565886,0.0,0.0,0.07406719488532511,0.2432489111314778,0.0035090903348105984,0.16882548182880674,0.14711389375938483,0.0,0.9533635056734473,5.240656932195027,0.0,0.9888721804511277,0.0,0.901246399258095,0.0,0.9657031467557784,0.0,0.9576322962287873
2 | wine,MCAR(0.2),utrees,0.10703117843764451,0.1500977475902249,0.0,0.0,0.40604224206690687,1.473187616004656,0.01009838423341393,0.2996230669137918,0.25479589314502804,0.0,0.9506248663768152,43.43113430341085,0.0,0.9471527074612197,0.0,0.8861589265469987,0.0,0.9910199172018928,0.0,0.9781679142971489
3 | parkinsons,MCAR(0.2),utrees,0.05551321452108013,0.08524537540547424,0.0,0.0,0.38426840866581696,1.7739918499212077,0.0065856889843212835,0.19988221995294644,0.16980011273351714,0.0,0.824004272401742,132.9267556667328,0.0,0.745174148812501,0.0,0.849369991297509,0.0,0.8452874989174725,0.0,0.856185450579485
4 | climate_model_crashes,MCAR(0.2),utrees,0.2445575415292093,0.3445494156494738,0.0,0.0,1.2426306605524053,3.8839612332455897,0.04650630572459653,0.814139369339031,0.7026100326117098,0.0,0.6776350950992771,275.9443356990814,0.0,0.8300706838056722,0.0,0.7040739514597695,0.0,0.4823199662304771,0.0,0.6940757789011901
5 | concrete_compression,MCAR(0.2),utrees,0.03210828842755,0.056903224276065374,139.9718029014977,0.22962962962962963,0.08964405779202678,0.5147330608069801,0.0040364825986187736,0.1363458859714353,0.11504088872391835,0.7501702370745574,0.0,115.43401980400085,0.5628067782571478,0.0,0.7370457763261031,0.0,0.8608975679029993,0.0,0.8399308258119796,0.0
6 | yacht_hydrodynamics,MCAR(0.2),utrees,0.047272415294858,0.09254175399891258,90.36945434969547,0.9809523809523808,0.09956587072417561,0.5353933813685045,0.007927911237284414,0.21188108168258518,0.18117387052298115,0.8961366668730334,0.0,14.487419525782265,0.6082034338573993,0.0,0.9851336619886606,0.0,0.9958894477927889,0.0,0.9953201238532842,0.0
7 | airfoil_self_noise,MCAR(0.2),utrees,0.037140222355384125,0.083416395329536,4.063234761149441,1.0,0.06861066115715818,0.2632849337812995,0.00906505128438473,0.2407473071657216,0.21101632145252772,0.726789378235716,0.0,61.72480670611064,0.5073497692761099,0.0,0.6260807273477592,0.0,0.8835357178295489,0.0,0.8901912984894461,0.0
8 | connectionist_bench_sonar,MCAR(0.2),utrees,0.12594510082585392,0.1575817200024125,0.0,0.0,1.9018403823025247,8.832364115331467,0.01524932730916776,0.36679856139287836,0.31067455387346066,0.0,0.7934502116490639,1217.6911363601685,0.0,0.7189555311397549,0.0,0.8081375928343585,0.0,0.8141292146446422,0.0,0.8325785079775009
9 | ionosphere,MCAR(0.2),utrees,0.09313030024983923,0.12300146406326612,0.0,0.0,0.8143947765317358,4.441172261619947,0.011246970278524575,0.2295432671127555,0.18978076793547133,0.0,0.8963486608285675,542.9356789588928,0.0,0.8371620356469938,0.0,0.9124253626300302,0.0,0.9388813305587143,0.0,0.8969259144785321
10 | qsar_biodegradation,MCAR(0.2),utrees,0.019527684473879602,0.027992824231103,0.0,0.0,0.23001487701993542,1.3982340329151852,0.001090018369507807,0.05588821862156047,0.04755009522818494,0.0,0.8446301860513109,1721.6118144194284,0.0,0.8467325146932341,0.0,0.8306542562137806,0.0,0.8528585746608389,0.0,0.8482753986373902
11 | seeds,MCAR(0.2),utrees,0.0615945639620859,0.09729499079999612,0.0,0.0,0.14014201616568808,0.4889822967496476,0.004909057215034096,0.20851218941450503,0.18024954813653288,0.0,0.8759684438048736,16.49672166506449,0.0,0.9312245285578619,0.0,0.7145206838096979,0.0,0.9206997637526373,0.0,0.9374287990992972
12 | glass,MCAR(0.2),utrees,0.056683279899008815,0.08418955377322465,0.0,0.0,0.15643328511219365,0.6629243709553694,0.005207425246596656,0.14701144249299913,0.12408580449687279,0.0,0.5400166896494403,27.309990644454956,0.0,0.5740018743324092,0.0,0.31148857904273825,0.0,0.6503503977514334,0.0,0.62422590747118
13 | ecoli,MCAR(0.2),utrees,0.07824011385064211,0.13473034411769022,0.0,0.0,0.19019169962553592,0.479811891076667,0.007038748092204368,0.26565302594552415,0.2363700591455002,0.0,0.6791062196038589,21.325106700261433,0.0,0.8008041093956577,0.0,0.39762456661176226,0.0,0.8004187478580278,0.0,0.7175774545499879
14 | yeast,MCAR(0.2),utrees,0.06941257417556176,0.12116565147613076,0.0,0.0,0.18620507065740766,0.3962720433056949,0.005375955083372073,0.2540420088841035,0.2300358190571259,0.0,0.4441336024074725,123.21836757659912,0.0,0.5279781488912345,0.0,0.24764316643869705,0.0,0.5049291325437661,0.0,0.4959839617561922
15 | libras,MCAR(0.2),utrees,0.04337659174664109,0.054188700617526744,0.0,0.0,0.9781172408544512,9.087276475583508,0.002856839376379127,0.14195162056368987,0.12417461862538949,0.0,0.5668127539679112,5551.839271465938,0.0,0.707391694725028,0.0,0.046151433972063276,0.0,0.8003763298429966,0.0,0.7133315573315573
16 | planning_relax,MCAR(0.2),utrees,0.10076780027173504,0.14972397424144582,0.0,0.0,0.37768020618544945,1.497770661672028,0.011499330547311781,0.3311062364520646,0.28206416197506373,0.0,0.45340396658398313,39.88673202196757,0.0,0.4101516792156664,0.0,0.498282297881215,0.0,0.4314105982187582,0.0,0.4737712910202928
17 | blood_transfusion,MCAR(0.2),utrees,0.041859507915742486,0.07342162483104794,0.0,0.0,0.03713507439059031,0.1120434048586678,0.004067193066445469,0.15581042388651056,0.13510836072561155,0.0,0.5932689994790394,22.858375310897827,0.0,0.5587901528296861,0.0,0.6208276889233224,0.0,0.5952800168363233,0.0,0.5981781393268258
18 | breast_cancer_diagnostic,MCAR(0.2),utrees,0.04528500136927905,0.059025866114478055,0.0,0.0,0.3547227648437304,1.8643734784330732,0.0016410376729804662,0.11556167150646471,0.10212184999736261,0.0,0.9585506174839166,749.6472186247507,0.0,0.9559701297680048,0.0,0.9570539836293582,0.0,0.9570414965792896,0.0,0.9641368599590135
19 | connectionist_bench_vowel,MCAR(0.2),utrees,0.061398164895609314,0.1039307212244072,0.0,0.0,0.20811034421411279,0.7362538908543316,0.005596639823743729,0.2511663316487918,0.2201338332636656,0.0,0.6648635771081516,159.47676269213358,0.0,0.6416428022305954,0.0,0.2316282014085894,0.0,0.9161645012332263,0.0,0.8700188035601953
20 | concrete_slump,MCAR(0.2),utrees,0.1333530760953803,0.2087869919401345,47.93493627427608,0.7416666666666666,0.2918570338779636,1.1616849899114061,0.020308189548916368,0.41754187877797,0.3449666588250957,0.6617184542749365,0.0,9.68403736750285,0.7671580550398038,0.0,0.6267470830753135,0.0,0.6622900928791725,0.0,0.590678586105456,0.0
21 | wine_quality_red,MCAR(0.2),utrees,0.043987667301618495,0.07125263601096245,34.704274950328255,1.0,0.1413180725084123,0.517063871726785,0.00269422418878734,0.1602979380994121,0.14191509671282354,0.30131984599606504,0.0,237.22118894259137,0.31614333412554846,0.0,0.28376551757830254,0.0,0.36190006000046004,0.0,0.2434704722799491,0.0
22 | wine_quality_white,MCAR(0.2),utrees,0.036103445695891,0.06268985046227243,84.75792080373398,0.5555555555555555,0.13722323358222618,0.4512826541091945,0.002325762946555635,0.1625876365418874,0.1462748590117181,0.33976933748992155,0.0,922.8720918496449,0.26866062995313283,0.0,0.2833327598765625,0.0,0.4323614905194416,0.0,0.3747224696105493,0.0
23 | california,MCAR(0.2),utrees,0.02069658137403717,0.05039244339583727,26.527604512665064,0.5259259259259259,0.0,0.0,0.0044263438785602975,0.152781496744373,0.13628165960935112,0.6474823603570286,0.0,2793.5273619492846,0.5902896683351077,0.0,0.4149105745891897,0.0,0.7815453463555943,0.0,0.8031838521482231,0.0
24 | bean,MCAR(0.2),utrees,0.014279476432662424,0.026375358712360665,0.0,0.0,0.0,0.0,0.0010325532628770013,0.08178814262320773,0.07467284322933423,0.0,0.785844029183105,6019.071438709895,0.0,0.7804254388322259,0.0,0.49402300709239133,0.0,0.9333896534836575,0.0,0.935538017324145
25 | tictactoe,MCAR(0.2),utrees,0.3302788913404004,0.5047493883573801,0.0,0.0,0.7706701479547405,1.923353640847097,0.019924972355421657,1.1954983413252993,0.9192553211800524,0.0,0.8260940741613819,24.388316075007122,0.0,0.7628063445625028,0.0,0.7088680436660513,0.0,0.900034114313759,0.0,0.932667794103214
26 | congress,MCAR(0.2),utrees,0.1890438211894098,0.2666739760698372,0.0,0.0,0.7869731800766308,2.359770114942525,0.007118368287168615,0.4271020972301169,0.3176210020716755,0.0,0.9343496695461981,28.509195009867355,0.0,0.9181753899593685,0.0,0.9395054129002633,0.0,0.9437576372605899,0.0,0.9359602380645711
27 | car,MCAR(0.2),utrees,0.45735025839689725,0.6895113519508365,0.0,0.0,0.5051616015436621,1.0813146733811403,0.02605064368622248,1.8221719722572072,1.4401332127730782,0.0,0.819537784685739,29.95012299219767,0.0,0.8518452545343415,0.0,0.7390913722188874,0.0,0.8187829437640975,0.0,0.8684315682256301
28 |
--------------------------------------------------------------------------------
/Results/tabular_imputation_results_tobt.txt:
--------------------------------------------------------------------------------
1 | 271,iris,MCAR(0.2),utrees,0.0738508822640819,0.08621383344459618,0.0,0.0,0.0641272750008991,0.23875198098882847,0.0010849252155242946,0.060747518260253,0.050638420184983594,0.0,0.9592957351290684,6.241593678792317,0.0,0.9844054580896685,0.0,0.9208420883859479,0.0,0.9719632414369256,0.0,0.9599721526037313
2 | 272,wine,MCAR(0.2),utrees,0.10847833071435786,0.1273898930976937,0.0,0.0,0.34477782264980283,1.436851201603759,0.0022681103341390227,0.1234763420665328,0.10114861403442213,0.0,0.9500534430124237,31.166414578755692,0.0,0.9395602586713697,0.0,0.8874777130686393,0.0,0.9913154133784308,0.0,0.9818603869312548
3 | 273,parkinsons,MCAR(0.2),utrees,0.05216119916570852,0.06282486691810032,0.0,0.0,0.2829609680941807,1.7057565523311848,0.0011464499359240465,0.06601010209647172,0.05496063290496879,0.0,0.8197209207924157,67.40996813774109,0.0,0.7448269072722554,0.0,0.8205556286221438,0.0,0.8631517929896829,0.0,0.8503493542855808
4 | 274,climate_model_crashes,MCAR(0.2),utrees,0.2808700271957786,0.33931437568865086,0.0,0.0,1.2236978336740576,3.8843171348144097,0.01981178122553466,0.41717855463271586,0.3435656973505204,0.0,0.732257043439506,122.85987591743469,0.0,0.825510951198097,0.0,0.7657393075515603,0.0,0.5183956256861775,0.0,0.8193822893221895
5 | 275,concrete_compression,MCAR(0.2),utrees,0.03208937400224214,0.0449843462357446,95.02280181221532,0.3555555555555555,0.06950370247361591,0.5018743805043794,0.001868359858469494,0.06612196061508667,0.05437818351022399,0.7585029515684282,0.0,57.00657558441162,0.56886904606764,0.0,0.7410657368387151,0.0,0.8713681358466046,0.0,0.8527088875207527,0.0
6 | 276,yacht_hydrodynamics,MCAR(0.2),utrees,0.043468972067696544,0.055231025220134546,24.99998836102717,1.0,0.05533555453165772,0.5070570594782794,0.0038244437942096835,0.06134887096816548,0.04755683277595425,0.8965889191222446,0.0,10.832573652267456,0.6041882722010656,0.0,0.9897168134399937,0.0,0.9966183190506115,0.0,0.9958322717973075,0.0
7 | 277,airfoil_self_noise,MCAR(0.2),utrees,0.03601927283689502,0.056442182219915434,3.602237261278865,1.0,0.04072949157540118,0.248333447015149,0.004789185084037393,0.09646590551563475,0.07741676024143188,0.7259813075249679,0.0,37.661417404810585,0.5089654463051735,0.0,0.6209385469506234,0.0,0.8853326392820764,0.0,0.888688597561998,0.0
8 | 278,connectionist_bench_sonar,MCAR(0.2),utrees,0.10325464112454648,0.11572714409057966,0.0,0.0,1.396946930870398,8.490695914502343,0.0024130064258775265,0.12086357229675992,0.09973393051247184,0.0,0.7961100547917862,484.8436369101206,0.0,0.7305667860775371,0.0,0.7917005749057889,0.0,0.826771736701832,0.0,0.8354011214819863
9 | 279,ionosphere,MCAR(0.2),utrees,0.09274800547354099,0.11362675052595668,0.0,0.0,0.7530395291283524,4.428334329991218,0.007861143042296006,0.14133471175286777,0.11219339961349425,0.0,0.9079339126392311,226.5654911994934,0.0,0.8704189158504423,0.0,0.913392839991736,0.0,0.9347812384932109,0.0,0.9131426562215345
10 | 280,qsar_biodegradation,MCAR(0.2),utrees,0.016950361470988242,0.021991270274625278,0.0,0.0,0.18070410987453125,1.3799801507746234,0.0007616428769309639,0.029576808224069363,0.02395775645070889,0.0,0.8464247159723208,607.7569101651509,0.0,0.8493412113767648,0.0,0.8298656606493132,0.0,0.8522279272923382,0.0,0.8542640645708677
11 | 281,seeds,MCAR(0.2),utrees,0.06558186398759239,0.0818383717537164,0.0,0.0,0.11806093655638271,0.4758862336495159,0.0012582765118176159,0.08697161235318307,0.07302355757537271,0.0,0.8815366461629893,16.37628761927287,0.0,0.9413927210862075,0.0,0.7207285702837521,0.0,0.9292578215030322,0.0,0.9347674717789659
12 | 282,glass,MCAR(0.2),utrees,0.062050774465719924,0.07363746271541312,0.0,0.0,0.13664248771775228,0.6396874281780643,0.0018520039435309709,0.06659311935922695,0.054469169747801305,0.0,0.5200758571660892,19.19480045636495,0.0,0.5561984171683729,0.0,0.3218814414073257,0.0,0.6756645735935488,0.0,0.5265589964951093
13 | 283,ecoli,MCAR(0.2),utrees,0.06453228583324466,0.07761285744910246,0.0,0.0,0.1068919956390921,0.401773448510927,0.0011357623617926616,0.06588892321651871,0.054906589233967906,0.0,0.6939790972544977,17.368636290232338,0.0,0.7412225128742794,0.0,0.43088376118385335,0.0,0.8040077940861244,0.0,0.7998023208737332
14 | 284,yeast,MCAR(0.2),utrees,0.05895051406248822,0.06929448120458553,0.0,0.0,0.10311086536653524,0.3164615766358421,0.0009136845133219293,0.059834683194831914,0.049196346463218736,0.0,0.4443877944655975,72.2295389175415,0.0,0.5457200855194014,0.0,0.23381158819953407,0.0,0.5280076182562742,0.0,0.470011885887181
15 | 285,libras,MCAR(0.2),utrees,0.030900635362958416,0.03501986248736869,0.0,0.0,0.6320113064928582,8.927761073481095,0.0004123027700454415,0.05310431815487,0.046219991058215906,0.0,0.5763398196282307,2609.6879085699716,0.0,0.7266957363624029,0.0,0.0656637197506967,0.0,0.8029773387106721,0.0,0.7100224836891503
16 | 286,planning_relax,MCAR(0.2),utrees,0.09702467095872941,0.11677807762938469,0.0,0.0,0.294182350854369,1.4490622094955194,0.0024094633661303623,0.12394061275131349,0.10278404304613803,0.0,0.4656162078089963,29.979744831720986,0.0,0.41686890777319113,0.0,0.5128104109771937,0.0,0.4468584026329554,0.0,0.4859271098526448
17 | 287,blood_transfusion,MCAR(0.2),utrees,0.05250419832506664,0.06273627296215434,0.0,0.0,0.03417193597069273,0.11108166737525997,0.0009944133374911643,0.05049024576996288,0.04146980880327056,0.0,0.5909121710221329,15.97951825459798,0.0,0.5517558674610492,0.0,0.6136836696925864,0.0,0.5923459245005579,0.0,0.6058632224343383
18 | 288,breast_cancer_diagnostic,MCAR(0.2),utrees,0.04248342287389557,0.049991474574389065,0.0,0.0,0.30053080575196994,1.8403605511797796,0.0005091639620044048,0.05575907086130531,0.047511666583858865,0.0,0.9559513116989217,350.96746794382733,0.0,0.9565720466705175,0.0,0.950461040184128,0.0,0.9552653152164161,0.0,0.9615068447246249
19 | 289,connectionist_bench_vowel,MCAR(0.2),utrees,0.06443732869709562,0.08390184593288284,0.0,0.0,0.16803464619838576,0.7150281660352723,0.0019676379777712324,0.11290896288770606,0.09527958048799712,0.0,0.6612800914403835,105.91338030497232,0.0,0.6372764684188381,0.0,0.2091374133599314,0.0,0.9306224975774907,0.0,0.868083986405273
20 | 290,concrete_slump,MCAR(0.2),utrees,0.14946281300454173,0.17743712288363436,50.530740741392904,0.6083333333333333,0.24881000910566675,1.1559990504226032,0.005047817126113788,0.14313051571169366,0.11337670728696354,0.6813521957189256,0.0,12.78035036722819,0.739561590918806,0.0,0.630150174247179,0.0,0.692742719426452,0.0,0.6629542982832657,0.0
21 | 291,wine_quality_red,MCAR(0.2),utrees,0.05145864317084543,0.06453161345110682,48.04485802828075,1.0,0.1277755226999056,0.5127004214844084,0.0010588208629251602,0.0743885786222127,0.062241781160318296,0.3121816124810164,0.0,125.84174267450967,0.3161742440812635,0.0,0.29229668834595846,0.0,0.3695386514219234,0.0,0.27071686607492024,0.0
22 | 292,wine_quality_white,MCAR(0.2),utrees,0.042593903786731196,0.05505373728025714,74.43341369338839,0.5111111111111111,0.12074880247067918,0.44595229298284883,0.0008893732992206452,0.0717649368850878,0.060195874296490495,0.3360038411750484,0.0,450.17718839645386,0.26787301686659737,0.0,0.279956460830806,0.0,0.4362087549627255,0.0,0.35997713204006476,0.0
23 | 293,california,MCAR(0.2),utrees,0.030580305310313076,0.042526266610077086,20.09642335364787,0.49629629629629635,0.0,0.0,0.0014429541654798791,0.05886321962908722,0.04926780789859086,0.654840243501494,0.0,1278.6375177701314,0.5915782908304531,0.0,0.43503356451736247,0.0,0.7823153738824907,0.0,0.8104337447756698,0.0
24 | 294,bean,MCAR(0.2),utrees,0.013745072140138138,0.01936528961128679,0.0,0.0,0.0,0.0,0.00032362892341678007,0.036293178330078155,0.03214980747629488,0.0,0.7827055281793264,2614.430075327555,0.0,0.7895514158594643,0.0,0.47266482176976476,0.0,0.9335693470438311,0.0,0.9350365280442451
25 | 295,tictactoe,MCAR(0.2),utrees,0.2934076415933544,0.5098007688248267,0.0,0.0,0.7750217580504756,1.923804213809092,0.02449292020035092,1.4695752120210548,1.1344409107142162,0.0,0.8228683601865097,30.083193381627403,0.0,0.7564519924401047,0.0,0.7051401013935286,0.0,0.8980469741728717,0.0,0.9318343727395332
26 | 296,congress,MCAR(0.2),utrees,0.17090379339869008,0.2739610466703952,0.0,0.0,0.7969348659003855,2.373180076628348,0.009687297604583728,0.5812378562750236,0.43111389253746524,0.0,0.9348000068108923,34.59704256057739,0.0,0.918716692128598,0.0,0.9357333401867587,0.0,0.9488637868680861,0.0,0.935886208060126
27 | 297,car,MCAR(0.2),utrees,0.4012625479200538,0.6826690498634778,0.0,0.0,0.4819585142305888,1.0697378070373702,0.029454461974252055,2.0602595278022156,1.6405766453740909,0.0,0.7836670401499316,37.12033700942993,0.0,0.8325497237184465,0.0,0.6373154415056159,0.0,0.8105739604491329,0.0,0.8542290349265308
28 |
--------------------------------------------------------------------------------
/Results/tabular_imputation_results.txt:
--------------------------------------------------------------------------------
1 | iris , MCAR(0.2) , utrees , 0.06171134900290921 , 0.09286773754928909 , 0.0 , 0.0 , 0.06980779208322663 , 0.2397118967709102 , 0.0030206573130502293 , 0.15978792534668954 , 0.13658212292576377 , 0.0 , 0.9422763803480664 , 5.160777409871419 , 0.0 , 0.9866332497911444 , 0.0 , 0.8702625553914046 , 0.0 , 0.9634663088698177 , 0.0 , 0.9487434073398984
2 | wine , MCAR(0.2) , utrees , 0.10133481778238645 , 0.1439122311329542 , 0.0 , 0.0 , 0.3895478278301316 , 1.4645430063143305 , 0.008630066798274041 , 0.2829484754161229 , 0.23908105735580445 , 0.0 , 0.9387301795008196 , 46.97041702270508 , 0.0 , 0.9424267274641209 , 0.0 , 0.8394181522214272 , 0.0 , 0.9920429711070391 , 0.0 , 0.9810328672106903
3 | parkinsons , MCAR(0.2) , utrees , 0.052245517775528584 , 0.07630051184397803 , 0.0 , 0.0 , 0.3425923785306587 , 1.7516166306883203 , 0.004681727732015078 , 0.17334971838062277 , 0.14844598117147428 , 0.0 , 0.8151910426840616 , 140.84079313278198 , 0.0 , 0.7497556842583829 , 0.0 , 0.8285637760247496 , 0.0 , 0.8434252258261432 , 0.0 , 0.8390194846269702
4 | climate_model_crashes , MCAR(0.2) , utrees , 0.25313482958818423 , 0.36388498949116616 , 0.0 , 0.0 , 1.3125855770026345 , 3.9308358361374527 , 0.058244783820290275 , 0.9247016041837531 , 0.7911253027692842 , 0.0 , 0.6687687876893608 , 303.72695573170984 , 0.0 , 0.7815931202543376 , 0.0 , 0.715044914544767 , 0.0 , 0.485082313212326 , 0.0 , 0.6933548027460131
5 | concrete_compression , MCAR(0.2) , utrees , 0.032205947195894674 , 0.05695862771138322 , 144.7204580042744 , 0.2148148148148148 , 0.08956094803214486 , 0.51587579790969 , 0.0039197280163505745 , 0.13092247924677985 , 0.10971847972639154 , 0.7506571899571826 , 0.0 , 110.7270872592926 , 0.5635045538284041 , 0.0 , 0.7426887614177372 , 0.0 , 0.8593832855583278 , 0.0 , 0.8370521590242612 , 0.0
6 | yacht_hydrodynamics , MCAR(0.2) , utrees , 0.04600778874007024 , 0.08016372314818775 , 130.04117264988247 , 0.8952380952380953 , 0.08194873657018106 , 0.5212796431105499 , 0.006810609756057632 , 0.16949249963483226 , 0.13837548506080527 , 0.8959975452512271 , 0.0 , 12.314840714136759 , 0.6079608300788333 , 0.0 , 0.9850257618973024 , 0.0 , 0.995870603583432 , 0.0 , 0.9951329854453406 , 0.0
7 | airfoil_self_noise , MCAR(0.2) , utrees , 0.030701968468046875 , 0.07149217526283297 , 4.426930062522432 , 1.0 , 0.05527627319305399 , 0.255348994194407 , 0.0087494171344429 , 0.20225251607348688 , 0.1726012568420425 , 0.7288531573629374 , 0.0 , 61.05456566810608 , 0.5081040017180197 , 0.0 , 0.6304988371229822 , 0.0 , 0.883136433947107 , 0.0 , 0.8936733566636403 , 0.0
8 | connectionist_bench_sonar , MCAR(0.2) , utrees , 0.11712002366199573 , 0.14441412477065826 , 0.0 , 0.0 , 1.7433453859125134 , 8.675589836023347 , 0.010292110372589414 , 0.31071706160493046 , 0.2646020070218156 , 0.0 , 0.789587435581842 , 1221.7377189000447 , 0.0 , 0.7365130837758166 , 0.0 , 0.7853301364692437 , 0.0 , 0.8107724273701802 , 0.0 , 0.8257340947121282
9 | ionosphere , MCAR(0.2) , utrees , 0.09598313460475286 , 0.12398897668765305 , 0.0 , 0.0 , 0.8212715075994221 , 4.452108646183872 , 0.010532593079086822 , 0.22607161765535344 , 0.1872299367967365 , 0.0 , 0.8971300518445046 , 544.0828011035919 , 0.0 , 0.8335572700966875 , 0.0 , 0.912945819294007 , 0.0 , 0.937803996414423 , 0.0 , 0.9042131215729008
10 | qsar_biodegradation , MCAR(0.2) , utrees , 0.018528342646379738 , 0.026946503392505854 , 0.0 , 0.0 , 0.22144208357565423 , 1.3998370122598578 , 0.0012175729510766418 , 0.055858663558376945 , 0.047125598729442726 , 0.0 , 0.8440079997743205 , 1700.7666370868683 , 0.0 , 0.8492198286276188 , 0.0 , 0.8287018939881141 , 0.0 , 0.8524921264445106 , 0.0 , 0.845618150037039
11 | seeds , MCAR(0.2) , utrees , 0.06156737632234309 , 0.09532838023661071 , 0.0 , 0.0 , 0.13727812015204705 , 0.4818844341111245 , 0.0041747195976600746 , 0.19172721995362418 , 0.16607901238564932 , 0.0 , 0.8876327743360164 , 16.02963876724243 , 0.0 , 0.9392051110518542 , 0.0 , 0.7593816864030843 , 0.0 , 0.917707799905501 , 0.0 , 0.9342364999836262
12 | glass , MCAR(0.2) , utrees , 0.05690890150540151 , 0.08010917130462747 , 0.0 , 0.0 , 0.14928190914582376 , 0.6556699224704824 , 0.004381661572533172 , 0.1302607772932201 , 0.11177271760603738 , 0.0 , 0.5480099780432925 , 26.026296854019165 , 0.0 , 0.584371607108248 , 0.0 , 0.3177680015479405 , 0.0 , 0.6578281866630086 , 0.0 , 0.6320721168539731
13 | ecoli , MCAR(0.2) , utrees , 0.05629358048089226 , 0.08645477169671197 , 0.0 , 0.0 , 0.11921348698293545 , 0.4096222068800511 , 0.004157091351302514 , 0.1608139108009185 , 0.13803921739849476 , 0.0 , 0.7018928975330095 , 23.58275763193766 , 0.0 , 0.7876587322107141 , 0.0 , 0.438941456261841 , 0.0 , 0.799616636930169 , 0.0 , 0.7813547647293144
14 | yeast , MCAR(0.2) , utrees , 0.04535097744322566 , 0.07376880404694568 , 0.0 , 0.0 , 0.10812830155357582 , 0.31949312377513595 , 0.0035031547260809034 , 0.1568094959224052 , 0.13619697159380023 , 0.0 , 0.4573822425410245 , 135.45259356498718 , 0.0 , 0.5372650246353922 , 0.0 , 0.24436352760113153 , 0.0 , 0.5434300184509417 , 0.0 , 0.5044703994766325
15 | libras , MCAR(0.2) , utrees , 0.04227497283897813 , 0.052300055087484684 , 0.0 , 0.0 , 0.9441110249540525 , 9.064625281197733 , 0.002386057977089709 , 0.1351752676790998 , 0.11785024852303577 , 0.0 , 0.57221071121062 , 5732.288067579269 , 0.0 , 0.7167657034323701 , 0.0 , 0.054326598126233104 , 0.0 , 0.8033840825840824 , 0.0 , 0.714366460699794
16 | planning_relax , MCAR(0.2) , utrees , 0.09274661933098485 , 0.14182509487282988 , 0.0 , 0.0 , 0.3577914833040218 , 1.4829015968756003 , 0.010863641129838425 , 0.33208213992992913 , 0.2843774348922432 , 0.0 , 0.45357795616942376 , 39.74175707499186 , 0.0 , 0.40884725474889433 , 0.0 , 0.4809160683691389 , 0.0 , 0.4311217822489055 , 0.0 , 0.49342671931075655
17 | blood_transfusion , MCAR(0.2) , utrees , 0.040394437501713114 , 0.07071949375830425 , 0.0 , 0.0 , 0.0353891109851724 , 0.11301666799401906 , 0.00436917758066595 , 0.15255551896399422 , 0.13027972107616098 , 0.0 , 0.5891788320696434 , 20.771088282267254 , 0.0 , 0.551163926355941 , 0.0 , 0.6263315792365312 , 0.0 , 0.5861426479620243 , 0.0 , 0.5930771747240771
18 | breast_cancer_diagnostic , MCAR(0.2) , utrees , 0.04347497878768007 , 0.05818701882778201 , 0.0 , 0.0 , 0.34965808120004954 , 1.8694501937914967 , 0.0020290733719139677 , 0.12244840244125642 , 0.10732914690076638 , 0.0 , 0.9563682040026953 , 790.5507067044575 , 0.0 , 0.9552594974942447 , 0.0 , 0.9559410899580288 , 0.0 , 0.9570062408277437 , 0.0 , 0.9572659877307639
19 | connectionist_bench_vowel , MCAR(0.2) , utrees , 0.05976956621841621 , 0.10382249912440761 , 0.0 , 0.0 , 0.20792119393041208 , 0.737387697709447 , 0.006410373126504616 , 0.26627932092860485 , 0.2320615991149993 , 0.0 , 0.6628657357921165 , 162.6827143828074 , 0.0 , 0.6420072805560552 , 0.0 , 0.22917844837129117 , 0.0 , 0.9198757522587209 , 0.0 , 0.8604014619823989
20 | concrete_slump , MCAR(0.2) , utrees , 0.13336506913435178 , 0.21163758490490042 , 78.37393727464351 , 0.7166666666666666 , 0.296686493798486 , 1.1753105280747134 , 0.020353934344685538 , 0.41736394686933864 , 0.3433281455096098 , 0.6495419122376092 , 0.0 , 8.064836661020914 , 0.7383579877825754 , 0.0 , 0.6537969299030401 , 0.0 , 0.6491909134165881 , 0.0 , 0.556821817848233 , 0.0
21 | wine_quality_red , MCAR(0.2) , utrees , 0.04184544842818448 , 0.0721537184045824 , 22.31912314809874 , 0.9939393939393939 , 0.14254835041317498 , 0.5198545898986846 , 0.003952135412439399 , 0.18018191797690808 , 0.15692197788181259 , 0.29580732210337435 , 0.0 , 261.72702566782635 , 0.3173300368121892 , 0.0 , 0.28245331130378315 , 0.0 , 0.3492149836631822 , 0.0 , 0.2342309566343427 , 0.0
22 | wine_quality_white , MCAR(0.2) , utrees , 0.03527853218161474 , 0.06364601037858153 , 71.46802560034284 , 0.5 , 0.13879125826315522 , 0.4537299873404477 , 0.0033708521907209153 , 0.1781443095298026 , 0.15794928677256942 , 0.33958677441484636 , 0.0 , 1023.7504643599192 , 0.26957257718077343 , 0.0 , 0.28524000983461484 , 0.0 , 0.43435310087766693 , 0.0 , 0.3691814097663301 , 0.0
23 | california , MCAR(0.2) , utrees , 0.02085733850254894 , 0.052293952522410515 , 33.24752259308236 , 0.4518518518518519 , 0.0 , 0.0 , 0.005275218897392506 , 0.16294639904729727 , 0.1442791493561823 , 0.6562729852476104 , 0.0 , 3039.9609892368317 , 0.5865712054822544 , 0.0 , 0.4482862136463962 , 0.0 , 0.7846209767129884 , 0.0 , 0.8056135451488026 , 0.0
24 | bean , MCAR(0.2) , utrees , 0.013300777074474399 , 0.026065942643214362 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0015598047749835822 , 0.08307973988602241 , 0.07537123904968326 , 0.0 , 0.7889339055492804 , 6546.674813985825 , 0.0 , 0.7696592871750925 , 0.0 , 0.5184820537374906 , 0.0 , 0.9330313282932579 , 0.0 , 0.9345629529912808
25 | tictactoe , MCAR(0.2) , utrees , 0.3302788913404004 , 0.5047493883573801 , 0.0 , 0.0 , 0.7706701479547405 , 1.923353640847097 , 0.019924972355421657 , 1.1954983413252993 , 0.9192553211800524 , 0.0 , 0.8260135515119187 , 24.652372360229492 , 0.0 , 0.7624842539646501 , 0.0 , 0.7088680436660513 , 0.0 , 0.900034114313759 , 0.0 , 0.9326677941032139
26 | congress , MCAR(0.2) , utrees , 0.1890438211894098 , 0.2666739760698372 , 0.0 , 0.0 , 0.7869731800766308 , 2.359770114942525 , 0.007118368287168615 , 0.4271020972301169 , 0.3176210020716755 , 0.0 , 0.9346015785483384 , 28.701904376347862 , 0.0 , 0.9191830259679294 , 0.0 , 0.9395054129002633 , 0.0 , 0.9437576372605899 , 0.0 , 0.9359602380645711
27 | car , MCAR(0.2) , utrees , 0.45735025839689725 , 0.6895113519508365 , 0.0 , 0.0 , 0.5051616015436621 , 1.0813146733811403 , 0.02605064368622248 , 1.8221719722572072 , 1.4401332127730782 , 0.0 , 0.8191771109071323 , 29.15845568974813 , 0.0 , 0.8504025594199145 , 0.0 , 0.7390913722188874 , 0.0 , 0.8187829437640975 , 0.0 , 0.8684315682256301
28 |
--------------------------------------------------------------------------------
/Results/tabular_imputation_results_baltobot4d8f71b.txt:
--------------------------------------------------------------------------------
1 | iris , MCAR(0.2) , utrees , 0.05999383499585656 , 0.0891422525651869 , 0.0 , 0.0 , 0.06622844415087935 , 0.2399510489123476 , 0.0026519395607755435 , 0.14782735105221712 , 0.12211262422232116 , 0.0 , 0.952578092959672 , 5.305421670277913 , 0.0 , 0.9866332497911444 , 0.0 , 0.8935526948684843 , 0.0 , 0.9657432470064047 , 0.0 , 0.9643831801726537
2 | wine , MCAR(0.2) , utrees , 0.09483351208087865 , 0.1309220058072117 , 0.0 , 0.0 , 0.3544626598677714 , 1.4421646366534844 , 0.005640737222570057 , 0.23726015274535953 , 0.19939392703444025 , 0.0 , 0.9369265175356536 , 26.762377103169758 , 0.0 , 0.9632461278304145 , 0.0 , 0.8151893968182157 , 0.0 , 0.9871233059233867 , 0.0 , 0.9821472395705978
3 | parkinsons , MCAR(0.2) , utrees , 0.04692251597772083 , 0.06515695016716706 , 0.0 , 0.0 , 0.2935529924783244 , 1.7112368414543835 , 0.002726737547563139 , 0.12343294098537427 , 0.10333986633321987 , 0.0 , 0.8296471697593928 , 58.98031155268352 , 0.0 , 0.7628487453017361 , 0.0 , 0.838878856219079 , 0.0 , 0.8611200454212758 , 0.0 , 0.8557410320954802
4 | climate_model_crashes , MCAR(0.2) , utrees , 0.23759707521127293 , 0.33835467617379056 , 0.0 , 0.0 , 1.2202348736572843 , 3.8713029908094416 , 0.04378390850643781 , 0.7939392208828289 , 0.6876264315110343 , 0.0 , 0.7075253091655707 , 103.73175875345868 , 0.0 , 0.8216503397548771 , 0.0 , 0.7460006942675264 , 0.0 , 0.4956869617631286 , 0.0 , 0.7667632408767506
5 | concrete_compression , MCAR(0.2) , utrees , 0.02333038049146902 , 0.051929753638068496 , 117.4323195156428 , 0.24444444444444446 , 0.07950083700139465 , 0.5046659061722523 , 0.0056108336568096535 , 0.15947619621621514 , 0.1297940883578644 , 0.7550948876750425 , 0.0 , 47.19429659843445 , 0.5677727274909049 , 0.0 , 0.7378459497855332 , 0.0 , 0.8652996866710022 , 0.0 , 0.8494611867527299 , 0.0
6 | yacht_hydrodynamics , MCAR(0.2) , utrees , 0.027212127217251424 , 0.06525405415346494 , 87.8356344827388 , 0.9619047619047618 , 0.06461383266455692 , 0.5112652032726271 , 0.012247314991194035 , 0.1899437795671296 , 0.14489959622012152 , 0.8958445449867422 , 0.0 , 8.892837285995483 , 0.6051986543649295 , 0.0 , 0.9878021530766475 , 0.0 , 0.9957263778055365 , 0.0 , 0.9946509946998549 , 0.0
7 | airfoil_self_noise , MCAR(0.2) , utrees , 0.024177747618689484 , 0.06644906397149036 , 3.367058543769148 , 1.0 , 0.04595237772071268 , 0.25049733817022946 , 0.012535618795616179 , 0.22880454575076564 , 0.18357767558903554 , 0.7236830200188702 , 0.0 , 29.907463312149048 , 0.5076127218483516 , 0.0 , 0.6223125813761253 , 0.0 , 0.8807123657190816 , 0.0 , 0.884094411131922 , 0.0
8 | connectionist_bench_sonar , MCAR(0.2) , utrees , 0.09859288170328898 , 0.11840726141041186 , 0.0 , 0.0 , 1.429042454382952 , 8.508219085835636 , 0.005136353547750008 , 0.22206330143535155 , 0.18837572035423833 , 0.0 , 0.7987797915830596 , 440.6014119784037 , 0.0 , 0.7211420515274188 , 0.0 , 0.8168914005300593 , 0.0 , 0.8238102292793955 , 0.0 , 0.8332754849953646
9 | ionosphere , MCAR(0.2) , utrees , 0.08523464358983444 , 0.11811873486059866 , 0.0 , 0.0 , 0.7824976996114046 , 4.441913223417513 , 0.014656794885137926 , 0.25404546334261535 , 0.202321645205517 , 0.0 , 0.9098789590534062 , 201.8103465239207 , 0.0 , 0.8732615124612854 , 0.0 , 0.9279816794874949 , 0.0 , 0.9329457905205485 , 0.0 , 0.9053268537442958
10 | qsar_biodegradation , MCAR(0.2) , utrees , 0.015254570285918954 , 0.023403987013942196 , 0.0 , 0.0 , 0.19212359101896018 , 1.3856575124044705 , 0.001247031131769453 , 0.053514867631470106 , 0.04361988321482159 , 0.0 , 0.8486549180518138 , 560.5758926868439 , 0.0 , 0.8511907477600135 , 0.0 , 0.8314650338714398 , 0.0 , 0.8565060426520982 , 0.0 , 0.8554578479237033
11 | seeds , MCAR(0.2) , utrees , 0.0536062409688773 , 0.08449679733209269 , 0.0 , 0.0 , 0.12191528713404252 , 0.47779747604268186 , 0.003367278377665574 , 0.17413053418775576 , 0.14828548838041686 , 0.0 , 0.8828062340725844 , 14.943417390187584 , 0.0 , 0.942468848874979 , 0.0 , 0.7190544422051238 , 0.0 , 0.9288809390255441 , 0.0 , 0.9408207061846907
12 | glass , MCAR(0.2) , utrees , 0.0496206279015378 , 0.07585823492308724 , 0.0 , 0.0 , 0.1395464455143801 , 0.6422554966107725 , 0.00509210478076279 , 0.1437355547360558 , 0.11662889755117278 , 0.0 , 0.543444784278404 , 17.121328353881836 , 0.0 , 0.5597076318333436 , 0.0 , 0.33844731725051297 , 0.0 , 0.6716156389645124 , 0.0 , 0.6040085490652478
13 | ecoli , MCAR(0.2) , utrees , 0.051482101920573564 , 0.07996184068979219 , 0.0 , 0.0 , 0.1094547532281822 , 0.40433483990272084 , 0.0035961465359099646 , 0.1538224632971858 , 0.13044989354532197 , 0.0 , 0.6831116597210881 , 14.686814228693645 , 0.0 , 0.7796676608032461 , 0.0 , 0.3758002592414551 , 0.0 , 0.8016274932079795 , 0.0 , 0.7753512256316719
14 | yeast , MCAR(0.2) , utrees , 0.04378552325566909 , 0.07404024912539708 , 0.0 , 0.0 , 0.10720798893085218 , 0.31883588185505407 , 0.0035788436758160324 , 0.1730284435457075 , 0.14983662228831252 , 0.0 , 0.44429114417699456 , 62.958171367645264 , 0.0 , 0.5408617994798776 , 0.0 , 0.23325600825320442 , 0.0 , 0.520586653798653 , 0.0 , 0.48246011517624343
15 | libras , MCAR(0.2) , utrees , 0.03063284486586349 , 0.03638737201499229 , 0.0 , 0.0 , 0.6566358196725776 , 8.933323976745273 , 0.0008112346007017335 , 0.08171470473946038 , 0.07063236617473345 , 0.0 , 0.568547098210967 , 1975.7754249572754 , 0.0 , 0.7305759672426337 , 0.0 , 0.046811110866585924 , 0.0 , 0.8014451030451031 , 0.0 , 0.6953562116895451
16 | planning_relax , MCAR(0.2) , utrees , 0.08411692836706103 , 0.12137667314935206 , 0.0 , 0.0 , 0.3058909258660944 , 1.460002019772343 , 0.005571385169514211 , 0.23866930500059752 , 0.20262838701678848 , 0.0 , 0.45205072102550564 , 25.245799938837685 , 0.0 , 0.4181319219394311 , 0.0 , 0.4824260898832327 , 0.0 , 0.42651343398470287 , 0.0 , 0.4811314382946558
17 | blood_transfusion , MCAR(0.2) , utrees , 0.03444502502881287 , 0.06688821811064323 , 0.0 , 0.0 , 0.03212365113334971 , 0.11167959040823328 , 0.004795884147921586 , 0.15814227169156175 , 0.13153806150238792 , 0.0 , 0.5870497695788247 , 13.1585369904836 , 0.0 , 0.5558678508828397 , 0.0 , 0.6147739158922271 , 0.0 , 0.584815196941721 , 0.0 , 0.5927421145985113
18 | breast_cancer_diagnostic , MCAR(0.2) , utrees , 0.03956558039073609 , 0.051628040615653036 , 0.0 , 0.0 , 0.31024376602207276 , 1.846053935538329 , 0.0012147251350276003 , 0.10086794072744275 , 0.08730993646823641 , 0.0 , 0.9586634419673519 , 279.3262273470561 , 0.0 , 0.957830676675561 , 0.0 , 0.9554613875073541 , 0.0 , 0.9554992846097188 , 0.0 , 0.9658624190767728
19 | connectionist_bench_vowel , MCAR(0.2) , utrees , 0.05296810071933365 , 0.09388309984454896 , 0.0 , 0.0 , 0.1878788763752319 , 0.725079410750455 , 0.0055250297398647535 , 0.24841801262652985 , 0.21448056138434413 , 0.0 , 0.6641100768645267 , 79.84957003593445 , 0.0 , 0.6514435776789809 , 0.0 , 0.20775287118264388 , 0.0 , 0.9232551451738964 , 0.0 , 0.8739887134225857
20 | concrete_slump , MCAR(0.2) , utrees , 0.1251534691485566 , 0.18784554316110125 , 47.56531663695364 , 0.725 , 0.2639606575784001 , 1.159978473434967 , 0.014801640827374072 , 0.34044859085444257 , 0.27503999548986047 , 0.6745252901753798 , 0.0 , 9.739169359207153 , 0.7541523974515636 , 0.0 , 0.6569805595622225 , 0.0 , 0.6881385743588726 , 0.0 , 0.5988296293288606 , 0.0
21 | wine_quality_red , MCAR(0.2) , utrees , 0.040802649372952006 , 0.0716406477510256 , 20.145217590505506 , 1.0 , 0.14132476856233 , 0.5169713070740208 , 0.0035164216726677883 , 0.18320501100487552 , 0.1580449829731422 , 0.3059707949127771 , 0.0 , 106.22926425933838 , 0.3170426193941901 , 0.0 , 0.2886979501708553 , 0.0 , 0.3691983195529146 , 0.0 , 0.24894429053314845 , 0.0
22 | wine_quality_white , MCAR(0.2) , utrees , 0.03407725293496344 , 0.06448257401806023 , 77.35079172193863 , 0.47777777777777775 , 0.14023782049929634 , 0.45284739625845083 , 0.0033595964423761425 , 0.19108912571449158 , 0.16837634325966563 , 0.33787091146852133 , 0.0 , 357.68214988708496 , 0.2680334535357732 , 0.0 , 0.2820063034609025 , 0.0 , 0.43474693460671304 , 0.0 , 0.3666969542706964 , 0.0
23 | california , MCAR(0.2) , utrees , 0.0196647697319096 , 0.04973844353652448 , 23.21634260208505 , 0.5925925925925927 , 0.0 , 0.0 , 0.004982026822227645 , 0.15803420495132264 , 0.13993056163009363 , 0.6548593415015938 , 0.0 , 968.1369389692943 , 0.5916727995885759 , 0.0 , 0.4237977876023389 , 0.0 , 0.7901864293781431 , 0.0 , 0.8137803494373173 , 0.0
24 | bean , MCAR(0.2) , utrees , 0.010192959415544817 , 0.02059535423265372 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0010522477503439255 , 0.06418175575996427 , 0.05799520234562458 , 0.0 , 0.7816665143749051 , 1929.1566933790846 , 0.0 , 0.7961193970932972 , 0.0 , 0.4618599044548654 , 0.0 , 0.9329963646203907 , 0.0 , 0.9356903913310666
25 | tictactoe , MCAR(0.2) , utrees , 0.29630144489131827 , 0.5110035391675715 , 0.0 , 0.0 , 0.7762402088772818 , 1.9328428706121097 , 0.024496550015139 , 1.4697930009083398 , 1.126450727895411 , 0.0 , 0.8234316722545656 , 25.853189309438072 , 0.0 , 0.7602380388772744 , 0.0 , 0.7110043303006972 , 0.0 , 0.8965042745894501 , 0.0 , 0.9259800452508402
26 | congress , MCAR(0.2) , utrees , 0.1729734997216229 , 0.2774138758682399 , 0.0 , 0.0 , 0.8032567049808451 , 2.3781609195402256 , 0.00975922053794807 , 0.5855532322768843 , 0.4332464178272987 , 0.0 , 0.9327371459215983 , 30.412309885025024 , 0.0 , 0.9068635943563466 , 0.0 , 0.932205365760348 , 0.0 , 0.9494113973748777 , 0.0 , 0.9424682261948213
27 | car , MCAR(0.2) , utrees , 0.4036153107206999 , 0.6852703901190534 , 0.0 , 0.0 , 0.48369512783406193 , 1.070682515914789 , 0.029514653850087412 , 2.0644051407904334 , 1.6485394454306355 , 0.0 , 0.80069479177201 , 30.177862962086998 , 0.0 , 0.8484301461689044 , 0.0 , 0.6644293855448289 , 0.0 , 0.8200663390532182 , 0.0 , 0.8698532963210888
28 |
--------------------------------------------------------------------------------
/Results/latex_tables.txt:
--------------------------------------------------------------------------------
1 | % latex table generated in R 4.4.1 by xtable 1.8-4 package
2 | % Fri Jul 5 22:11:55 2024
3 | \begin{table}[ht]
4 | \centering
5 | \begin{tabular}{rlllllllll}
6 | \hline
7 | & MinMAE & AvgMAE & W\_train & W\_test & MedianMAD & R2\_imp & F1\_imp & PercentBias & CoverageRate \\
8 | \hline
9 | KNN & 0.16 (0.03) & 0.16 (0.03) & 0.42 (0.08) & 1.89 (0.49) & 0 (0) & 0.59 (0.09) & 0.75 (0.04) & 1.27 (0.25) & 0.4 (0.11) \\
10 | ICE & 0.1 (0.01) & 0.21 (0.03) & 0.52 (0.09) & 1.99 (0.49) & 0.69 (0.1) & 0.59 (0.09) & 0.74 (0.04) & 1.05 (0.29) & 0.39 (0.09) \\
11 | MICE-Forest & 0.08 (0.02) & 0.13 (0.03) & 0.34 (0.07) & 1.86 (0.48) & 0.29 (0.08) & 0.61 (0.1) & 0.76 (0.04) & 0.61 (0.2) & 0.75 (0.11) \\
12 | MissForest & 0.1 (0.03) & 0.12 (0.03) & 0.32 (0.07) & 1.85 (0.48) & 0.1 (0.03) & 0.61 (0.1) & 0.76 (0.04) & 0.62 (0.22) & 0.79 (0.08) \\
13 | Softimpute & 0.22 (0.03) & 0.22 (0.03) & 0.53 (0.07) & 1.99 (0.48) & 0 (0) & 0.58 (0.09) & 0.74 (0.04) & 1.18 (0.34) & 0.31 (0.09) \\
14 | OT & 0.14 (0.02) & 0.19 (0.03) & 0.56 (0.1) & 1.98 (0.49) & 0.28 (0.05) & 0.59 (0.1) & 0.75 (0.04) & 1.09 (0.27) & 0.39 (0.12) \\
15 | GAIN & 0.16 (0.03) & 0.17 (0.03) & 0.49 (0.11) & 1.95 (0.51) & 0.01 (0) & 0.6 (0.1) & 0.75 (0.04) & 1.04 (0.25) & 0.54 (0.12) \\
16 | Forest-VP & 0.14 (0.04) & 0.17 (0.03) & 0.55 (0.13) & 1.96 (0.5) & 0.25 (0.03) & 0.61 (0.1) & 0.74 (0.04) & 0.81 (0.25) & 0.57 (0.14) \\
17 | utrees & 0.1 (0.02) & 0.14 (0.03) & 0.42 (0.09) & 1.89 (0.49) & 0.27 (0.06) & 0.61 (0.1) & 0.76 (0.04) & 0.68 (0.24) & 0.68 (0.14) \\
18 | Oracle & 0 (0) & 0 (0) & 0 (0) & 1.87 (0.49) & 0 (0) & 0.64 (0.09) & 0.78 (0.04) & 0 (0) & 1 (0) \\
19 | \hline
20 | \end{tabular}
21 | \end{table}
22 | % latex table generated in R 4.4.1 by xtable 1.8-4 package
23 | % Fri Jul 5 22:11:55 2024
24 | \begin{table}[ht]
25 | \centering
26 | \begin{tabular}{rlllllllll}
27 | \hline
28 | & MinMAE & AvgMAE & W\_train & W\_test & MedianMAD & R2\_imp & F1\_imp & PercentBias & CoverageRate \\
29 | \hline
30 | KNN & 5.5 (0.5) & 6.3 (0.5) & 4.7 (0.4) & 4.8 (0.4) & 8.4 (0) & 6.5 (1) & 5.7 (1.1) & 6.2 (1) & 5.4 (0.6) \\
31 | ICE & 6.7 (0.4) & 4.5 (0.4) & 7 (0.5) & 7.2 (0.4) & 1.6 (0.2) & 6 (1.1) & 7 (0.6) & 5.7 (0.9) & 5.2 (0.6) \\
32 | MICE-Forest & 3.9 (0.4) & 2.5 (0.4) & 2.8 (0.2) & 2.7 (0.3) & 3.6 (0.2) & 3.7 (1.4) & 3.2 (1) & 5.5 (1.2) & 4.1 (0.6) \\
33 | MissForest & 2.6 (0.5) & 4 (0.4) & 1.8 (0.3) & 1.9 (0.3) & 5.4 (0.2) & 3.8 (1.4) & 2.3 (0.6) & 5.5 (1.5) & 3.2 (0.4) \\
34 | Softimpute & 6.7 (0.4) & 7.6 (0.4) & 7.1 (0.5) & 7.3 (0.5) & 8.4 (0) & 5.8 (1) & 7.8 (0.4) & 6.3 (0.9) & 6.5 (0.4) \\
35 | OT & 5.9 (0.4) & 6 (0.3) & 6 (0.5) & 6 (0.5) & 3.7 (0.3) & 6.2 (0.5) & 6.8 (0.6) & 5.3 (0.7) & 4.8 (0.5) \\
36 | GAIN & 4.6 (0.4) & 6.5 (0.3) & 5.8 (0.3) & 6 (0.3) & 6.9 (0.1) & 5.5 (0.8) & 5.4 (0.8) & 4.7 (1) & 5.1 (0.6) \\
37 | Forest-VP & 5.1 (0.4) & 3.8 (0.5) & 5.6 (0.3) & 5 (0.4) & 3.2 (0.4) & 4.5 (0.9) & 4.6 (0.8) & 3.2 (0.7) & 5.4 (0.6) \\
38 | utrees & 4 (0.5) & 3.8 (0.5) & 4.3 (0.5) & 4.2 (0.5) & 3.7 (0.3) & 3 (0.8) & 2.3 (0.6) & 2.7 (1) & 5.4 (0.6) \\
39 | \hline
40 | \end{tabular}
41 | \end{table}
42 |
43 |
44 | % latex table generated in R 4.4.1 by xtable 1.8-4 package
45 | % Sun Jul 7 22:39:34 2024
46 | \begin{table}[ht]
47 | \centering
48 | \begin{tabular}{rlllllllll}
49 | \hline
50 | & W\_train & W\_test & coverage\_train & coverage\_test & R2\_fake & F1\_fake & class\_score & percent\_bias & coverage\_rate \\
51 | \hline
52 | GaussianCopula & 2.74 (0.56) & 2.99 (0.61) & 0.18 (0.04) & 0.37 (0.06) & 0.2 (0.14) & 0.46 (0.06) & NA (0.05) & 2.27 (0.77) & 0.23 (0.12) \\
53 | TVAE & 2.12 (0.58) & 2.35 (0.63) & 0.33 (0.04) & 0.63 (0.04) & -0.47 (0.61) & 0.52 (0.08) & NA (0.01) & 4.15 (1.97) & 0.26 (0.09) \\
54 | CTGAN & 3.58 (0.99) & 3.74 (1.01) & 0.12 (0.03) & 0.28 (0.04) & -0.43 (0.08) & 0.35 (0.04) & NA (0.01) & 2.48 (1.3) & 0.2 (0.08) \\
55 | CTAB-GAN+ & 2.71 (0.81) & 2.89 (0.83) & 0.22 (0.04) & 0.44 (0.05) & 0.05 (0.12) & 0.44 (0.05) & NA (0.02) & 2.95 (1.04) & 0.26 (0.07) \\
56 | STaSy & 3.41 (1.39) & 3.66 (1.42) & 0.38 (0.05) & 0.63 (0.05) & -4.21 (4.44) & 0.61 (0.06) & NA (0.02) & 1.23 (0.44) & 0.45 (0.12) \\
57 | TabDDPM & 4.27 (1.89) & 4.79 (1.89) & 0.76 (0.06) & 0.8 (0.06) & 0.6 (0.11) & 0.66 (0.06) & NA (0.03) & 0.76 (0.28) & 0.72 (0.11) \\
58 | Forest-VP & 1.46 (0.4) & 1.94 (0.5) & 0.67 (0.05) & 0.84 (0.03) & 0.55 (0.1) & 0.73 (0.04) & NA (0.01) & 0.94 (0.3) & 0.52 (0.15) \\
59 | Forest-Flow & 1.36 (0.39) & 1.9 (0.5) & 0.83 (0.03) & 0.9 (0.03) & 0.57 (0.11) & 0.73 (0.04) & NA (0.01) & 0.83 (0.23) & 0.63 (0.11) \\
60 | UTrees & 1.73 (0.52) & 2.12 (0.58) & 0.63 (0.05) & 0.82 (0.03) & 0 (0.51) & 0.67 (0.06) & NA (NA) & 2.04 (1.03) & 0.36 (0.07) \\
61 | Oracle & 0 (0) & 1.81 (0.47) & 0.99 (0.01) & 0.91 (0.04) & 0.64 (0.09) & 0.77 (0.04) & NA (0) & 0 (0) & 1 (0) \\
62 | \hline
63 | \end{tabular}
64 | \end{table}
65 | % latex table generated in R 4.4.1 by xtable 1.8-4 package
66 | % Sun Jul 7 22:39:34 2024
67 | \begin{table}[ht]
68 | \centering
69 | \begin{tabular}{rlllllllll}
70 | \hline
71 | & W\_train & W\_test & coverage\_train & coverage\_test & R2\_fake & F1\_fake & class\_score & percent\_bias & coverage\_rate \\
72 | \hline
73 | GaussianCopula & 7.1 (0.3) & 7.2 (0.3) & 7.3 (0.3) & 7.3 (0.3) & 6 (0) & 6.4 (0.3) & 7 (0.4) & 6.3 (1.1) & 7.3 (0.8) \\
74 | TVAE & 5.2 (0.2) & 5 (0.2) & 5.7 (0.2) & 5.6 (0.2) & 6.5 (0.7) & 6 (0.5) & 5.5 (0.3) & 7.3 (0.6) & 6.7 (0.6) \\
75 | CTGAN & 8.4 (0.1) & 8.4 (0.2) & 8.3 (0.2) & 8.1 (0.2) & 8.3 (0.3) & 8.3 (0.2) & 6.7 (0.3) & 5.3 (1.1) & 7.2 (0.5) \\
76 | CTAB-GAN+ & 6.8 (0.3) & 6.6 (0.3) & 7.2 (0.3) & 7 (0.3) & 6.8 (0.4) & 6.8 (0.4) & 6.9 (0.3) & 7.7 (0.8) & 7 (0.8) \\
77 | STaSy & 6.1 (0.2) & 6.2 (0.2) & 5.3 (0.2) & 5.2 (0.3) & 6 (1.2) & 5 (0.3) & 6.1 (0.3) & 4.5 (0.8) & 4.4 (1.1) \\
78 | TabDDPM & 3 (0.7) & 3.8 (0.6) & 2.8 (0.5) & 3.2 (0.5) & 1.2 (0.2) & 3.6 (0.6) & 3.2 (0.4) & 3 (0.9) & 1.6 (0.2) \\
79 | Forest-VP & 2.8 (0.1) & 2.6 (0.2) & 3.3 (0.3) & 3 (0.3) & 3 (0.4) & 2.1 (0.3) & 4.3 (0.4) & 3.2 (0.9) & 3.3 (0.7) \\
80 | Forest-Flow & 1.9 (0.2) & 1.4 (0.2) & 1.7 (0.2) & 1.7 (0.2) & 2.3 (0.4) & 2.3 (0.3) & 4.3 (0.4) & 2.8 (0.5) & 2.7 (0.3) \\
81 | UTrees & 3.5 (0.2) & 3.6 (0.2) & 3.6 (0.3) & 3.7 (0.3) & 4.8 (0.7) & 4.3 (0.4) & 1 (0) & 4.8 (0.8) & 4.8 (0.9) \\
82 | \hline
83 | \end{tabular}
84 | \end{table}
85 |
86 |
87 | Rscript generation_script_nmiss50.R
88 | % latex table generated in R 4.4.1 by xtable 1.8-4 package
89 | % Fri Aug 9 21:21:43 2024
90 | \begin{table}[ht]
91 | \centering
92 | \begin{tabular}{rlllllllll}
93 | \hline
94 | & W\_train & W\_test & coverage\_train & coverage\_test & R2\_fake & F1\_fake & class\_score & percent\_bias & coverage\_rate \\
95 | \hline
96 | GaussianCopula & 7 (0.3) & 7.1 (0.2) & 7.2 (0.3) & 7.1 (0.3) & 6 (0.4) & 6.6 (0.3) & 6.7 (0.4) & 5.5 (1) & 7.7 (0.6) \\
97 | TVAE & 5.2 (0.3) & 4.8 (0.3) & 5.7 (0.3) & 5.7 (0.2) & 6 (1) & 5.7 (0.5) & 5.8 (0.4) & 7.8 (0.6) & 6.2 (1) \\
98 | CTGAN & 8.3 (0.2) & 8.4 (0.2) & 8.4 (0.2) & 8.3 (0.2) & 8.2 (0.5) & 8.3 (0.2) & 6.5 (0.2) & 4.8 (1.2) & 6.9 (0.7) \\
99 | CTABGAN & 6.6 (0.4) & 6.5 (0.4) & 7 (0.3) & 6.7 (0.3) & 7.3 (0.6) & 7.1 (0.4) & 6.6 (0.3) & 7.5 (1) & 6.1 (0.5) \\
100 | Stasy & 5.9 (0.2) & 6 (0.2) & 5.2 (0.2) & 5 (0.3) & 5.5 (1) & 4.3 (0.4) & 5.3 (0.4) & 3.8 (0.4) & 4.6 (1.1) \\
101 | TabDDPM & 3 (0.7) & 3.4 (0.7) & 2.3 (0.5) & 2.9 (0.6) & 1.7 (0.3) & 3.2 (0.6) & 3.9 (0.6) & 3.8 (1.2) & 2 (0.5) \\
102 | Forest-VP & 3.1 (0.3) & 3 (0.3) & 3.6 (0.2) & 3.5 (0.3) & 3 (0.3) & 2.2 (0.3) & 4.2 (0.4) & 4.3 (0.8) & 4.2 (1.2) \\
103 | Forest-Flow & 2.5 (0.4) & 2.3 (0.3) & 2.3 (0.2) & 2.5 (0.3) & 3 (0.7) & 3.5 (0.3) & 5 (0.5) & 3.7 (0.8) & 3.1 (0.8) \\
104 | UTrees & 3.3 (0.2) & 3.6 (0.3) & 3.3 (0.3) & 3.3 (0.3) & 4.3 (1.1) & 4.2 (0.5) & 1 (0) & 3.7 (1) & 4.3 (1) \\
105 | \hline
106 | \end{tabular}
107 | \end{table}
108 |
109 |
110 | Rscript imputation_script.R -- commit daa3b12
111 | % latex table generated in R 4.4.1 by xtable 1.8-4 package
112 | % Sat Aug 31 13:59:43 2024
113 | \begin{table}[ht]
114 | \centering
115 | \begin{tabular}{rlllllllll}
116 | \hline
117 | & MinMAE & AvgMAE & W\_train & W\_test & MedianMAD & R2\_imp & F1\_imp & PercentBias & CoverageRate \\
118 | \hline
119 | KNN & 0.16 (0.03) & 0.16 (0.03) & 0.42 (0.08) & 1.89 (0.49) & 0 (0) & 0.59 (0.09) & 0.75 (0.04) & 1.27 (0.25) & 0.4 (0.11) \\
120 | ICE & 0.1 (0.01) & 0.21 (0.03) & 0.52 (0.09) & 1.99 (0.49) & 0.69 (0.1) & 0.59 (0.09) & 0.74 (0.04) & 1.05 (0.29) & 0.39 (0.09) \\
121 | MICE-Forest & 0.08 (0.02) & 0.13 (0.03) & 0.34 (0.07) & 1.86 (0.48) & 0.29 (0.08) & 0.61 (0.1) & 0.76 (0.04) & 0.61 (0.2) & 0.75 (0.11) \\
122 | MissForest & 0.1 (0.03) & 0.12 (0.03) & 0.32 (0.07) & 1.85 (0.48) & 0.1 (0.03) & 0.61 (0.1) & 0.76 (0.04) & 0.62 (0.22) & 0.79 (0.08) \\
123 | Softimpute & 0.22 (0.03) & 0.22 (0.03) & 0.53 (0.07) & 1.99 (0.48) & 0 (0) & 0.58 (0.09) & 0.74 (0.04) & 1.18 (0.34) & 0.31 (0.09) \\
124 | OT & 0.14 (0.02) & 0.19 (0.03) & 0.56 (0.1) & 1.98 (0.49) & 0.28 (0.05) & 0.59 (0.1) & 0.75 (0.04) & 1.09 (0.27) & 0.39 (0.12) \\
125 | GAIN & 0.16 (0.03) & 0.17 (0.03) & 0.49 (0.11) & 1.95 (0.51) & 0.01 (0) & 0.6 (0.1) & 0.75 (0.04) & 1.04 (0.25) & 0.54 (0.12) \\
126 | Forest-VP & 0.14 (0.04) & 0.17 (0.03) & 0.55 (0.13) & 1.96 (0.5) & 0.25 (0.03) & 0.61 (0.1) & 0.74 (0.04) & 0.81 (0.25) & 0.57 (0.14) \\
127 | utrees & 0.09 (0.02) & 0.13 (0.03) & 0.36 (0.08) & 1.87 (0.48) & 0.19 (0.07) & 0.61 (0.1) & 0.76 (0.04) & 0.44 (0.14) & 0.73 (0.12) \\
128 | Oracle & 0 (0) & 0 (0) & 0 (0) & 1.87 (0.49) & 0 (0) & 0.64 (0.09) & 0.78 (0.04) & 0 (0) & 1 (0) \\
129 | \hline
130 | \end{tabular}
131 | \end{table}
132 | % latex table generated in R 4.4.1 by xtable 1.8-4 package
133 | % Sat Aug 31 13:59:43 2024
134 | \begin{table}[ht]
135 | \centering
136 | \begin{tabular}{rlllllllll}
137 | \hline
138 | & MinMAE & AvgMAE & W\_train & W\_test & MedianMAD & R2\_imp & F1\_imp & PercentBias & CoverageRate \\
139 | \hline
140 | KNN & 5.5 (0.5) & 6.3 (0.5) & 5 (0.4) & 5 (0.4) & 8.4 (0) & 6.5 (1) & 5.8 (1.1) & 6.2 (1) & 5.5 (0.6) \\
141 | ICE & 6.8 (0.4) & 4.5 (0.4) & 7 (0.5) & 7.2 (0.4) & 1.5 (0.2) & 6.2 (1) & 7 (0.6) & 5.7 (0.9) & 5.2 (0.7) \\
142 | MICE-Forest & 3.9 (0.4) & 2.3 (0.4) & 2.9 (0.2) & 3.1 (0.2) & 3.4 (0.2) & 3.8 (1.4) & 3.2 (1) & 5.5 (1.2) & 4.3 (0.6) \\
143 | MissForest & 2.8 (0.5) & 4 (0.4) & 1.8 (0.3) & 2.1 (0.3) & 5.3 (0.2) & 4 (1.3) & 2.3 (0.6) & 5.5 (1.5) & 3.4 (0.5) \\
144 | Softimpute & 6.7 (0.4) & 7.5 (0.4) & 7.1 (0.5) & 7.3 (0.5) & 8.4 (0) & 6.2 (0.8) & 7.8 (0.4) & 6.3 (0.9) & 6.7 (0.4) \\
145 | OT & 6 (0.4) & 6 (0.3) & 6.1 (0.5) & 6 (0.5) & 3.5 (0.3) & 6.2 (0.5) & 6.8 (0.6) & 5.5 (0.8) & 4.9 (0.5) \\
146 | GAIN & 4.7 (0.4) & 6.5 (0.3) & 6 (0.3) & 6 (0.2) & 6.9 (0.1) & 5.7 (0.8) & 5.4 (0.8) & 4.7 (1) & 5.1 (0.6) \\
147 | Forest-VP & 5.3 (0.4) & 3.8 (0.5) & 5.9 (0.3) & 5.2 (0.4) & 3.1 (0.4) & 4.3 (0.9) & 4.6 (0.8) & 3.5 (0.7) & 5.4 (0.6) \\
148 | utrees & 3.3 (0.5) & 4.1 (0.5) & 3.2 (0.4) & 3.1 (0.5) & 4.3 (0.2) & 2.2 (0.7) & 2.2 (0.5) & 2.2 (0.7) & 4.5 (0.5) \\
149 | \hline
150 | \end{tabular}
151 | \end{table}
152 |
--------------------------------------------------------------------------------
/paper/moons.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "7fbcfa7d-672a-4033-bb7d-4f58038778c5",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import matplotlib as mpl\n",
11 | "mpl.rcParams['font.family'] = 'Arial'\n",
12 | "mpl.rcParams['text.usetex'] = False\n",
13 | "import matplotlib.pyplot as plt\n",
14 | "import numpy as np\n",
15 | "import seaborn as sns\n",
16 | "import sklearn.datasets as skd\n",
17 | "import pandas as pd\n",
18 | "\n",
19 | "from sklearn.utils import check_random_state\n",
20 | "from ForestDiffusion import ForestDiffusionModel\n",
21 | "from miceforest import ImputationKernel\n",
22 | "from missforest import MissForest\n",
23 | "from utrees import UnmaskingTrees"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": null,
29 | "id": "cbc3d109-ae5b-4f5f-9070-e1d9da13cea1",
30 | "metadata": {},
31 | "outputs": [],
32 | "source": [
33 | "rix = 0\n",
34 | "n = 200\n",
35 | "nimp = 1 # number of multiple imputations needed\n",
36 | "ngen = 200\n",
37 | "\n",
38 | "rng = check_random_state(rix)\n",
39 | "data, labels = skd.make_moons(n, shuffle=False, noise=0.1, random_state=rix)\n",
40 | "data4impute = data.copy()\n",
41 | "data4impute[:, 1] = np.nan\n",
42 | "X = np.concatenate([data, data4impute], axis=0)\n",
43 | "impute_samples = np.isnan(X).any(axis=1)\n",
44 | "\n",
45 | "missfer = MissForest(random_state=rix)\n",
46 | "impute_missf = missfer.fit_transform(X.copy(), cat_vars=None)\n",
47 | "\n",
48 | "micer = ImputationKernel(X.copy(), random_state=rix)\n",
49 | "micer.mice(5)\n",
50 | "impute_mice = micer.complete_data()\n",
51 | "\n",
52 | "utreer = UnmaskingTrees(random_state=rix)\n",
53 | "utreer.fit(X.copy())\n",
54 | "gen_utrees = utreer.generate(n_generate=ngen);\n",
55 | "impute_utrees = utreer.impute(n_impute=nimp)[0, :, :]\n",
56 | "\n",
57 | "utaber = UnmaskingTrees(tabpfn=True, random_state=rix)\n",
58 | "utaber.fit(X.copy())\n",
59 | "gen_utab = utaber.generate(n_generate=ngen);\n",
60 | "impute_utab = utaber.impute(n_impute=nimp)[0, :, :]\n",
61 | "\n",
62 | "forestvper = ForestDiffusionModel(\n",
63 | " X=X.copy(),\n",
64 | " n_t=50, duplicate_K=100, diffusion_type='vp',\n",
65 | " bin_indexes=[], cat_indexes=[], int_indexes=[], n_jobs=-1, seed=rix)\n",
66 | "gen_forestvp = forestvper.generate(batch_size=ngen)\n",
67 | "impute_forestvp_fast = forestvper.impute(k=nimp) # regular (fast)\n",
68 | "impute_forestvp_repaint = forestvper.impute(repaint=True, r=10, j=5, k=nimp) # REPAINT (slow, but better)\n",
69 | "\n",
70 | "forestflower = ForestDiffusionModel(\n",
71 | " X=X.copy(),\n",
72 | " n_t=50, duplicate_K=100, diffusion_type='flow',\n",
73 | " bin_indexes=[], cat_indexes=[], int_indexes=[], n_jobs=-1, seed=rix)\n",
74 | "gen_forestflow = forestflower.generate(batch_size=ngen)"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": null,
80 | "id": "b209f40a-e281-4e72-ab39-4ed723a2e1fb",
81 | "metadata": {},
82 | "outputs": [],
83 | "source": [
84 | "fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(7, 5), squeeze=False, dpi=200, sharex=True, sharey=True);\n",
85 | "markersize = 5\n",
86 | "alpha = 0.8\n",
87 | "color = 'blue'\n",
88 | "axes[0, 0].set_title('Original data');\n",
89 | "axes[0, 0].scatter(data[:, 0], data[:, 1], s=markersize, alpha=alpha, color='black', marker='x', linewidth=0.5,);\n",
90 | "axes[1, 0].set_title('Training data');\n",
91 | "axes[1, 0].scatter(X[:, 0], X[:, 1], s=markersize, alpha=alpha, color='green', marker='x', linewidth=0.5,);\n",
92 | "xlim = axes[0, 0].get_xlim();\n",
93 | "ylim = axes[0, 0].get_ylim();\n",
94 | "# Rug marks show the observed marginals, taken from the original data\n",
95 | "onlyfirst = data[:, 0]\n",
96 | "onlysecond = data[:, 1]\n",
97 | "\n",
98 | "axes[1, 0].scatter(onlyfirst, ylim[0]*np.ones_like(onlyfirst), marker='|', s=100, linewidth=0.5, color='green');\n",
99 | "axes[1, 0].scatter(xlim[0]*np.ones_like(onlysecond), onlysecond, marker='_', s=100, linewidth=0.5, color='green');\n",
100 | "axes[1, 0].set_xlim(xlim);\n",
101 | "axes[1, 0].set_ylim(ylim);\n",
102 | "\n",
103 | "axes[0, 0].set_title('Training data');\n",
104 | "axes[0, 0].scatter(X[:, 0], X[:, 1], s=markersize, alpha=alpha, color='black', marker='x', linewidth=0.5,);\n",
105 | "\n",
106 | "axes[0, 1].set_title('Forest-VP (generate)');\n",
107 | "axes[0, 1].scatter(gen_forestvp[:, 0], gen_forestvp[:, 1], s=markersize, alpha=alpha, color=color);\n",
108 | "axes[1, 1].set_title('Forest-Flow (generate)');\n",
109 | "axes[1, 1].scatter(gen_forestflow[:, 0], gen_forestflow[:, 1], s=markersize, alpha=alpha, color=color);\n",
110 | "axes[0, 2].set_title('UnmaskingTrees (generate)');\n",
111 | "axes[0, 2].scatter(gen_utrees[:, 0], gen_utrees[:, 1], s=markersize, alpha=alpha, color=color);\n",
112 | "axes[1, 2].set_title('UnmaskingTabPFN (generate)');\n",
113 | "axes[1, 2].scatter(gen_utab[:, 0], gen_utab[:, 1], s=markersize, alpha=alpha, color=color);\n",
114 | "plt.tight_layout();\n",
115 | "\n",
116 | "\n",
117 | "fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(4, 4), squeeze=False, dpi=200, sharex=True, sharey=True);\n",
118 | "markersize = 5\n",
119 | "alpha = 0.7\n",
120 | "color = 'red'\n",
121 | "datacolor = 'green'\n",
122 | "axes[0, 0].set_title('Forest-VP');\n",
123 | "axes[0, 0].scatter(X[:, 0], X[:, 1], s=markersize, alpha=alpha, color=datacolor, marker='x', linewidth=0.5, zorder=5,);\n",
124 | "axes[0, 0].scatter(gen_forestvp[:, 0], gen_forestvp[:, 1], s=markersize, alpha=alpha, color=color);\n",
125 | "axes[1, 0].set_title('Forest-Flow');\n",
126 | "axes[1, 0].scatter(X[:, 0], X[:, 1], s=markersize, alpha=alpha, color=datacolor, marker='x', linewidth=0.5,zorder=5);\n",
127 | "axes[1, 0].scatter(gen_forestflow[:, 0], gen_forestflow[:, 1], s=markersize, alpha=alpha, color=color);\n",
128 | "axes[0, 1].set_title('UnmaskingTrees');\n",
129 | "axes[0, 1].scatter(X[:, 0], X[:, 1], s=markersize, alpha=alpha, color=datacolor, marker='x', linewidth=0.5,zorder=5);\n",
130 | "axes[0, 1].scatter(gen_utrees[:, 0], gen_utrees[:, 1], s=markersize, alpha=alpha, color=color);\n",
131 | "axes[1, 1].set_title('UnmaskingTabPFN');\n",
132 | "axes[1, 1].scatter(X[:, 0], X[:, 1], s=markersize, alpha=alpha, color=datacolor, marker='x', linewidth=0.5,zorder=5,label='data');\n",
133 | "axes[1, 1].scatter(gen_utab[:, 0], gen_utab[:, 1], s=markersize, alpha=alpha, color=color,label='generated');\n",
134 | "plt.legend(handlelength=0.4, borderpad=0.4, labelspacing=0.3,framealpha=0.9,handletextpad=0.3);\n",
135 | "for curax in axes.flatten():\n",
136 | " curax.set_yticks([0, 1]);\n",
137 | "plt.tight_layout();\n",
138 | "plt.savefig('moons-generation.png');"
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "execution_count": null,
144 | "id": "bf0923b0-5bdb-4533-9ada-25c10ab98581",
145 | "metadata": {},
146 | "outputs": [],
147 | "source": [
148 | "fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(6, 4.3), squeeze=False, dpi=200, sharex=True, sharey=True);\n",
149 | "markersize = 5\n",
150 | "alpha = 0.8\n",
151 | "color = 'blue'\n",
152 | "datacolor = 'green'\n",
153 | "\n",
154 | "axes[0, 0].set_title('MissForest');\n",
155 | "axes[0, 0].scatter(\n",
156 | " impute_missf[impute_samples, 0], impute_missf[impute_samples, 1],\n",
157 | " s=markersize, alpha=alpha, color=color);\n",
158 | "axes[0, 0].scatter(\n",
159 | " data[:, 0], data[:, 1],\n",
160 | " s=markersize, alpha=alpha, color=datacolor, marker='x', linewidth=0.5, zorder=5,);\n",
161 | "axes[1, 0].set_title('MICE-Forest');\n",
162 | "axes[1, 0].scatter(\n",
163 | " impute_mice[impute_samples, 0], impute_mice[impute_samples, 1],\n",
164 | " s=markersize, alpha=alpha, color=color);\n",
165 | "axes[1, 0].scatter(\n",
166 | " data[:, 0], data[:, 1],\n",
167 | " s=markersize, alpha=alpha, color=datacolor, marker='x', linewidth=0.5, zorder=5,);\n",
168 | "axes[0, 1].set_title('Forest-VP');\n",
169 | "axes[0, 1].scatter(\n",
170 | " impute_forestvp_fast[impute_samples, 0], impute_forestvp_fast[impute_samples, 1],\n",
171 | " s=markersize, alpha=alpha, color=color);\n",
172 | "axes[0, 1].scatter(\n",
173 | " data[:, 0], data[:, 1],\n",
174 | " s=markersize, alpha=alpha, color=datacolor, marker='x', linewidth=0.5, zorder=5,);\n",
175 | "axes[1, 1].set_title('Forest-VP w/ RePaint');\n",
176 | "axes[1, 1].scatter(\n",
177 | " impute_forestvp_repaint[impute_samples, 0], impute_forestvp_repaint[impute_samples, 1],\n",
178 | " s=markersize, alpha=alpha, color=color);\n",
179 | "axes[1, 1].scatter(\n",
180 | " data[:, 0], data[:, 1],\n",
181 | " s=markersize, alpha=alpha, color=datacolor, marker='x', linewidth=0.5, zorder=5,);\n",
182 | "axes[0, 2].set_title('UnmaskingTrees');\n",
183 | "axes[0, 2].scatter(\n",
184 | " impute_utrees[impute_samples, 0], impute_utrees[impute_samples, 1],\n",
185 | " s=markersize, alpha=alpha, color=color);\n",
186 | "axes[0, 2].scatter(\n",
187 | " data[:, 0], data[:, 1],\n",
188 | " s=markersize, alpha=alpha, color=datacolor, marker='x', linewidth=0.5, zorder=5,);\n",
189 | "axes[1, 2].set_title('UnmaskingTabPFN');\n",
190 | "axes[1, 2].scatter(\n",
191 | " data[:, 0], data[:, 1],\n",
192 | " s=markersize, alpha=alpha, color=datacolor, marker='x', linewidth=0.5,zorder=5,label='data');\n",
193 | "axes[1, 2].scatter(\n",
194 | " impute_utab[impute_samples, 0], impute_utab[impute_samples, 1],\n",
195 | " s=markersize, alpha=alpha, color=color,label='imputed');\n",
196 | "\n",
197 | "for curax in axes.flatten():\n",
198 | " curax.set_yticks([0, 1]);\n",
199 | "plt.legend(handlelength=0.4, borderpad=0.4, labelspacing=0.3,framealpha=0.9,handletextpad=0.3);\n",
200 | "plt.tight_layout();\n",
201 | "plt.savefig('moons-imputation.png');"
202 | ]
203 | },
204 | {
205 | "cell_type": "code",
206 | "execution_count": null,
207 | "id": "10a7d730-04db-4b7c-94c1-422e1556cf3f",
208 | "metadata": {},
209 | "outputs": [],
210 | "source": []
211 | }
212 | ],
213 | "metadata": {
214 | "kernelspec": {
215 | "display_name": "Python 3 (ipykernel)",
216 | "language": "python",
217 | "name": "python3"
218 | },
219 | "language_info": {
220 | "codemirror_mode": {
221 | "name": "ipython",
222 | "version": 3
223 | },
224 | "file_extension": ".py",
225 | "mimetype": "text/x-python",
226 | "name": "python",
227 | "nbconvert_exporter": "python",
228 | "pygments_lexer": "ipython3",
229 | "version": "3.9.19"
230 | }
231 | },
232 | "nbformat": 4,
233 | "nbformat_minor": 5
234 | }
235 |
--------------------------------------------------------------------------------
/utrees/kdi_quantizer.py:
--------------------------------------------------------------------------------
1 | # Adapted from scikit-learn/sklearn/preprocessing/_discretization.py
2 | # with the following authorship and license:
3 |
4 | # Author: Henry Lin
5 | # Tom Dupré la Tour
6 |
7 | # License: BSD
8 | import warnings
9 | from numbers import Integral
10 |
11 | import numpy as np
12 |
13 | from kditransform import KDITransformer
14 | from sklearn.base import BaseEstimator, TransformerMixin
15 | from sklearn.cluster import KMeans
16 | from sklearn.preprocessing import OneHotEncoder
17 | from sklearn.utils import (
18 | check_random_state,
19 | resample,
20 | )
21 | from sklearn.utils._param_validation import Interval, Options, StrOptions
22 | from sklearn.utils.stats import _weighted_percentile
23 | from sklearn.utils.validation import (
24 | _check_feature_names_in,
25 | _check_sample_weight,
26 | check_array,
27 | check_is_fitted,
28 | )
29 |
30 |
31 | class KDIQuantizer(TransformerMixin, BaseEstimator):
32 | _parameter_constraints: dict = {
33 | "n_bins": [Interval(Integral, 2, None, closed="left"), "array-like"],
34 | "encode": [StrOptions({"ordinal"})],
35 | "strategy": [StrOptions({"uniform", "quantile", "kmeans", "kdiquantile"})],
36 | "dtype": [Options(type, {np.float64, np.float32}), None],
37 | "subsample": [Interval(Integral, 1, None, closed="left"), None],
38 | "random_state": ["random_state"],
39 | }
40 |
41 | def __init__(
42 | self,
43 | n_bins=5,
44 | *,
45 | encode="ordinal",
46 | strategy="kdiquantile",
47 | dtype=None,
48 | subsample=200_000,
49 | random_state=None,
50 | ):
51 | self.n_bins = n_bins
52 | self.encode = encode
53 | self.strategy = strategy
54 | self.dtype = dtype
55 | self.subsample = subsample
56 | self.random_state = random_state
57 |
58 | def fit(self, X, y=None, sample_weight=None):
59 | """
60 | Fit the estimator.
61 |
62 | Parameters
63 | ----------
64 | X : array-like of shape (n_samples, n_features)
65 | Data to be discretized.
66 |
67 | y : None
68 | Ignored. This parameter exists only for compatibility with
69 | :class:`~sklearn.pipeline.Pipeline`.
70 |
71 | sample_weight : ndarray of shape (n_samples,)
72 | Contains weight values to be associated with each sample.
73 | Cannot be used when `strategy` is set to `'uniform'`.
74 |
75 | .. versionadded:: 1.3
76 |
77 | Returns
78 | -------
79 | self : object
80 | Returns the instance itself.
81 | """
82 | X = self._validate_data(X, dtype="numeric")
83 |
84 | if self.dtype in (np.float64, np.float32):
85 | output_dtype = self.dtype
86 | else: # self.dtype is None
87 | output_dtype = X.dtype
88 |
89 | n_samples, n_features = X.shape
90 |
91 | if sample_weight is not None and self.strategy in ("uniform", "kdiquantile"):
92 | raise ValueError(
93 | "`sample_weight` was provided but it cannot be "
94 | "used with uniform or kdiquantile. Got strategy="
95 | f"{self.strategy!r} instead."
96 | )
97 |
98 | if self.subsample is not None and n_samples > self.subsample:
99 | # Take a subsample of `X`
100 | X = resample(
101 | X,
102 | replace=False,
103 | n_samples=self.subsample,
104 | random_state=self.random_state,
105 | )
106 |
107 | n_features = X.shape[1]
108 | n_bins = self._validate_n_bins(n_features)
109 |
110 | if sample_weight is not None:
111 | sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)
112 |
113 | bin_edges = np.zeros(n_features, dtype=object)
114 | for jj in range(n_features):
115 | column = X[:, jj]
116 | col_min, col_max = column.min(), column.max()
117 |
118 | if col_min == col_max:
119 | # warnings.warn(
120 | # 'Feature %d is constant and will be replaced with 0.' % jj
121 | # )
122 | n_bins[jj] = 1
123 | bin_edges[jj] = np.array([-np.inf, np.inf])
124 | continue
125 |
126 | if self.strategy == "uniform":
127 | bin_edges[jj] = np.linspace(col_min, col_max, n_bins[jj] + 1)
128 | elif self.strategy == "quantile":
129 | quantiles = np.linspace(0, 100, n_bins[jj] + 1)
130 | if sample_weight is None:
131 | bin_edges[jj] = np.asarray(np.percentile(column, quantiles))
132 | else:
133 | bin_edges[jj] = np.asarray(
134 | [
135 | _weighted_percentile(column, sample_weight, q)
136 | for q in quantiles
137 | ],
138 | dtype=np.float64,
139 | )
140 | elif self.strategy == "kdiquantile":
141 | quantiles = np.linspace(0, 1, n_bins[jj] + 1)
142 | kdier = KDITransformer().fit(column.reshape(-1, 1))
143 | cur_bin_edges = kdier.inverse_transform(
144 | quantiles.reshape(-1, 1)
145 | ).ravel()
146 | if np.unique(cur_bin_edges).shape[0] < n_bins[jj] + 1:
147 | # XXX - this bugfix is gross: find local minima instead
148 | cur_bin_edges = np.linspace(col_min, col_max, n_bins[jj] + 1)
149 | warnings.warn("kdiquantile numerical error detected, backing off.")
150 | bin_edges[jj] = cur_bin_edges # ndarray of shape (n_bins[jj] + 1,)
151 | elif self.strategy == "kmeans":
152 | # Deterministic initialization with uniform spacing
153 | uniform_edges = np.linspace(col_min, col_max, n_bins[jj] + 1)
154 | init = (uniform_edges[1:] + uniform_edges[:-1])[:, None] * 0.5
155 |
156 | # 1D k-means procedure
157 | km = KMeans(n_clusters=n_bins[jj], init=init, n_init=1)
158 | centers = km.fit(
159 | column[:, None], sample_weight=sample_weight
160 | ).cluster_centers_[:, 0]
161 | # Must sort, centers may be unsorted even with sorted init
162 | centers.sort()
163 | bin_edges[jj] = (centers[1:] + centers[:-1]) * 0.5
164 | bin_edges[jj] = np.r_[col_min, bin_edges[jj], col_max]
165 |
166 | """
167 | # Remove bins whose width are too small (i.e., <= 1e-8)
168 | if self.strategy in ('quantile', 'kmeans', 'kdiquantile'):
169 | mask = np.ediff1d(bin_edges[jj], to_begin=np.inf) > 1e-8
170 | bin_edges[jj] = bin_edges[jj][mask]
171 | if len(bin_edges[jj]) - 1 != n_bins[jj]:
172 | warnings.warn(
173 | 'Bins whose width are too small (i.e., <= '
174 | '1e-8) in feature %d are removed. Consider '
175 | 'decreasing the number of bins.' % jj
176 | )
177 | n_bins[jj] = len(bin_edges[jj]) - 1
178 | """
179 |
180 | self.bin_edges_ = bin_edges
181 | self.n_bins_ = n_bins
182 |
183 | if "onehot" in self.encode:
184 | self._encoder = OneHotEncoder(
185 | categories=[np.arange(i) for i in self.n_bins_],
186 | sparse_output=self.encode == "onehot",
187 | dtype=output_dtype,
188 | )
189 | # Fit the OneHotEncoder with toy datasets
190 | # so that it's ready for use after the KDIQuantizer is fitted
191 | self._encoder.fit(np.zeros((1, len(self.n_bins_))))
192 |
193 | return self
194 |
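    # A minimal usage sketch (illustrative; not from the original test suite):
    #
    #     import numpy as np
    #     from utrees.kdi_quantizer import KDIQuantizer
    #
    #     X = np.random.default_rng(0).normal(size=(100, 1))
    #     quantizer = KDIQuantizer(n_bins=4, strategy="kdiquantile").fit(X)
    #     codes = quantizer.transform(X)  # ordinal bin indices in {0, 1, 2, 3}
    #     centers = quantizer.inverse_transform(codes)  # deterministic bin centers
    #     draws = quantizer.inverse_transform_sample(codes, random_state=0)  # uniform within each bin
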
195 | def _validate_n_bins(self, n_features):
196 | """Returns n_bins_, the number of bins per feature."""
197 | orig_bins = self.n_bins
198 | if isinstance(orig_bins, Integral):
199 | return np.full(n_features, orig_bins, dtype=int)
200 |
201 | n_bins = check_array(orig_bins, dtype=int, copy=True, ensure_2d=False)
202 |
203 | if n_bins.ndim > 1 or n_bins.shape[0] != n_features:
204 | raise ValueError("n_bins must be a scalar or array of shape (n_features,).")
205 |
206 | bad_nbins_value = (n_bins < 2) | (n_bins != orig_bins)
207 |
208 | violating_indices = np.where(bad_nbins_value)[0]
209 | if violating_indices.shape[0] > 0:
210 | indices = ", ".join(str(i) for i in violating_indices)
211 | raise ValueError(
212 | "{} received an invalid number "
213 | "of bins at indices {}. Number of bins "
214 | "must be at least 2, and must be an int.".format(
215 | KDIQuantizer.__name__, indices
216 | )
217 | )
218 | return n_bins
219 |
220 | def transform(self, X):
221 | """
222 | Discretize the data.
223 |
224 | Parameters
225 | ----------
226 | X : array-like of shape (n_samples, n_features)
227 | Data to be discretized.
228 |
229 | Returns
230 | -------
231 | Xt : {ndarray, sparse matrix}, dtype={np.float32, np.float64}
232 | Data in the binned space. Will be a sparse matrix if
233 | `self.encode='onehot'` and ndarray otherwise.
234 | """
235 | check_is_fitted(self)
236 |
237 | # check input and attribute dtypes
238 | dtype = (np.float64, np.float32) if self.dtype is None else self.dtype
239 | Xt = self._validate_data(X, copy=True, dtype=dtype, reset=False)
240 |
241 | bin_edges = self.bin_edges_
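        # Map each value to an ordinal bin index in [0, n_bins_[jj] - 1]:
        # searching only the interior edges (bin_edges[jj][1:-1]) ensures that
        # values beyond the outermost edges still land in the first or last bin.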
242 | for jj in range(Xt.shape[1]):
243 | Xt[:, jj] = np.searchsorted(bin_edges[jj][1:-1], Xt[:, jj], side="right")
244 |
245 | if self.encode == "ordinal":
246 | return Xt
247 |
248 | dtype_init = None
249 | if "onehot" in self.encode:
250 | dtype_init = self._encoder.dtype
251 | self._encoder.dtype = Xt.dtype
252 | try:
253 | Xt_enc = self._encoder.transform(Xt)
254 | finally:
255 | # revert the initial dtype to avoid modifying self.
256 | self._encoder.dtype = dtype_init
257 | return Xt_enc
258 |
259 | def inverse_transform(self, X):
260 | """
261 | Transform discretized data back to original feature space.
262 |
263 | Note that this function does not regenerate the original data
264 | due to discretization rounding.
265 |
266 | Parameters
267 | ----------
268 | X : array-like of shape (n_samples, n_features)
269 | Transformed data in the binned space.
270 |
271 | Returns
272 | -------
273 | Xinv : ndarray, dtype={np.float32, np.float64}
274 | Data in the original feature space.
275 | """
276 | check_is_fitted(self)
277 |
278 | if "onehot" in self.encode:
279 | X = self._encoder.inverse_transform(X)
280 |
281 | Xinv = check_array(X, copy=True, dtype=(np.float64, np.float32))
282 | n_features = self.n_bins_.shape[0]
283 | if Xinv.shape[1] != n_features:
284 | raise ValueError(
285 | "Incorrect number of features. Expecting {}, received {}.".format(
286 | n_features, Xinv.shape[1]
287 | )
288 | )
289 |
290 | for jj in range(n_features):
291 | bin_edges = self.bin_edges_[jj]
292 | bin_centers = (bin_edges[1:] + bin_edges[:-1]) * 0.5
293 | Xinv[:, jj] = bin_centers[(Xinv[:, jj]).astype(np.int64)]
294 |
295 | return Xinv
296 |
297 |     def inverse_transform_sample(self, X, random_state=None):
298 | rng = check_random_state(random_state)
299 | check_is_fitted(self)
300 |
301 | if "onehot" in self.encode:
302 | X = self._encoder.inverse_transform(X)
303 |
304 | Xinv = check_array(X, copy=True, dtype=(np.float64, np.float32))
305 | n_features = self.n_bins_.shape[0]
306 | if Xinv.shape[1] != n_features:
307 | raise ValueError(
308 | "Incorrect number of features. Expecting {}, received {}.".format(
309 | n_features, Xinv.shape[1]
310 | )
311 | )
312 |         n = Xinv.shape[0]
313 |         for jj in range(n_features):
314 |             jitter = rng.uniform(0.0, 1.0, size=n)
315 |             bin_edges = self.bin_edges_[jj]
316 |             codes = Xinv[:, jj].astype(np.int64)
317 |             bin_lefts = bin_edges[:-1][codes]
318 |             bin_rights = bin_edges[1:][codes]
319 |             Xinv[:, jj] = bin_lefts * (1 - jitter) + bin_rights * jitter
320 |
321 | return Xinv
322 |
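    # Worked example (illustrative): if a feature has bin_edges [0., 1., 3.],
    # ordinal code 1 denotes the bin [1., 3.]. inverse_transform returns the
    # bin center 2.0, whereas inverse_transform_sample returns a uniform draw
    # from [1., 3.], so repeated sampling reconstructs the within-bin spread.
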
323 | def get_feature_names_out(self, input_features=None):
324 | """Get output feature names.
325 |
326 | Parameters
327 | ----------
328 | input_features : array-like of str or None, default=None
329 | Input features.
330 |
331 | - If `input_features` is `None`, then `feature_names_in_` is
332 | used as feature names in. If `feature_names_in_` is not defined,
333 | then the following input feature names are generated:
334 | `['x0', 'x1', ..., 'x(n_features_in_ - 1)']`.
335 | - If `input_features` is an array-like, then `input_features` must
336 | match `feature_names_in_` if `feature_names_in_` is defined.
337 |
338 | Returns
339 | -------
340 | feature_names_out : ndarray of str objects
341 | Transformed feature names.
342 | """
343 | check_is_fitted(self, "n_features_in_")
344 | input_features = _check_feature_names_in(self, input_features)
345 | if hasattr(self, "_encoder"):
346 | return self._encoder.get_feature_names_out(input_features)
347 |
348 | # ordinal encoding
349 | return input_features
350 |
--------------------------------------------------------------------------------
/utrees/baltobot.py:
--------------------------------------------------------------------------------
1 | from typing import Union, Optional
2 | from copy import deepcopy
3 | import warnings
4 |
5 | import numpy as np
6 | import xgboost as xgb
7 |
8 | from sklearn.base import BaseEstimator, ClassifierMixin
9 | from sklearn.preprocessing import LabelEncoder
10 | from sklearn.utils import check_random_state
11 |
12 | from utrees.kdi_quantizer import KDIQuantizer
13 |
14 | XGBOOST_DEFAULT_KWARGS = {
15 | "tree_method": "hist",
16 | "verbosity": 1,
17 | "objective": "binary:logistic",
18 | }
19 |
20 | TABPFN_DEFAULT_KWARGS = {
21 | "N_ensemble_configurations": 1,
22 | "seed": 42,
23 | "batch_size_inference": 1,
24 | "subsample_features": True,
25 | }
26 |
27 |
28 | class NanTabPFNClassifier(BaseEstimator, ClassifierMixin):
29 | def __init__(
30 | self,
31 | random_state=None,
32 | **tabpfn_kwargs,
33 | ):
34 | # We do not import this at the top, because Python dependencies are a nightmare.
35 | from tabpfn import TabPFNClassifier
36 |
37 | self.tabpfn = TabPFNClassifier(**tabpfn_kwargs)
38 | self.rng = check_random_state(random_state)
39 |
40 | def fit(
41 | self,
42 | X,
43 | y,
44 | ):
45 | self.Xtrain = X.copy()
46 | self.ytrain = y.copy()
47 | self.n_classes = np.unique(y).shape[0]
48 | return self
49 |
50 | def predict_proba(
51 | self,
52 | X,
53 | ):
54 | dtype = self.Xtrain.dtype
55 | X = X.astype(dtype)
56 | n_test, d = X.shape
57 | pred_probas = np.zeros((n_test, self.n_classes), dtype=dtype)
58 |
59 | obsXtest = ~np.isnan(X)
60 | obsXtrain = ~np.isnan(self.Xtrain)
61 | obs_patterns = np.unique(obsXtest, axis=0)
62 | n_patterns = obs_patterns.shape[0]
63 | for pix in range(n_patterns):
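            # Gather the test rows sharing this observation pattern, and keep
            # only the training rows that observe every feature observed in
            # the pattern, so the classifier sees a fully observed submatrix.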
64 | isobs = obs_patterns[[pix], :]
65 | test_ixs = (obsXtest == isobs).all(axis=1).nonzero()[0]
66 | obs_ixs = isobs.ravel().nonzero()[0]
67 | train_ixs = (obsXtrain | ~isobs).all(axis=1).nonzero()[0]
68 | while train_ixs.shape[0] == 0:
69 | # No training example covers the observed features in this test example.
70 |                 # This should be rare, but we handle it anyway.
71 | # We introduce fake missingness into test until training covers it.
72 | isobs[0, self.rng.choice(obs_ixs)] = False
73 | obs_ixs = isobs.ravel().nonzero()[0]
74 | train_ixs = (obsXtrain | ~isobs).all(axis=1).nonzero()[0]
75 | if isobs.sum() == 0:
76 | curXtrain = np.zeros((self.Xtrain.shape[0], 1), dtype=dtype)
77 | curytrain = self.ytrain
78 | curXtest = np.zeros((X[test_ixs, :].shape[0], 1), dtype=dtype)
79 | else:
80 | curXtrain = self.Xtrain[np.ix_(train_ixs, obs_ixs)]
81 | curytrain = self.ytrain[train_ixs]
82 | curXtest = X[np.ix_(test_ixs, obs_ixs)]
83 | if curXtrain.shape[0] > 1024:
84 | # warnings.warn(f'TabPFN must shrink from {curXtrain.shape[0]} to 1024 samples')
85 | sixs = self.rng.choice(curXtrain.shape[0], 1024, replace=False)
86 | curXtrain = curXtrain[sixs, :]
87 | curytrain = curytrain[sixs]
88 | if np.unique(curytrain).shape[0] == 1:
89 | # Removing missingness might make us only see one label
90 | pred_probas[test_ixs, curytrain[0]] = 1.0
91 | else:
92 | self.tabpfn.fit(curXtrain, curytrain)
93 | cur_pred_prob = self.tabpfn.predict_proba(curXtest)
94 | pred_probas[test_ixs, :] = cur_pred_prob
95 | return pred_probas
96 |
97 |
98 | class Baltobot(BaseEstimator):
99 | """Performs probabilistic prediction using a BALanced Tree Of BOosted Trees.
100 |
101 | Parameters
102 | ----------
103 | depth : int >= 0
104 | Depth of balanced binary tree for recursively quantizing each feature.
105 | The total number of quantization bins is 2^depth.
106 |
107 | clf_kwargs : dict
108 | Arguments for XGBoost (or TabPFN) classifier.
109 |
110 | strategy : 'kdiquantile', 'quantile', 'uniform', 'kmeans'
111 | The quantization strategy for discretizing continuous features.
112 |
113 | softmax_temp : float > 0
114 | Softmax temperature for sampling from predicted probabilities.
115 | As temperature decreases below the default of 1, predictions converge
116 | to the argmax of each conditional distribution.
117 |
118 |     tabpfn : bool or int >= 0
119 |         Whether to use TabPFN instead of an XGBoost classifier.
120 |
121 | random_state : int, RandomState instance or None, default=None
122 | Determines random number generation.
123 |
124 |
125 | Attributes
126 | ----------
127 | constant_val_ : None or float
128 | Stores the value of a non-varying target variable, avoiding uniform sampling.
129 | This allows us to sample from discrete or mixed-type random variables.
130 |
131 | quantizer_ : KDIQuantizer
132 | Maps continuous value to bin. Almost always 2 bins.
133 |
134 | encoder_ : LabelEncoder
135 | Maps bin (float) to label (int). Almost always 2 classes.
136 |
137 |     clfer_ : XGBClassifier
138 |         Binary classifier that tells us whether to go to the left or right bin.
139 |
140 | left_child_, right_child_ : Baltobot
141 | Next level in the balanced binary tree.
142 | """
143 |
144 | def __init__(
145 | self,
146 | depth: int = 4,
147 | clf_kwargs: dict = {},
148 | strategy: str = "kdiquantile",
149 | softmax_temp: float = 1.0,
150 | tabpfn: bool = False,
151 | random_state=None,
152 | ):
153 |
154 | self.depth = depth
155 | self.clf_kwargs = clf_kwargs
156 | self.strategy = strategy
157 | self.softmax_temp = softmax_temp
158 | self.tabpfn = tabpfn
159 | self.random_state = check_random_state(random_state)
160 | assert depth < 10 # 2^10 models is insane
161 |
162 | self.left_child_ = None
163 | self.right_child_ = None
164 | self.quantizer_ = KDIQuantizer(
165 | n_bins=2, strategy=strategy, random_state=self.random_state
166 | )
167 | self.encoder_ = LabelEncoder()
168 | my_kwargs = (
169 | XGBOOST_DEFAULT_KWARGS | clf_kwargs | {"random_state": self.random_state}
170 | )
171 | self.clfer_ = xgb.XGBClassifier(**my_kwargs)
172 |
173 | if self.tabpfn:
174 | assert self.tabpfn > 0
175 | my_kwargs = TABPFN_DEFAULT_KWARGS | clf_kwargs
176 | self.clfer_ = NanTabPFNClassifier(**my_kwargs)
177 |
178 | self.constant_val_ = None
179 |
180 | if depth > 0:
181 | self.left_child_ = Baltobot(
182 | depth=depth - 1,
183 | clf_kwargs=clf_kwargs,
184 | strategy=strategy,
185 | softmax_temp=softmax_temp,
186 | tabpfn=tabpfn,
187 | random_state=self.random_state,
188 | )
189 | self.right_child_ = Baltobot(
190 | depth=depth - 1,
191 | clf_kwargs=clf_kwargs,
192 | strategy=strategy,
193 | softmax_temp=softmax_temp,
194 | tabpfn=tabpfn,
195 | random_state=self.random_state,
196 | )
197 |
198 | def fit(
199 | self,
200 | X: np.ndarray,
201 | y: np.ndarray,
202 | ):
203 | """Recursively fits balanced tree of boosted tree classifiers.
204 |
205 | Parameters
206 | ----------
207 | X : array-like of shape (n_samples, n_features)
208 | Input variables of train set.
209 |
210 | y : array-like of shape (n_samples,)
211 | Continuous target variable of train set.
212 |
213 | Returns
214 | -------
215 | self
216 | """
217 |
218 | assert np.isnan(y).sum() == 0
219 | if y.size == 0:
220 | self.constant_val_ = 0.0
221 | return self
222 | self.constant_val_ = None
223 | rng = check_random_state(self.random_state)
224 |
225 | self.quantizer_.fit(y.reshape(-1, 1))
226 | y_quant = self.quantizer_.transform(y.reshape(-1, 1)).ravel()
227 | self.encoder_.fit(y_quant)
228 | if len(self.encoder_.classes_) != 2:
229 | assert len(self.encoder_.classes_) == 1
230 | assert np.unique(y).shape[0] == 1
231 | self.constant_val_ = np.mean(y)
232 | return self
233 | y_enc = self.encoder_.transform(y_quant)
234 |
235 | self.clfer_.fit(X, y_enc)
236 |
237 | if self.depth > 0:
238 | left_ixs = y_enc == 0
239 | right_ixs = y_enc == 1
240 | self.left_child_.fit(X[left_ixs, :], y[left_ixs])
241 | self.right_child_.fit(X[right_ixs, :], y[right_ixs])
242 |
243 | return self
244 |
245 | def sample(
246 | self,
247 | X: np.ndarray,
248 | ):
249 | """Samples y conditional on X.
250 |
251 | Parameters
252 | ----------
253 | X : array-like of shape (n_samples, n_features)
254 | Input variables of test set.
255 |
256 | Returns
257 | -------
258 | y : array-like of shape (n_samples,)
259 | Samples from the conditional distribution.
260 | """
261 |
262 | n, _ = X.shape
263 | rng = check_random_state(self.random_state)
264 |
265 | if self.constant_val_ is not None:
266 | return np.full((n,), fill_value=self.constant_val_)
267 |
268 | pred_prob = self.clfer_.predict_proba(X)
269 |
270 | with np.errstate(divide="ignore"):
271 | annealed_logits = np.log(pred_prob) / self.softmax_temp
272 | pred_prob = np.exp(annealed_logits) / np.sum(
273 | np.exp(annealed_logits), axis=1, keepdims=True
274 | )
275 | pred_enc = np.zeros((n,), dtype=int)
276 | for i in range(n):
277 | pred_enc[i] = rng.choice(a=len(self.encoder_.classes_), p=pred_prob[i, :])
278 |
279 | if self.depth == 0:
280 | pred_quant = self.encoder_.inverse_transform(pred_enc)
281 | pred_val = self.quantizer_.inverse_transform_sample(
282 | pred_quant.reshape(-1, 1)
283 | )
284 | return pred_val.ravel()
285 | else:
286 | left_ixs = pred_enc == 0
287 | right_ixs = pred_enc == 1
288 | pred_val = np.zeros((n,))
289 | if left_ixs.sum() > 0:
290 | pred_val[left_ixs] = self.left_child_.sample(X[left_ixs, :])
291 | if right_ixs.sum() > 0:
292 | pred_val[right_ixs] = self.right_child_.sample(X[right_ixs, :])
293 | return pred_val
294 |
295 | def score_samples(
296 | self,
297 | X: np.ndarray,
298 | y: np.ndarray,
299 | ):
300 | """Compute the conditional log-likelihood of y given X.
301 |
302 | Parameters
303 | ----------
304 | X : array-like of shape (n_samples, n_features)
305 | An array of points to query. Last dimension should match dimension
306 | of training data (n_features).
307 |
308 | y : array-like of shape (n_samples,)
309 | Output variable.
310 |
311 | Returns
312 | -------
313 |         density : ndarray of shape (n_samples,)
314 |             Log-likelihood of each sample. After exponentiating, these are
315 |             normalized probability densities, though they may not integrate
316 |             exactly to 1 due to numerical error.
317 |         """
318 |
319 | n, d = X.shape
320 | rng = check_random_state(self.random_state)
321 |         if self.constant_val_ is not None:
322 |             return np.log(y == self.constant_val_)  # log(1) = 0 on match, log(0) = -inf otherwise
323 |
324 | if self.tabpfn:
325 | XX = X.copy()
326 | XX[np.isnan(XX)] = rng.normal(size=XX.shape)[np.isnan(XX)]
327 | pred_prob = self.clfer_.predict_proba(XX)
328 | else:
329 | pred_prob = self.clfer_.predict_proba(X)
330 |
331 | y_quant = self.quantizer_.transform(y.reshape(-1, 1)).ravel()
332 | y_enc = self.encoder_.transform(y_quant)
333 | if self.depth == 0:
334 | left_ixs = y_enc == 0
335 | right_ixs = y_enc == 1
336 |             bin_edges = self.quantizer_.bin_edges_[0]
337 |             left_width = bin_edges[1] - bin_edges[0] + 1e-15  # epsilon guards zero-width bins
338 |             right_width = bin_edges[2] - bin_edges[1] + 1e-15
339 |             scores = np.zeros((n,))
340 |             scores[left_ixs] = np.log(pred_prob[left_ixs, 0] / left_width)  # mass / width = density
341 |             scores[right_ixs] = np.log(pred_prob[right_ixs, 1] / right_width)
342 | return scores
343 | else:
344 | left_ixs = y_enc == 0
345 | right_ixs = y_enc == 1
346 | scores = np.zeros((n,))
347 | if left_ixs.sum() > 0:
348 | left_scores = self.left_child_.score_samples(
349 | X[left_ixs, :], y[left_ixs]
350 | )
351 | scores[left_ixs] = np.log(pred_prob[left_ixs, 0]) + left_scores
352 | if right_ixs.sum() > 0:
353 | right_scores = self.right_child_.score_samples(
354 | X[right_ixs, :], y[right_ixs]
355 | )
356 | scores[right_ixs] = np.log(pred_prob[right_ixs, 1]) + right_scores
357 | return scores
358 |
359 | def score(
360 | self,
361 | X: np.ndarray,
362 | y: np.ndarray,
363 | ):
364 | """Compute the total log probability density under the model.
365 |
366 | Parameters
367 | ----------
368 | X : array_like, shape (n_samples, n_features)
369 | List of n_features-dimensional data points. Each row
370 | corresponds to a single data point.
371 |
372 | y : array-like of shape (n_samples,)
373 | Output variable.
374 |
375 | Returns
376 | -------
377 | logprob : float
378 | Total log-likelihood of y given X.
379 | """
380 | logprob = np.sum(self.score_samples(X, y))
381 | return logprob
382 |
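383 |
384 |
385 | if __name__ == "__main__":
386 |     # Minimal usage sketch (illustrative only): fit Baltobot on a toy
387 |     # regression task, draw one conditional sample per row, and report the
388 |     # mean conditional log-likelihood.
389 |     demo_rng = np.random.default_rng(0)
390 |     X_demo = demo_rng.uniform(0.0, 2.0 * np.pi, size=(500, 1))
391 |     y_demo = np.sin(X_demo[:, 0]) + demo_rng.normal(scale=0.1, size=500)
392 |     model = Baltobot(depth=3, random_state=0).fit(X_demo, y_demo)
393 |     y_samp = model.sample(X_demo)  # one draw from p(y | x) per row
394 |     print("mean loglik:", model.score(X_demo, y_demo) / X_demo.shape[0])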
--------------------------------------------------------------------------------
/paper/baltobot-simple-demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "c8be354b-c400-4d08-9434-cc0051072192",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import matplotlib as mpl\n",
11 | "mpl.rcParams['font.family'] = 'Arial'\n",
12 | "mpl.rcParams['text.usetex'] = False\n",
13 | "import matplotlib.pyplot as plt\n",
14 | "import numpy as np\n",
15 | "import seaborn as sns\n",
16 | "import pandas as pd\n",
17 | "import time\n",
18 | "from utrees import Baltobot\n",
19 | "from treeffuser import Treeffuser\n",
20 | "time.time()\n",
21 | "\n",
22 | "# Generate the data\n",
23 | "seed = 0\n",
24 | "n = 5000\n",
25 | "rng = np.random.default_rng(seed=seed)\n",
26 | "x = rng.uniform(0, 2 * np.pi, size=n)\n",
27 | "z = rng.integers(0, 2, size=n)\n",
28 | "y = z * np.sin(x - np.pi / 2) + (1 - z) * np.cos(x) + rng.laplace(scale=x / 30, size=n)"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 2,
34 | "id": "6383abd1-52a1-42ab-ae29-eb09b5d28c6b",
35 | "metadata": {
36 | "scrolled": true
37 | },
38 | "outputs": [
39 | {
40 | "name": "stderr",
41 | "output_type": "stream",
42 | "text": [
43 | "/Users/calvinm/miniconda3/envs/maskingtrees/lib/python3.9/site-packages/treeffuser/_base_tabular_diffusion.py:110: CastFloat32Warning: Input array is not float32; it has been recast to float32.\n",
44 | " X = _check_array(X)\n",
45 | "/Users/calvinm/miniconda3/envs/maskingtrees/lib/python3.9/site-packages/treeffuser/_base_tabular_diffusion.py:113: CastFloat32Warning: Input array is not float32; it has been recast to float32.\n",
46 | " y = _check_array(y)\n",
47 | "/Users/calvinm/miniconda3/envs/maskingtrees/lib/python3.9/site-packages/treeffuser/_base_tabular_diffusion.py:110: CastFloat32Warning: Input array is not float32; it has been recast to float32.\n",
48 | " X = _check_array(X)\n",
49 | "/Users/calvinm/sandbox/unmasking-trees/utrees/baltobot.py:49: UserWarning: Support for TabPFN is experimental.\n",
50 | " warnings.warn('Support for TabPFN is experimental.')\n",
51 | "/Users/calvinm/miniconda3/envs/maskingtrees/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
52 | " from .autonotebook import tqdm as notebook_tqdm\n"
53 | ]
54 | }
55 | ],
56 | "source": [
57 | "# Fit the models\n",
58 | "start_time = time.time()\n",
59 | "tfer = Treeffuser(sde_initialize_from_data=True, seed=seed)\n",
60 | "tfer.fit(x, y)\n",
61 | "tf_train_time = time.time() - start_time\n",
62 | "y_tfer = tfer.sample(x, n_samples=1, seed=seed, verbose=True)\n",
63 | "tf_time = time.time() - start_time\n",
64 | "\n",
65 | "start_time = time.time()\n",
66 | "tber = Baltobot(random_state=seed)\n",
67 | "tber.fit(x.reshape(-1, 1), y)\n",
68 | "tb_train_time = time.time() - start_time\n",
69 | "y_tber = tber.sample(x.reshape(-1, 1))\n",
70 | "tb_time = time.time() - start_time\n",
71 | "\n",
72 | "start_time = time.time()\n",
73 | "tbtaber = Baltobot(tabpfn=True, random_state=seed)\n",
74 | "tbtaber.fit(x.reshape(-1, 1), y)\n",
75 | "tbtab_train_time = time.time() - start_time\n",
76 | "y_tbtaber = tbtaber.sample(x.reshape(-1, 1))\n",
77 | "tbtab_time = time.time() - start_time"
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": 3,
83 | "id": "2cb69a60-0533-40f0-a405-93be28619f6d",
84 | "metadata": {},
85 | "outputs": [
86 | {
87 | "data": {
171 | "text/plain": [
172 | " Method Task Time\n",
173 | "0 Treeffuser Total 6.401965\n",
174 | "1 Treeffuser Training 1.356733\n",
175 | "2 Treeffuser Sampling 5.045232\n",
176 | "3 Baltobot Total 3.175086\n",
177 | "4 Baltobot Training 2.396088\n",
178 | "5 Baltobot Sampling 0.778998\n",
179 | "6 BaltoboTabPFN Total 12.216128\n",
180 | "7 BaltoboTabPFN Training 2.148450\n",
181 | "8 BaltoboTabPFN Sampling 10.067678"
182 | ]
183 | },
184 | "execution_count": 3,
185 | "metadata": {},
186 | "output_type": "execute_result"
187 | }
188 | ],
189 | "source": [
190 | "fig, axes = plt.subplots(nrows=3, figsize=(7, 7), sharex=True, dpi=300);\n",
191 | "axes[0].scatter(x, y, s=1, label=\"Observed data\")\n",
192 | "axes[0].scatter(x, y_tfer[0, :], s=1, alpha=0.7, label=\"Treeffuser samples\")\n",
193 | "axes[0].legend();\n",
194 | "\n",
195 | "axes[1].scatter(x, y, s=1, label=\"Observed data\")\n",
196 | "axes[1].scatter(x, y_tber, s=1, alpha=0.7, label=\"Baltobot samples\")\n",
197 | "axes[1].legend();\n",
198 | "\n",
199 | "axes[2].scatter(x, y, s=1, label=\"Observed data\")\n",
200 | "axes[2].scatter(x, y_tbtaber, s=1, alpha=0.7, label=\"BaltoboTabPFN samples\")\n",
201 | "axes[2].legend();\n",
202 | "plt.tight_layout();\n",
203 | "plt.savefig('wave-demo.png');\n",
204 | "plt.close();\n",
205 | "\n",
206 | "plt.figure(dpi=200, figsize=(4,3));\n",
207 | "total_time_df = pd.DataFrame.from_dict({'Treeffuser': [tf_time], 'Baltobot': [tb_time], 'BaltoboTabPFN': [tbtab_time]}).T\n",
208 | "total_time_df.columns = ['Total']\n",
209 | "train_time_df = pd.DataFrame.from_dict({'Treeffuser': [tf_train_time], 'Baltobot': [tb_train_time], 'BaltoboTabPFN': [tbtab_train_time]}).T\n",
210 | "train_time_df.columns = ['Training']\n",
211 | "time_df = pd.concat([total_time_df, train_time_df], axis=1)\n",
212 | "time_df['Sampling'] = time_df['Total'] - time_df['Training']\n",
213 | "\n",
214 | "time_dff = time_df.stack().reset_index()\n",
215 | "time_dff.columns = ['Method', 'Task', 'Time']\n",
216 | "sns.barplot(data=time_dff, y='Method', x='Time', hue='Task');\n",
217 | "plt.ylabel('Method');\n",
218 | "plt.xlabel('Time (s)');\n",
219 | "plt.tight_layout();\n",
220 | "plt.savefig('wave-demo-time.png');\n",
221 | "plt.close()\n",
222 | "time_dff"
223 | ]
224 | },
225 | {
226 | "cell_type": "code",
227 | "execution_count": 4,
228 | "id": "ccb3079e-72a3-403a-8ea4-d3b9a2435493",
229 | "metadata": {},
230 | "outputs": [],
231 | "source": [
232 | "lhs, rhs = np.meshgrid(np.linspace(-1, 7, 30), np.linspace(-3,2, 30))\n",
233 | "lhsrhs = np.hstack([lhs.reshape(-1, 1), rhs.reshape(-1, 1)])\n",
234 | "plt.figure();\n",
235 | "plt.scatter(lhsrhs[:, 0], lhsrhs[:, 1])\n",
236 | "scores = tber.score_samples(lhs.reshape(-1, 1), rhs.reshape(-1))\n",
237 | "plt.close();\n",
238 | "plt.figure();\n",
239 | "plt.scatter(lhsrhs[:, 0], lhsrhs[:, 1], s=100*np.exp(scores));\n",
240 | "plt.close();"
241 | ]
242 | },
243 | {
244 | "cell_type": "code",
245 | "execution_count": 5,
246 | "id": "671dac1d-054b-4f64-8ca8-17676bd064b7",
247 | "metadata": {},
248 | "outputs": [
249 | {
250 | "name": "stdout",
251 | "output_type": "stream",
252 | "text": [
253 | "0.9820769258371294 0.9859963150095391\n"
254 | ]
255 | }
256 | ],
257 | "source": [
258 | "Xs = 2 * np.ones((1000, 1))\n",
259 | "Ys = np.linspace(-5, 5, 1000)\n",
260 | "tb_scores = tber.score_samples(Xs, Ys)\n",
261 | "tbtab_scores = tbtaber.score_samples(Xs, Ys)\n",
262 | "print(np.exp(tb_scores).sum() * (Ys[1]-Ys[0]), np.exp(tbtab_scores).sum() * (Ys[1]-Ys[0]))\n",
263 | "plt.figure(figsize=(4,2), dpi=200);\n",
264 | "plt.plot(Ys, np.exp(tb_scores), label='Baltobot');\n",
265 | "plt.plot(Ys, np.exp(tbtab_scores), '--', label='BaltoboTabPFN');\n",
266 | "plt.xlim(-2, 2);\n",
267 | "plt.legend(loc='upper left', bbox_to_anchor=(1.0, 1.05));\n",
268 | "plt.xlabel('y');\n",
269 | "plt.ylabel('pdf at x=2');\n",
270 | "plt.tight_layout();\n",
271 | "plt.savefig('wave-pdfat2.png');\n",
272 | "plt.close();"
273 | ]
274 | },
275 | {
276 | "cell_type": "code",
277 | "execution_count": 6,
278 | "id": "24ca0b83-9b7e-468f-8f34-823a36c917e6",
279 | "metadata": {},
280 | "outputs": [
281 | {
282 | "name": "stderr",
283 | "output_type": "stream",
284 | "text": [
285 | "/Users/calvinm/miniconda3/envs/maskingtrees/lib/python3.9/site-packages/treeffuser/_base_tabular_diffusion.py:110: CastFloat32Warning: Input array is not float32; it has been recast to float32.\n",
286 | " X = _check_array(X)\n",
287 | "/Users/calvinm/miniconda3/envs/maskingtrees/lib/python3.9/site-packages/treeffuser/_base_tabular_diffusion.py:113: CastFloat32Warning: Input array is not float32; it has been recast to float32.\n",
288 | " y = _check_array(y)\n",
289 | "/Users/calvinm/miniconda3/envs/maskingtrees/lib/python3.9/site-packages/treeffuser/_base_tabular_diffusion.py:110: CastFloat32Warning: Input array is not float32; it has been recast to float32.\n",
290 | " X = _check_array(X)\n"
291 | ]
292 | }
293 | ],
294 | "source": [
295 | "nP = 500\n",
296 | "rng = np.random.default_rng(seed=seed)\n",
297 | "XP = rng.uniform(0, 3, size=nP)\n",
298 | "YP = rng.poisson(np.sqrt(XP), size=nP)\n",
299 | "tfer = Treeffuser(sde_initialize_from_data=True, seed=seed)\n",
300 | "tfer.fit(XP, YP)\n",
301 | "YP_tfer = tfer.sample(XP, n_samples=1, seed=seed, verbose=True)\n",
302 | "tber = Baltobot(random_state=seed)\n",
303 | "tber.fit(XP.reshape(-1, 1), YP)\n",
304 | "YP_tber = tber.sample(XP.reshape(-1, 1))"
305 | ]
306 | },
307 | {
308 | "cell_type": "code",
309 | "execution_count": 7,
310 | "id": "12dfdd1a-805c-4f7d-9b58-f06552ab97fb",
311 | "metadata": {},
312 | "outputs": [],
313 | "source": [
314 | "dfP = pd.DataFrame(); dfP['x'] = XP; dfP['y'] = YP\n",
315 | "s = 8; linewidth=0.3; edgecolor='white'; markercolor='blue';\n",
316 | "fig, axes = plt.subplots(figsize=(7,3), ncols=3, sharey=True, dpi=500);\n",
317 | "sns.scatterplot(data=dfP, x='x', y='y', s=s, edgecolor=edgecolor, linewidth=linewidth, color=markercolor, ax=axes[0])\n",
318 | "axes[0].set_title('Original data');\n",
319 | "dfP_tfer = pd.DataFrame(); dfP_tfer['x'] = XP; dfP_tfer['y'] = YP_tfer.ravel()\n",
320 | "sns.scatterplot(data=dfP_tfer, x='x', y='y', s=s, edgecolor=edgecolor, linewidth=linewidth, color=markercolor, ax=axes[1])\n",
321 | "axes[1].set_title('Treeffuser')\n",
322 | "dfP_tber = pd.DataFrame(); dfP_tber['x'] = XP; dfP_tber['y'] = YP_tber\n",
323 | "sns.scatterplot(data=dfP_tber, x='x', y='y', s=s, edgecolor=edgecolor, linewidth=linewidth, color=markercolor, ax=axes[2])\n",
324 | "axes[2].set_title('Baltobot')\n",
325 | "plt.tight_layout();\n",
326 | "plt.savefig('poisson-demo.png');\n",
327 | "plt.close();"
328 | ]
329 | },
330 | {
331 | "cell_type": "code",
332 | "execution_count": null,
333 | "id": "92ae5c29-84e4-403b-8dd4-5fd7f0bb827c",
334 | "metadata": {},
335 | "outputs": [],
336 | "source": []
337 | }
338 | ],
339 | "metadata": {
340 | "kernelspec": {
341 | "display_name": "Python 3 (ipykernel)",
342 | "language": "python",
343 | "name": "python3"
344 | },
345 | "language_info": {
346 | "codemirror_mode": {
347 | "name": "ipython",
348 | "version": 3
349 | },
350 | "file_extension": ".py",
351 | "mimetype": "text/x-python",
352 | "name": "python",
353 | "nbconvert_exporter": "python",
354 | "pygments_lexer": "ipython3",
355 | "version": "3.9.19"
356 | }
357 | },
358 | "nbformat": 4,
359 | "nbformat_minor": 5
360 | }
361 |
--------------------------------------------------------------------------------
/paper/iris.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "1f6d6fc7-a52a-471d-9607-ad42f993159e",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import matplotlib as mpl\n",
11 | "mpl.rcParams['font.family'] = 'Arial'\n",
12 | "mpl.rcParams['text.usetex'] = False\n",
13 | "import numpy as np\n",
14 | "import pandas as pd\n",
15 | "import seaborn as sns\n",
16 | "import matplotlib.pyplot as plt\n",
17 | "\n",
18 | "from sklearn.datasets import load_iris\n",
19 | "from sklearn.utils import check_random_state\n",
20 | "from ForestDiffusion import ForestDiffusionModel\n",
21 | "from utrees import UnmaskingTrees\n",
22 | "from missforest import MissForest"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": null,
28 | "id": "6348b6f5-8eff-461d-833a-73c1bb8cfc16",
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "# Iris: numpy dataset with 4 variables (all numerical) and 1 outcome (categorical)\n",
33 | "my_data = load_iris()\n",
34 | "X, y = my_data['data'], my_data['target']\n",
35 | "Xy = np.concatenate((X, np.expand_dims(y, axis=1)), axis=1)\n",
36 | "palette = {'setosa': 'red', 'versicolor': 'green', 'virginica': 'blue'} \n",
37 | "alpha = 0.6\n",
38 | "rix = 0"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": null,
44 | "id": "5eb55da3-b595-448a-a776-5b6e67ff9a38",
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "forestflow = ForestDiffusionModel(\n",
49 | " X, label_y=y, diffusion_type='flow', \n",
50 | " n_t=50, duplicate_K=100, bin_indexes=[], cat_indexes=[], int_indexes=[], n_jobs=-1, seed=rix)\n",
51 | "Xy_gen_forestflow = forestflow.generate(batch_size=X.shape[0]) # last variable will be the label_y\n",
52 | "\n",
53 | "forestvp = ForestDiffusionModel(\n",
54 | " X, label_y=y, diffusion_type='vp', \n",
55 | " n_t=50, duplicate_K=100, bin_indexes=[], cat_indexes=[], int_indexes=[], n_jobs=-1, seed=rix)\n",
56 | "Xy_gen_forestvp = forestvp.generate(batch_size=X.shape[0]) # last variable will be the label_y\n",
57 | "\n",
58 | "utrees = UnmaskingTrees(random_state=rix)\n",
59 | "utrees.fit(Xy, quantize_cols=['continuous']*4 + ['categorical'])\n",
60 | "Xy_gen_utrees = utrees.generate(n_generate=X.shape[0]);"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": null,
66 | "id": "1b97c3e5-c53b-4493-b5ca-235f8328bc81",
67 | "metadata": {},
68 | "outputs": [],
69 | "source": [
70 | "X_mcar = X.copy()\n",
71 | "rng = check_random_state(rix)\n",
72 | "hasmissing = rng.normal(size=(150,)) < 0.\n",
73 | "X_hasmissing = X_mcar[hasmissing, :]\n",
74 | "X_hasmissing[rng.normal(size=X_hasmissing.shape) < 0] = np.nan\n",
75 | "X_mcar[hasmissing, :] = X_hasmissing\n",
76 | "Xy_mcar = np.concatenate((X_mcar, np.expand_dims(y, axis=1)), axis=1)\n",
77 | "np.isnan(X_mcar).sum()"
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": null,
83 | "id": "5047c2aa-cc70-4b08-a153-89bed3ee51d3",
84 | "metadata": {},
85 | "outputs": [],
86 | "source": [
87 | "missfer = MissForest(random_state=rix)\n",
88 | "Xy_impute_missf = missfer.fit_transform(Xy_mcar.copy(), cat_vars=[4])\n",
89 | "\n",
90 | "forestvp_mcar = ForestDiffusionModel(\n",
91 | " X=Xy_mcar, diffusion_type='vp', \n",
92 | " n_t=50, duplicate_K=100, bin_indexes=[], cat_indexes=[4], int_indexes=[0], n_jobs=-1, seed=rix)\n",
93 | "Xy_impute_forest_fast = forestvp_mcar.impute(k=1) # regular (fast)\n",
94 | "Xy_impute_forest = forestvp_mcar.impute(repaint=True, r=10, j=5, k=1) # REPAINT (slow, but better)\n",
95 | "\n",
96 | "utrees_mcar = UnmaskingTrees(random_state=rix)\n",
97 | "utrees_mcar.fit(Xy_mcar, quantize_cols=['continuous']*4 + ['categorical'])\n",
98 | "Xy_impute_utrees = utrees_mcar.impute(n_impute=1)[0, :, :]"
99 | ]
100 | },
101 | {
102 | "cell_type": "code",
103 | "execution_count": null,
104 | "id": "b4fffa3c-4778-47d7-a18c-7fb778a67528",
105 | "metadata": {},
106 | "outputs": [],
107 | "source": [
108 | "fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(7, 3.5), sharex=True, sharey=True, squeeze=False);\n",
109 | "\n",
110 | "Xy_df = pd.DataFrame(data=Xy, columns=my_data['feature_names']+['target_names'])\n",
111 | "Xy_df['Species'] = Xy_df['target_names'].apply(lambda x: my_data['target_names'][int(x)].item())\n",
112 | "sns.scatterplot(\n",
113 | " data=Xy_df, x='petal length (cm)', y='petal width (cm)', hue='Species', s=15,\n",
114 | " palette=palette, alpha=alpha, ax=axes[0, 0], legend=False)\n",
115 | "axes[0, 0].set_title('Original data');\n",
116 | "\n",
117 | "Xy_gen_forestvp_df = pd.DataFrame(data=Xy_gen_forestvp, columns=my_data['feature_names']+['target_names'])\n",
118 | "Xy_gen_forestvp_df['Species'] = Xy_gen_forestvp_df['target_names'].apply(lambda x: my_data['target_names'][int(x)].item())\n",
119 | "sns.scatterplot(\n",
120 | " data=Xy_gen_forestvp_df, x='petal length (cm)', y='petal width (cm)', hue='Species', s=15,\n",
121 | " palette=palette, alpha=alpha, ax=axes[0, 1], legend=False)\n",
122 | "axes[0, 1].set_title('Forest-VP\\ngenerated');\n",
123 | "\n",
124 | "Xy_gen_forestflow_df = pd.DataFrame(data=Xy_gen_forestflow, columns=my_data['feature_names']+['target_names'])\n",
125 | "Xy_gen_forestflow_df['Species'] = Xy_gen_forestflow_df['target_names'].apply(lambda x: my_data['target_names'][int(x)].item())\n",
126 | "sns.scatterplot(\n",
127 | " data=Xy_gen_forestflow_df, x='petal length (cm)', y='petal width (cm)', hue='Species', s=15,\n",
128 | " palette=palette, alpha=alpha, ax=axes[0, 2], legend=False)\n",
129 | "axes[0, 2].set_title('Forest-Flow\\ngenerated');\n",
130 | "\n",
131 | "Xy_gen_utrees_df = pd.DataFrame(data=Xy_gen_utrees, columns=my_data['feature_names']+['target_names'])\n",
132 | "Xy_gen_utrees_df['Species'] = Xy_gen_utrees_df['target_names'].apply(lambda x: my_data['target_names'][int(x)].item())\n",
133 | "sns.scatterplot(\n",
134 | " data=Xy_gen_utrees_df, x='petal length (cm)', y='petal width (cm)', hue='Species', s=15,\n",
135 | " palette=palette, alpha=alpha, ax=axes[0, 3], legend=True)\n",
136 | "axes[0, 3].set_title('UnmaskingTrees\\ngenerated');\n",
137 | "\n",
138 | "sns.move_legend(axes[0, 3], \"upper left\", bbox_to_anchor=(1, 1.1))\n",
139 | "\n",
140 | "Xy_impute_missf_df = pd.DataFrame(data=Xy_impute_missf, columns=my_data['feature_names']+['target_names'])\n",
141 | "Xy_impute_missf_df['Species'] = Xy_impute_missf_df['target_names'].apply(lambda x: my_data['target_names'][int(x)].item())\n",
142 | "Xy_impute_missf_df['Imputed'] = hasmissing\n",
143 | "sns.scatterplot(\n",
144 | " data=Xy_impute_missf_df, x='petal length (cm)', y='petal width (cm)', hue='Species', style='Imputed',\n",
145 | " size='Imputed', sizes={False: 10, True: 30},\n",
146 | " palette=palette, alpha=alpha, ax=axes[1, 0], legend=False)\n",
147 | "axes[1, 0].set_title('MissForest\\nimputed');\n",
148 | "\n",
149 | "Xy_impute_forest_fast_df = pd.DataFrame(data=Xy_impute_forest_fast, columns=my_data['feature_names']+['target_names'])\n",
150 | "Xy_impute_forest_fast_df['Species'] = Xy_impute_forest_fast_df['target_names'].apply(lambda x: my_data['target_names'][int(x)].item())\n",
151 | "Xy_impute_forest_fast_df['Imputed'] = hasmissing\n",
152 | "sns.scatterplot(\n",
153 | " data=Xy_impute_forest_fast_df, x='petal length (cm)', y='petal width (cm)', hue='Species', style='Imputed',\n",
154 | " size='Imputed', sizes={False: 10, True: 30},\n",
155 | " palette=palette, alpha=alpha, ax=axes[1, 1], legend=False)\n",
156 | "axes[1, 1].set_title('Forest-VP\\nimputed');\n",
157 | "\n",
158 | "Xy_impute_forest_df = pd.DataFrame(data=Xy_impute_forest, columns=my_data['feature_names']+['target_names'])\n",
159 | "Xy_impute_forest_df['Species'] = Xy_impute_forest_df['target_names'].apply(lambda x: my_data['target_names'][int(x)].item())\n",
160 | "Xy_impute_forest_df['Imputed'] = hasmissing\n",
161 | "sns.scatterplot(\n",
162 | " data=Xy_impute_forest_df, x='petal length (cm)', y='petal width (cm)', hue='Species', style='Imputed',\n",
163 | " size='Imputed', sizes={False: 10, True: 30},\n",
164 | " palette=palette, alpha=alpha, ax=axes[1, 2], legend=False)\n",
165 | "axes[1, 2].set_title('Forest-VP\\nimputed w/ RePaint');\n",
166 | "\n",
167 | "Xy_impute_utrees_df = pd.DataFrame(data=Xy_impute_utrees, columns=my_data['feature_names']+['target_names'])\n",
168 | "Xy_impute_utrees_df['Species'] = Xy_impute_utrees_df['target_names'].apply(lambda x: my_data['target_names'][int(x)].item())\n",
169 | "Xy_impute_utrees_df['Imputed'] = hasmissing\n",
170 | "\n",
171 | "sns.scatterplot(\n",
172 | " data=Xy_impute_utrees_df, x='petal length (cm)', y='petal width (cm)', hue='Species', style='Imputed',\n",
173 | " size='Imputed', sizes={False: 10, True: 30},\n",
174 | " palette=palette, alpha=alpha, ax=axes[1, 3], legend=True)\n",
175 | "axes[1, 3].set_title('UnmaskingTrees\\nimputed');\n",
176 | "sns.move_legend(axes[1, 3], \"upper left\", bbox_to_anchor=(1, 1.1))\n",
177 | "\n",
178 | "plt.tight_layout();\n",
179 | "plt.savefig('iris.png');"
180 | ]
181 | },
182 | {
183 | "cell_type": "code",
184 | "execution_count": null,
185 | "id": "0d040488-6031-45ea-8f1d-497e86f612a4",
186 | "metadata": {},
187 | "outputs": [],
188 | "source": [
189 | "fig, axes = plt.subplots(nrows=4, ncols=2, figsize=(5, 7), dpi=200, sharex=True, sharey=True, squeeze=False);\n",
190 | "\n",
191 | "Xy_df = pd.DataFrame(data=Xy, columns=my_data['feature_names']+['target_names'])\n",
192 | "Xy_df['Species'] = Xy_df['target_names'].apply(lambda x: my_data['target_names'][int(x)].item())\n",
193 | "sns.scatterplot(\n",
194 | " data=Xy_df, x='petal length (cm)', y='petal width (cm)', hue='Species', s=9,\n",
195 | " palette=palette, alpha=alpha, ax=axes[0, 0], legend=False)\n",
196 | "axes[0, 0].set_title('Original data');\n",
197 | "\n",
198 | "Xy_gen_forestvp_df = pd.DataFrame(data=Xy_gen_forestvp, columns=my_data['feature_names']+['target_names'])\n",
199 | "Xy_gen_forestvp_df['Species'] = Xy_gen_forestvp_df['target_names'].apply(lambda x: my_data['target_names'][int(x)].item())\n",
200 | "sns.scatterplot(\n",
201 | " data=Xy_gen_forestvp_df, x='petal length (cm)', y='petal width (cm)', hue='Species', s=9,\n",
202 | " palette=palette, alpha=alpha, ax=axes[1, 0], legend=False)\n",
203 | "axes[1, 0].set_title('Forest-VP\\ngenerated');\n",
204 | "\n",
205 | "Xy_gen_forestflow_df = pd.DataFrame(data=Xy_gen_forestflow, columns=my_data['feature_names']+['target_names'])\n",
206 | "Xy_gen_forestflow_df['Species'] = Xy_gen_forestflow_df['target_names'].apply(lambda x: my_data['target_names'][int(x)].item())\n",
207 | "sns.scatterplot(\n",
208 | " data=Xy_gen_forestflow_df, x='petal length (cm)', y='petal width (cm)', hue='Species', s=9,\n",
209 | " palette=palette, alpha=alpha, ax=axes[2, 0], legend=False)\n",
210 | "axes[2, 0].set_title('Forest-Flow\\ngenerated');\n",
211 | "\n",
212 | "Xy_gen_utrees_df = pd.DataFrame(data=Xy_gen_utrees, columns=my_data['feature_names']+['target_names'])\n",
213 | "Xy_gen_utrees_df['Species'] = Xy_gen_utrees_df['target_names'].apply(lambda x: my_data['target_names'][int(x)].item())\n",
214 | "sns.scatterplot(\n",
215 | " data=Xy_gen_utrees_df, x='petal length (cm)', y='petal width (cm)', hue='Species', s=9,\n",
216 | " palette=palette, alpha=alpha, ax=axes[3, 0], legend=False)\n",
217 | "axes[3, 0].set_title('UnmaskingTrees\\ngenerated');\n",
218 | "\n",
219 | "#sns.move_legend(axes[3, 0], \"upper left\", bbox_to_anchor=(1, 1.1))\n",
220 | "\n",
221 | "Xy_impute_missf_df = pd.DataFrame(data=Xy_impute_missf, columns=my_data['feature_names']+['target_names'])\n",
222 | "Xy_impute_missf_df['Species'] = Xy_impute_missf_df['target_names'].apply(lambda x: my_data['target_names'][int(x)].item())\n",
223 | "Xy_impute_missf_df['Imputed'] = hasmissing\n",
224 | "sns.scatterplot(\n",
225 | " data=Xy_impute_missf_df, x='petal length (cm)', y='petal width (cm)', hue='Species', style='Imputed',\n",
226 | " size='Imputed', sizes={False: 5, True: 20},\n",
227 | " palette=palette, alpha=alpha, ax=axes[0, 1], legend=False)\n",
228 | "axes[0, 1].set_title('MissForest\\nimputed');\n",
229 | "\n",
230 | "Xy_impute_forest_fast_df = pd.DataFrame(data=Xy_impute_forest_fast, columns=my_data['feature_names']+['target_names'])\n",
231 | "Xy_impute_forest_fast_df['Species'] = Xy_impute_forest_fast_df['target_names'].apply(lambda x: my_data['target_names'][int(x)].item())\n",
232 | "Xy_impute_forest_fast_df['Imputed'] = hasmissing\n",
233 | "sns.scatterplot(\n",
234 | " data=Xy_impute_forest_fast_df, x='petal length (cm)', y='petal width (cm)', hue='Species', style='Imputed',\n",
235 | " size='Imputed', sizes={False: 5, True: 20},\n",
236 | " palette=palette, alpha=alpha, ax=axes[1, 1], legend=False)\n",
237 | "axes[1, 1].set_title('Forest-VP\\nimputed');\n",
238 | "\n",
239 | "Xy_impute_forest_df = pd.DataFrame(data=Xy_impute_forest, columns=my_data['feature_names']+['target_names'])\n",
240 | "Xy_impute_forest_df['Species'] = Xy_impute_forest_df['target_names'].apply(lambda x: my_data['target_names'][int(x)].item())\n",
241 | "Xy_impute_forest_df['Imputed'] = hasmissing\n",
242 | "sns.scatterplot(\n",
243 | " data=Xy_impute_forest_df, x='petal length (cm)', y='petal width (cm)', hue='Species', style='Imputed',\n",
244 | " size='Imputed', sizes={False: 5, True: 20},\n",
245 | " palette=palette, alpha=alpha, ax=axes[2, 1], legend=False)\n",
246 | "axes[2, 1].set_title('Forest-VP\\nimputed w/ RePaint');\n",
247 | "\n",
248 | "Xy_impute_utrees_df = pd.DataFrame(data=Xy_impute_utrees, columns=my_data['feature_names']+['target_names'])\n",
249 | "Xy_impute_utrees_df['Species'] = Xy_impute_utrees_df['target_names'].apply(lambda x: my_data['target_names'][int(x)].item())\n",
250 | "Xy_impute_utrees_df['Imputed'] = hasmissing\n",
251 | "\n",
252 | "sns.scatterplot(\n",
253 | " data=Xy_impute_utrees_df, x='petal length (cm)', y='petal width (cm)', hue='Species', style='Imputed',\n",
254 | " size='Imputed', sizes={False: 5, True: 20},\n",
255 | " palette=palette, alpha=alpha, ax=axes[3, 1], legend=True)\n",
256 | "axes[3, 1].set_title('UnmaskingTrees\\nimputed');\n",
257 | "sns.move_legend(axes[3, 1], \"upper left\", bbox_to_anchor=(1.1, 1.3))\n",
258 | "\n",
259 | "plt.tight_layout();\n",
260 | "plt.savefig('iris-vertical.png');\n",
261 | "plt.savefig('iris-vertical.pdf');"
262 | ]
263 | }
264 | ],
265 | "metadata": {
266 | "kernelspec": {
267 | "display_name": "Python 3 (ipykernel)",
268 | "language": "python",
269 | "name": "python3"
270 | },
271 | "language_info": {
272 | "codemirror_mode": {
273 | "name": "ipython",
274 | "version": 3
275 | },
276 | "file_extension": ".py",
277 | "mimetype": "text/x-python",
278 | "name": "python",
279 | "nbconvert_exporter": "python",
280 | "pygments_lexer": "ipython3",
281 | "version": "3.9.19"
282 | }
283 | },
284 | "nbformat": 4,
285 | "nbformat_minor": 5
286 | }
287 |
--------------------------------------------------------------------------------
/Results/imputation_script.R:
--------------------------------------------------------------------------------
1 | options(show.error.locations = TRUE)
2 | options(repos=c(CRAN = "https://cloud.r-project.org"))
3 | install.packages('plotrix')
4 | install.packages('matrixStats')
5 | install.packages('xtable')
6 | library(matrixStats)
7 | library(xtable)
8 | library(plotrix)
9 |
10 | data = read.csv("Results_tabular - imputation.csv")
12 |
13 | # Better scaling for clean tables
14 | data$PercentBias = data$PercentBias / 100
15 |
16 | # Clamp tiny floating-point noise to exact zero
17 | data$MeanVariance[data$MeanVariance < 1e-12] = 0
18 | data$MeanMAD[data$MeanMAD < 1e-12] = 0
19 | data$MedianMAD[data$MedianMAD < 1e-12] = 0
20 |
21 | all_vars = c('iris', 'wine', 'parkinsons', 'climate_model_crashes', 'concrete_compression', 'yacht_hydrodynamics', 'airfoil_self_noise', 'connectionist_bench_sonar', 'ionosphere', 'qsar_biodegradation', 'seeds', 'glass', 'yeast', 'libras', 'planning_relax', 'blood_transfusion', 'breast_cancer_diagnostic', 'connectionist_bench_vowel', 'concrete_slump', 'wine_quality_red', 'wine_quality_white', 'california', 'bean', 'car','congress','tictactoe')
22 | W_vars = c('iris', 'wine', 'parkinsons', 'climate_model_crashes', 'concrete_compression', 'yacht_hydrodynamics', 'airfoil_self_noise', 'connectionist_bench_sonar', 'ionosphere', 'qsar_biodegradation', 'seeds', 'glass', 'yeast', 'libras', 'planning_relax', 'blood_transfusion', 'breast_cancer_diagnostic', 'connectionist_bench_vowel', 'concrete_slump', 'wine_quality_red', 'wine_quality_white', 'car','congress','tictactoe')
23 | R2_vars = c('concrete_compression', 'yacht_hydrodynamics', 'airfoil_self_noise', 'wine_quality_red', 'wine_quality_white', 'california')
24 | F1_vars = c('iris', 'wine', 'parkinsons', 'climate_model_crashes', 'connectionist_bench_sonar', 'ionosphere', 'qsar_biodegradation', 'seeds', 'glass', 'yeast', 'libras', 'planning_relax', 'blood_transfusion', 'breast_cancer_diagnostic', 'connectionist_bench_vowel', 'bean', 'car','congress','tictactoe')
25 |
26 | methods = c('KNN(n_neighbors=1)', 'ice', 'miceforest', 'MissForest', 'softimpute', 'OT', 'GAIN', 'forest_diffusion_repaint_nt51_ycond', 'utrees')
27 | methods_all = c('KNN(n_neighbors=1)', 'ice', 'miceforest', 'MissForest', 'softimpute', 'OT', 'GAIN', 'forest_diffusion_repaint_nt51_ycond', 'utrees', 'oracle')
28 |
29 | data_summary = data.frame(matrix(ncol = length(methods_all), nrow = 10))
30 | colnames(data_summary) = methods_all
31 | rownames(data_summary) = c("MinMAE", "AvgMAE", "W_train", "W_test", "MedianMAD", "R2_imp", "F1_imp", "PercentBias", "CoverageRate", "time")
32 |
33 | for (method in methods_all){
34 |
35 | MinMAE_mean = mean(data$MinMAE[data$method==method & data$dataset %in% all_vars])
36 | MinMAE_sd = std.error(data$MinMAE[data$method==method & data$dataset %in% all_vars])
37 |
38 | AvgMAE_mean = mean(data$AvgMAE[data$method==method & data$dataset %in% all_vars])
39 | AvgMAE_sd = std.error(data$AvgMAE[data$method==method & data$dataset %in% all_vars])
40 |
41 | W_train_mean = mean(data$W_train[data$method==method & data$dataset %in% W_vars])
42 | W_train_sd = std.error(data$W_train[data$method==method & data$dataset %in% W_vars])
43 |
44 | W_test_mean = mean(data$W_test[data$method==method & data$dataset %in% W_vars])
45 | W_test_sd = std.error(data$W_test[data$method==method & data$dataset %in% W_vars])
46 |
47 | PercentBias_mean = mean(data$PercentBias[data$method==method & data$dataset %in% R2_vars])
48 | PercentBias_sd = std.error(data$PercentBias[data$method==method & data$dataset %in% R2_vars])
49 |
50 | CoverageRate_mean = mean(data$CoverageRate[data$method==method & data$dataset %in% R2_vars])
51 | CoverageRate_sd = std.error(data$CoverageRate[data$method==method & data$dataset %in% R2_vars])
52 |
53 | MedianMAD_mean = mean(data$MedianMAD[data$method==method & data$dataset %in% all_vars])
54 | MedianMAD_sd = std.error(data$MedianMAD[data$method==method & data$dataset %in% all_vars])
55 |
56 | R2_imp_mean = mean(data$R2_imp[data$method==method & data$dataset %in% R2_vars])
57 | R2_imp_sd = std.error(data$R2_imp[data$method==method & data$dataset %in% R2_vars])
58 |
59 | F1_imp_mean = mean(data$F1_imp[data$method==method & data$dataset %in% F1_vars])
60 | F1_imp_sd = std.error(data$F1_imp[data$method==method & data$dataset %in% F1_vars])
61 |
62 | time_mean = mean(data$time[data$method==method & data$dataset %in% all_vars])
63 | time_sd = std.error(data$time[data$method==method & data$dataset %in% all_vars])
64 |
65 | data_summary[method] = c(
66 | paste0(as.character(round(MinMAE_mean, 2)), ' (', as.character(round(MinMAE_sd, 2)), ')'),
67 | paste0(as.character(round(AvgMAE_mean, 2)), ' (', as.character(round(AvgMAE_sd, 2)), ')'),
68 | paste0(as.character(round(W_train_mean, 2)), ' (', as.character(round(W_train_sd, 2)), ')'),
69 | paste0(as.character(round(W_test_mean, 2)), ' (', as.character(round(W_test_sd, 2)), ')'),
70 | paste0(as.character(round(MedianMAD_mean, 2)), ' (', as.character(round(MedianMAD_sd, 2)), ')'),
71 | paste0(as.character(round(R2_imp_mean, 2)), ' (', as.character(round(R2_imp_sd, 2)), ')'),
72 | paste0(as.character(round(F1_imp_mean, 2)), ' (', as.character(round(F1_imp_sd, 2)), ')'),
73 | paste0(as.character(round(PercentBias_mean, 2)), ' (', as.character(round(PercentBias_sd, 2)), ')'),
74 | paste0(as.character(round(CoverageRate_mean, 2)), ' (', as.character(round(CoverageRate_sd, 2)), ')'),
75 | paste0(as.character(round(time_mean, 2)), ' (', as.character(round(time_sd, 2)), ')')
76 | )
77 |
78 | }
79 | data_summary
80 | data_summary_t <- t(data_summary)
81 | data_summary_t
82 |
83 | ###################### RANK
84 |
85 | data_summary_rank = data.frame(matrix(ncol = length(methods), nrow = 10))
86 | colnames(data_summary_rank) = methods
87 | rownames(data_summary_rank) = c("MinMAE", "AvgMAE", "W_train", "W_test", "MedianMAD", "R2_imp", "F1_imp", "PercentBias", "CoverageRate", "time")
88 |
89 | MinMAE = data.frame(matrix(ncol = length(methods), nrow = length(all_vars)))
90 | colnames(MinMAE) = methods
91 | rownames(MinMAE) = all_vars
92 | AvgMAE = data.frame(matrix(ncol = length(methods), nrow = length(all_vars)))
93 | colnames(AvgMAE) = methods
94 | rownames(AvgMAE) = all_vars
95 | W_train = data.frame(matrix(ncol = length(methods), nrow = length(W_vars)))
96 | colnames(W_train) = methods
97 | rownames(W_train) = W_vars
98 | W_test = data.frame(matrix(ncol = length(methods), nrow = length(W_vars)))
99 | colnames(W_test) = methods
100 | rownames(W_test) = W_vars
101 | MedianMAD = data.frame(matrix(ncol = length(methods), nrow = length(all_vars)))
102 | colnames(MedianMAD) = methods
103 | rownames(MedianMAD) = all_vars
104 | R2_imp = data.frame(matrix(ncol = length(methods), nrow = length(R2_vars)))
105 | colnames(R2_imp) = methods
106 | rownames(R2_imp) = R2_vars
107 | F1_imp = data.frame(matrix(ncol = length(methods), nrow = length(F1_vars)))
108 | colnames(F1_imp) = methods
109 | rownames(F1_imp) = F1_vars
110 | PercentBias = data.frame(matrix(ncol = length(methods), nrow = length(R2_vars)))
111 | colnames(PercentBias) = methods
112 | rownames(PercentBias) = R2_vars
113 | CoverageRate = data.frame(matrix(ncol = length(methods), nrow = length(R2_vars)))
114 | colnames(CoverageRate) = methods
115 | rownames(CoverageRate) = R2_vars
116 | Time = data.frame(matrix(ncol = length(methods), nrow = length(all_vars)))
117 | colnames(Time) = methods
118 | rownames(Time) = all_vars
119 |
120 | for (method in methods){
121 | print(method)
122 | MinMAE[method] = data$MinMAE[data$method==method & data$dataset %in% all_vars]
123 | AvgMAE[method] = data$AvgMAE[data$method==method & data$dataset %in% all_vars]
125 | W_train[method] = data$W_train[data$method==method & data$dataset %in% W_vars]
126 | W_test[method] = data$W_test[data$method==method & data$dataset %in% W_vars]
127 | # Higher-is-better metrics are negated, so that rank() gives the best method rank 1
128 | MedianMAD[method] = -data$MedianMAD[data$method==method & data$dataset %in% all_vars]
130 | PercentBias[method] = data$PercentBias[data$method==method & data$dataset %in% R2_vars]
131 | CoverageRate[method] = -data$CoverageRate[data$method==method & data$dataset %in% R2_vars]
133 | R2_imp[method] = -data$R2_imp[data$method==method & data$dataset %in% R2_vars]
134 | F1_imp[method] = -data$F1_imp[data$method==method & data$dataset %in% F1_vars]
135 |
136 | Time[method] = data$time[data$method==method & data$dataset %in% all_vars]
137 |
138 | }
139 |
140 | # Rank by dataset
141 |
142 | MinMAE_ = as.data.frame(t(sapply(as.data.frame(t(MinMAE)), rank)))
143 | colnames(MinMAE_) = colnames(MinMAE)
144 | MinMAE = MinMAE_
145 |
146 | AvgMAE_ = as.data.frame(t(sapply(as.data.frame(t(AvgMAE)), rank)))
147 | colnames(AvgMAE_) = colnames(AvgMAE)
148 | AvgMAE = AvgMAE_
149 |
150 | W_train_ = as.data.frame(t(sapply(as.data.frame(t(W_train)), rank)))
151 | colnames(W_train_) = colnames(W_train)
152 | W_train = W_train_
153 |
154 | W_test_ = as.data.frame(t(sapply(as.data.frame(t(W_test)), rank)))
155 | colnames(W_test_) = colnames(W_test)
156 | W_test = W_test_
157 |
158 | MedianMAD_ = as.data.frame(t(sapply(as.data.frame(t(MedianMAD)), rank)))
159 | colnames(MedianMAD_) = colnames(MedianMAD)
160 | MedianMAD = MedianMAD_
161 |
162 | PercentBias_ = as.data.frame(t(sapply(as.data.frame(t(PercentBias)), rank)))
163 | colnames(PercentBias_) = colnames(PercentBias)
164 | PercentBias = PercentBias_
165 |
166 | CoverageRate_ = as.data.frame(t(sapply(as.data.frame(t(CoverageRate)), rank)))
167 | colnames(CoverageRate_) = colnames(CoverageRate)
168 | CoverageRate = CoverageRate_
169 |
170 | R2_imp_ = as.data.frame(t(sapply(as.data.frame(t(R2_imp)), rank)))
171 | colnames(R2_imp_) = colnames(R2_imp)
172 | R2_imp = R2_imp_
173 |
174 | F1_imp_ = as.data.frame(t(sapply(as.data.frame(t(F1_imp)), rank)))
175 | colnames(F1_imp_) = colnames(F1_imp)
176 | F1_imp = F1_imp_
177 |
178 | Time_ = as.data.frame(t(sapply(as.data.frame(t(Time)), rank)))
179 | colnames(Time_) = colnames(Time)
180 | Time = Time_
181 |
182 | for (method in methods){
183 |
184 | AvgMAE_mean = mean(unlist(AvgMAE[method]))
185 | AvgMAE_sd = std.error(unlist(AvgMAE[method]))
186 |
187 | MinMAE_mean = mean(unlist(MinMAE[method]))
188 | MinMAE_sd = std.error(unlist(MinMAE[method]))
189 |
190 | W_train_mean = mean(unlist(W_train[method]))
191 | W_train_sd = std.error(unlist(W_train[method]))
192 |
193 | W_test_mean = mean(unlist(W_test[method]))
194 | W_test_sd = std.error(unlist(W_test[method]))
195 |
196 | MedianMAD_mean = mean(unlist(MedianMAD[method]))
197 | MedianMAD_sd = std.error(unlist(MedianMAD[method]))
198 |
199 | PercentBias_mean = mean(unlist(PercentBias[method]))
200 | PercentBias_sd = std.error(unlist(PercentBias[method]))
201 |
202 | CoverageRate_mean = mean(unlist(CoverageRate[method]))
203 | CoverageRate_sd = std.error(unlist(CoverageRate[method]))
204 |
205 | R2_imp_mean = mean(unlist(R2_imp[method]))
206 | R2_imp_sd = std.error(unlist(R2_imp[method]))
207 |
208 | F1_imp_mean = mean(unlist(F1_imp[method]))
209 | F1_imp_sd = std.error(unlist(F1_imp[method]))
210 |
211 | time_mean = mean(unlist(Time[method]))
212 | time_sd = std.error(unlist(Time[method]))
213 |
214 | data_summary_rank[method] = c(
215 |     paste0(as.character(round(MinMAE_mean, 1)), ' (', as.character(round(MinMAE_sd, 1)), ')'),
216 |     paste0(as.character(round(AvgMAE_mean, 1)), ' (', as.character(round(AvgMAE_sd, 1)), ')'),
217 |     paste0(as.character(round(W_train_mean, 1)), ' (', as.character(round(W_train_sd, 1)), ')'),
218 |     paste0(as.character(round(W_test_mean, 1)), ' (', as.character(round(W_test_sd, 1)), ')'),
219 |     paste0(as.character(round(MedianMAD_mean, 1)), ' (', as.character(round(MedianMAD_sd, 1)), ')'),
220 |     paste0(as.character(round(R2_imp_mean, 1)), ' (', as.character(round(R2_imp_sd, 1)), ')'),
221 |     paste0(as.character(round(F1_imp_mean, 1)), ' (', as.character(round(F1_imp_sd, 1)), ')'),
222 |     paste0(as.character(round(PercentBias_mean, 1)), ' (', as.character(round(PercentBias_sd, 1)), ')'),
223 |     paste0(as.character(round(CoverageRate_mean, 1)), ' (', as.character(round(CoverageRate_sd, 1)), ')'),
224 |     paste0(as.character(round(time_mean, 1)), ' (', as.character(round(time_sd, 1)), ')')
225 | )
226 |
227 | }
228 | data_summary_rank
229 | data_summary_rank_t <- t(data_summary_rank)
230 | data_summary_rank_t
231 |
232 | ########### Latex tables
233 |
234 | rownames(data_summary_t) = c("KNN", "ICE", "MICE-Forest", "MissForest", "Softimpute", "OT", "GAIN", "Forest-VP", "utrees", "Oracle")
235 | rownames(data_summary_rank_t) = c("KNN", "ICE", "MICE-Forest", "MissForest", "Softimpute", "OT", "GAIN", "Forest-VP", "utrees")
236 |
237 | xtable(data_summary_t[,-10], type = "latex")
238 | xtable(data_summary_rank_t[,-10], type = "latex")
239 |
240 | if (FALSE) {
241 | # For building the barplot, to csv
242 |
243 | MinMAE = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars)))
244 | colnames(MinMAE) = methods_all
245 | rownames(MinMAE) = all_vars
246 | MinMAE[,] = 0
247 | AvgMAE = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars)))
248 | colnames(AvgMAE) = methods_all
249 | rownames(AvgMAE) = all_vars
250 | AvgMAE[,] = 0
251 | W_train = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars)))
252 | colnames(W_train) = methods_all
253 | rownames(W_train) = all_vars
254 | W_train[,] = 0
255 | W_test = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars)))
256 | colnames(W_test) = methods_all
257 | rownames(W_test) = all_vars
258 | W_test[,] = 0
259 | MedianMAD = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars)))
260 | colnames(MedianMAD) = methods_all
261 | rownames(MedianMAD) = all_vars
262 | MedianMAD[,] = 0
263 | R2_imp = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars)))
264 | colnames(R2_imp) = methods_all
265 | rownames(R2_imp) = all_vars
266 | R2_imp[,] = 0
267 | F1_imp = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars)))
268 | colnames(F1_imp) = methods_all
269 | rownames(F1_imp) = all_vars
270 | F1_imp[,] = 0
271 | PercentBias = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars)))
272 | colnames(PercentBias) = methods_all
273 | rownames(PercentBias) = all_vars
274 | PercentBias[,] = 0
275 | CoverageRate = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars)))
276 | colnames(CoverageRate) = methods_all
277 | rownames(CoverageRate) = all_vars
278 | CoverageRate[,] = 0
279 |
280 | for (method in methods_all){
281 | MinMAE[rownames(MinMAE) %in% all_vars, method] = data$MinMAE[data$method==method & data$dataset %in% all_vars]
282 | AvgMAE[rownames(AvgMAE) %in% all_vars, method] = data$AvgMAE[data$method==method & data$dataset %in% all_vars]
283 |
284 | W_train[rownames(W_train) %in% W_vars, method] = data$W_train[data$method==method & data$dataset %in% W_vars]
285 | W_test[rownames(W_test) %in% W_vars, method] = data$W_test[data$method==method & data$dataset %in% W_vars]
286 | MedianMAD[, method] = data$MedianMAD[data$method==method & data$dataset %in% all_vars]
287 |
288 | R2_imp[rownames(R2_imp) %in% R2_vars, method] = data$R2_imp[data$method==method & data$dataset %in% R2_vars]
289 | F1_imp[rownames(F1_imp) %in% F1_vars, method] = data$F1_imp[data$method==method & data$dataset %in% F1_vars]
290 | PercentBias[rownames(PercentBias) %in% R2_vars, method] = data$PercentBias[data$method==method & data$dataset %in% R2_vars]
291 | CoverageRate[rownames(CoverageRate) %in% R2_vars, method] = data$CoverageRate[data$method==method & data$dataset %in% R2_vars]
292 | }
293 |
294 |
296 | write.csv(MinMAE, file="imp_MinMAE.csv")
297 | write.csv(AvgMAE, file="imp_AvgMAE.csv")
298 | write.csv(W_train, file="imp_Wtrain.csv")
299 | write.csv(W_test, file="imp_W.csv")
300 | write.csv(MedianMAD, file="imp_MedianMAD.csv")
301 | write.csv(R2_imp, file="imp_R2.csv")
302 | write.csv(F1_imp, file="imp_F1.csv")
303 | write.csv(PercentBias, file="imp_pb.csv")
304 | write.csv(CoverageRate, file="imp_cr.csv")
305 | }
306 |
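307 | # Usage sketch: this script assumes the "Results_tabular - imputation.csv"
308 | # export sits in the working directory. xtable() prints the LaTeX tables to
309 | # stdout, so they can be captured with, e.g.,
310 | #   Rscript imputation_script.R > latex_imputation_tables.txt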
--------------------------------------------------------------------------------
/utrees/unmasking_trees.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | from typing import Union, Optional
3 | import warnings
4 |
5 | import numpy as np
6 | import pandas as pd
7 | import xgboost as xgb
8 |
9 | from sklearn.base import BaseEstimator
10 | from sklearn.utils import check_random_state
11 | from sklearn.preprocessing import LabelEncoder
12 | from tqdm import tqdm
13 |
14 | from utrees.baltobot import Baltobot, NanTabPFNClassifier
15 | from utrees.kdi_quantizer import KDIQuantizer
16 |
17 |
18 | XGBOOST_DEFAULT_KWARGS = {
19 | "tree_method": "hist",
20 | "verbosity": 1,
21 | "objective": "binary:logistic",
22 | }
23 |
24 | TABPFN_DEFAULT_KWARGS = {
25 | "seed": 42,
26 | "batch_size_inference": 1,
27 | "subsample_features": True,
28 | }
29 |
30 |
31 | class UnmaskingTrees(BaseEstimator):
32 | """Performs generative modeling and multiple imputation on tabular data.
33 |
34 | This method generates training data by iteratively masking samples,
35 | then training per-feature XGBoost models to unmask the data.
36 |
37 | Parameters
38 | ----------
39 | depth : int >= 0
40 | Depth of balanced binary tree for recursively quantizing each feature.
41 | The total number of quantization bins is 2^depth.
42 |
43 | duplicate_K : int > 0
44 | Number of random masking orders per actual sample.
45 | The training dataset will be of size (n_samples * n_dims * duplicate_K, n_dims).
46 |
47 | clf_kwargs : dict
48 | Arguments for XGBoost (or TabPFN) classifier.
49 |
50 | strategy : 'kdiquantile', 'quantile', 'uniform', 'kmeans'
51 | The quantization strategy for discretizing continuous features.
52 |
53 | softmax_temp : float > 0
54 | Softmax temperature for sampling from predicted probabilities.
55 | As temperature decreases below the default of 1, predictions converge
56 | to the argmax of each conditional distribution.
57 |
58 |     tabpfn : bool
59 |         Whether to use TabPFN instead of an XGBoost classifier.
60 |
61 |     cast_float32 : bool
62 |         Whether to always convert inputs to float32.
63 |
64 | random_state : int, RandomState instance or None, default=None
65 | Determines random number generation.
66 |
67 |
68 | Attributes
69 | ----------
70 | trees_ : list of Baltobot or XGBClassifier
71 | Fitted per-feature sampling models. For each column ix, this
72 | will be Baltobot if quantize_cols_[ix], and XGBClassifier otherwise.
73 |
74 | constant_vals_ : list of float or None, with length n_dims
75 | Gives the values of constant features, or None otherwise.
76 |
77 | quantize_cols : str or list of strs
78 | Saved user-provided quantize_cols input.
79 |
80 | quantize_cols_ : list of bool, with length n_dims
81 | Whether to apply (and un-apply) discretization to each feature.
82 |
83 | encoders_ : list of None or LabelEncoder
84 | Preprocessors for non-continuous features.
85 |
86 | clf_kwargs_ : dict
87 | Args passed to XGBClassifier or NanTabPFNClassifier.
88 |
89 | X_ : np.ndarray (n_samples, n_dims)
90 | Input data.
91 |
92 | """
93 |
94 | def __init__(
95 | self,
96 | depth: int = 4,
97 | duplicate_K: int = 50,
98 | clf_kwargs: dict = {},
99 | strategy: str = "kdiquantile",
100 | softmax_temp: float = 1.0,
101 | cast_float32: bool = True,
102 | tabpfn: bool = False,
103 | random_state=None,
104 | ):
105 | self.depth = depth
106 | self.duplicate_K = duplicate_K
107 | self.clf_kwargs = clf_kwargs
108 | self.strategy = strategy
109 | self.softmax_temp = softmax_temp
110 | self.cast_float32 = cast_float32
111 | self.tabpfn = tabpfn
112 | self.random_state = random_state
113 |
114 | if self.tabpfn:
115 | self.clf_kwargs_ = TABPFN_DEFAULT_KWARGS.copy()
116 | else:
117 | self.clf_kwargs_ = XGBOOST_DEFAULT_KWARGS.copy()
118 | self.clf_kwargs_.update(clf_kwargs)
119 |
120 | assert 1 <= duplicate_K
121 | assert strategy in ("kdiquantile", "quantile", "uniform", "kmeans")
122 | assert 0 < softmax_temp
123 |
124 | self.random_state_ = check_random_state(random_state)
125 | self.trees_ = None
126 | self.constant_vals_ = None
127 | self.quantize_cols_ = None
128 | self.encoders_ = None
129 | self.X_ = None
130 |
131 | def fit(
132 | self,
133 | X: np.ndarray,
134 | quantize_cols: Union[str, list] = "all",
135 | ):
136 | """
137 | Fit the estimator.
138 |
139 | Parameters
140 | ----------
141 | X : array-like of shape (n_samples, n_dims)
142 | Data to be modeled and imputed (possibly with np.nan).
143 |
144 | quantize_cols : 'all', 'none', or list of strs ('continuous', 'categorical', 'integer')
145 | Whether to apply (and un-apply) discretization to each feature.
146 |
147 | Returns
148 | -------
149 | self : object
150 | Returns the instance itself.
151 | """
152 | rng = check_random_state(self.random_state_)
153 | assert isinstance(X, np.ndarray)
154 | self.X_ = X.copy()
155 | if self.cast_float32:
156 | X = X.astype(np.float32)
157 | n_samples, n_dims = X.shape
158 |
159 | if isinstance(quantize_cols, list):
160 | assert len(quantize_cols) == n_dims
161 | self.quantize_cols_ = []
162 | for d, elt in enumerate(quantize_cols):
163 | if elt == "continuous":
164 | self.quantize_cols_.append(True)
165 | elif elt == "categorical":
166 | self.quantize_cols_.append(False)
167 | elif elt == "integer":
168 | self.quantize_cols_.append(True)
169 |                 else:
170 |                     raise ValueError(f"unexpected quantize_cols element: {elt}")
171 | elif quantize_cols == "none":
172 | self.quantize_cols_ = [False] * n_dims
173 | elif quantize_cols == "all":
174 | self.quantize_cols_ = [True] * n_dims
175 | else:
176 | raise ValueError(f"unexpected quantize_cols: {quantize_cols}")
177 | self.quantize_cols = deepcopy(quantize_cols)
178 |
179 | # Find features with constant vals, to be unmasked before training and inference
180 | self.constant_vals_ = []
181 | for d in range(n_dims):
182 | col_d = X[~np.isnan(X[:, d]), d]
183 | if len(np.unique(col_d)) == 1:
184 | self.constant_vals_.append(np.unique(col_d).item())
185 | else:
186 | self.constant_vals_.append(None)
187 |
188 | # Fit encoders
189 | self.encoders_ = []
190 | for d in range(n_dims):
191 | if self.quantize_cols_[d]:
192 | cur_enc = None
193 | else:
194 | cur_enc = LabelEncoder()
195 | cur_enc.fit(X[~np.isnan(X[:, d]), d])
196 | self.encoders_.append(cur_enc)
197 |
198 | # Generate training data
199 | X_train = []
200 | Y_train = []
201 | for dupix in range(self.duplicate_K):
202 | mask_ixs = np.repeat(np.arange(n_dims)[np.newaxis, :], n_samples, axis=0)
203 | mask_ixs = np.apply_along_axis(
204 | rng.permutation, axis=1, arr=mask_ixs
205 | ) # n_samples, n_dims
206 | for n in range(n_samples):
207 | fuller_X = X[n, :]
208 | for d in range(n_dims):
209 | victim_ix = mask_ixs[n, d]
210 |                 if not np.isnan(fuller_X[victim_ix]):  # note: "x != np.nan" is always True
211 | emptier_X = fuller_X.copy()
212 | emptier_X[victim_ix] = np.nan
213 | X_train.append(emptier_X.reshape(1, -1))
214 | Y_train.append(fuller_X.reshape(1, -1))
215 | fuller_X = emptier_X
216 | X_train = np.concatenate(X_train, axis=0)
217 | Y_train = np.concatenate(Y_train, axis=0)
218 |
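        # Worked example of the masking schedule above (illustrative): with
        # n_dims=3 and one sampled order [2, 0, 1] for a fully observed row x,
        # we emit three (input, target) pairs: (x with col 2 masked, x), then
        # (cols {2, 0} masked, x with col 2 masked), then (all masked, the
        # previous input). Each per-feature model later learns to predict its
        # column from such partially masked inputs, so the stacked training
        # set has at most n_samples * n_dims * duplicate_K rows.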
219 | # Unmask constant-value columns before training
220 | for d in range(n_dims):
221 | if self.constant_vals_[d] is not None:
222 | X_train[:, d] = self.constant_vals_[d]
223 |
224 | # Fit trees
225 | self.trees_ = []
226 | for d in range(n_dims):
227 | if self.constant_vals_[d] is not None:
228 | self.trees_.append(None)
229 | continue
230 | train_ixs = ~np.isnan(Y_train[:, d])
231 | curX_train = X_train[train_ixs, :]
232 | if self.quantize_cols_[d]:
233 | curY_train = Y_train[train_ixs, d]
234 | balto = Baltobot(
235 | depth=self.depth,
236 | clf_kwargs=self.clf_kwargs,
237 | strategy=self.strategy,
238 | softmax_temp=self.softmax_temp,
239 | tabpfn=self.tabpfn,
240 | random_state=self.random_state_,
241 | )
242 | balto.fit(curX_train, curY_train)
243 | self.trees_.append(balto)
244 | else:
245 | curY_train = self.encoders_[d].transform(Y_train[train_ixs, d])
246 |                 if self.tabpfn:
247 |                     clfer = NanTabPFNClassifier(**self.clf_kwargs_)
248 |                 else:
249 |                     clfer = xgb.XGBClassifier(**self.clf_kwargs_)
250 |
251 | clfer.fit(curX_train, curY_train)
252 | self.trees_.append(clfer)
253 | return self
254 |
255 | def generate(
256 | self,
257 | n_generate: int = 1,
258 | ):
259 | """
260 | Generate new data.
261 |
262 | Parameters
263 | ----------
264 | n_generate : int > 0
265 | Desired number of samples.
266 |
267 | Returns
268 | -------
269 | X : np.ndarray of size (n_generate, n_dims)
270 | Generated data.
271 | """
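        # Design note: generation is imputation of a fully masked matrix.
        # Every entry starts as nan (constant columns aside), so impute()
        # samples all features in a random order per generated row.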
272 | _, n_dims = self.X_.shape
273 | X = np.full(fill_value=np.nan, shape=(n_generate, n_dims))
274 | for d in range(n_dims):
275 | if self.constant_vals_[d] is not None:
276 | X[:, d] = self.constant_vals_[d]
277 | genX = self.impute(n_impute=1, X=X)[0, :, :]
278 | return genX
279 |
280 | def impute(
281 | self,
282 | n_impute: int = 1,
283 | X: np.ndarray = None,
284 | ):
285 | """
286 |         Impute missing values, producing multiple imputations per sample.
287 |
288 |         Parameters
289 |         ----------
290 |         n_impute : int > 0
291 |             Desired number of multiple imputations per sample.
292 |
293 | X : np.ndarray of size (n_samples, n_dims) or None
294 | Data to impute. If None, uses the data passed to fit.
295 |
296 | Returns
297 | -------
298 | imputedX : np.ndarray of size (n_impute, n_samples, n_dims)
299 | Imputed data.
300 | """
301 | if X is None:
302 | X = self.X_.copy()
303 | (n_samples, n_dims) = X.shape
304 | rng = check_random_state(self.random_state_)
305 |
306 | for d in range(n_dims):
307 | if self.constant_vals_[d] is not None:
308 |             # Only replace nans with constant vals; observed entries stay fixed during imputation
309 | X[np.isnan(X[:, d]), d] = self.constant_vals_[d]
310 |
311 | imputedX = np.repeat(
312 | X[np.newaxis, :, :], repeats=n_impute, axis=0
313 | ) # (n_impute, n_samples, n_dims)
314 | with tqdm(total=n_samples * n_impute) as pbar:
315 | for n in range(n_samples):
316 | to_unmask = np.where(np.isnan(X[n, :]))[0] # (n_to_unmask,)
317 | unmask_ixs = np.repeat(
318 | to_unmask[np.newaxis, :], n_impute, axis=0
319 | ) # (n_impute, n_to_unmask)
320 | unmask_ixs = np.apply_along_axis(
321 | rng.permutation, axis=1, arr=unmask_ixs
322 | ) # (n_impute, n_to_unmask)
323 | n_to_unmask = unmask_ixs.shape[1]
324 | for kix in range(n_impute):
325 | pbar.update(1)
326 | for dix in range(n_to_unmask):
327 | unmask_ix = unmask_ixs[kix, dix]
328 | if self.quantize_cols_[unmask_ix]:
329 | pred_val = self.trees_[unmask_ix].sample(
330 | imputedX[kix, [n], :]
331 | )
332 | else:
333 | # TODO: use TabPFN if requested
334 | pred_probas = self.trees_[unmask_ix].predict_proba(
335 | imputedX[kix, [n], :]
336 | )
337 | with np.errstate(divide="ignore"):
338 | annealed_logits = (
339 | np.log(pred_probas) / self.softmax_temp
340 | )
341 | pred_probas = np.exp(annealed_logits) / np.sum(
342 | np.exp(annealed_logits), axis=1
343 | )
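                            # Annealing is equivalent to renormalizing
                            # p_i^(1/T): T < 1 sharpens the distribution
                            # toward its argmax, T > 1 flattens it.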
344 | cur_enc = self.encoders_[unmask_ix]
345 | pred_enc = rng.choice(
346 | a=len(cur_enc.classes_), p=pred_probas.ravel()
347 | )
348 | pred_val = cur_enc.inverse_transform(np.array([pred_enc]))
349 | imputedX[kix, n, unmask_ix] = pred_val.item()
350 | return imputedX
351 |
352 | def score_samples(
353 | self,
354 | X: np.ndarray,
355 | n_evals: int = 1,
356 | ):
357 | """Compute the log-likelihood of each sample under the model.
358 |
359 | Parameters
360 | ----------
361 | X : array-like of shape (n_samples, n_features)
362 | An array of points to query. Last dimension should match dimension
363 | of training data (n_features).
364 |
365 |         n_evals : int > 0
366 |             Number of random feature orderings to average over.
366 |
367 | Returns
368 | -------
369 | density : ndarray of shape (n_samples,)
370 | Log-likelihood of each sample in `X`. These are normalized to be
371 | probability densities, so values will be low for high-dimensional
372 | data.
373 | """
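        # Chain rule over a random feature ordering sigma:
        #   log p(x) = sum_d log p(x_sigma(d) | x_sigma(1), ..., x_sigma(d-1)).
        # Averaging over n_evals random orderings reduces the estimate's
        # dependence on any single ordering.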
374 | (n_samples, n_dims) = X.shape
375 | rng = check_random_state(self.random_state_)
376 |
377 | density = np.zeros((n_samples,))
378 | for n in range(n_samples):
379 | cur_density = np.zeros((n_evals,))
380 | for k in range(n_evals):
381 | eval_order = rng.permutation(n_dims)
382 | for dix in range(n_dims):
383 | eval_ix = eval_order[dix]
384 | if not np.isnan(X[n, eval_ix]):
385 | evalX = X[[n], :].copy()
386 | evalX[:, eval_order[dix:]] = np.nan
387 | evaly = X[[n], eval_ix]
388 | if self.quantize_cols_[eval_ix]:
389 | cur_density[k] += (
390 | self.trees_[eval_ix].score_samples(evalX, evaly).item()
391 | )
392 | else:
393 | probas = self.trees_[eval_ix].predict_proba(evalX).ravel()
394 | cur_enc = self.encoders_[eval_ix]
395 | true_class = cur_enc.transform(evaly.reshape(1, 1)).item()
396 | cur_density[k] += np.log(probas[true_class])
397 | density[n] = np.mean(cur_density)
398 | return density
399 |
--------------------------------------------------------------------------------
/Results/generation_script_nmiss50.R:
--------------------------------------------------------------------------------
1 | options(show.error.locations = TRUE)
2 | options(repos=c(CRAN = "https://cloud.r-project.org"))
3 | install.packages('matrixStats')
4 | install.packages('xtable')
5 | install.packages('plotrix')
6 | library(matrixStats)
7 | library(xtable)
8 | library(plotrix)
9 |
10 | all_vars = c('iris', 'wine', 'parkinsons', 'climate_model_crashes', 'concrete_compression', 'yacht_hydrodynamics', 'airfoil_self_noise', 'connectionist_bench_sonar', 'ionosphere', 'qsar_biodegradation', 'seeds', 'glass', 'yeast', 'libras', 'planning_relax', 'blood_transfusion', 'breast_cancer_diagnostic', 'connectionist_bench_vowel', 'concrete_slump', 'wine_quality_red', 'wine_quality_white', 'california', 'bean', 'car','congress','tictactoe')
11 | W_vars = c('iris', 'wine', 'parkinsons', 'climate_model_crashes', 'concrete_compression', 'yacht_hydrodynamics', 'airfoil_self_noise', 'connectionist_bench_sonar', 'ionosphere', 'qsar_biodegradation', 'seeds', 'glass', 'yeast', 'libras', 'planning_relax', 'blood_transfusion', 'breast_cancer_diagnostic', 'connectionist_bench_vowel', 'concrete_slump', 'wine_quality_red', 'wine_quality_white', 'car','congress','tictactoe')
12 | R2_vars = c('concrete_compression', 'yacht_hydrodynamics', 'airfoil_self_noise', 'wine_quality_red', 'wine_quality_white', 'california')
13 | F1_vars = c('iris', 'wine', 'parkinsons', 'climate_model_crashes', 'connectionist_bench_sonar', 'ionosphere', 'qsar_biodegradation', 'seeds', 'glass', 'yeast', 'libras', 'planning_relax', 'blood_transfusion', 'breast_cancer_diagnostic', 'connectionist_bench_vowel', 'bean', 'car','congress','tictactoe')
14 |
15 | methods = c('GaussianCopula', 'TVAE', 'CTGAN', 'CTABGAN', 'stasy', 'TabDDPM', 'forest_diffusion_vp_nt51_ycond', 'forest_diffusion_flow_nt51_ycond', 'utrees')
16 | methods_all = c('GaussianCopula', 'TVAE', 'CTGAN', 'CTABGAN', 'stasy', 'TabDDPM', 'forest_diffusion_vp_nt51_ycond', 'forest_diffusion_flow_nt51_ycond', 'oracle', 'utrees')
17 |
18 | # Change the folder to your own
19 | data = read.csv("Results_tabular - generation_miss.csv")
20 |
21 | # Better scaling for clean tables
22 | data$percent_bias = data$percent_bias / 100
23 |
24 | #### CHOOSE HERE: keep exactly one of the two filters below active
25 |
26 | #data = data[data$missingness!='MCAR(0.2 MissForest)',] # keep the miceforest results
27 | data = data[data$missingness!='MCAR(0.2 miceforest)',] # keep the MissForest results
28 |
29 | ####
30 |
31 | data_summary = data.frame(matrix(ncol = 10, nrow = 10))
32 | colnames(data_summary) = methods_all
33 | rownames(data_summary) = c("W_train", "W_test", "coverage_train", "coverage_test", "R2_fake", "F1_fake", "class_score", "percent_bias", "coverage_rate", "time")
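34 | # Each summary cell below is formatted as "mean (standard error)" over the
35 | # datasets applicable to that metric.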
34 |
35 | for (method in methods_all){
36 | W_train_mean = mean(data$W_train[data$method==method & data$dataset %in% W_vars])
37 | W_train_sd = std.error(data$W_train[data$method==method & data$dataset %in% W_vars])
38 |
39 | W_test_mean = mean(data$W_test[data$method==method & data$dataset %in% W_vars])
40 | W_test_sd = std.error(data$W_test[data$method==method & data$dataset %in% W_vars])
41 |
42 | coverage_mean = mean(data$coverage_train[data$method==method & data$dataset %in% all_vars])
43 | coverage_sd = std.error(data$coverage_train[data$method==method & data$dataset %in% all_vars])
44 |
45 | coverage_test_mean = mean(data$coverage_test[data$method==method & data$dataset %in% all_vars])
46 | coverage_test_sd = std.error(data$coverage_test[data$method==method & data$dataset %in% all_vars])
47 |
48 | R2_fake_mean = mean(data$R2_fake[data$method==method & data$dataset %in% R2_vars])
49 | R2_fake_sd = std.error(data$R2_fake[data$method==method & data$dataset %in% R2_vars])
50 |
51 | F1_fake_mean = mean(data$F1_fake[data$method==method & data$dataset %in% F1_vars])
52 | F1_fake_sd = std.error(data$F1_fake[data$method==method & data$dataset %in% F1_vars])
53 |
54 | percent_bias_mean = mean(data$percent_bias[data$method==method & data$dataset %in% R2_vars])
55 | percent_bias_sd = std.error(data$percent_bias[data$method==method & data$dataset %in% R2_vars])
56 |
57 | coverage_rate_mean = mean(data$coverage_rate[data$method==method & data$dataset %in% R2_vars])
58 | coverage_rate_sd = std.error(data$coverage_rate[data$method==method & data$dataset %in% R2_vars])
59 |
60 | class_score_mean = mean(data$class_score[data$method==method & data$dataset %in% all_vars])
61 | class_score_sd = std.error(data$class_score[data$method==method & data$dataset %in% all_vars])
62 |
63 | time_mean = mean(data$time[data$method==method & data$dataset %in% all_vars])
64 | time_sd = std.error(data$time[data$method==method & data$dataset %in% all_vars])
65 |
66 | data_summary[method] = c(
67 | paste0(as.character(round(W_train_mean, 2)), ' (', as.character(round(W_train_sd, 2)), ')'),
68 | paste0(as.character(round(W_test_mean, 2)), ' (', as.character(round(W_test_sd, 2)), ')'),
69 | paste0(as.character(round(coverage_mean, 2)), ' (', as.character(round(coverage_sd, 2)), ')'),
70 | paste0(as.character(round(coverage_test_mean, 2)), ' (', as.character(round(coverage_test_sd, 2)), ')'),
71 | paste0(as.character(round(R2_fake_mean, 2)), ' (', as.character(round(R2_fake_sd, 2)), ')'),
72 | paste0(as.character(round(F1_fake_mean, 2)), ' (', as.character(round(F1_fake_sd, 2)), ')'),
73 | paste0(as.character(round(class_score_mean, 2)), ' (', as.character(round(class_score_sd, 2)), ')'),
74 | paste0(as.character(round(percent_bias_mean, 2)), ' (', as.character(round(percent_bias_sd, 2)), ')'),
75 | paste0(as.character(round(coverage_rate_mean, 2)), ' (', as.character(round(coverage_rate_sd, 2)), ')'),
76 | paste0(as.character(round(time_mean, 2)), ' (', as.character(round(time_sd, 2)), ')')
77 | )
78 |
79 | }
80 | data_summary
81 | data_summary_t <- t(data_summary)
82 | data_summary_t
83 |
84 |
85 |
86 | ###################### RANK
87 |
88 | data_summary_rank = data.frame(matrix(ncol = 9, nrow = 10))
89 | colnames(data_summary_rank) = methods
90 | rownames(data_summary_rank) = c("W_train", "W_test", "coverage_train", "coverage_test", "R2_fake", "F1_fake", "class_score", "percent_bias", "coverage_rate", "time")
91 |
92 | W_train = data.frame(matrix(ncol = 9, nrow = length(W_vars)))
93 | colnames(W_train) = methods
94 | rownames(W_train) = W_vars
95 | W_test = data.frame(matrix(ncol = 9, nrow = length(W_vars)))
96 | colnames(W_test) = methods
97 | rownames(W_test) = W_vars
98 | coverage = data.frame(matrix(ncol = 9, nrow = length(all_vars)))
99 | colnames(coverage) = methods
100 | rownames(coverage) = all_vars
101 | coverage_test = data.frame(matrix(ncol = 9, nrow = length(all_vars)))
102 | colnames(coverage_test) = methods
103 | rownames(coverage_test) = all_vars
104 | R2_fake = data.frame(matrix(ncol = 9, nrow = length(R2_vars)))
105 | colnames(R2_fake) = methods
106 | rownames(R2_fake) = R2_vars
107 | F1_fake = data.frame(matrix(ncol = 9, nrow = length(F1_vars)))
108 | colnames(F1_fake) = methods
109 | rownames(F1_fake) = F1_vars
110 | percent_bias = data.frame(matrix(ncol = 9, nrow = length(R2_vars)))
111 | colnames(percent_bias) = methods
112 | rownames(percent_bias) = R2_vars
113 | coverage_rate = data.frame(matrix(ncol = 9, nrow = length(R2_vars)))
114 | colnames(coverage_rate) = methods
115 | rownames(coverage_rate) = R2_vars
116 | ClassScore = data.frame(matrix(ncol = 9, nrow = length(all_vars)))
117 | colnames(ClassScore) = methods
118 | rownames(ClassScore) = all_vars
119 | Time = data.frame(matrix(ncol = 9, nrow = length(all_vars)))
120 | colnames(Time) = methods
121 | rownames(Time) = all_vars
122 |
123 | for (method in methods){
124 |
125 | W_train[method] = data$W_train[data$method==method & data$dataset %in% W_vars]
126 | W_test[method] = data$W_test[data$method==method & data$dataset %in% W_vars]
127 |
128 | coverage[method] = -data$coverage_train[data$method==method & data$dataset %in% all_vars]
129 | coverage_test[method] = -data$coverage_test[data$method==method & data$dataset %in% all_vars]
130 |
131 | R2_fake[method] = -data$R2_fake[data$method==method & data$dataset %in% R2_vars]
132 | F1_fake[method] = -data$F1_fake[data$method==method & data$dataset %in% F1_vars]
133 | percent_bias[method] = data$percent_bias[data$method==method & data$dataset %in% R2_vars]
134 | coverage_rate[method] = -data$coverage_rate[data$method==method & data$dataset %in% R2_vars]
135 | ClassScore[method] = data$class_score[data$method==method & data$dataset %in% all_vars]
136 | Time[method] = data$time[data$method==method & data$dataset %in% all_vars]
137 |
138 | }
139 |
140 | # Rank by dataset
141 |
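# Each metric table is (datasets x methods); the double transpose around
# rank() ranks the methods within each dataset, so rank 1 is best. Metrics
# where higher is better were negated above so that smaller is always better.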
142 | W_train_ = as.data.frame(t(sapply(as.data.frame(t(W_train)), rank)))
143 | colnames(W_train_) = colnames(W_train)
144 | W_train = W_train_
145 |
146 | W_test_ = as.data.frame(t(sapply(as.data.frame(t(W_test)), rank)))
147 | colnames(W_test_) = colnames(W_test)
148 | W_test = W_test_
149 |
150 | coverage_ = as.data.frame(t(sapply(as.data.frame(t(coverage)), rank)))
151 | colnames(coverage_) = colnames(coverage)
152 | coverage = coverage_
153 | coverage_test_ = as.data.frame(t(sapply(as.data.frame(t(coverage_test)), rank)))
154 | colnames(coverage_test_) = colnames(coverage_test)
155 | coverage_test = coverage_test_
156 |
157 | R2_fake_ = as.data.frame(t(sapply(as.data.frame(t(R2_fake)), rank)))
158 | colnames(R2_fake_) = colnames(R2_fake)
159 | R2_fake = R2_fake_
160 |
161 | F1_fake_ = as.data.frame(t(sapply(as.data.frame(t(F1_fake)), rank)))
162 | colnames(F1_fake_) = colnames(F1_fake)
163 | F1_fake = F1_fake_
164 |
165 | percent_bias_ = as.data.frame(t(sapply(as.data.frame(t(percent_bias)), rank)))
166 | colnames(percent_bias_) = colnames(percent_bias)
167 | percent_bias = percent_bias_
168 | coverage_rate_ = as.data.frame(t(sapply(as.data.frame(t(coverage_rate)), rank)))
169 | colnames(coverage_rate_) = colnames(coverage_rate)
170 | coverage_rate = coverage_rate_
171 |
172 | ClassScore_ = as.data.frame(t(sapply(as.data.frame(t(ClassScore)), rank)))
173 | colnames(ClassScore_) = colnames(ClassScore)
174 | ClassScore = ClassScore_
175 | Time_ = as.data.frame(t(sapply(as.data.frame(t(Time)), rank)))
176 | colnames(Time_) = colnames(Time)
177 | Time = Time_
178 |
179 | for (method in methods){
180 | W_train_mean = mean(unlist(W_train[method]))
181 | W_train_sd = std.error(unlist(W_train[method]))
182 |
183 | W_test_mean = mean(unlist(W_test[method]))
184 | W_test_sd = std.error(unlist(W_test[method]))
185 |
186 | coverage_mean = mean(unlist(coverage[method]))
187 | coverage_sd = std.error(unlist(coverage[method]))
188 | coverage_test_mean = mean(unlist(coverage_test[method]))
189 | coverage_test_sd = std.error(unlist(coverage_test[method]))
190 |
191 | R2_fake_mean = mean(unlist(R2_fake[method]))
192 | R2_fake_sd = std.error(unlist(R2_fake[method]))
193 |
194 | F1_fake_mean = mean(unlist(F1_fake[method]))
195 | F1_fake_sd = std.error(unlist(F1_fake[method]))
196 |
197 | percent_bias_mean = mean(unlist(percent_bias[method]))
198 | percent_bias_sd = std.error(unlist(percent_bias[method]))
199 |
200 | coverage_rate_mean = mean(unlist(coverage_rate[method]))
201 | coverage_rate_sd = std.error(unlist(coverage_rate[method]))
202 |
203 | class_score_mean = mean(unlist(ClassScore[method]))
204 | class_score_sd = std.error(unlist(ClassScore[method]))
205 |
206 | time_mean = mean(unlist(Time[method]))
207 | time_sd = std.error(unlist(Time[method]))
208 |
209 | data_summary_rank[method] = c(
210 | paste0(as.character(round(W_train_mean, 1)), ' (', as.character(round(W_train_sd, 1)), ')'),
211 | paste0(as.character(round(W_test_mean, 1)), ' (', as.character(round(W_test_sd, 1)), ')'),
212 | paste0(as.character(round(coverage_mean, 1)), ' (', as.character(round(coverage_sd, 1)), ')'),
213 | paste0(as.character(round(coverage_test_mean, 1)), ' (', as.character(round(coverage_test_sd, 1)), ')'),
214 | paste0(as.character(round(R2_fake_mean, 1)), ' (', as.character(round(R2_fake_sd, 1)), ')'),
215 | paste0(as.character(round(F1_fake_mean, 1)), ' (', as.character(round(F1_fake_sd, 1)), ')'),
216 | paste0(as.character(round(class_score_mean, 1)), ' (', as.character(round(class_score_sd, 1)), ')'),
217 | paste0(as.character(round(percent_bias_mean, 1)), ' (', as.character(round(percent_bias_sd, 1)), ')'),
218 | paste0(as.character(round(coverage_rate_mean, 1)), ' (', as.character(round(coverage_rate_sd, 1)), ')'),
219 | paste0(as.character(round(time_mean, 1)), ' (', as.character(round(time_sd, 1)), ')')
220 | )
221 |
222 | }
223 | rownames(data_summary_rank) = c("W_train", "W_test", "coverage_train", "coverage_test", "R2_fake", "F1_fake", "class_score", "percent_bias", "coverage_rate", "time")
224 | data_summary_rank
225 | data_summary_rank_t <- t(data_summary_rank)
226 | data_summary_rank_t
227 |
228 |
229 | ########### Latex tables
230 |
231 | rownames(data_summary_t) = c("GaussianCopula", "TVAE", "CTGAN", "CTAB-GAN+", "STaSy", "TabDDPM", "Forest-VP", "Forest-Flow", "Oracle", "UTrees")
232 | rownames(data_summary_rank_t) = c("GaussianCopula", "TVAE", "CTGAN", "CTAB-GAN+", "STaSy", "TabDDPM", "Forest-VP", "Forest-Flow", "UTrees")
233 |
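# The 10th column ("time") is dropped from the LaTeX tables.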
234 | xtable(data_summary_t[,-10], type = "latex")
235 | xtable(data_summary_rank_t[,-10], type = "latex")
236 |
237 |
238 | if (FALSE) {
239 | # For building the barplot, to csv
240 |
241 |   W_train = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars)))
242 |   colnames(W_train) = methods_all
243 |   rownames(W_train) = all_vars
244 |   W_train[,] = 0
245 |   W_test = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars)))
246 |   colnames(W_test) = methods_all
247 |   rownames(W_test) = all_vars
248 |   W_test[,] = 0
249 |   coverage = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars)))
250 |   colnames(coverage) = methods_all
251 |   rownames(coverage) = all_vars
252 |   coverage[,] = 0
253 |   coverage_test = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars)))
254 |   colnames(coverage_test) = methods_all
255 |   rownames(coverage_test) = all_vars
256 |   coverage_test[,] = 0
257 |   R2_fake = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars)))
258 |   colnames(R2_fake) = methods_all
259 |   rownames(R2_fake) = all_vars
260 |   R2_fake[,] = 0
261 |   F1_fake = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars)))
262 |   colnames(F1_fake) = methods_all
263 |   rownames(F1_fake) = all_vars
264 |   F1_fake[,] = 0
265 |   ClassScore = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars)))
266 |   colnames(ClassScore) = methods_all
267 |   rownames(ClassScore) = all_vars
268 |   ClassScore[,] = 0
269 |   percent_bias = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars)))
270 |   colnames(percent_bias) = methods_all
271 |   rownames(percent_bias) = all_vars
272 |   percent_bias[,] = 0
273 |   coverage_rate = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars)))
274 |   colnames(coverage_rate) = methods_all
275 |   rownames(coverage_rate) = all_vars
276 |   coverage_rate[,] = 0
277 |   Time = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars)))
278 |   colnames(Time) = methods_all
279 |   rownames(Time) = all_vars
280 |   Time[,] = 0
281 |
282 | for (method in methods_all){
283 | W_train[rownames(W_train) %in% W_vars, method] = data$W_train[data$method==method & data$dataset %in% W_vars]
284 | W_test[rownames(W_train) %in% W_vars, method] = data$W_test[data$method==method & data$dataset %in% W_vars]
285 |
286 | coverage[, method] = data$coverage_train[data$method==method & data$dataset %in% all_vars]
287 | coverage_test[, method] = data$coverage_test[data$method==method & data$dataset %in% all_vars]
288 |
289 | R2_fake[rownames(W_train) %in% R2_vars, method] = data$R2_fake[data$method==method & data$dataset %in% R2_vars]
290 | F1_fake[rownames(W_train) %in% F1_vars, method] = data$F1_fake[data$method==method & data$dataset %in% F1_vars]
291 | ClassScore[rownames(W_train) %in% all_vars, method] = data$class_score[data$method==method & data$dataset %in% all_vars]
292 | percent_bias[rownames(W_train) %in% R2_vars, method] = data$percent_bias[data$method==method & data$dataset %in% R2_vars]
293 | coverage_rate[rownames(W_train) %in% R2_vars, method] = data$coverage_rate[data$method==method & data$dataset %in% R2_vars]
294 | }
295 |
296 | # For building the barplot, to csv
297 | write.csv(W_test, file="gen_W_nmiss.csv")
298 | write.csv(coverage_test, file="gen_cov_nmiss.csv")
299 | write.csv(R2_fake, file="gen_R2_nmiss.csv")
300 | write.csv(F1_fake, file="gen_F1_nmiss.csv")
301 | write.csv(ClassScore, file="gen_disc_nmiss.csv")
302 | write.csv(percent_bias, file="gen_pb_nmiss.csv")
303 | write.csv(coverage_rate, file="gen_cr_nmiss.csv")
304 | }
305 |
--------------------------------------------------------------------------------
/Results/generation_script.R:
--------------------------------------------------------------------------------
1 | options(show.error.locations = TRUE)
2 | options(repos=c(CRAN = "https://cloud.r-project.org"))
3 | install.packages('matrixStats')
4 | install.packages('xtable')
5 | install.packages('plotrix')
6 | library(matrixStats)
7 | library(xtable)
8 | library(plotrix)
9 |
10 | # Change the folder to your own
11 | data = read.csv("Results_tabular - generation.csv")
12 |
13 | # Better scaling for clean tables
14 | data$percent_bias = data$percent_bias / 100
15 |
16 | all_vars = c('iris', 'wine', 'parkinsons', 'climate_model_crashes', 'concrete_compression', 'yacht_hydrodynamics', 'airfoil_self_noise', 'connectionist_bench_sonar', 'ionosphere', 'qsar_biodegradation', 'seeds', 'glass', 'ecoli', 'yeast', 'libras', 'planning_relax', 'blood_transfusion', 'breast_cancer_diagnostic', 'connectionist_bench_vowel', 'concrete_slump', 'wine_quality_red', 'wine_quality_white', 'california', 'bean', 'car','congress','tictactoe')
17 | W_vars = c('iris', 'wine', 'parkinsons', 'climate_model_crashes', 'concrete_compression', 'yacht_hydrodynamics', 'airfoil_self_noise', 'connectionist_bench_sonar', 'ionosphere', 'qsar_biodegradation', 'seeds', 'glass', 'ecoli', 'yeast', 'libras', 'planning_relax', 'blood_transfusion', 'breast_cancer_diagnostic', 'connectionist_bench_vowel', 'concrete_slump', 'wine_quality_red', 'wine_quality_white', 'car','congress','tictactoe')
18 | R2_vars = c('concrete_compression', 'yacht_hydrodynamics', 'airfoil_self_noise', 'wine_quality_red', 'wine_quality_white', 'california')
19 | F1_vars = c('iris', 'wine', 'parkinsons', 'climate_model_crashes', 'connectionist_bench_sonar', 'ionosphere', 'qsar_biodegradation', 'seeds', 'glass', 'ecoli', 'yeast', 'libras', 'planning_relax', 'blood_transfusion', 'breast_cancer_diagnostic', 'connectionist_bench_vowel', 'bean', 'car','congress','tictactoe')
20 |
21 | methods = c('GaussianCopula', 'TVAE', 'CTGAN', 'CTABGAN', 'stasy', 'TabDDPM', 'forest_diffusion_vp_nt51_ycond', 'forest_diffusion_flow_nt51_ycond', 'utrees')
22 | methods_all = c('GaussianCopula', 'TVAE', 'CTGAN', 'CTABGAN', 'stasy', 'TabDDPM', 'forest_diffusion_vp_nt51_ycond', 'forest_diffusion_flow_nt51_ycond', 'utrees', 'oracle')
23 |
24 | data_summary = data.frame(matrix(ncol = 10, nrow = 10))
25 | colnames(data_summary) = methods_all
26 | rownames(data_summary) = c("W_train", "W_test", "coverage_train", "coverage_test", "R2_fake", "F1_fake","class_score", "percent_bias", "coverage_rate", "time")
27 |
28 | for (method in methods_all){
29 | W_train_mean = mean(data$W_train[data$method==method & data$dataset %in% W_vars])
30 | W_train_sd = std.error(data$W_train[data$method==method & data$dataset %in% W_vars])
31 |
32 | W_test_mean = mean(data$W_test[data$method==method & data$dataset %in% W_vars])
33 | W_test_sd = std.error(data$W_test[data$method==method & data$dataset %in% W_vars])
34 |
35 | coverage_mean = mean(data$coverage_train[data$method==method & data$dataset %in% all_vars])
36 | coverage_sd = std.error(data$coverage_train[data$method==method & data$dataset %in% all_vars])
37 |
38 | coverage_test_mean = mean(data$coverage_test[data$method==method & data$dataset %in% all_vars])
39 | coverage_test_sd = std.error(data$coverage_test[data$method==method & data$dataset %in% all_vars])
40 |
41 | R2_fake_mean = mean(data$R2_fake[data$method==method & data$dataset %in% R2_vars])
42 | R2_fake_sd = std.error(data$R2_fake[data$method==method & data$dataset %in% R2_vars])
43 |
44 | F1_fake_mean = mean(data$F1_fake[data$method==method & data$dataset %in% F1_vars])
45 | F1_fake_sd = std.error(data$F1_fake[data$method==method & data$dataset %in% F1_vars])
46 |
47 | percent_bias_mean = mean(data$percent_bias[data$method==method & data$dataset %in% R2_vars])
48 | percent_bias_sd = std.error(data$percent_bias[data$method==method & data$dataset %in% R2_vars])
49 |
50 | coverage_rate_mean = mean(data$coverage_rate[data$method==method & data$dataset %in% R2_vars])
51 | coverage_rate_sd = std.error(data$coverage_rate[data$method==method & data$dataset %in% R2_vars])
52 |
53 | class_score_mean = mean(data$class_score[data$method==method & data$dataset %in% all_vars])
54 | class_score_sd = std.error(data$class_score[data$method==method & data$dataset %in% all_vars])
55 |
56 | time_mean = mean(data$time[data$method==method & data$dataset %in% all_vars])
57 | time_sd = std.error(data$time[data$method==method & data$dataset %in% all_vars])
58 |
59 | data_summary[method] = c(
60 | paste0(as.character(round(W_train_mean, 2)), ' (', as.character(round(W_train_sd, 2)), ')'),
61 | paste0(as.character(round(W_test_mean, 2)), ' (', as.character(round(W_test_sd, 2)), ')'),
62 | paste0(as.character(round(coverage_mean, 2)), ' (', as.character(round(coverage_sd, 2)), ')'),
63 | paste0(as.character(round(coverage_test_mean, 2)), ' (', as.character(round(coverage_test_sd, 2)), ')'),
64 | paste0(as.character(round(R2_fake_mean, 2)), ' (', as.character(round(R2_fake_sd, 2)), ')'),
65 | paste0(as.character(round(F1_fake_mean, 2)), ' (', as.character(round(F1_fake_sd, 2)), ')'),
66 | paste0(as.character(round(class_score_mean, 2)), ' (', as.character(round(class_score_sd, 2)), ')'),
67 | paste0(as.character(round(percent_bias_mean, 2)), ' (', as.character(round(percent_bias_sd, 2)), ')'),
68 | paste0(as.character(round(coverage_rate_mean, 2)), ' (', as.character(round(coverage_rate_sd, 2)), ')'),
69 | paste0(as.character(round(time_mean, 2)), ' (', as.character(round(time_sd, 2)), ')')
70 | )
71 |
72 | }
73 | data_summary
74 | data_summary_t <- t(data_summary)
75 | data_summary_t
76 |
77 |
78 | ###################### RANK
79 |
80 | data_summary_rank = data.frame(matrix(ncol = 9, nrow = 10))
81 | colnames(data_summary_rank) = methods
82 | rownames(data_summary_rank) = c("W_train", "W_test", "coverage_train", "coverage_test", "R2_fake", "F1_fake", "class_score", "percent_bias", "coverage_rate", "time")
83 |
84 | W_train = data.frame(matrix(ncol = 9, nrow = length(W_vars)))
85 | colnames(W_train) = methods
86 | rownames(W_train) = W_vars
87 | W_test = data.frame(matrix(ncol = 9, nrow = length(W_vars)))
88 | colnames(W_test) = methods
89 | rownames(W_test) = W_vars
90 | coverage = data.frame(matrix(ncol = 9, nrow = length(all_vars)))
91 | colnames(coverage) = methods
92 | rownames(coverage) = all_vars
93 | coverage_test = data.frame(matrix(ncol = 9, nrow = length(all_vars)))
94 | colnames(coverage_test) = methods
95 | rownames(coverage_test) = all_vars
96 | R2_fake = data.frame(matrix(ncol = 9, nrow = length(R2_vars)))
97 | colnames(R2_fake) = methods
98 | rownames(R2_fake) = R2_vars
99 | F1_fake = data.frame(matrix(ncol = 9, nrow = length(F1_vars)))
100 | colnames(F1_fake) = methods
101 | rownames(F1_fake) = F1_vars
102 | percent_bias = data.frame(matrix(ncol = 9, nrow = length(R2_vars)))
103 | colnames(percent_bias) = methods
104 | rownames(percent_bias) = R2_vars
105 | coverage_rate = data.frame(matrix(ncol = 9, nrow = length(R2_vars)))
106 | colnames(coverage_rate) = methods
107 | rownames(coverage_rate) = R2_vars
108 | ClassScore = data.frame(matrix(ncol = 9, nrow = length(all_vars)))
109 | colnames(ClassScore) = methods
110 | rownames(ClassScore) = all_vars
111 | Time = data.frame(matrix(ncol = 9, nrow = length(all_vars)))
112 | colnames(Time) = methods
113 | rownames(Time) = all_vars
114 |
115 | for (method in methods){
116 | print(method)
117 | print('W')
118 | W_train[method] = data$W_train[data$method==method & data$dataset %in% W_vars]
119 | W_test[method] = data$W_test[data$method==method & data$dataset %in% W_vars]
120 | print('cov')
121 | coverage[method] = -data$coverage_train[data$method==method & data$dataset %in% all_vars]
122 | coverage_test[method] = -data$coverage_test[data$method==method & data$dataset %in% all_vars]
123 | print('R2')
124 | R2_fake[method] = -data$R2_fake[data$method==method & data$dataset %in% R2_vars]
125 | print('F1')
126 | F1_fake[method] = -data$F1_fake[data$method==method & data$dataset %in% F1_vars]
127 | print('stats')
128 | percent_bias[method] = data$percent_bias[data$method==method & data$dataset %in% R2_vars]
129 | coverage_rate[method] = -data$coverage_rate[data$method==method & data$dataset %in% R2_vars]
130 | print('class')
131 | ClassScore[method] = data$class_score[data$method==method & data$dataset %in% all_vars]
132 | print('time')
133 | Time[method] = data$time[data$method==method & data$dataset %in% all_vars]
134 | }
135 |
136 | # Rank by dataset
137 |
138 | W_train_ = as.data.frame(t(sapply(as.data.frame(t(W_train)), rank)))
139 | colnames(W_train_) = colnames(W_train)
140 | W_train = W_train_
141 |
142 | W_test_ = as.data.frame(t(sapply(as.data.frame(t(W_test)), rank)))
143 | colnames(W_test_) = colnames(W_test)
144 | W_test = W_test_
145 |
146 | coverage_ = as.data.frame(t(sapply(as.data.frame(t(coverage)), rank)))
147 | colnames(coverage_) = colnames(coverage)
148 | coverage = coverage_
149 | coverage_test_ = as.data.frame(t(sapply(as.data.frame(t(coverage_test)), rank)))
150 | colnames(coverage_test_) = colnames(coverage_test)
151 | coverage_test = coverage_test_
152 |
153 | R2_fake_ = as.data.frame(t(sapply(as.data.frame(t(R2_fake)), rank)))
154 | colnames(R2_fake_) = colnames(R2_fake)
155 | R2_fake = R2_fake_
156 |
157 | F1_fake_ = as.data.frame(t(sapply(as.data.frame(t(F1_fake)), rank)))
158 | colnames(F1_fake_) = colnames(F1_fake)
159 | F1_fake = F1_fake_
160 |
161 | percent_bias_ = as.data.frame(t(sapply(as.data.frame(t(percent_bias)), rank)))
162 | colnames(percent_bias_) = colnames(percent_bias)
163 | percent_bias = percent_bias_
164 | coverage_rate_ = as.data.frame(t(sapply(as.data.frame(t(coverage_rate)), rank)))
165 | colnames(coverage_rate_) = colnames(coverage_rate)
166 | coverage_rate = coverage_rate_
167 |
168 | ClassScore_ = as.data.frame(t(sapply(as.data.frame(t(ClassScore)), rank)))
169 | colnames(ClassScore_) = colnames(ClassScore)
170 | ClassScore = ClassScore_
171 |
172 | Time_ = as.data.frame(t(sapply(as.data.frame(t(Time)), rank)))
173 | colnames(Time_) = colnames(Time)
174 | Time = Time_
175 |
176 | for (method in methods){
177 | W_train_mean = mean(unlist(W_train[method]))
178 | W_train_sd = std.error(unlist(W_train[method]))
179 |
180 | W_test_mean = mean(unlist(W_test[method]))
181 | W_test_sd = std.error(unlist(W_test[method]))
182 |
183 | coverage_mean = mean(unlist(coverage[method]))
184 | coverage_sd = std.error(unlist(coverage[method]))
185 | coverage_test_mean = mean(unlist(coverage_test[method]))
186 | coverage_test_sd = std.error(unlist(coverage_test[method]))
187 |
188 | R2_fake_mean = mean(unlist(R2_fake[method]))
189 | R2_fake_sd = std.error(unlist(R2_fake[method]))
190 |
191 | F1_fake_mean = mean(unlist(F1_fake[method]))
192 | F1_fake_sd = std.error(unlist(F1_fake[method]))
193 |
194 | percent_bias_mean = mean(unlist(percent_bias[method]))
195 | percent_bias_sd = std.error(unlist(percent_bias[method]))
196 |
197 | coverage_rate_mean = mean(unlist(coverage_rate[method]))
198 | coverage_rate_sd = std.error(unlist(coverage_rate[method]))
199 |
200 | class_score_mean = mean(unlist(ClassScore[method]))
201 | class_score_sd = std.error(unlist(ClassScore[method]))
202 |
203 | time_mean = mean(unlist(Time[method]))
204 | time_sd = std.error(unlist(Time[method]))
205 |
206 | data_summary_rank[method] = c(
207 | paste0(as.character(round(W_train_mean, 1)), ' (', as.character(round(W_train_sd, 1)), ')'),
208 | paste0(as.character(round(W_test_mean, 1)), ' (', as.character(round(W_test_sd, 1)), ')'),
209 | paste0(as.character(round(coverage_mean, 1)), ' (', as.character(round(coverage_sd, 1)), ')'),
210 | paste0(as.character(round(coverage_test_mean, 1)), ' (', as.character(round(coverage_test_sd, 1)), ')'),
211 | paste0(as.character(round(R2_fake_mean, 1)), ' (', as.character(round(R2_fake_sd, 1)), ')'),
212 | paste0(as.character(round(F1_fake_mean, 1)), ' (', as.character(round(F1_fake_sd, 1)), ')'),
213 | paste0(as.character(round(class_score_mean, 1)), ' (', as.character(round(class_score_sd, 1)), ')'),
214 | paste0(as.character(round(percent_bias_mean, 1)), ' (', as.character(round(percent_bias_sd, 1)), ')'),
215 | paste0(as.character(round(coverage_rate_mean, 1)), ' (', as.character(round(coverage_rate_sd, 1)), ')'),
216 | paste0(as.character(round(time_mean, 1)), ' (', as.character(round(time_sd, 1)), ')')
217 |
218 | )
219 | }
220 |
221 | rownames(data_summary_rank) = c("W_train", "W_test", "coverage_train", "coverage_test", "R2_fake", "F1_fake", "class_score", "percent_bias", "coverage_rate", "time")
222 | data_summary_rank
223 | data_summary_rank_t <- t(data_summary_rank)
224 | data_summary_rank_t
225 |
226 |
227 |
228 | ############################## Latex tables
229 |
230 |
231 | rownames(data_summary_t) = c("GaussianCopula", "TVAE", "CTGAN", "CTAB-GAN+", "STaSy", "TabDDPM", "Forest-VP", "Forest-Flow", "UTrees", "Oracle")
232 | rownames(data_summary_rank_t) = c("GaussianCopula", "TVAE", "CTGAN", "CTAB-GAN+", "STaSy", "TabDDPM", "Forest-VP", "Forest-Flow", "UTrees")
233 |
234 | xtable(data_summary_t[,-10], type = "latex")
235 | xtable(data_summary_rank_t[,-10], type = "latex")
236 |
237 |
238 | if (FALSE) {
239 | ### ablation
240 |
241 | data_x = read.csv('C:/Users/Alexia-Mini/Downloads/Results_tabular - ablation_iris.csv')
242 | xtable(data_x, type = "latex")
243 |
244 | # For building the barplot, to csv
245 |
246 |   W_train = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars)))
247 |   colnames(W_train) = methods_all
248 |   rownames(W_train) = all_vars
249 |   W_train[,] = 0
250 |   W_test = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars)))
251 |   colnames(W_test) = methods_all
252 |   rownames(W_test) = all_vars
253 |   W_test[,] = 0
254 |   coverage = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars)))
255 |   colnames(coverage) = methods_all
256 |   rownames(coverage) = all_vars
257 |   coverage[,] = 0
258 |   coverage_test = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars)))
259 |   colnames(coverage_test) = methods_all
260 |   rownames(coverage_test) = all_vars
261 |   coverage_test[,] = 0
262 |   R2_fake = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars)))
263 |   colnames(R2_fake) = methods_all
264 |   rownames(R2_fake) = all_vars
265 |   R2_fake[,] = 0
266 |   F1_fake = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars)))
267 |   colnames(F1_fake) = methods_all
268 |   rownames(F1_fake) = all_vars
269 |   F1_fake[,] = 0
270 |   ClassScore = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars)))
271 |   colnames(ClassScore) = methods_all
272 |   rownames(ClassScore) = all_vars
273 |   ClassScore[,] = 0
274 |   percent_bias = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars)))
275 |   colnames(percent_bias) = methods_all
276 |   rownames(percent_bias) = all_vars
277 |   percent_bias[,] = 0
278 |   coverage_rate = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars)))
279 |   colnames(coverage_rate) = methods_all
280 |   rownames(coverage_rate) = all_vars
281 |   coverage_rate[,] = 0
282 |   Time = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars)))
283 |   colnames(Time) = methods_all
284 |   rownames(Time) = all_vars
285 |   Time[,] = 0
286 |
287 | for (method in methods_all){
288 | W_train[rownames(W_train) %in% W_vars, method] = data$W_train[data$method==method & data$dataset %in% W_vars]
289 | W_test[rownames(W_train) %in% W_vars, method] = data$W_test[data$method==method & data$dataset %in% W_vars]
290 |
291 | coverage[, method] = data$coverage_train[data$method==method & data$dataset %in% all_vars]
292 | coverage_test[, method] = data$coverage_test[data$method==method & data$dataset %in% all_vars]
293 |
294 | R2_fake[rownames(W_train) %in% R2_vars, method] = data$R2_fake[data$method==method & data$dataset %in% R2_vars]
295 | F1_fake[rownames(W_train) %in% F1_vars, method] = data$F1_fake[data$method==method & data$dataset %in% F1_vars]
296 | ClassScore[rownames(W_train) %in% all_vars, method] = data$class_score[data$method==method & data$dataset %in% all_vars]
297 | percent_bias[rownames(W_train) %in% R2_vars, method] = data$percent_bias[data$method==method & data$dataset %in% R2_vars]
298 | coverage_rate[rownames(W_train) %in% R2_vars, method] = data$coverage_rate[data$method==method & data$dataset %in% R2_vars]
299 | }
300 |
301 | # For building the barplot, to csv
302 | write.csv(W_test, file="gen_W.csv")
303 | write.csv(coverage_test, file="gen_cov.csv")
304 | write.csv(R2_fake, file="gen_R2.csv")
305 | write.csv(F1_fake, file="gen_F1.csv")
306 | write.csv(ClassScore, file="gen_disc.csv")
307 | write.csv(percent_bias, file="gen_pb.csv")
308 | write.csv(coverage_rate, file="gen_cr.csv")
309 | }
310 |
--------------------------------------------------------------------------------