├── tests
│   ├── __init__.py
│   ├── utrees
│   │   ├── __init__.py
│   │   ├── test_baltobot.py
│   │   └── test_utrees.py
│   └── test_imports.py
├── .gitattributes
├── paper
│   ├── wave-demo.png
│   ├── iris-vertical.png
│   ├── poisson-demo.png
│   ├── moons-generation.png
│   ├── moons-imputation.png
│   ├── wave-time-pdfat2.png
│   ├── wave-baltobot-readme.jpg
│   ├── baltobotmodel.py
│   ├── moons.ipynb
│   ├── baltobot-simple-demo.ipynb
│   └── iris.ipynb
├── .pre-commit-config.yaml
├── utrees
│   ├── __init__.py
│   ├── kdi_quantizer.py
│   ├── baltobot.py
│   └── unmasking_trees.py
├── requirements.txt
├── LICENSE
├── Results
│   ├── latex_generation_wmissing_baltobot4d8f71b.txt
│   ├── latex_imputation_baltobot4d8f71b.txt
│   ├── latex_generation_baltobot4d8f71b.txt
│   ├── latex_imputation_ablation_baltobot4d8f71b.txt
│   ├── tabular_imputation_results_kmeans.txt
│   ├── tabular_imputation_results_tobt.txt
│   ├── tabular_imputation_results.txt
│   ├── tabular_imputation_results_baltobot4d8f71b.txt
│   ├── latex_tables.txt
│   ├── imputation_script.R
│   ├── generation_script_nmiss50.R
│   └── generation_script.R
├── pyproject.toml
├── setup.py
├── .gitignore
└── README.md

/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/utrees/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.pdf binary
2 | *.eps binary
3 | *.png binary
--------------------------------------------------------------------------------
/paper/wave-demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calvinmccarter/unmasking-trees/HEAD/paper/wave-demo.png
--------------------------------------------------------------------------------
/paper/iris-vertical.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calvinmccarter/unmasking-trees/HEAD/paper/iris-vertical.png
--------------------------------------------------------------------------------
/paper/poisson-demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calvinmccarter/unmasking-trees/HEAD/paper/poisson-demo.png
--------------------------------------------------------------------------------
/paper/moons-generation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calvinmccarter/unmasking-trees/HEAD/paper/moons-generation.png
--------------------------------------------------------------------------------
/paper/moons-imputation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calvinmccarter/unmasking-trees/HEAD/paper/moons-imputation.png
--------------------------------------------------------------------------------
/paper/wave-time-pdfat2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calvinmccarter/unmasking-trees/HEAD/paper/wave-time-pdfat2.png
--------------------------------------------------------------------------------
/paper/wave-baltobot-readme.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/calvinmccarter/unmasking-trees/HEAD/paper/wave-baltobot-readme.jpg -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | files: ^utrees/ 2 | repos: 3 | - repo: https://github.com/psf/black 4 | rev: 22.10.0 5 | hooks: 6 | - id: black 7 | -------------------------------------------------------------------------------- /tests/test_imports.py: -------------------------------------------------------------------------------- 1 | def test_import_packages(): 2 | """Test that importing our package works.""" 3 | import utrees 4 | from utrees import UnmaskingTrees 5 | from utrees import KDIQuantizer 6 | -------------------------------------------------------------------------------- /utrees/__init__.py: -------------------------------------------------------------------------------- 1 | from utrees.baltobot import Baltobot 2 | from utrees.baltobot import NanTabPFNClassifier 3 | from utrees.kdi_quantizer import KDIQuantizer 4 | from utrees.unmasking_trees import UnmaskingTrees 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | black 2 | numpy 3 | pre-commit>=2.2.0 4 | pytest>=5.4.1 5 | scikit-learn>=0.23 6 | scipy>=1.0 7 | tqdm 8 | torch 9 | lightgbm 10 | xgboost 11 | pandas 12 | matplotlib 13 | jupyterlab 14 | seaborn 15 | -------------------------------------------------------------------------------- /tests/utrees/test_baltobot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from sklearn.datasets import make_moons 5 | 6 | from utrees import Baltobot 7 | 8 | 9 | def test_moons(): 10 | n_upper = 100 11 | n_lower = 100 12 | n_generate = 123 13 | n = n_upper + n_lower 14 | data, labels = make_moons( 15 | (n_upper, n_lower), shuffle=False, noise=0.1, random_state=12345) 16 | X = data[:, 0:1] 17 | y = data[:, 1] 18 | balto = Baltobot(random_state=12345) 19 | balto.fit(X, y) 20 | sampley = balto.sample(X) 21 | assert sampley.shape == (n, ) 22 | evaly = np.linspace(-3, 3, 1000) 23 | evalX = np.full((1000, 1), fill_value=0.) 24 | scores = balto.score_samples(evalX, evaly) 25 | assert scores.shape == (1000,) 26 | probs = np.exp(scores) 27 | assert 0 <= np.min(probs) 28 | assert 0.99 < np.sum(probs) 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Calvin McCarter 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Results/latex_generation_wmissing_baltobot4d8f71b.txt: -------------------------------------------------------------------------------- 1 | % latex table generated in R 4.4.1 by xtable 1.8-4 package 2 | % Tue Sep 3 11:00:16 2024 3 | \begin{table}[ht] 4 | \centering 5 | \begin{tabular}{rlllllllll} 6 | \hline 7 | & W\_train & W\_test & coverage\_train & coverage\_test & R2\_fake & F1\_fake & class\_score & percent\_bias & coverage\_rate \\ 8 | \hline 9 | GaussianCopula & 7 (0.3) & 7.1 (0.2) & 7.2 (0.3) & 7.1 (0.3) & 6.3 (0.4) & 6.6 (0.3) & 6.7 (0.4) & 5.5 (1) & 7.7 (0.6) \\ 10 | TVAE & 5.2 (0.3) & 4.9 (0.3) & 5.7 (0.3) & 5.8 (0.2) & 6 (1) & 5.8 (0.5) & 5.8 (0.4) & 8 (0.4) & 6.2 (1) \\ 11 | CTGAN & 8.3 (0.2) & 8.4 (0.2) & 8.4 (0.2) & 8.3 (0.2) & 8.3 (0.3) & 8.4 (0.2) & 6.5 (0.2) & 4.8 (1.2) & 7.1 (0.7) \\ 12 | CTABGAN & 6.7 (0.4) & 6.5 (0.4) & 7.1 (0.3) & 6.8 (0.3) & 7.3 (0.6) & 7.1 (0.4) & 6.6 (0.3) & 7.5 (1) & 6.1 (0.6) \\ 13 | Stasy & 5.9 (0.2) & 6.1 (0.3) & 5.3 (0.2) & 5.1 (0.3) & 5.8 (0.9) & 4.4 (0.4) & 5.3 (0.4) & 3.7 (0.4) & 4.6 (1.1) \\ 14 | TabDDPM & 3 (0.7) & 3.4 (0.7) & 2.3 (0.5) & 2.9 (0.6) & 1.7 (0.3) & 3.3 (0.6) & 3.9 (0.6) & 3.8 (1.2) & 2 (0.5) \\ 15 | Forest-VP & 3.7 (0.2) & 3.2 (0.3) & 3.9 (0.2) & 3.8 (0.3) & 3.2 (0.3) & 2.3 (0.3) & 4.2 (0.4) & 4.2 (0.8) & 4.5 (1.1) \\ 16 | Forest-Flow & 3 (0.3) & 2.6 (0.3) & 2.6 (0.3) & 2.7 (0.2) & 3 (0.7) & 3.7 (0.3) & 5 (0.5) & 3.8 (0.9) & 3.2 (0.8) \\ 17 | UTrees & 2.1 (0.2) & 2.8 (0.3) & 2.5 (0.2) & 2.5 (0.2) & 3.3 (0.8) & 3.5 (0.5) & 1 (0) & 3.7 (0.9) & 3.7 (1) \\ 18 | \hline 19 | \end{tabular} 20 | \end{table} -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "utrees" 7 | version = "0.3.0" 8 | description = "Unmasking trees for tabular data generation and imputation" 9 | readme = "README.md" 10 | authors = [ 11 | { name = "Calvin McCarter", email = "mccarter.calvin@gmail.com" }, 12 | ] 13 | maintainers = [ 14 | { name = "Calvin McCarter", email = "mccarter.calvin@gmail.com" }, 15 | ] 16 | keywords = [ 17 | "generation", 18 | "imputation", 19 | "tabular", 20 | ] 21 | classifiers = [ 22 | "Development Status :: 3 - Alpha", 23 | "Intended Audience :: Developers", 24 | "Intended Audience :: Science/Research", 25 | "License :: OSI Approved", 26 | "Operating System :: MacOS", 27 | "Operating System :: Microsoft :: Windows", 28 | "Operating System :: POSIX", 29 | "Operating System :: Unix", 30 | "Programming Language :: Python", 31 | "Programming Language :: Python :: 3.6", 32 | "Programming Language :: Python :: 3.7", 33 | "Programming Language :: Python :: 3.8", 34 | "Programming Language :: Python :: 3.9", 35 | "Topic :: Scientific/Engineering", 36 | "Topic :: Software Development", 37 | ] 38 | dependencies = [ 39 | 
"kditransform", 40 | "numba >= 0.48", 41 | "numpy", 42 | "pandas", 43 | "scikit-learn >= 0.23", 44 | "scipy >= 1.0", 45 | "tqdm", 46 | "xgboost", 47 | ] 48 | 49 | [project.urls] 50 | Homepage = "http://github.com/calvinmccarter/unmasking-trees" 51 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | 4 | def readme(): 5 | with open("README.md") as readme_file: 6 | return readme_file.read() 7 | 8 | 9 | configuration = { 10 | "name": "utrees", 11 | "version": "0.3.0", 12 | "description": "Unmasking trees for tabular data generation and imputation", 13 | "long_description": readme(), 14 | "long_description_content_type": "text/markdown", 15 | "classifiers": [ 16 | "Development Status :: 3 - Alpha", 17 | "Intended Audience :: Science/Research", 18 | "Intended Audience :: Developers", 19 | "License :: OSI Approved", 20 | "Programming Language :: Python", 21 | "Topic :: Software Development", 22 | "Topic :: Scientific/Engineering", 23 | "Operating System :: Microsoft :: Windows", 24 | "Operating System :: POSIX", 25 | "Operating System :: Unix", 26 | "Operating System :: MacOS", 27 | "Programming Language :: Python :: 3.6", 28 | "Programming Language :: Python :: 3.7", 29 | "Programming Language :: Python :: 3.8", 30 | "Programming Language :: Python :: 3.9", 31 | ], 32 | "keywords": "tabular,imputation,generation", 33 | "url": "http://github.com/calvinmccarter/unmasking-trees", 34 | "author": "Calvin McCarter", 35 | "author_email": "mccarter.calvin@gmail.com", 36 | "maintainer": "Calvin McCarter", 37 | "maintainer_email": "mccarter.calvin@gmail.com", 38 | "packages": ["utrees"], 39 | "install_requires": [ 40 | "kditransform", 41 | "numba >= 0.48", 42 | "numpy", 43 | "pandas", 44 | "scikit-learn >= 0.23", 45 | "scipy >= 1.0", 46 | "tqdm", 47 | "xgboost", 48 | ], 49 | "ext_modules": [], 50 | "cmdclass": {}, 51 | "test_suite": "nose.collector", 52 | "tests_require": ["nose"], 53 | "data_files": (), 54 | "zip_safe": True, 55 | } 56 | 57 | setup(**configuration) 58 | -------------------------------------------------------------------------------- /paper/baltobotmodel.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from typing import Optional 3 | 4 | from jaxtyping import Float 5 | from numpy import ndarray 6 | from sklearn.base import MultiOutputMixin 7 | from skopt.space import Categorical 8 | from skopt.space import Integer 9 | from skopt.space import Real 10 | 11 | import numpy as np 12 | import utrees 13 | 14 | from .base_model import ProbabilisticModel 15 | 16 | 17 | class BaltoBotModel(ProbabilisticModel, MultiOutputMixin): 18 | """ 19 | Wrapping the BaltoBot model as a ProbabilisticModel. 
20 | """ 21 | 22 | def __init__( 23 | self, 24 | learning_rate: float = 0.3, 25 | max_leaves: int = 0, 26 | subsample: float = 1, 27 | seed: int = 0, 28 | ): 29 | super().__init__(seed) 30 | self.subsample = subsample 31 | self.max_leaves = max_leaves 32 | self.learning_rate = learning_rate 33 | self.model = None 34 | 35 | def fit( 36 | self, 37 | X: Float[ndarray, "batch x_dim"], 38 | y: Float[ndarray, "batch y_dim"], 39 | cat_idx: Optional[List[int]] = None, 40 | ) -> "ProbabilisticModel": 41 | self.model = utrees.Baltobot( 42 | tabpfn=True, 43 | random_state=self.seed, 44 | xgboost_kwargs={ 45 | "subsample": self.subsample, 46 | "max_leaves": self.max_leaves, 47 | "learning_rate": self.learning_rate, 48 | }, 49 | ) 50 | assert y.shape[1] == 1 51 | self.model.fit(X, y.ravel()) 52 | return self 53 | 54 | def predict(self, X: Float[ndarray, "batch x_dim"]) -> Float[ndarray, "batch y_dim"]: 55 | y_samples = self.sample(X, n_samples=50) 56 | y_samples = y_samples.mean(axis=0) 57 | return y_samples 58 | 59 | def sample( 60 | self, X: Float[ndarray, "batch x_dim"], n_samples=100, seed=None 61 | ) -> Float[ndarray, "n_samples batch y_dim"]: 62 | batch = X.shape[0] 63 | y_samples = np.zeros((n_samples, batch, 1)) 64 | for n in range(n_samples): 65 | y_samples[n, :, 0] = self.model.sample(X) 66 | return y_samples 67 | 68 | @staticmethod 69 | def search_space() -> dict: 70 | return { 71 | "learning_rate": Real(0.05, 0.5, "log-uniform"), 72 | "max_leaves": Categorical([0, 25, 50]), 73 | "subsample": Real(0.3, 1., "log-uniform"), 74 | } 75 | 76 | def get_extra_stats(self) -> dict: 77 | return {} 78 | -------------------------------------------------------------------------------- /Results/latex_imputation_baltobot4d8f71b.txt: -------------------------------------------------------------------------------- 1 | % latex table generated in R 4.4.1 by xtable 1.8-4 package 2 | % Mon Sep 2 13:56:23 2024 3 | \begin{table}[ht] 4 | \centering 5 | \begin{tabular}{rlllllllll} 6 | \hline 7 | & MinMAE & AvgMAE & W\_train & W\_test & MedianMAD & R2\_imp & F1\_imp & PercentBias & CoverageRate \\ 8 | \hline 9 | KNN & 0.16 (0.03) & 0.16 (0.03) & 0.42 (0.08) & 1.89 (0.49) & 0 (0) & 0.59 (0.09) & 0.75 (0.04) & 1.27 (0.25) & 0.4 (0.11) \\ 10 | ICE & 0.1 (0.01) & 0.21 (0.03) & 0.52 (0.09) & 1.99 (0.49) & 0.69 (0.1) & 0.59 (0.09) & 0.74 (0.04) & 1.05 (0.29) & 0.39 (0.09) \\ 11 | MICE-Forest & 0.08 (0.02) & 0.13 (0.03) & 0.34 (0.07) & 1.86 (0.48) & 0.29 (0.08) & 0.61 (0.1) & 0.76 (0.04) & 0.61 (0.2) & 0.75 (0.11) \\ 12 | MissForest & 0.1 (0.03) & 0.12 (0.03) & 0.32 (0.07) & 1.85 (0.48) & 0.1 (0.03) & 0.61 (0.1) & 0.76 (0.04) & 0.62 (0.22) & 0.79 (0.08) \\ 13 | Softimpute & 0.22 (0.03) & 0.22 (0.03) & 0.53 (0.07) & 1.99 (0.48) & 0 (0) & 0.58 (0.09) & 0.74 (0.04) & 1.18 (0.34) & 0.31 (0.09) \\ 14 | OT & 0.14 (0.02) & 0.19 (0.03) & 0.56 (0.1) & 1.98 (0.49) & 0.28 (0.05) & 0.59 (0.1) & 0.75 (0.04) & 1.09 (0.27) & 0.39 (0.12) \\ 15 | GAIN & 0.16 (0.03) & 0.17 (0.03) & 0.49 (0.11) & 1.95 (0.51) & 0.01 (0) & 0.6 (0.1) & 0.75 (0.04) & 1.04 (0.25) & 0.54 (0.12) \\ 16 | Forest-VP & 0.14 (0.04) & 0.17 (0.03) & 0.55 (0.13) & 1.96 (0.5) & 0.25 (0.03) & 0.61 (0.1) & 0.74 (0.04) & 0.81 (0.25) & 0.57 (0.14) \\ 17 | utrees & 0.08 (0.02) & 0.14 (0.03) & 0.37 (0.08) & 1.87 (0.48) & 0.27 (0.07) & 0.61 (0.1) & 0.76 (0.04) & 0.55 (0.19) & 0.71 (0.13) \\ 18 | Oracle & 0 (0) & 0 (0) & 0 (0) & 1.87 (0.49) & 0 (0) & 0.64 (0.09) & 0.78 (0.04) & 0 (0) & 1 (0) \\ 19 | \hline 20 | \end{tabular} 21 | \end{table} 22 | % latex table generated in R 4.4.1 
by xtable 1.8-4 package 23 | % Mon Sep 2 13:56:23 2024 24 | \begin{table}[ht] 25 | \centering 26 | \begin{tabular}{rlllllllll} 27 | \hline 28 | & MinMAE & AvgMAE & W\_train & W\_test & MedianMAD & R2\_imp & F1\_imp & PercentBias & CoverageRate \\ 29 | \hline 30 | KNN & 5.5 (0.5) & 6.3 (0.4) & 4.9 (0.4) & 5 (0.4) & 8.4 (0) & 6.5 (1) & 5.7 (1.1) & 6.2 (1) & 5.4 (0.6) \\ 31 | ICE & 6.8 (0.4) & 4.7 (0.4) & 7 (0.5) & 7.2 (0.4) & 1.6 (0.2) & 6.2 (1) & 7 (0.6) & 5.7 (0.9) & 5.3 (0.6) \\ 32 | MICE-Forest & 3.9 (0.4) & 2.5 (0.4) & 2.9 (0.2) & 3 (0.2) & 3.6 (0.2) & 3.7 (1.4) & 3.2 (1) & 5.5 (1.2) & 4.3 (0.6) \\ 33 | MissForest & 2.7 (0.5) & 4 (0.4) & 1.8 (0.3) & 2 (0.3) & 5.5 (0.2) & 3.8 (1.4) & 2.5 (0.5) & 5.5 (1.5) & 3.3 (0.5) \\ 34 | Softimpute & 6.7 (0.4) & 7.6 (0.4) & 7.1 (0.5) & 7.3 (0.5) & 8.4 (0) & 6 (0.9) & 7.8 (0.4) & 6.3 (0.9) & 6.7 (0.4) \\ 35 | OT & 5.9 (0.4) & 6.1 (0.3) & 6 (0.5) & 6 (0.5) & 3.7 (0.3) & 6.2 (0.5) & 6.8 (0.6) & 5.5 (0.8) & 4.8 (0.5) \\ 36 | GAIN & 4.7 (0.4) & 6.5 (0.3) & 6 (0.3) & 6 (0.2) & 6.9 (0.1) & 5.7 (0.8) & 5.4 (0.8) & 4.7 (1) & 5 (0.6) \\ 37 | Forest-VP & 5.3 (0.4) & 4 (0.5) & 5.8 (0.3) & 5.1 (0.4) & 3.2 (0.4) & 4.5 (0.9) & 4.6 (0.8) & 3.3 (0.6) & 5.5 (0.7) \\ 38 | utrees & 3.5 (0.5) & 3.2 (0.5) & 3.5 (0.4) & 3.5 (0.5) & 3.8 (0.2) & 2.5 (0.6) & 2.2 (0.6) & 2.3 (0.9) & 4.7 (0.6) \\ 39 | \hline 40 | \end{tabular} 41 | \end{table} -------------------------------------------------------------------------------- /Results/latex_generation_baltobot4d8f71b.txt: -------------------------------------------------------------------------------- 1 | % latex table generated in R 4.4.1 by xtable 1.8-4 package 2 | % Sat Sep 7 22:19:01 2024 3 | \begin{table}[ht] 4 | \centering 5 | \begin{tabular}{rlllllllll} 6 | \hline 7 | & W\_train & W\_test & coverage\_train & coverage\_test & R2\_fake & F1\_fake & class\_score & percent\_bias & coverage\_rate \\ 8 | \hline 9 | GaussianCopula & 2.74 (0.56) & 2.99 (0.61) & 0.18 (0.04) & 0.37 (0.06) & 0.2 (0.14) & 0.46 (0.06) & NA (0.05) & 2.27 (0.77) & 0.23 (0.12) \\ 10 | TVAE & 2.12 (0.58) & 2.35 (0.63) & 0.33 (0.04) & 0.63 (0.04) & -0.47 (0.61) & 0.52 (0.08) & NA (0.01) & 4.15 (1.97) & 0.26 (0.09) \\ 11 | CTGAN & 3.58 (0.99) & 3.74 (1.01) & 0.12 (0.03) & 0.28 (0.04) & -0.43 (0.08) & 0.35 (0.04) & NA (0.01) & 2.48 (1.3) & 0.2 (0.08) \\ 12 | CTAB-GAN+ & 2.71 (0.81) & 2.89 (0.83) & 0.22 (0.04) & 0.44 (0.05) & 0.05 (0.12) & 0.44 (0.05) & NA (0.02) & 2.95 (1.04) & 0.26 (0.07) \\ 13 | STaSy & 3.41 (1.39) & 3.66 (1.42) & 0.38 (0.05) & 0.63 (0.05) & -4.21 (4.44) & 0.61 (0.06) & NA (0.02) & 1.23 (0.44) & 0.45 (0.12) \\ 14 | TabDDPM & 4.27 (1.89) & 4.79 (1.89) & 0.76 (0.06) & 0.8 (0.06) & 0.6 (0.11) & 0.66 (0.06) & NA (0.03) & 0.76 (0.28) & 0.72 (0.11) \\ 15 | Forest-VP & 1.46 (0.4) & 1.94 (0.5) & 0.67 (0.05) & 0.84 (0.03) & 0.55 (0.1) & 0.73 (0.04) & NA (0.01) & 0.94 (0.3) & 0.52 (0.15) \\ 16 | Forest-Flow & 1.36 (0.39) & 1.9 (0.5) & 0.83 (0.03) & 0.9 (0.03) & 0.57 (0.11) & 0.73 (0.04) & NA (0.01) & 0.83 (0.23) & 0.63 (0.11) \\ 17 | UTrees & 1.54 (0.47) & 2 (0.55) & 0.73 (0.04) & 0.87 (0.02) & 0.49 (0.1) & 0.7 (0.04) & NA (NA) & 1.2 (0.28) & 0.4 (0.09) \\ 18 | Oracle & 0 (0) & 1.81 (0.47) & 0.99 (0.01) & 0.91 (0.04) & 0.64 (0.09) & 0.77 (0.04) & NA (0) & 0 (0) & 1 (0) \\ 19 | \hline 20 | \end{tabular} 21 | \end{table} 22 | % latex table generated in R 4.4.1 by xtable 1.8-4 package 23 | % Sat Sep 7 22:19:01 2024 24 | \begin{table}[ht] 25 | \centering 26 | \begin{tabular}{rlllllllll} 27 | \hline 28 | & W\_train & W\_test & coverage\_train & 
coverage\_test & R2\_fake & F1\_fake & class\_score & percent\_bias & coverage\_rate \\ 29 | \hline 30 | GaussianCopula & 7.1 (0.3) & 7.2 (0.3) & 7.3 (0.3) & 7.4 (0.3) & 6.2 (0.2) & 6.4 (0.3) & 7 (0.4) & 6.5 (1.1) & 7.5 (0.7) \\ 31 | TVAE & 5.3 (0.2) & 5.1 (0.2) & 5.7 (0.2) & 5.7 (0.2) & 6.5 (0.7) & 6 (0.5) & 5.5 (0.3) & 7.3 (0.6) & 6.7 (0.6) \\ 32 | CTGAN & 8.4 (0.1) & 8.4 (0.2) & 8.3 (0.2) & 8.1 (0.2) & 8.5 (0.2) & 8.3 (0.2) & 6.7 (0.3) & 5.3 (1.1) & 7.2 (0.5) \\ 33 | CTAB-GAN+ & 6.8 (0.3) & 6.7 (0.3) & 7.2 (0.3) & 7.1 (0.3) & 6.8 (0.4) & 6.9 (0.4) & 6.9 (0.3) & 7.7 (0.8) & 6.7 (0.8) \\ 34 | STaSy & 6.1 (0.2) & 6.3 (0.2) & 5.3 (0.2) & 5.4 (0.2) & 6 (1.2) & 5.1 (0.3) & 6.1 (0.3) & 4.5 (0.8) & 4.2 (1.1) \\ 35 | TabDDPM & 3 (0.7) & 3.9 (0.6) & 2.8 (0.5) & 3.4 (0.5) & 1.2 (0.2) & 3.8 (0.6) & 3.2 (0.4) & 3 (0.9) & 1.4 (0.2) \\ 36 | Forest-VP & 3.2 (0.2) & 2.8 (0.2) & 3.6 (0.3) & 3.3 (0.3) & 2.8 (0.3) & 2.2 (0.3) & 4.3 (0.4) & 3.2 (0.9) & 3.5 (0.8) \\ 37 | Forest-Flow & 1.9 (0.2) & 1.5 (0.2) & 1.7 (0.2) & 1.8 (0.2) & 2.3 (0.4) & 2.4 (0.3) & 4.3 (0.4) & 2.8 (0.5) & 2.7 (0.4) \\ 38 | UTrees & 3.1 (0.1) & 3.1 (0.2) & 3.1 (0.2) & 2.8 (0.2) & 4.7 (0.3) & 3.9 (0.3) & 1 (0) & 4.7 (0.7) & 5.2 (0.9) \\ 39 | \hline 40 | \end{tabular} 41 | \end{table} -------------------------------------------------------------------------------- /Results/latex_imputation_ablation_baltobot4d8f71b.txt: -------------------------------------------------------------------------------- 1 | % latex table generated in R 4.4.1 by xtable 1.8-4 package 2 | % Tue Sep 3 11:40:24 2024 3 | \begin{table}[ht] 4 | \centering 5 | \begin{tabular}{rlllllllll} 6 | \hline 7 | & MinMAE & AvgMAE & W\_train & W\_test & MedianMAD & R2\_imp & F1\_imp & PercentBias & CoverageRate \\ 8 | \hline 9 | KNN & 0.16 (0.03) & 0.16 (0.03) & 0.42 (0.08) & 1.89 (0.49) & 0 (0) & 0.59 (0.09) & 0.75 (0.04) & 1.27 (0.25) & 0.4 (0.11) \\ 10 | ICE & 0.1 (0.01) & 0.21 (0.03) & 0.52 (0.09) & 1.99 (0.49) & 0.69 (0.1) & 0.59 (0.09) & 0.74 (0.04) & 1.05 (0.29) & 0.39 (0.09) \\ 11 | MICE-Forest & 0.08 (0.02) & 0.13 (0.03) & 0.34 (0.07) & 1.86 (0.48) & 0.29 (0.08) & 0.61 (0.1) & 0.76 (0.04) & 0.61 (0.2) & 0.75 (0.11) \\ 12 | MissForest & 0.1 (0.03) & 0.12 (0.03) & 0.32 (0.07) & 1.85 (0.48) & 0.1 (0.03) & 0.61 (0.1) & 0.76 (0.04) & 0.62 (0.22) & 0.79 (0.08) \\ 13 | Softimpute & 0.22 (0.03) & 0.22 (0.03) & 0.53 (0.07) & 1.99 (0.48) & 0 (0) & 0.58 (0.09) & 0.74 (0.04) & 1.18 (0.34) & 0.31 (0.09) \\ 14 | OT & 0.14 (0.02) & 0.19 (0.03) & 0.56 (0.1) & 1.98 (0.49) & 0.28 (0.05) & 0.59 (0.1) & 0.75 (0.04) & 1.09 (0.27) & 0.39 (0.12) \\ 15 | GAIN & 0.16 (0.03) & 0.17 (0.03) & 0.49 (0.11) & 1.95 (0.51) & 0.01 (0) & 0.6 (0.1) & 0.75 (0.04) & 1.04 (0.25) & 0.54 (0.12) \\ 16 | Forest-VP & 0.14 (0.04) & 0.17 (0.03) & 0.55 (0.13) & 1.96 (0.5) & 0.25 (0.03) & 0.61 (0.1) & 0.74 (0.04) & 0.81 (0.25) & 0.57 (0.14) \\ 17 | utrees-kmeans & 0.1 (0.02) & 0.15 (0.03) & 0.43 (0.09) & 1.9 (0.5) & 0.28 (0.06) & 0.61 (0.1) & 0.76 (0.04) & 0.63 (0.21) & 0.72 (0.13) \\ 18 | utrees-kdi & 0.1 (0.02) & 0.14 (0.03) & 0.42 (0.09) & 1.89 (0.49) & 0.27 (0.06) & 0.61 (0.1) & 0.76 (0.04) & 0.68 (0.24) & 0.68 (0.14) \\ 19 | utrees & 0.08 (0.02) & 0.14 (0.03) & 0.37 (0.08) & 1.87 (0.48) & 0.27 (0.07) & 0.61 (0.1) & 0.76 (0.04) & 0.55 (0.19) & 0.71 (0.13) \\ 20 | Oracle & 0 (0) & 0 (0) & 0 (0) & 1.87 (0.49) & 0 (0) & 0.64 (0.09) & 0.78 (0.04) & 0 (0) & 1 (0) \\ 21 | \hline 22 | \end{tabular} 23 | \end{table} 24 | % latex table generated in R 4.4.1 by xtable 1.8-4 package 25 | % Tue Sep 3 11:40:24 2024 26 | 
\begin{table}[ht] 27 | \centering 28 | \begin{tabular}{rlllllllll} 29 | \hline 30 | & MinMAE & AvgMAE & W\_train & W\_test & MedianMAD & R2\_imp & F1\_imp & PercentBias & CoverageRate \\ 31 | \hline 32 | KNN & 6.8 (0.6) & 7.8 (0.6) & 6 (0.4) & 6.1 (0.5) & 10.4 (0) & 8.2 (1.3) & 7 (1.5) & 7.5 (1.5) & 6.5 (0.8) \\ 33 | ICE & 8.3 (0.5) & 5.8 (0.5) & 8.5 (0.6) & 8.8 (0.5) & 1.9 (0.4) & 8 (1.1) & 9 (0.6) & 7.2 (1.1) & 6.4 (0.8) \\ 34 | MICE-Forest & 4.8 (0.6) & 3.3 (0.6) & 3.5 (0.3) & 3.4 (0.3) & 4.6 (0.4) & 4.3 (1.8) & 4.3 (1.3) & 6.8 (1.6) & 4.8 (0.7) \\ 35 | MissForest & 3.3 (0.7) & 5 (0.6) & 2.2 (0.4) & 2.3 (0.4) & 7.2 (0.3) & 4.7 (1.8) & 3.3 (0.9) & 6.8 (1.9) & 3.8 (0.6) \\ 36 | Softimpute & 8.3 (0.5) & 9.3 (0.5) & 8.8 (0.6) & 8.9 (0.6) & 10.4 (0) & 7.5 (1.2) & 9.8 (0.4) & 8.3 (0.9) & 7.9 (0.6) \\ 37 | OT & 7.2 (0.5) & 7.6 (0.4) & 7.4 (0.6) & 7.4 (0.6) & 4.8 (0.4) & 8.2 (0.5) & 8.8 (0.6) & 7.3 (0.7) & 5.8 (0.7) \\ 38 | GAIN & 5.8 (0.5) & 8.3 (0.4) & 7.2 (0.5) & 7.5 (0.4) & 8.9 (0.1) & 7.5 (0.8) & 7.4 (0.8) & 6.7 (1) & 6.1 (0.8) \\ 39 | Forest-VP & 6.4 (0.5) & 4.8 (0.6) & 7 (0.4) & 6.1 (0.5) & 3.8 (0.5) & 6.5 (0.9) & 6.6 (0.8) & 4.5 (0.8) & 6.5 (0.8) \\ 40 | utrees-kmeans & 6 (0.6) & 5.8 (0.5) & 6.3 (0.6) & 6.1 (0.6) & 4.1 (0.3) & 4 (0.7) & 2.9 (0.6) & 3.8 (1) & 6 (0.7) \\ 41 | utrees-kdi & 5.1 (0.5) & 5.1 (0.5) & 5.4 (0.6) & 5.6 (0.5) & 4.8 (0.3) & 4.5 (0.9) & 4 (0.5) & 3.5 (1.2) & 6.4 (0.7) \\ 42 | utrees & 3.8 (0.5) & 3.2 (0.5) & 3.8 (0.4) & 3.8 (0.5) & 5 (0.3) & 2.7 (0.6) & 2.9 (0.8) & 3.5 (0.8) & 5.8 (0.7) \\ 43 | \hline 44 | \end{tabular} 45 | \end{table} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # unmasking-trees 😷➡️🥳 🌲🌲🌲 2 | 3 | [![PyPI version](https://badge.fury.io/py/utrees.svg)](https://badge.fury.io/py/utrees) 4 | [![Downloads](https://static.pepy.tech/badge/utrees)](https://pepy.tech/project/utrees) 5 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 6 | 7 | UnmaskingTrees is a method for tabular data generation and imputation. It's an order-agnostic autoregressive diffusion model, wherein a training dataset is constructed by incrementally masking features in random order. Per-feature gradient-boosted trees are then trained to unmask each feature. Read more about it in my [paper](https://arxiv.org/abs/2407.05593)! 8 | 9 | To better model conditional distributions which are multi-modal ("modal" as in "modes", not as in "modalities"), we hierarchically partition each feature, recursively training XGBoost classifiers at each node of the binary "meta-tree". This approach for conditional modeling of individual features, dubbed BaltoBot, outperforms quantile regression and diffusion-based probabilistic prediction. 
You can also customize quantization via the `quantize_cols` parameter of the `fit` method: provide a list of length `n_dims`, with each entry one of `('continuous', 'categorical', 'integer')`. Features marked `categorical` currently skip quantization. A configuration sketch exercising this option appears at the end of this README.
10 |
11 | <div align="center">
12 |
Here's how well UnmaskingTrees works on imputation with the Two Moons synthetic dataset:
13 | <img src="paper/moons-imputation.png" alt="drawing"/>
14 | </div>
15 |
16 | <div align="center">
17 |
Here's how well BaltoBot works on probabilistic prediction:
18 | <img src="paper/wave-baltobot-readme.jpg" alt="drawing"/>
19 | </div>
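Below is a minimal sketch of the `Baltobot` interface for conditional density estimation, mirroring `tests/utrees/test_baltobot.py`. (The constructor also accepts `xgboost_kwargs` and a `tabpfn` flag, as used in `paper/baltobotmodel.py`.)

```
import numpy as np
from sklearn.datasets import make_moons
from utrees import Baltobot

data, _ = make_moons((100, 100), shuffle=False, noise=0.1, random_state=12345)
X, y = data[:, 0:1], data[:, 1]  # predict the y-coordinate from the x-coordinate

balto = Baltobot(random_state=12345)
balto.fit(X, y)
sampley = balto.sample(X)  # one draw from p(y | x) per row, shape (200,)
scores = balto.score_samples(X, y)  # per-sample log-density estimates, shape (200,)
```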
20 |
21 | ## Installation
22 |
23 | ### Installation from PyPI
24 | ```
25 | pip install utrees
26 | ```
27 |
28 | ### Installation from source
29 | After cloning this repo, install the dependencies on the command line, then install utrees:
30 | ```
31 | pip install -r requirements.txt
32 | pip install -e .
33 | pytest
34 | ```
35 |
36 | ## Usage
37 |
38 | Check out [this notebook](https://github.com/calvinmccarter/unmasking-trees/blob/master/paper/moons.ipynb) with the Two Moons example, or [this one](https://github.com/calvinmccarter/unmasking-trees/blob/master/paper/iris.ipynb) with the Iris dataset.
39 |
40 | ### Synthetic data generation
41 |
42 | You can fit `utrees.UnmaskingTrees` the way you would an sklearn model, with the added option of calling `fit` with `quantize_cols`, a list specifying which columns are continuous (and therefore need to be discretized). By default, all columns are assumed to contain continuous features.
43 |
44 | ```
45 | import numpy as np
46 | from sklearn.datasets import make_moons
47 | from utrees import UnmaskingTrees
48 | data, labels = make_moons((100, 100), shuffle=False, noise=0.1, random_state=123)  # size (200, 2)
49 | utree = UnmaskingTrees().fit(data)
50 | ```
51 |
52 | Then, you can generate new data:
53 |
54 | ```
55 | newdata = utree.generate(n_generate=123)  # size (123, 2)
56 | ```
57 |
58 | ### Missing data imputation
59 |
60 | You can fit your `UnmaskingTrees` model on data with missing elements, provided as `np.nan`, then impute the missing values, potentially with multiple imputations per missing element. Given an array of size `(n_samples, n_dims)`, you will get back an array of size `(n_impute, n_samples, n_dims)`, where the NaNs have been replaced and the observed values are unchanged.
61 |
62 | ```
63 | data4impute = data.copy()
64 | data4impute[:, 1] = np.nan
65 | X = np.concatenate([data, data4impute], axis=0)  # size (400, 2)
66 | utree = UnmaskingTrees().fit(X)
67 | imputeddata = utree.impute(n_impute=5)  # size (5, 400, 2)
68 | ```
69 |
70 | You can also provide an entirely new dataset to be imputed, so the model performs imputation without retraining:
71 |
72 | ```
73 | utree = UnmaskingTrees().fit(data)
74 | imputeddata = utree.impute(n_impute=5, X=data4impute)  # size (5, 200, 2)
75 | ```
76 |
77 | ### Hyperparameters
78 |
79 | - `depth`: depth of the balanced binary tree for recursively quantizing each feature.
80 | - `duplicate_K`: number of random masking orders per actual sample. The training dataset will be of size `(n_samples * n_dims * duplicate_K, n_dims)`.
81 | - `xgboost_kwargs`: dict to pass to XGBClassifier.
82 | - `strategy`: how to quantize continuous features (`'kdiquantile'`, `'quantile'`, `'uniform'`, or `'kmeans'`).
83 | - `random_state`: controls randomness.
84 | A configuration sketch exercising these hyperparameters appears at the end of this README.
85 |
86 | ## Citing this method
87 |
88 | Please consider citing the UnmaskingTrees [arXiv preprint](https://arxiv.org/pdf/2407.05593). The BibTeX is:
89 |
90 | ```
91 | @article{mccarter2024unmasking,
92 |   title={Unmasking Trees for Tabular Data},
93 |   author={McCarter, Calvin},
94 |   journal={arXiv preprint arXiv:2407.05593},
95 |   year={2024}
96 | }
97 | ```
98 |
99 | Also, please consider citing ForestDiffusion ([code](https://github.com/SamsungSAILMontreal/ForestDiffusion) and [paper](https://arxiv.org/abs/2309.09968)), which this work builds on.
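For reference, here is a configuration sketch that exercises the hyperparameters and the `quantize_cols` option documented above. It follows the Iris example in `tests/utrees/test_utrees.py`; the hyperparameter values are illustrative rather than tuned. (`UnmaskingTrees` also exposes `score_samples` for per-sample log-density estimates, as exercised in the tests.)

```
import numpy as np
from sklearn.datasets import load_iris
from utrees import UnmaskingTrees

iris = load_iris()
# Four continuous measurements plus a categorical target column.
Xy = np.concatenate([iris["data"], iris["target"][:, None]], axis=1)  # size (150, 5)

utree = UnmaskingTrees(
    depth=4,  # 2**4 = 16 bins per continuous feature
    duplicate_K=50,  # random masking orders per sample
    strategy="kdiquantile",  # or 'quantile', 'uniform', 'kmeans'
    random_state=12345,
)
utree.fit(Xy, quantize_cols=["continuous"] * 4 + ["categorical"])
Xy_gen = utree.generate(n_generate=150)  # size (150, 5)
```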
100 | -------------------------------------------------------------------------------- /tests/utrees/test_utrees.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import time 4 | 5 | from sklearn.datasets import ( 6 | load_iris, 7 | make_moons, 8 | ) 9 | from sklearn.neighbors import KernelDensity 10 | 11 | from utrees import UnmaskingTrees 12 | 13 | CREATE_PLOTS = False 14 | 15 | @pytest.mark.parametrize("strategy, depth, duplicate_K, min_score", [ 16 | ("kdiquantile", 3, 10, -2.), 17 | ("kdiquantile", 3, 10, -2.), 18 | ("kdiquantile", 4, 10, -2.), 19 | ("kdiquantile", 4, 50, -2.), 20 | ("kdiquantile", 5, 50, -2.), 21 | ]) 22 | def test_moons_generate(strategy, depth, duplicate_K, min_score): 23 | n_upper = 100 24 | n_lower = 100 25 | n_generate = 123 26 | n = n_upper + n_lower 27 | data, labels = make_moons( 28 | (n_upper, n_lower), shuffle=False, noise=0.1, random_state=12345) 29 | 30 | utree = UnmaskingTrees( 31 | depth=depth, 32 | strategy=strategy, 33 | duplicate_K=duplicate_K, 34 | random_state=12345, 35 | ) 36 | utree.fit(data) 37 | newdata = utree.generate(n_generate=n_generate) 38 | if CREATE_PLOTS: 39 | import matplotlib.pyplot as plt 40 | plt.figure() 41 | plt.scatter(newdata[:, 0], newdata[:, 1]); 42 | plt.savefig(f'test_moons_generate-{strategy}-{depth}-{duplicate_K}.pdf') 43 | assert newdata.shape == (n_generate, 2) 44 | 45 | kde = KernelDensity(kernel='gaussian', bandwidth=0.2).fit(data) 46 | scores = kde.score_samples(newdata) 47 | assert scores.mean() >= min_score 48 | 49 | 50 | @pytest.mark.parametrize("strategy, depth, duplicate_K, cast_float32, min_score, k", [ 51 | ("kdiquantile", 3, 10, True, -1.5, 1), 52 | ("kdiquantile", 3, 10, True, -1.5, 1), 53 | ("kdiquantile", 4, 10, True, -1.5, 1), 54 | ("kdiquantile", 4, 50, True, -1.5, 1), 55 | ("kdiquantile", 5, 50, True, -1.5, 1), 56 | ("kdiquantile", 3, 10, False, -1.5, 1), 57 | ("kdiquantile", 3, 10, False, -1.5, 1), 58 | ("kdiquantile", 4, 10, False, -1.5, 1), 59 | ("kdiquantile", 4, 50, False, -1.5, 1), 60 | ("kdiquantile", 5, 50, False, -1.5, 1), 61 | ("kdiquantile", 3, 10, False, -1.5, 3), 62 | ("kdiquantile", 3, 10, False, -1.5, 3), 63 | ("kdiquantile", 4, 10, False, -1.5, 3), 64 | ("kdiquantile", 4, 50, False, -1.5, 3), 65 | ("kdiquantile", 5, 50, False, -1.5, 3), 66 | ]) 67 | def test_moons_impute(strategy, depth, cast_float32, duplicate_K, min_score, k): 68 | n_upper = 100 69 | n_lower = 100 70 | n = n_upper + n_lower 71 | data, labels = make_moons( 72 | (n_upper, n_lower), shuffle=False, noise=0.1, random_state=12345) 73 | data4impute = data.copy() 74 | data4impute[:, 1] = np.nan 75 | X=np.concatenate([data, data4impute], axis=0) 76 | 77 | utree = UnmaskingTrees( 78 | depth=depth, 79 | strategy=strategy, 80 | duplicate_K=duplicate_K, 81 | cast_float32=cast_float32, 82 | random_state=12345, 83 | ) 84 | utree.fit(X) 85 | kde = KernelDensity(kernel='gaussian', bandwidth=0.2).fit(data) 86 | 87 | # Tests using fitted data 88 | imputedX = utree.impute(n_impute=k) 89 | assert imputedX.shape == (k, X.shape[0], X.shape[1]) 90 | nannystate = ~np.isnan(X) 91 | for kk in range(k): 92 | imputedXcur = imputedX[kk, :, :] 93 | # Assert observed data is unchanged 94 | np.testing.assert_equal(imputedXcur[nannystate], X[nannystate]) 95 | # Assert all NaNs removed 96 | assert np.isnan(imputedXcur).sum() == 0 97 | scores = kde.score_samples(imputedXcur) 98 | # Assert imputed data is close to observed data distribution 99 | assert scores.mean() >= 
min_score 100 | if CREATE_PLOTS: 101 | import matplotlib.pyplot as plt 102 | plt.figure() 103 | plt.scatter(imputedX[0, :, 0], imputedX[0, :, 1]); 104 | plt.savefig(f'test_moons_impute-{strategy}-{depth}-{duplicate_K}.pdf') 105 | 106 | # Tests providing data to impute 107 | imputedX = utree.impute(n_impute=k, X=data4impute) 108 | assert imputedX.shape == (k, data4impute.shape[0], data4impute.shape[1]) 109 | nannystate = ~np.isnan(data4impute) 110 | for kk in range(k): 111 | imputedXcur = imputedX[kk, :, :] 112 | 113 | np.testing.assert_equal(imputedXcur[nannystate], data4impute[nannystate]) 114 | scores = kde.score_samples(imputedXcur) 115 | assert scores.mean() >= min_score 116 | 117 | def test_moons_score_samples(): 118 | n_upper = 100 119 | n_lower = 100 120 | n_generate = 123 121 | n = n_upper + n_lower 122 | data, labels = make_moons( 123 | (n_upper, n_lower), shuffle=False, noise=0.1, random_state=12345) 124 | 125 | utree = UnmaskingTrees(random_state=12345) 126 | utree.fit(data) 127 | scores = utree.score_samples(data) 128 | assert scores.shape == (200,) 129 | assert np.min(scores) > -0.5 130 | 131 | @pytest.mark.parametrize('target_type', [ 132 | 'continuous', 'categorical', 'integer', 133 | ]) 134 | def test_iris(target_type): 135 | my_data = load_iris() 136 | X, y = my_data['data'], my_data['target'] 137 | n = X.shape[0] 138 | Xy = np.concatenate((X, np.expand_dims(y, axis=1)), axis=1) 139 | ut_model = UnmaskingTrees(random_state=12345) 140 | ut_model.fit(Xy, quantize_cols=['continuous']*4 + [target_type]) 141 | Xy_gen_utrees = ut_model.generate(n_generate=n) 142 | col_names = my_data['feature_names'] + ['target_names'] 143 | petalw = col_names.index('petal width (cm)') 144 | petall = col_names.index('petal length (cm)') 145 | assert np.unique(Xy_gen_utrees[:, -1]).size == 3 146 | assert np.unique(Xy_gen_utrees[:, petall]).size > n - 10 147 | assert 50 <= np.unique(Xy_gen_utrees[:, petalw]).size < 100 148 | -------------------------------------------------------------------------------- /Results/tabular_imputation_results_kmeans.txt: -------------------------------------------------------------------------------- 1 | iris,MCAR(0.2),utrees,0.06436957134970737,0.09820029507565886,0.0,0.0,0.07406719488532511,0.2432489111314778,0.0035090903348105984,0.16882548182880674,0.14711389375938483,0.0,0.9533635056734473,5.240656932195027,0.0,0.9888721804511277,0.0,0.901246399258095,0.0,0.9657031467557784,0.0,0.9576322962287873 2 | wine,MCAR(0.2),utrees,0.10703117843764451,0.1500977475902249,0.0,0.0,0.40604224206690687,1.473187616004656,0.01009838423341393,0.2996230669137918,0.25479589314502804,0.0,0.9506248663768152,43.43113430341085,0.0,0.9471527074612197,0.0,0.8861589265469987,0.0,0.9910199172018928,0.0,0.9781679142971489 3 | parkinsons,MCAR(0.2),utrees,0.05551321452108013,0.08524537540547424,0.0,0.0,0.38426840866581696,1.7739918499212077,0.0065856889843212835,0.19988221995294644,0.16980011273351714,0.0,0.824004272401742,132.9267556667328,0.0,0.745174148812501,0.0,0.849369991297509,0.0,0.8452874989174725,0.0,0.856185450579485 4 | climate_model_crashes,MCAR(0.2),utrees,0.2445575415292093,0.3445494156494738,0.0,0.0,1.2426306605524053,3.8839612332455897,0.04650630572459653,0.814139369339031,0.7026100326117098,0.0,0.6776350950992771,275.9443356990814,0.0,0.8300706838056722,0.0,0.7040739514597695,0.0,0.4823199662304771,0.0,0.6940757789011901 5 | 
concrete_compression,MCAR(0.2),utrees,0.03210828842755,0.056903224276065374,139.9718029014977,0.22962962962962963,0.08964405779202678,0.5147330608069801,0.0040364825986187736,0.1363458859714353,0.11504088872391835,0.7501702370745574,0.0,115.43401980400085,0.5628067782571478,0.0,0.7370457763261031,0.0,0.8608975679029993,0.0,0.8399308258119796,0.0 6 | yacht_hydrodynamics,MCAR(0.2),utrees,0.047272415294858,0.09254175399891258,90.36945434969547,0.9809523809523808,0.09956587072417561,0.5353933813685045,0.007927911237284414,0.21188108168258518,0.18117387052298115,0.8961366668730334,0.0,14.487419525782265,0.6082034338573993,0.0,0.9851336619886606,0.0,0.9958894477927889,0.0,0.9953201238532842,0.0 7 | airfoil_self_noise,MCAR(0.2),utrees,0.037140222355384125,0.083416395329536,4.063234761149441,1.0,0.06861066115715818,0.2632849337812995,0.00906505128438473,0.2407473071657216,0.21101632145252772,0.726789378235716,0.0,61.72480670611064,0.5073497692761099,0.0,0.6260807273477592,0.0,0.8835357178295489,0.0,0.8901912984894461,0.0 8 | connectionist_bench_sonar,MCAR(0.2),utrees,0.12594510082585392,0.1575817200024125,0.0,0.0,1.9018403823025247,8.832364115331467,0.01524932730916776,0.36679856139287836,0.31067455387346066,0.0,0.7934502116490639,1217.6911363601685,0.0,0.7189555311397549,0.0,0.8081375928343585,0.0,0.8141292146446422,0.0,0.8325785079775009 9 | ionosphere,MCAR(0.2),utrees,0.09313030024983923,0.12300146406326612,0.0,0.0,0.8143947765317358,4.441172261619947,0.011246970278524575,0.2295432671127555,0.18978076793547133,0.0,0.8963486608285675,542.9356789588928,0.0,0.8371620356469938,0.0,0.9124253626300302,0.0,0.9388813305587143,0.0,0.8969259144785321 10 | qsar_biodegradation,MCAR(0.2),utrees,0.019527684473879602,0.027992824231103,0.0,0.0,0.23001487701993542,1.3982340329151852,0.001090018369507807,0.05588821862156047,0.04755009522818494,0.0,0.8446301860513109,1721.6118144194284,0.0,0.8467325146932341,0.0,0.8306542562137806,0.0,0.8528585746608389,0.0,0.8482753986373902 11 | seeds,MCAR(0.2),utrees,0.0615945639620859,0.09729499079999612,0.0,0.0,0.14014201616568808,0.4889822967496476,0.004909057215034096,0.20851218941450503,0.18024954813653288,0.0,0.8759684438048736,16.49672166506449,0.0,0.9312245285578619,0.0,0.7145206838096979,0.0,0.9206997637526373,0.0,0.9374287990992972 12 | glass,MCAR(0.2),utrees,0.056683279899008815,0.08418955377322465,0.0,0.0,0.15643328511219365,0.6629243709553694,0.005207425246596656,0.14701144249299913,0.12408580449687279,0.0,0.5400166896494403,27.309990644454956,0.0,0.5740018743324092,0.0,0.31148857904273825,0.0,0.6503503977514334,0.0,0.62422590747118 13 | ecoli,MCAR(0.2),utrees,0.07824011385064211,0.13473034411769022,0.0,0.0,0.19019169962553592,0.479811891076667,0.007038748092204368,0.26565302594552415,0.2363700591455002,0.0,0.6791062196038589,21.325106700261433,0.0,0.8008041093956577,0.0,0.39762456661176226,0.0,0.8004187478580278,0.0,0.7175774545499879 14 | yeast,MCAR(0.2),utrees,0.06941257417556176,0.12116565147613076,0.0,0.0,0.18620507065740766,0.3962720433056949,0.005375955083372073,0.2540420088841035,0.2300358190571259,0.0,0.4441336024074725,123.21836757659912,0.0,0.5279781488912345,0.0,0.24764316643869705,0.0,0.5049291325437661,0.0,0.4959839617561922 15 | libras,MCAR(0.2),utrees,0.04337659174664109,0.054188700617526744,0.0,0.0,0.9781172408544512,9.087276475583508,0.002856839376379127,0.14195162056368987,0.12417461862538949,0.0,0.5668127539679112,5551.839271465938,0.0,0.707391694725028,0.0,0.046151433972063276,0.0,0.8003763298429966,0.0,0.7133315573315573 16 | 
planning_relax,MCAR(0.2),utrees,0.10076780027173504,0.14972397424144582,0.0,0.0,0.37768020618544945,1.497770661672028,0.011499330547311781,0.3311062364520646,0.28206416197506373,0.0,0.45340396658398313,39.88673202196757,0.0,0.4101516792156664,0.0,0.498282297881215,0.0,0.4314105982187582,0.0,0.4737712910202928 17 | blood_transfusion,MCAR(0.2),utrees,0.041859507915742486,0.07342162483104794,0.0,0.0,0.03713507439059031,0.1120434048586678,0.004067193066445469,0.15581042388651056,0.13510836072561155,0.0,0.5932689994790394,22.858375310897827,0.0,0.5587901528296861,0.0,0.6208276889233224,0.0,0.5952800168363233,0.0,0.5981781393268258 18 | breast_cancer_diagnostic,MCAR(0.2),utrees,0.04528500136927905,0.059025866114478055,0.0,0.0,0.3547227648437304,1.8643734784330732,0.0016410376729804662,0.11556167150646471,0.10212184999736261,0.0,0.9585506174839166,749.6472186247507,0.0,0.9559701297680048,0.0,0.9570539836293582,0.0,0.9570414965792896,0.0,0.9641368599590135 19 | connectionist_bench_vowel,MCAR(0.2),utrees,0.061398164895609314,0.1039307212244072,0.0,0.0,0.20811034421411279,0.7362538908543316,0.005596639823743729,0.2511663316487918,0.2201338332636656,0.0,0.6648635771081516,159.47676269213358,0.0,0.6416428022305954,0.0,0.2316282014085894,0.0,0.9161645012332263,0.0,0.8700188035601953 20 | concrete_slump,MCAR(0.2),utrees,0.1333530760953803,0.2087869919401345,47.93493627427608,0.7416666666666666,0.2918570338779636,1.1616849899114061,0.020308189548916368,0.41754187877797,0.3449666588250957,0.6617184542749365,0.0,9.68403736750285,0.7671580550398038,0.0,0.6267470830753135,0.0,0.6622900928791725,0.0,0.590678586105456,0.0 21 | wine_quality_red,MCAR(0.2),utrees,0.043987667301618495,0.07125263601096245,34.704274950328255,1.0,0.1413180725084123,0.517063871726785,0.00269422418878734,0.1602979380994121,0.14191509671282354,0.30131984599606504,0.0,237.22118894259137,0.31614333412554846,0.0,0.28376551757830254,0.0,0.36190006000046004,0.0,0.2434704722799491,0.0 22 | wine_quality_white,MCAR(0.2),utrees,0.036103445695891,0.06268985046227243,84.75792080373398,0.5555555555555555,0.13722323358222618,0.4512826541091945,0.002325762946555635,0.1625876365418874,0.1462748590117181,0.33976933748992155,0.0,922.8720918496449,0.26866062995313283,0.0,0.2833327598765625,0.0,0.4323614905194416,0.0,0.3747224696105493,0.0 23 | california,MCAR(0.2),utrees,0.02069658137403717,0.05039244339583727,26.527604512665064,0.5259259259259259,0.0,0.0,0.0044263438785602975,0.152781496744373,0.13628165960935112,0.6474823603570286,0.0,2793.5273619492846,0.5902896683351077,0.0,0.4149105745891897,0.0,0.7815453463555943,0.0,0.8031838521482231,0.0 24 | bean,MCAR(0.2),utrees,0.014279476432662424,0.026375358712360665,0.0,0.0,0.0,0.0,0.0010325532628770013,0.08178814262320773,0.07467284322933423,0.0,0.785844029183105,6019.071438709895,0.0,0.7804254388322259,0.0,0.49402300709239133,0.0,0.9333896534836575,0.0,0.935538017324145 25 | tictactoe,MCAR(0.2),utrees,0.3302788913404004,0.5047493883573801,0.0,0.0,0.7706701479547405,1.923353640847097,0.019924972355421657,1.1954983413252993,0.9192553211800524,0.0,0.8260940741613819,24.388316075007122,0.0,0.7628063445625028,0.0,0.7088680436660513,0.0,0.900034114313759,0.0,0.932667794103214 26 | congress,MCAR(0.2),utrees,0.1890438211894098,0.2666739760698372,0.0,0.0,0.7869731800766308,2.359770114942525,0.007118368287168615,0.4271020972301169,0.3176210020716755,0.0,0.9343496695461981,28.509195009867355,0.0,0.9181753899593685,0.0,0.9395054129002633,0.0,0.9437576372605899,0.0,0.9359602380645711 27 | 
car,MCAR(0.2),utrees,0.45735025839689725,0.6895113519508365,0.0,0.0,0.5051616015436621,1.0813146733811403,0.02605064368622248,1.8221719722572072,1.4401332127730782,0.0,0.819537784685739,29.95012299219767,0.0,0.8518452545343415,0.0,0.7390913722188874,0.0,0.8187829437640975,0.0,0.8684315682256301 28 | -------------------------------------------------------------------------------- /Results/tabular_imputation_results_tobt.txt: -------------------------------------------------------------------------------- 1 | 271,iris,MCAR(0.2),utrees,0.0738508822640819,0.08621383344459618,0.0,0.0,0.0641272750008991,0.23875198098882847,0.0010849252155242946,0.060747518260253,0.050638420184983594,0.0,0.9592957351290684,6.241593678792317,0.0,0.9844054580896685,0.0,0.9208420883859479,0.0,0.9719632414369256,0.0,0.9599721526037313 2 | 272,wine,MCAR(0.2),utrees,0.10847833071435786,0.1273898930976937,0.0,0.0,0.34477782264980283,1.436851201603759,0.0022681103341390227,0.1234763420665328,0.10114861403442213,0.0,0.9500534430124237,31.166414578755692,0.0,0.9395602586713697,0.0,0.8874777130686393,0.0,0.9913154133784308,0.0,0.9818603869312548 3 | 273,parkinsons,MCAR(0.2),utrees,0.05216119916570852,0.06282486691810032,0.0,0.0,0.2829609680941807,1.7057565523311848,0.0011464499359240465,0.06601010209647172,0.05496063290496879,0.0,0.8197209207924157,67.40996813774109,0.0,0.7448269072722554,0.0,0.8205556286221438,0.0,0.8631517929896829,0.0,0.8503493542855808 4 | 274,climate_model_crashes,MCAR(0.2),utrees,0.2808700271957786,0.33931437568865086,0.0,0.0,1.2236978336740576,3.8843171348144097,0.01981178122553466,0.41717855463271586,0.3435656973505204,0.0,0.732257043439506,122.85987591743469,0.0,0.825510951198097,0.0,0.7657393075515603,0.0,0.5183956256861775,0.0,0.8193822893221895 5 | 275,concrete_compression,MCAR(0.2),utrees,0.03208937400224214,0.0449843462357446,95.02280181221532,0.3555555555555555,0.06950370247361591,0.5018743805043794,0.001868359858469494,0.06612196061508667,0.05437818351022399,0.7585029515684282,0.0,57.00657558441162,0.56886904606764,0.0,0.7410657368387151,0.0,0.8713681358466046,0.0,0.8527088875207527,0.0 6 | 276,yacht_hydrodynamics,MCAR(0.2),utrees,0.043468972067696544,0.055231025220134546,24.99998836102717,1.0,0.05533555453165772,0.5070570594782794,0.0038244437942096835,0.06134887096816548,0.04755683277595425,0.8965889191222446,0.0,10.832573652267456,0.6041882722010656,0.0,0.9897168134399937,0.0,0.9966183190506115,0.0,0.9958322717973075,0.0 7 | 277,airfoil_self_noise,MCAR(0.2),utrees,0.03601927283689502,0.056442182219915434,3.602237261278865,1.0,0.04072949157540118,0.248333447015149,0.004789185084037393,0.09646590551563475,0.07741676024143188,0.7259813075249679,0.0,37.661417404810585,0.5089654463051735,0.0,0.6209385469506234,0.0,0.8853326392820764,0.0,0.888688597561998,0.0 8 | 278,connectionist_bench_sonar,MCAR(0.2),utrees,0.10325464112454648,0.11572714409057966,0.0,0.0,1.396946930870398,8.490695914502343,0.0024130064258775265,0.12086357229675992,0.09973393051247184,0.0,0.7961100547917862,484.8436369101206,0.0,0.7305667860775371,0.0,0.7917005749057889,0.0,0.826771736701832,0.0,0.8354011214819863 9 | 279,ionosphere,MCAR(0.2),utrees,0.09274800547354099,0.11362675052595668,0.0,0.0,0.7530395291283524,4.428334329991218,0.007861143042296006,0.14133471175286777,0.11219339961349425,0.0,0.9079339126392311,226.5654911994934,0.0,0.8704189158504423,0.0,0.913392839991736,0.0,0.9347812384932109,0.0,0.9131426562215345 10 | 
280,qsar_biodegradation,MCAR(0.2),utrees,0.016950361470988242,0.021991270274625278,0.0,0.0,0.18070410987453125,1.3799801507746234,0.0007616428769309639,0.029576808224069363,0.02395775645070889,0.0,0.8464247159723208,607.7569101651509,0.0,0.8493412113767648,0.0,0.8298656606493132,0.0,0.8522279272923382,0.0,0.8542640645708677 11 | 281,seeds,MCAR(0.2),utrees,0.06558186398759239,0.0818383717537164,0.0,0.0,0.11806093655638271,0.4758862336495159,0.0012582765118176159,0.08697161235318307,0.07302355757537271,0.0,0.8815366461629893,16.37628761927287,0.0,0.9413927210862075,0.0,0.7207285702837521,0.0,0.9292578215030322,0.0,0.9347674717789659 12 | 282,glass,MCAR(0.2),utrees,0.062050774465719924,0.07363746271541312,0.0,0.0,0.13664248771775228,0.6396874281780643,0.0018520039435309709,0.06659311935922695,0.054469169747801305,0.0,0.5200758571660892,19.19480045636495,0.0,0.5561984171683729,0.0,0.3218814414073257,0.0,0.6756645735935488,0.0,0.5265589964951093 13 | 283,ecoli,MCAR(0.2),utrees,0.06453228583324466,0.07761285744910246,0.0,0.0,0.1068919956390921,0.401773448510927,0.0011357623617926616,0.06588892321651871,0.054906589233967906,0.0,0.6939790972544977,17.368636290232338,0.0,0.7412225128742794,0.0,0.43088376118385335,0.0,0.8040077940861244,0.0,0.7998023208737332 14 | 284,yeast,MCAR(0.2),utrees,0.05895051406248822,0.06929448120458553,0.0,0.0,0.10311086536653524,0.3164615766358421,0.0009136845133219293,0.059834683194831914,0.049196346463218736,0.0,0.4443877944655975,72.2295389175415,0.0,0.5457200855194014,0.0,0.23381158819953407,0.0,0.5280076182562742,0.0,0.470011885887181 15 | 285,libras,MCAR(0.2),utrees,0.030900635362958416,0.03501986248736869,0.0,0.0,0.6320113064928582,8.927761073481095,0.0004123027700454415,0.05310431815487,0.046219991058215906,0.0,0.5763398196282307,2609.6879085699716,0.0,0.7266957363624029,0.0,0.0656637197506967,0.0,0.8029773387106721,0.0,0.7100224836891503 16 | 286,planning_relax,MCAR(0.2),utrees,0.09702467095872941,0.11677807762938469,0.0,0.0,0.294182350854369,1.4490622094955194,0.0024094633661303623,0.12394061275131349,0.10278404304613803,0.0,0.4656162078089963,29.979744831720986,0.0,0.41686890777319113,0.0,0.5128104109771937,0.0,0.4468584026329554,0.0,0.4859271098526448 17 | 287,blood_transfusion,MCAR(0.2),utrees,0.05250419832506664,0.06273627296215434,0.0,0.0,0.03417193597069273,0.11108166737525997,0.0009944133374911643,0.05049024576996288,0.04146980880327056,0.0,0.5909121710221329,15.97951825459798,0.0,0.5517558674610492,0.0,0.6136836696925864,0.0,0.5923459245005579,0.0,0.6058632224343383 18 | 288,breast_cancer_diagnostic,MCAR(0.2),utrees,0.04248342287389557,0.049991474574389065,0.0,0.0,0.30053080575196994,1.8403605511797796,0.0005091639620044048,0.05575907086130531,0.047511666583858865,0.0,0.9559513116989217,350.96746794382733,0.0,0.9565720466705175,0.0,0.950461040184128,0.0,0.9552653152164161,0.0,0.9615068447246249 19 | 289,connectionist_bench_vowel,MCAR(0.2),utrees,0.06443732869709562,0.08390184593288284,0.0,0.0,0.16803464619838576,0.7150281660352723,0.0019676379777712324,0.11290896288770606,0.09527958048799712,0.0,0.6612800914403835,105.91338030497232,0.0,0.6372764684188381,0.0,0.2091374133599314,0.0,0.9306224975774907,0.0,0.868083986405273 20 | 
290,concrete_slump,MCAR(0.2),utrees,0.14946281300454173,0.17743712288363436,50.530740741392904,0.6083333333333333,0.24881000910566675,1.1559990504226032,0.005047817126113788,0.14313051571169366,0.11337670728696354,0.6813521957189256,0.0,12.78035036722819,0.739561590918806,0.0,0.630150174247179,0.0,0.692742719426452,0.0,0.6629542982832657,0.0 21 | 291,wine_quality_red,MCAR(0.2),utrees,0.05145864317084543,0.06453161345110682,48.04485802828075,1.0,0.1277755226999056,0.5127004214844084,0.0010588208629251602,0.0743885786222127,0.062241781160318296,0.3121816124810164,0.0,125.84174267450967,0.3161742440812635,0.0,0.29229668834595846,0.0,0.3695386514219234,0.0,0.27071686607492024,0.0 22 | 292,wine_quality_white,MCAR(0.2),utrees,0.042593903786731196,0.05505373728025714,74.43341369338839,0.5111111111111111,0.12074880247067918,0.44595229298284883,0.0008893732992206452,0.0717649368850878,0.060195874296490495,0.3360038411750484,0.0,450.17718839645386,0.26787301686659737,0.0,0.279956460830806,0.0,0.4362087549627255,0.0,0.35997713204006476,0.0 23 | 293,california,MCAR(0.2),utrees,0.030580305310313076,0.042526266610077086,20.09642335364787,0.49629629629629635,0.0,0.0,0.0014429541654798791,0.05886321962908722,0.04926780789859086,0.654840243501494,0.0,1278.6375177701314,0.5915782908304531,0.0,0.43503356451736247,0.0,0.7823153738824907,0.0,0.8104337447756698,0.0 24 | 294,bean,MCAR(0.2),utrees,0.013745072140138138,0.01936528961128679,0.0,0.0,0.0,0.0,0.00032362892341678007,0.036293178330078155,0.03214980747629488,0.0,0.7827055281793264,2614.430075327555,0.0,0.7895514158594643,0.0,0.47266482176976476,0.0,0.9335693470438311,0.0,0.9350365280442451 25 | 295,tictactoe,MCAR(0.2),utrees,0.2934076415933544,0.5098007688248267,0.0,0.0,0.7750217580504756,1.923804213809092,0.02449292020035092,1.4695752120210548,1.1344409107142162,0.0,0.8228683601865097,30.083193381627403,0.0,0.7564519924401047,0.0,0.7051401013935286,0.0,0.8980469741728717,0.0,0.9318343727395332 26 | 296,congress,MCAR(0.2),utrees,0.17090379339869008,0.2739610466703952,0.0,0.0,0.7969348659003855,2.373180076628348,0.009687297604583728,0.5812378562750236,0.43111389253746524,0.0,0.9348000068108923,34.59704256057739,0.0,0.918716692128598,0.0,0.9357333401867587,0.0,0.9488637868680861,0.0,0.935886208060126 27 | 297,car,MCAR(0.2),utrees,0.4012625479200538,0.6826690498634778,0.0,0.0,0.4819585142305888,1.0697378070373702,0.029454461974252055,2.0602595278022156,1.6405766453740909,0.0,0.7836670401499316,37.12033700942993,0.0,0.8325497237184465,0.0,0.6373154415056159,0.0,0.8105739604491329,0.0,0.8542290349265308 28 | -------------------------------------------------------------------------------- /Results/tabular_imputation_results.txt: -------------------------------------------------------------------------------- 1 | iris , MCAR(0.2) , utrees , 0.06171134900290921 , 0.09286773754928909 , 0.0 , 0.0 , 0.06980779208322663 , 0.2397118967709102 , 0.0030206573130502293 , 0.15978792534668954 , 0.13658212292576377 , 0.0 , 0.9422763803480664 , 5.160777409871419 , 0.0 , 0.9866332497911444 , 0.0 , 0.8702625553914046 , 0.0 , 0.9634663088698177 , 0.0 , 0.9487434073398984 2 | wine , MCAR(0.2) , utrees , 0.10133481778238645 , 0.1439122311329542 , 0.0 , 0.0 , 0.3895478278301316 , 1.4645430063143305 , 0.008630066798274041 , 0.2829484754161229 , 0.23908105735580445 , 0.0 , 0.9387301795008196 , 46.97041702270508 , 0.0 , 0.9424267274641209 , 0.0 , 0.8394181522214272 , 0.0 , 0.9920429711070391 , 0.0 , 0.9810328672106903 3 | parkinsons , MCAR(0.2) , utrees , 0.052245517775528584 , 
0.07630051184397803 , 0.0 , 0.0 , 0.3425923785306587 , 1.7516166306883203 , 0.004681727732015078 , 0.17334971838062277 , 0.14844598117147428 , 0.0 , 0.8151910426840616 , 140.84079313278198 , 0.0 , 0.7497556842583829 , 0.0 , 0.8285637760247496 , 0.0 , 0.8434252258261432 , 0.0 , 0.8390194846269702 4 | climate_model_crashes , MCAR(0.2) , utrees , 0.25313482958818423 , 0.36388498949116616 , 0.0 , 0.0 , 1.3125855770026345 , 3.9308358361374527 , 0.058244783820290275 , 0.9247016041837531 , 0.7911253027692842 , 0.0 , 0.6687687876893608 , 303.72695573170984 , 0.0 , 0.7815931202543376 , 0.0 , 0.715044914544767 , 0.0 , 0.485082313212326 , 0.0 , 0.6933548027460131 5 | concrete_compression , MCAR(0.2) , utrees , 0.032205947195894674 , 0.05695862771138322 , 144.7204580042744 , 0.2148148148148148 , 0.08956094803214486 , 0.51587579790969 , 0.0039197280163505745 , 0.13092247924677985 , 0.10971847972639154 , 0.7506571899571826 , 0.0 , 110.7270872592926 , 0.5635045538284041 , 0.0 , 0.7426887614177372 , 0.0 , 0.8593832855583278 , 0.0 , 0.8370521590242612 , 0.0 6 | yacht_hydrodynamics , MCAR(0.2) , utrees , 0.04600778874007024 , 0.08016372314818775 , 130.04117264988247 , 0.8952380952380953 , 0.08194873657018106 , 0.5212796431105499 , 0.006810609756057632 , 0.16949249963483226 , 0.13837548506080527 , 0.8959975452512271 , 0.0 , 12.314840714136759 , 0.6079608300788333 , 0.0 , 0.9850257618973024 , 0.0 , 0.995870603583432 , 0.0 , 0.9951329854453406 , 0.0 7 | airfoil_self_noise , MCAR(0.2) , utrees , 0.030701968468046875 , 0.07149217526283297 , 4.426930062522432 , 1.0 , 0.05527627319305399 , 0.255348994194407 , 0.0087494171344429 , 0.20225251607348688 , 0.1726012568420425 , 0.7288531573629374 , 0.0 , 61.05456566810608 , 0.5081040017180197 , 0.0 , 0.6304988371229822 , 0.0 , 0.883136433947107 , 0.0 , 0.8936733566636403 , 0.0 8 | connectionist_bench_sonar , MCAR(0.2) , utrees , 0.11712002366199573 , 0.14441412477065826 , 0.0 , 0.0 , 1.7433453859125134 , 8.675589836023347 , 0.010292110372589414 , 0.31071706160493046 , 0.2646020070218156 , 0.0 , 0.789587435581842 , 1221.7377189000447 , 0.0 , 0.7365130837758166 , 0.0 , 0.7853301364692437 , 0.0 , 0.8107724273701802 , 0.0 , 0.8257340947121282 9 | ionosphere , MCAR(0.2) , utrees , 0.09598313460475286 , 0.12398897668765305 , 0.0 , 0.0 , 0.8212715075994221 , 4.452108646183872 , 0.010532593079086822 , 0.22607161765535344 , 0.1872299367967365 , 0.0 , 0.8971300518445046 , 544.0828011035919 , 0.0 , 0.8335572700966875 , 0.0 , 0.912945819294007 , 0.0 , 0.937803996414423 , 0.0 , 0.9042131215729008 10 | qsar_biodegradation , MCAR(0.2) , utrees , 0.018528342646379738 , 0.026946503392505854 , 0.0 , 0.0 , 0.22144208357565423 , 1.3998370122598578 , 0.0012175729510766418 , 0.055858663558376945 , 0.047125598729442726 , 0.0 , 0.8440079997743205 , 1700.7666370868683 , 0.0 , 0.8492198286276188 , 0.0 , 0.8287018939881141 , 0.0 , 0.8524921264445106 , 0.0 , 0.845618150037039 11 | seeds , MCAR(0.2) , utrees , 0.06156737632234309 , 0.09532838023661071 , 0.0 , 0.0 , 0.13727812015204705 , 0.4818844341111245 , 0.0041747195976600746 , 0.19172721995362418 , 0.16607901238564932 , 0.0 , 0.8876327743360164 , 16.02963876724243 , 0.0 , 0.9392051110518542 , 0.0 , 0.7593816864030843 , 0.0 , 0.917707799905501 , 0.0 , 0.9342364999836262 12 | glass , MCAR(0.2) , utrees , 0.05690890150540151 , 0.08010917130462747 , 0.0 , 0.0 , 0.14928190914582376 , 0.6556699224704824 , 0.004381661572533172 , 0.1302607772932201 , 0.11177271760603738 , 0.0 , 0.5480099780432925 , 26.026296854019165 , 0.0 , 0.584371607108248 , 0.0 , 
0.3177680015479405 , 0.0 , 0.6578281866630086 , 0.0 , 0.6320721168539731 13 | ecoli , MCAR(0.2) , utrees , 0.05629358048089226 , 0.08645477169671197 , 0.0 , 0.0 , 0.11921348698293545 , 0.4096222068800511 , 0.004157091351302514 , 0.1608139108009185 , 0.13803921739849476 , 0.0 , 0.7018928975330095 , 23.58275763193766 , 0.0 , 0.7876587322107141 , 0.0 , 0.438941456261841 , 0.0 , 0.799616636930169 , 0.0 , 0.7813547647293144 14 | yeast , MCAR(0.2) , utrees , 0.04535097744322566 , 0.07376880404694568 , 0.0 , 0.0 , 0.10812830155357582 , 0.31949312377513595 , 0.0035031547260809034 , 0.1568094959224052 , 0.13619697159380023 , 0.0 , 0.4573822425410245 , 135.45259356498718 , 0.0 , 0.5372650246353922 , 0.0 , 0.24436352760113153 , 0.0 , 0.5434300184509417 , 0.0 , 0.5044703994766325 15 | libras , MCAR(0.2) , utrees , 0.04227497283897813 , 0.052300055087484684 , 0.0 , 0.0 , 0.9441110249540525 , 9.064625281197733 , 0.002386057977089709 , 0.1351752676790998 , 0.11785024852303577 , 0.0 , 0.57221071121062 , 5732.288067579269 , 0.0 , 0.7167657034323701 , 0.0 , 0.054326598126233104 , 0.0 , 0.8033840825840824 , 0.0 , 0.714366460699794 16 | planning_relax , MCAR(0.2) , utrees , 0.09274661933098485 , 0.14182509487282988 , 0.0 , 0.0 , 0.3577914833040218 , 1.4829015968756003 , 0.010863641129838425 , 0.33208213992992913 , 0.2843774348922432 , 0.0 , 0.45357795616942376 , 39.74175707499186 , 0.0 , 0.40884725474889433 , 0.0 , 0.4809160683691389 , 0.0 , 0.4311217822489055 , 0.0 , 0.49342671931075655 17 | blood_transfusion , MCAR(0.2) , utrees , 0.040394437501713114 , 0.07071949375830425 , 0.0 , 0.0 , 0.0353891109851724 , 0.11301666799401906 , 0.00436917758066595 , 0.15255551896399422 , 0.13027972107616098 , 0.0 , 0.5891788320696434 , 20.771088282267254 , 0.0 , 0.551163926355941 , 0.0 , 0.6263315792365312 , 0.0 , 0.5861426479620243 , 0.0 , 0.5930771747240771 18 | breast_cancer_diagnostic , MCAR(0.2) , utrees , 0.04347497878768007 , 0.05818701882778201 , 0.0 , 0.0 , 0.34965808120004954 , 1.8694501937914967 , 0.0020290733719139677 , 0.12244840244125642 , 0.10732914690076638 , 0.0 , 0.9563682040026953 , 790.5507067044575 , 0.0 , 0.9552594974942447 , 0.0 , 0.9559410899580288 , 0.0 , 0.9570062408277437 , 0.0 , 0.9572659877307639 19 | connectionist_bench_vowel , MCAR(0.2) , utrees , 0.05976956621841621 , 0.10382249912440761 , 0.0 , 0.0 , 0.20792119393041208 , 0.737387697709447 , 0.006410373126504616 , 0.26627932092860485 , 0.2320615991149993 , 0.0 , 0.6628657357921165 , 162.6827143828074 , 0.0 , 0.6420072805560552 , 0.0 , 0.22917844837129117 , 0.0 , 0.9198757522587209 , 0.0 , 0.8604014619823989 20 | concrete_slump , MCAR(0.2) , utrees , 0.13336506913435178 , 0.21163758490490042 , 78.37393727464351 , 0.7166666666666666 , 0.296686493798486 , 1.1753105280747134 , 0.020353934344685538 , 0.41736394686933864 , 0.3433281455096098 , 0.6495419122376092 , 0.0 , 8.064836661020914 , 0.7383579877825754 , 0.0 , 0.6537969299030401 , 0.0 , 0.6491909134165881 , 0.0 , 0.556821817848233 , 0.0 21 | wine_quality_red , MCAR(0.2) , utrees , 0.04184544842818448 , 0.0721537184045824 , 22.31912314809874 , 0.9939393939393939 , 0.14254835041317498 , 0.5198545898986846 , 0.003952135412439399 , 0.18018191797690808 , 0.15692197788181259 , 0.29580732210337435 , 0.0 , 261.72702566782635 , 0.3173300368121892 , 0.0 , 0.28245331130378315 , 0.0 , 0.3492149836631822 , 0.0 , 0.2342309566343427 , 0.0 22 | wine_quality_white , MCAR(0.2) , utrees , 0.03527853218161474 , 0.06364601037858153 , 71.46802560034284 , 0.5 , 0.13879125826315522 , 0.4537299873404477 , 
0.0033708521907209153 , 0.1781443095298026 , 0.15794928677256942 , 0.33958677441484636 , 0.0 , 1023.7504643599192 , 0.26957257718077343 , 0.0 , 0.28524000983461484 , 0.0 , 0.43435310087766693 , 0.0 , 0.3691814097663301 , 0.0 23 | california , MCAR(0.2) , utrees , 0.02085733850254894 , 0.052293952522410515 , 33.24752259308236 , 0.4518518518518519 , 0.0 , 0.0 , 0.005275218897392506 , 0.16294639904729727 , 0.1442791493561823 , 0.6562729852476104 , 0.0 , 3039.9609892368317 , 0.5865712054822544 , 0.0 , 0.4482862136463962 , 0.0 , 0.7846209767129884 , 0.0 , 0.8056135451488026 , 0.0 24 | bean , MCAR(0.2) , utrees , 0.013300777074474399 , 0.026065942643214362 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0015598047749835822 , 0.08307973988602241 , 0.07537123904968326 , 0.0 , 0.7889339055492804 , 6546.674813985825 , 0.0 , 0.7696592871750925 , 0.0 , 0.5184820537374906 , 0.0 , 0.9330313282932579 , 0.0 , 0.9345629529912808 25 | tictactoe , MCAR(0.2) , utrees , 0.3302788913404004 , 0.5047493883573801 , 0.0 , 0.0 , 0.7706701479547405 , 1.923353640847097 , 0.019924972355421657 , 1.1954983413252993 , 0.9192553211800524 , 0.0 , 0.8260135515119187 , 24.652372360229492 , 0.0 , 0.7624842539646501 , 0.0 , 0.7088680436660513 , 0.0 , 0.900034114313759 , 0.0 , 0.9326677941032139 26 | congress , MCAR(0.2) , utrees , 0.1890438211894098 , 0.2666739760698372 , 0.0 , 0.0 , 0.7869731800766308 , 2.359770114942525 , 0.007118368287168615 , 0.4271020972301169 , 0.3176210020716755 , 0.0 , 0.9346015785483384 , 28.701904376347862 , 0.0 , 0.9191830259679294 , 0.0 , 0.9395054129002633 , 0.0 , 0.9437576372605899 , 0.0 , 0.9359602380645711 27 | car , MCAR(0.2) , utrees , 0.45735025839689725 , 0.6895113519508365 , 0.0 , 0.0 , 0.5051616015436621 , 1.0813146733811403 , 0.02605064368622248 , 1.8221719722572072 , 1.4401332127730782 , 0.0 , 0.8191771109071323 , 29.15845568974813 , 0.0 , 0.8504025594199145 , 0.0 , 0.7390913722188874 , 0.0 , 0.8187829437640975 , 0.0 , 0.8684315682256301 28 | -------------------------------------------------------------------------------- /Results/tabular_imputation_results_baltobot4d8f71b.txt: -------------------------------------------------------------------------------- 1 | iris , MCAR(0.2) , utrees , 0.05999383499585656 , 0.0891422525651869 , 0.0 , 0.0 , 0.06622844415087935 , 0.2399510489123476 , 0.0026519395607755435 , 0.14782735105221712 , 0.12211262422232116 , 0.0 , 0.952578092959672 , 5.305421670277913 , 0.0 , 0.9866332497911444 , 0.0 , 0.8935526948684843 , 0.0 , 0.9657432470064047 , 0.0 , 0.9643831801726537 2 | wine , MCAR(0.2) , utrees , 0.09483351208087865 , 0.1309220058072117 , 0.0 , 0.0 , 0.3544626598677714 , 1.4421646366534844 , 0.005640737222570057 , 0.23726015274535953 , 0.19939392703444025 , 0.0 , 0.9369265175356536 , 26.762377103169758 , 0.0 , 0.9632461278304145 , 0.0 , 0.8151893968182157 , 0.0 , 0.9871233059233867 , 0.0 , 0.9821472395705978 3 | parkinsons , MCAR(0.2) , utrees , 0.04692251597772083 , 0.06515695016716706 , 0.0 , 0.0 , 0.2935529924783244 , 1.7112368414543835 , 0.002726737547563139 , 0.12343294098537427 , 0.10333986633321987 , 0.0 , 0.8296471697593928 , 58.98031155268352 , 0.0 , 0.7628487453017361 , 0.0 , 0.838878856219079 , 0.0 , 0.8611200454212758 , 0.0 , 0.8557410320954802 4 | climate_model_crashes , MCAR(0.2) , utrees , 0.23759707521127293 , 0.33835467617379056 , 0.0 , 0.0 , 1.2202348736572843 , 3.8713029908094416 , 0.04378390850643781 , 0.7939392208828289 , 0.6876264315110343 , 0.0 , 0.7075253091655707 , 103.73175875345868 , 0.0 , 0.8216503397548771 , 0.0 , 0.7460006942675264 , 0.0 , 
0.4956869617631286 , 0.0 , 0.7667632408767506 5 | concrete_compression , MCAR(0.2) , utrees , 0.02333038049146902 , 0.051929753638068496 , 117.4323195156428 , 0.24444444444444446 , 0.07950083700139465 , 0.5046659061722523 , 0.0056108336568096535 , 0.15947619621621514 , 0.1297940883578644 , 0.7550948876750425 , 0.0 , 47.19429659843445 , 0.5677727274909049 , 0.0 , 0.7378459497855332 , 0.0 , 0.8652996866710022 , 0.0 , 0.8494611867527299 , 0.0 6 | yacht_hydrodynamics , MCAR(0.2) , utrees , 0.027212127217251424 , 0.06525405415346494 , 87.8356344827388 , 0.9619047619047618 , 0.06461383266455692 , 0.5112652032726271 , 0.012247314991194035 , 0.1899437795671296 , 0.14489959622012152 , 0.8958445449867422 , 0.0 , 8.892837285995483 , 0.6051986543649295 , 0.0 , 0.9878021530766475 , 0.0 , 0.9957263778055365 , 0.0 , 0.9946509946998549 , 0.0 7 | airfoil_self_noise , MCAR(0.2) , utrees , 0.024177747618689484 , 0.06644906397149036 , 3.367058543769148 , 1.0 , 0.04595237772071268 , 0.25049733817022946 , 0.012535618795616179 , 0.22880454575076564 , 0.18357767558903554 , 0.7236830200188702 , 0.0 , 29.907463312149048 , 0.5076127218483516 , 0.0 , 0.6223125813761253 , 0.0 , 0.8807123657190816 , 0.0 , 0.884094411131922 , 0.0 8 | connectionist_bench_sonar , MCAR(0.2) , utrees , 0.09859288170328898 , 0.11840726141041186 , 0.0 , 0.0 , 1.429042454382952 , 8.508219085835636 , 0.005136353547750008 , 0.22206330143535155 , 0.18837572035423833 , 0.0 , 0.7987797915830596 , 440.6014119784037 , 0.0 , 0.7211420515274188 , 0.0 , 0.8168914005300593 , 0.0 , 0.8238102292793955 , 0.0 , 0.8332754849953646 9 | ionosphere , MCAR(0.2) , utrees , 0.08523464358983444 , 0.11811873486059866 , 0.0 , 0.0 , 0.7824976996114046 , 4.441913223417513 , 0.014656794885137926 , 0.25404546334261535 , 0.202321645205517 , 0.0 , 0.9098789590534062 , 201.8103465239207 , 0.0 , 0.8732615124612854 , 0.0 , 0.9279816794874949 , 0.0 , 0.9329457905205485 , 0.0 , 0.9053268537442958 10 | qsar_biodegradation , MCAR(0.2) , utrees , 0.015254570285918954 , 0.023403987013942196 , 0.0 , 0.0 , 0.19212359101896018 , 1.3856575124044705 , 0.001247031131769453 , 0.053514867631470106 , 0.04361988321482159 , 0.0 , 0.8486549180518138 , 560.5758926868439 , 0.0 , 0.8511907477600135 , 0.0 , 0.8314650338714398 , 0.0 , 0.8565060426520982 , 0.0 , 0.8554578479237033 11 | seeds , MCAR(0.2) , utrees , 0.0536062409688773 , 0.08449679733209269 , 0.0 , 0.0 , 0.12191528713404252 , 0.47779747604268186 , 0.003367278377665574 , 0.17413053418775576 , 0.14828548838041686 , 0.0 , 0.8828062340725844 , 14.943417390187584 , 0.0 , 0.942468848874979 , 0.0 , 0.7190544422051238 , 0.0 , 0.9288809390255441 , 0.0 , 0.9408207061846907 12 | glass , MCAR(0.2) , utrees , 0.0496206279015378 , 0.07585823492308724 , 0.0 , 0.0 , 0.1395464455143801 , 0.6422554966107725 , 0.00509210478076279 , 0.1437355547360558 , 0.11662889755117278 , 0.0 , 0.543444784278404 , 17.121328353881836 , 0.0 , 0.5597076318333436 , 0.0 , 0.33844731725051297 , 0.0 , 0.6716156389645124 , 0.0 , 0.6040085490652478 13 | ecoli , MCAR(0.2) , utrees , 0.051482101920573564 , 0.07996184068979219 , 0.0 , 0.0 , 0.1094547532281822 , 0.40433483990272084 , 0.0035961465359099646 , 0.1538224632971858 , 0.13044989354532197 , 0.0 , 0.6831116597210881 , 14.686814228693645 , 0.0 , 0.7796676608032461 , 0.0 , 0.3758002592414551 , 0.0 , 0.8016274932079795 , 0.0 , 0.7753512256316719 14 | yeast , MCAR(0.2) , utrees , 0.04378552325566909 , 0.07404024912539708 , 0.0 , 0.0 , 0.10720798893085218 , 0.31883588185505407 , 0.0035788436758160324 , 0.1730284435457075 , 
0.14983662228831252 , 0.0 , 0.44429114417699456 , 62.958171367645264 , 0.0 , 0.5408617994798776 , 0.0 , 0.23325600825320442 , 0.0 , 0.520586653798653 , 0.0 , 0.48246011517624343 15 | libras , MCAR(0.2) , utrees , 0.03063284486586349 , 0.03638737201499229 , 0.0 , 0.0 , 0.6566358196725776 , 8.933323976745273 , 0.0008112346007017335 , 0.08171470473946038 , 0.07063236617473345 , 0.0 , 0.568547098210967 , 1975.7754249572754 , 0.0 , 0.7305759672426337 , 0.0 , 0.046811110866585924 , 0.0 , 0.8014451030451031 , 0.0 , 0.6953562116895451 16 | planning_relax , MCAR(0.2) , utrees , 0.08411692836706103 , 0.12137667314935206 , 0.0 , 0.0 , 0.3058909258660944 , 1.460002019772343 , 0.005571385169514211 , 0.23866930500059752 , 0.20262838701678848 , 0.0 , 0.45205072102550564 , 25.245799938837685 , 0.0 , 0.4181319219394311 , 0.0 , 0.4824260898832327 , 0.0 , 0.42651343398470287 , 0.0 , 0.4811314382946558 17 | blood_transfusion , MCAR(0.2) , utrees , 0.03444502502881287 , 0.06688821811064323 , 0.0 , 0.0 , 0.03212365113334971 , 0.11167959040823328 , 0.004795884147921586 , 0.15814227169156175 , 0.13153806150238792 , 0.0 , 0.5870497695788247 , 13.1585369904836 , 0.0 , 0.5558678508828397 , 0.0 , 0.6147739158922271 , 0.0 , 0.584815196941721 , 0.0 , 0.5927421145985113 18 | breast_cancer_diagnostic , MCAR(0.2) , utrees , 0.03956558039073609 , 0.051628040615653036 , 0.0 , 0.0 , 0.31024376602207276 , 1.846053935538329 , 0.0012147251350276003 , 0.10086794072744275 , 0.08730993646823641 , 0.0 , 0.9586634419673519 , 279.3262273470561 , 0.0 , 0.957830676675561 , 0.0 , 0.9554613875073541 , 0.0 , 0.9554992846097188 , 0.0 , 0.9658624190767728 19 | connectionist_bench_vowel , MCAR(0.2) , utrees , 0.05296810071933365 , 0.09388309984454896 , 0.0 , 0.0 , 0.1878788763752319 , 0.725079410750455 , 0.0055250297398647535 , 0.24841801262652985 , 0.21448056138434413 , 0.0 , 0.6641100768645267 , 79.84957003593445 , 0.0 , 0.6514435776789809 , 0.0 , 0.20775287118264388 , 0.0 , 0.9232551451738964 , 0.0 , 0.8739887134225857 20 | concrete_slump , MCAR(0.2) , utrees , 0.1251534691485566 , 0.18784554316110125 , 47.56531663695364 , 0.725 , 0.2639606575784001 , 1.159978473434967 , 0.014801640827374072 , 0.34044859085444257 , 0.27503999548986047 , 0.6745252901753798 , 0.0 , 9.739169359207153 , 0.7541523974515636 , 0.0 , 0.6569805595622225 , 0.0 , 0.6881385743588726 , 0.0 , 0.5988296293288606 , 0.0 21 | wine_quality_red , MCAR(0.2) , utrees , 0.040802649372952006 , 0.0716406477510256 , 20.145217590505506 , 1.0 , 0.14132476856233 , 0.5169713070740208 , 0.0035164216726677883 , 0.18320501100487552 , 0.1580449829731422 , 0.3059707949127771 , 0.0 , 106.22926425933838 , 0.3170426193941901 , 0.0 , 0.2886979501708553 , 0.0 , 0.3691983195529146 , 0.0 , 0.24894429053314845 , 0.0 22 | wine_quality_white , MCAR(0.2) , utrees , 0.03407725293496344 , 0.06448257401806023 , 77.35079172193863 , 0.47777777777777775 , 0.14023782049929634 , 0.45284739625845083 , 0.0033595964423761425 , 0.19108912571449158 , 0.16837634325966563 , 0.33787091146852133 , 0.0 , 357.68214988708496 , 0.2680334535357732 , 0.0 , 0.2820063034609025 , 0.0 , 0.43474693460671304 , 0.0 , 0.3666969542706964 , 0.0 23 | california , MCAR(0.2) , utrees , 0.0196647697319096 , 0.04973844353652448 , 23.21634260208505 , 0.5925925925925927 , 0.0 , 0.0 , 0.004982026822227645 , 0.15803420495132264 , 0.13993056163009363 , 0.6548593415015938 , 0.0 , 968.1369389692943 , 0.5916727995885759 , 0.0 , 0.4237977876023389 , 0.0 , 0.7901864293781431 , 0.0 , 0.8137803494373173 , 0.0 24 | bean , MCAR(0.2) , utrees , 
0.010192959415544817 , 0.02059535423265372 , 0.0 , 0.0 , 0.0 , 0.0 , 0.0010522477503439255 , 0.06418175575996427 , 0.05799520234562458 , 0.0 , 0.7816665143749051 , 1929.1566933790846 , 0.0 , 0.7961193970932972 , 0.0 , 0.4618599044548654 , 0.0 , 0.9329963646203907 , 0.0 , 0.9356903913310666 25 | tictactoe , MCAR(0.2) , utrees , 0.29630144489131827 , 0.5110035391675715 , 0.0 , 0.0 , 0.7762402088772818 , 1.9328428706121097 , 0.024496550015139 , 1.4697930009083398 , 1.126450727895411 , 0.0 , 0.8234316722545656 , 25.853189309438072 , 0.0 , 0.7602380388772744 , 0.0 , 0.7110043303006972 , 0.0 , 0.8965042745894501 , 0.0 , 0.9259800452508402 26 | congress , MCAR(0.2) , utrees , 0.1729734997216229 , 0.2774138758682399 , 0.0 , 0.0 , 0.8032567049808451 , 2.3781609195402256 , 0.00975922053794807 , 0.5855532322768843 , 0.4332464178272987 , 0.0 , 0.9327371459215983 , 30.412309885025024 , 0.0 , 0.9068635943563466 , 0.0 , 0.932205365760348 , 0.0 , 0.9494113973748777 , 0.0 , 0.9424682261948213 27 | car , MCAR(0.2) , utrees , 0.4036153107206999 , 0.6852703901190534 , 0.0 , 0.0 , 0.48369512783406193 , 1.070682515914789 , 0.029514653850087412 , 2.0644051407904334 , 1.6485394454306355 , 0.0 , 0.80069479177201 , 30.177862962086998 , 0.0 , 0.8484301461689044 , 0.0 , 0.6644293855448289 , 0.0 , 0.8200663390532182 , 0.0 , 0.8698532963210888 28 | -------------------------------------------------------------------------------- /Results/latex_tables.txt: -------------------------------------------------------------------------------- 1 | % latex table generated in R 4.4.1 by xtable 1.8-4 package 2 | % Fri Jul 5 22:11:55 2024 3 | \begin{table}[ht] 4 | \centering 5 | \begin{tabular}{rlllllllll} 6 | \hline 7 | & MinMAE & AvgMAE & W\_train & W\_test & MedianMAD & R2\_imp & F1\_imp & PercentBias & CoverageRate \\ 8 | \hline 9 | KNN & 0.16 (0.03) & 0.16 (0.03) & 0.42 (0.08) & 1.89 (0.49) & 0 (0) & 0.59 (0.09) & 0.75 (0.04) & 1.27 (0.25) & 0.4 (0.11) \\ 10 | ICE & 0.1 (0.01) & 0.21 (0.03) & 0.52 (0.09) & 1.99 (0.49) & 0.69 (0.1) & 0.59 (0.09) & 0.74 (0.04) & 1.05 (0.29) & 0.39 (0.09) \\ 11 | MICE-Forest & 0.08 (0.02) & 0.13 (0.03) & 0.34 (0.07) & 1.86 (0.48) & 0.29 (0.08) & 0.61 (0.1) & 0.76 (0.04) & 0.61 (0.2) & 0.75 (0.11) \\ 12 | MissForest & 0.1 (0.03) & 0.12 (0.03) & 0.32 (0.07) & 1.85 (0.48) & 0.1 (0.03) & 0.61 (0.1) & 0.76 (0.04) & 0.62 (0.22) & 0.79 (0.08) \\ 13 | Softimpute & 0.22 (0.03) & 0.22 (0.03) & 0.53 (0.07) & 1.99 (0.48) & 0 (0) & 0.58 (0.09) & 0.74 (0.04) & 1.18 (0.34) & 0.31 (0.09) \\ 14 | OT & 0.14 (0.02) & 0.19 (0.03) & 0.56 (0.1) & 1.98 (0.49) & 0.28 (0.05) & 0.59 (0.1) & 0.75 (0.04) & 1.09 (0.27) & 0.39 (0.12) \\ 15 | GAIN & 0.16 (0.03) & 0.17 (0.03) & 0.49 (0.11) & 1.95 (0.51) & 0.01 (0) & 0.6 (0.1) & 0.75 (0.04) & 1.04 (0.25) & 0.54 (0.12) \\ 16 | Forest-VP & 0.14 (0.04) & 0.17 (0.03) & 0.55 (0.13) & 1.96 (0.5) & 0.25 (0.03) & 0.61 (0.1) & 0.74 (0.04) & 0.81 (0.25) & 0.57 (0.14) \\ 17 | utrees & 0.1 (0.02) & 0.14 (0.03) & 0.42 (0.09) & 1.89 (0.49) & 0.27 (0.06) & 0.61 (0.1) & 0.76 (0.04) & 0.68 (0.24) & 0.68 (0.14) \\ 18 | Oracle & 0 (0) & 0 (0) & 0 (0) & 1.87 (0.49) & 0 (0) & 0.64 (0.09) & 0.78 (0.04) & 0 (0) & 1 (0) \\ 19 | \hline 20 | \end{tabular} 21 | \end{table} 22 | % latex table generated in R 4.4.1 by xtable 1.8-4 package 23 | % Fri Jul 5 22:11:55 2024 24 | \begin{table}[ht] 25 | \centering 26 | \begin{tabular}{rlllllllll} 27 | \hline 28 | & MinMAE & AvgMAE & W\_train & W\_test & MedianMAD & R2\_imp & F1\_imp & PercentBias & CoverageRate \\ 29 | \hline 30 | KNN & 5.5 (0.5) & 6.3 (0.5) & 4.7 
(0.4) & 4.8 (0.4) & 8.4 (0) & 6.5 (1) & 5.7 (1.1) & 6.2 (1) & 5.4 (0.6) \\ 31 | ICE & 6.7 (0.4) & 4.5 (0.4) & 7 (0.5) & 7.2 (0.4) & 1.6 (0.2) & 6 (1.1) & 7 (0.6) & 5.7 (0.9) & 5.2 (0.6) \\ 32 | MICE-Forest & 3.9 (0.4) & 2.5 (0.4) & 2.8 (0.2) & 2.7 (0.3) & 3.6 (0.2) & 3.7 (1.4) & 3.2 (1) & 5.5 (1.2) & 4.1 (0.6) \\ 33 | MissForest & 2.6 (0.5) & 4 (0.4) & 1.8 (0.3) & 1.9 (0.3) & 5.4 (0.2) & 3.8 (1.4) & 2.3 (0.6) & 5.5 (1.5) & 3.2 (0.4) \\ 34 | Softimpute & 6.7 (0.4) & 7.6 (0.4) & 7.1 (0.5) & 7.3 (0.5) & 8.4 (0) & 5.8 (1) & 7.8 (0.4) & 6.3 (0.9) & 6.5 (0.4) \\ 35 | OT & 5.9 (0.4) & 6 (0.3) & 6 (0.5) & 6 (0.5) & 3.7 (0.3) & 6.2 (0.5) & 6.8 (0.6) & 5.3 (0.7) & 4.8 (0.5) \\ 36 | GAIN & 4.6 (0.4) & 6.5 (0.3) & 5.8 (0.3) & 6 (0.3) & 6.9 (0.1) & 5.5 (0.8) & 5.4 (0.8) & 4.7 (1) & 5.1 (0.6) \\ 37 | Forest-VP & 5.1 (0.4) & 3.8 (0.5) & 5.6 (0.3) & 5 (0.4) & 3.2 (0.4) & 4.5 (0.9) & 4.6 (0.8) & 3.2 (0.7) & 5.4 (0.6) \\ 38 | utrees & 4 (0.5) & 3.8 (0.5) & 4.3 (0.5) & 4.2 (0.5) & 3.7 (0.3) & 3 (0.8) & 2.3 (0.6) & 2.7 (1) & 5.4 (0.6) \\ 39 | \hline 40 | \end{tabular} 41 | \end{table} 42 | 43 | 44 | % latex table generated in R 4.4.1 by xtable 1.8-4 package 45 | % Sun Jul 7 22:39:34 2024 46 | \begin{table}[ht] 47 | \centering 48 | \begin{tabular}{rlllllllll} 49 | \hline 50 | & W\_train & W\_test & coverage\_train & coverage\_test & R2\_fake & F1\_fake & class\_score & percent\_bias & coverage\_rate \\ 51 | \hline 52 | GaussianCopula & 2.74 (0.56) & 2.99 (0.61) & 0.18 (0.04) & 0.37 (0.06) & 0.2 (0.14) & 0.46 (0.06) & NA (0.05) & 2.27 (0.77) & 0.23 (0.12) \\ 53 | TVAE & 2.12 (0.58) & 2.35 (0.63) & 0.33 (0.04) & 0.63 (0.04) & -0.47 (0.61) & 0.52 (0.08) & NA (0.01) & 4.15 (1.97) & 0.26 (0.09) \\ 54 | CTGAN & 3.58 (0.99) & 3.74 (1.01) & 0.12 (0.03) & 0.28 (0.04) & -0.43 (0.08) & 0.35 (0.04) & NA (0.01) & 2.48 (1.3) & 0.2 (0.08) \\ 55 | CTAB-GAN+ & 2.71 (0.81) & 2.89 (0.83) & 0.22 (0.04) & 0.44 (0.05) & 0.05 (0.12) & 0.44 (0.05) & NA (0.02) & 2.95 (1.04) & 0.26 (0.07) \\ 56 | STaSy & 3.41 (1.39) & 3.66 (1.42) & 0.38 (0.05) & 0.63 (0.05) & -4.21 (4.44) & 0.61 (0.06) & NA (0.02) & 1.23 (0.44) & 0.45 (0.12) \\ 57 | TabDDPM & 4.27 (1.89) & 4.79 (1.89) & 0.76 (0.06) & 0.8 (0.06) & 0.6 (0.11) & 0.66 (0.06) & NA (0.03) & 0.76 (0.28) & 0.72 (0.11) \\ 58 | Forest-VP & 1.46 (0.4) & 1.94 (0.5) & 0.67 (0.05) & 0.84 (0.03) & 0.55 (0.1) & 0.73 (0.04) & NA (0.01) & 0.94 (0.3) & 0.52 (0.15) \\ 59 | Forest-Flow & 1.36 (0.39) & 1.9 (0.5) & 0.83 (0.03) & 0.9 (0.03) & 0.57 (0.11) & 0.73 (0.04) & NA (0.01) & 0.83 (0.23) & 0.63 (0.11) \\ 60 | UTrees & 1.73 (0.52) & 2.12 (0.58) & 0.63 (0.05) & 0.82 (0.03) & 0 (0.51) & 0.67 (0.06) & NA (NA) & 2.04 (1.03) & 0.36 (0.07) \\ 61 | Oracle & 0 (0) & 1.81 (0.47) & 0.99 (0.01) & 0.91 (0.04) & 0.64 (0.09) & 0.77 (0.04) & NA (0) & 0 (0) & 1 (0) \\ 62 | \hline 63 | \end{tabular} 64 | \end{table} 65 | % latex table generated in R 4.4.1 by xtable 1.8-4 package 66 | % Sun Jul 7 22:39:34 2024 67 | \begin{table}[ht] 68 | \centering 69 | \begin{tabular}{rlllllllll} 70 | \hline 71 | & W\_train & W\_test & coverage\_train & coverage\_test & R2\_fake & F1\_fake & class\_score & percent\_bias & coverage\_rate \\ 72 | \hline 73 | GaussianCopula & 7.1 (0.3) & 7.2 (0.3) & 7.3 (0.3) & 7.3 (0.3) & 6 (0) & 6.4 (0.3) & 7 (0.4) & 6.3 (1.1) & 7.3 (0.8) \\ 74 | TVAE & 5.2 (0.2) & 5 (0.2) & 5.7 (0.2) & 5.6 (0.2) & 6.5 (0.7) & 6 (0.5) & 5.5 (0.3) & 7.3 (0.6) & 6.7 (0.6) \\ 75 | CTGAN & 8.4 (0.1) & 8.4 (0.2) & 8.3 (0.2) & 8.1 (0.2) & 8.3 (0.3) & 8.3 (0.2) & 6.7 (0.3) & 5.3 (1.1) & 7.2 (0.5) \\ 76 | CTAB-GAN+ & 6.8 (0.3) & 
6.6 (0.3) & 7.2 (0.3) & 7 (0.3) & 6.8 (0.4) & 6.8 (0.4) & 6.9 (0.3) & 7.7 (0.8) & 7 (0.8) \\ 77 | STaSy & 6.1 (0.2) & 6.2 (0.2) & 5.3 (0.2) & 5.2 (0.3) & 6 (1.2) & 5 (0.3) & 6.1 (0.3) & 4.5 (0.8) & 4.4 (1.1) \\ 78 | TabDDPM & 3 (0.7) & 3.8 (0.6) & 2.8 (0.5) & 3.2 (0.5) & 1.2 (0.2) & 3.6 (0.6) & 3.2 (0.4) & 3 (0.9) & 1.6 (0.2) \\ 79 | Forest-VP & 2.8 (0.1) & 2.6 (0.2) & 3.3 (0.3) & 3 (0.3) & 3 (0.4) & 2.1 (0.3) & 4.3 (0.4) & 3.2 (0.9) & 3.3 (0.7) \\ 80 | Forest-Flow & 1.9 (0.2) & 1.4 (0.2) & 1.7 (0.2) & 1.7 (0.2) & 2.3 (0.4) & 2.3 (0.3) & 4.3 (0.4) & 2.8 (0.5) & 2.7 (0.3) \\ 81 | UTrees & 3.5 (0.2) & 3.6 (0.2) & 3.6 (0.3) & 3.7 (0.3) & 4.8 (0.7) & 4.3 (0.4) & 1 (0) & 4.8 (0.8) & 4.8 (0.9) \\ 82 | \hline 83 | \end{tabular} 84 | \end{table} 85 | 86 | 87 | Rscript generation_script_nmiss50.R 88 | % latex table generated in R 4.4.1 by xtable 1.8-4 package 89 | % Fri Aug 9 21:21:43 2024 90 | \begin{table}[ht] 91 | \centering 92 | \begin{tabular}{rlllllllll} 93 | \hline 94 | & W\_train & W\_test & coverage\_train & coverage\_test & R2\_fake & F1\_fake & class\_score & percent\_bias & coverage\_rate \\ 95 | \hline 96 | GaussianCopula & 7 (0.3) & 7.1 (0.2) & 7.2 (0.3) & 7.1 (0.3) & 6 (0.4) & 6.6 (0.3) & 6.7 (0.4) & 5.5 (1) & 7.7 (0.6) \\ 97 | TVAE & 5.2 (0.3) & 4.8 (0.3) & 5.7 (0.3) & 5.7 (0.2) & 6 (1) & 5.7 (0.5) & 5.8 (0.4) & 7.8 (0.6) & 6.2 (1) \\ 98 | CTGAN & 8.3 (0.2) & 8.4 (0.2) & 8.4 (0.2) & 8.3 (0.2) & 8.2 (0.5) & 8.3 (0.2) & 6.5 (0.2) & 4.8 (1.2) & 6.9 (0.7) \\ 99 | CTABGAN & 6.6 (0.4) & 6.5 (0.4) & 7 (0.3) & 6.7 (0.3) & 7.3 (0.6) & 7.1 (0.4) & 6.6 (0.3) & 7.5 (1) & 6.1 (0.5) \\ 100 | Stasy & 5.9 (0.2) & 6 (0.2) & 5.2 (0.2) & 5 (0.3) & 5.5 (1) & 4.3 (0.4) & 5.3 (0.4) & 3.8 (0.4) & 4.6 (1.1) \\ 101 | TabDDPM & 3 (0.7) & 3.4 (0.7) & 2.3 (0.5) & 2.9 (0.6) & 1.7 (0.3) & 3.2 (0.6) & 3.9 (0.6) & 3.8 (1.2) & 2 (0.5) \\ 102 | Forest-VP & 3.1 (0.3) & 3 (0.3) & 3.6 (0.2) & 3.5 (0.3) & 3 (0.3) & 2.2 (0.3) & 4.2 (0.4) & 4.3 (0.8) & 4.2 (1.2) \\ 103 | Forest-Flow & 2.5 (0.4) & 2.3 (0.3) & 2.3 (0.2) & 2.5 (0.3) & 3 (0.7) & 3.5 (0.3) & 5 (0.5) & 3.7 (0.8) & 3.1 (0.8) \\ 104 | UTrees & 3.3 (0.2) & 3.6 (0.3) & 3.3 (0.3) & 3.3 (0.3) & 4.3 (1.1) & 4.2 (0.5) & 1 (0) & 3.7 (1) & 4.3 (1) \\ 105 | \hline 106 | \end{tabular} 107 | \end{table} 108 | 109 | 110 | Rscript imputation_script.R -- commit daa3b12 111 | % latex table generated in R 4.4.1 by xtable 1.8-4 package 112 | % Sat Aug 31 13:59:43 2024 113 | \begin{table}[ht] 114 | \centering 115 | \begin{tabular}{rlllllllll} 116 | \hline 117 | & MinMAE & AvgMAE & W\_train & W\_test & MedianMAD & R2\_imp & F1\_imp & PercentBias & CoverageRate \\ 118 | \hline 119 | KNN & 0.16 (0.03) & 0.16 (0.03) & 0.42 (0.08) & 1.89 (0.49) & 0 (0) & 0.59 (0.09) & 0.75 (0.04) & 1.27 (0.25) & 0.4 (0.11) \\ 120 | ICE & 0.1 (0.01) & 0.21 (0.03) & 0.52 (0.09) & 1.99 (0.49) & 0.69 (0.1) & 0.59 (0.09) & 0.74 (0.04) & 1.05 (0.29) & 0.39 (0.09) \\ 121 | MICE-Forest & 0.08 (0.02) & 0.13 (0.03) & 0.34 (0.07) & 1.86 (0.48) & 0.29 (0.08) & 0.61 (0.1) & 0.76 (0.04) & 0.61 (0.2) & 0.75 (0.11) \\ 122 | MissForest & 0.1 (0.03) & 0.12 (0.03) & 0.32 (0.07) & 1.85 (0.48) & 0.1 (0.03) & 0.61 (0.1) & 0.76 (0.04) & 0.62 (0.22) & 0.79 (0.08) \\ 123 | Softimpute & 0.22 (0.03) & 0.22 (0.03) & 0.53 (0.07) & 1.99 (0.48) & 0 (0) & 0.58 (0.09) & 0.74 (0.04) & 1.18 (0.34) & 0.31 (0.09) \\ 124 | OT & 0.14 (0.02) & 0.19 (0.03) & 0.56 (0.1) & 1.98 (0.49) & 0.28 (0.05) & 0.59 (0.1) & 0.75 (0.04) & 1.09 (0.27) & 0.39 (0.12) \\ 125 | GAIN & 0.16 (0.03) & 0.17 (0.03) & 0.49 (0.11) & 1.95 (0.51) & 0.01 (0) & 0.6 (0.1) 
& 0.75 (0.04) & 1.04 (0.25) & 0.54 (0.12) \\ 126 | Forest-VP & 0.14 (0.04) & 0.17 (0.03) & 0.55 (0.13) & 1.96 (0.5) & 0.25 (0.03) & 0.61 (0.1) & 0.74 (0.04) & 0.81 (0.25) & 0.57 (0.14) \\ 127 | utrees & 0.09 (0.02) & 0.13 (0.03) & 0.36 (0.08) & 1.87 (0.48) & 0.19 (0.07) & 0.61 (0.1) & 0.76 (0.04) & 0.44 (0.14) & 0.73 (0.12) \\ 128 | Oracle & 0 (0) & 0 (0) & 0 (0) & 1.87 (0.49) & 0 (0) & 0.64 (0.09) & 0.78 (0.04) & 0 (0) & 1 (0) \\ 129 | \hline 130 | \end{tabular} 131 | \end{table} 132 | % latex table generated in R 4.4.1 by xtable 1.8-4 package 133 | % Sat Aug 31 13:59:43 2024 134 | \begin{table}[ht] 135 | \centering 136 | \begin{tabular}{rlllllllll} 137 | \hline 138 | & MinMAE & AvgMAE & W\_train & W\_test & MedianMAD & R2\_imp & F1\_imp & PercentBias & CoverageRate \\ 139 | \hline 140 | KNN & 5.5 (0.5) & 6.3 (0.5) & 5 (0.4) & 5 (0.4) & 8.4 (0) & 6.5 (1) & 5.8 (1.1) & 6.2 (1) & 5.5 (0.6) \\ 141 | ICE & 6.8 (0.4) & 4.5 (0.4) & 7 (0.5) & 7.2 (0.4) & 1.5 (0.2) & 6.2 (1) & 7 (0.6) & 5.7 (0.9) & 5.2 (0.7) \\ 142 | MICE-Forest & 3.9 (0.4) & 2.3 (0.4) & 2.9 (0.2) & 3.1 (0.2) & 3.4 (0.2) & 3.8 (1.4) & 3.2 (1) & 5.5 (1.2) & 4.3 (0.6) \\ 143 | MissForest & 2.8 (0.5) & 4 (0.4) & 1.8 (0.3) & 2.1 (0.3) & 5.3 (0.2) & 4 (1.3) & 2.3 (0.6) & 5.5 (1.5) & 3.4 (0.5) \\ 144 | Softimpute & 6.7 (0.4) & 7.5 (0.4) & 7.1 (0.5) & 7.3 (0.5) & 8.4 (0) & 6.2 (0.8) & 7.8 (0.4) & 6.3 (0.9) & 6.7 (0.4) \\ 145 | OT & 6 (0.4) & 6 (0.3) & 6.1 (0.5) & 6 (0.5) & 3.5 (0.3) & 6.2 (0.5) & 6.8 (0.6) & 5.5 (0.8) & 4.9 (0.5) \\ 146 | GAIN & 4.7 (0.4) & 6.5 (0.3) & 6 (0.3) & 6 (0.2) & 6.9 (0.1) & 5.7 (0.8) & 5.4 (0.8) & 4.7 (1) & 5.1 (0.6) \\ 147 | Forest-VP & 5.3 (0.4) & 3.8 (0.5) & 5.9 (0.3) & 5.2 (0.4) & 3.1 (0.4) & 4.3 (0.9) & 4.6 (0.8) & 3.5 (0.7) & 5.4 (0.6) \\ 148 | utrees & 3.3 (0.5) & 4.1 (0.5) & 3.2 (0.4) & 3.1 (0.5) & 4.3 (0.2) & 2.2 (0.7) & 2.2 (0.5) & 2.2 (0.7) & 4.5 (0.5) \\ 149 | \hline 150 | \end{tabular} 151 | \end{table} 152 | -------------------------------------------------------------------------------- /paper/moons.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "7fbcfa7d-672a-4033-bb7d-4f58038778c5", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import matplotlib as mpl\n", 11 | "mpl.rcParams['font.family'] = 'Arial'\n", 12 | "mpl.rcParams['text.usetex'] = False\n", 13 | "import matplotlib.pyplot as plt\n", 14 | "import numpy as np\n", 15 | "import seaborn as sns\n", 16 | "import sklearn.datasets as skd\n", 17 | "import pandas as pd\n", 18 | "\n", 19 | "from sklearn.utils import check_random_state\n", 20 | "from ForestDiffusion import ForestDiffusionModel\n", 21 | "from miceforest import ImputationKernel\n", 22 | "from missforest import MissForest\n", 23 | "from utrees import UnmaskingTrees" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "id": "cbc3d109-ae5b-4f5f-9070-e1d9da13cea1", 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "rix = 0\n", 34 | "n = 200\n", 35 | "nimp = 1 # number of multiple imputations needed\n", 36 | "ngen = 200\n", 37 | "\n", 38 | "rng = check_random_state(rix)\n", 39 | "data, labels = skd.make_moons(n, shuffle=False, noise=0.1, random_state=rix)\n", 40 | "data4impute = data.copy()\n", 41 | "data4impute[:, 1] = np.nan\n", 42 | "X=np.concatenate([data, data4impute], axis=0)\n", 43 | "impute_samples = np.isnan(X).any(axis=1)\n", 44 | "\n", 45 | "missfer = MissForest(random_state=rix)\n", 46 | 
"impute_missf = missfer.fit_transform(X.copy(), cat_vars=None)\n", 47 | "\n", 48 | "micer = ImputationKernel(X.copy(), random_state=rix)\n", 49 | "micer.mice(5)\n", 50 | "impute_mice = micer.complete_data()\n", 51 | "\n", 52 | "utreer = UnmaskingTrees(random_state=rix)\n", 53 | "utreer.fit(X.copy())\n", 54 | "gen_utrees = utreer.generate(n_generate=ngen);\n", 55 | "impute_utrees = utreer.impute(n_impute=nimp)[0, :, :]\n", 56 | "\n", 57 | "utaber = UnmaskingTrees(tabpfn=True, random_state=rix)\n", 58 | "utaber.fit(X.copy())\n", 59 | "gen_utab = utaber.generate(n_generate=ngen);\n", 60 | "impute_utab = utaber.impute(n_impute=nimp)[0, :, :]\n", 61 | "\n", 62 | "forestvper = ForestDiffusionModel(\n", 63 | " X=X.copy(),\n", 64 | " n_t=50, duplicate_K=100, diffusion_type='vp',\n", 65 | " bin_indexes=[], cat_indexes=[], int_indexes=[], n_jobs=-1, seed=rix)\n", 66 | "gen_forestvp = forestvper.generate(batch_size=ngen)\n", 67 | "impute_forestvp_fast = forestvper.impute(k=nimp) # regular (fast)\n", 68 | "impute_forestvp_repaint = forestvper.impute(repaint=True, r=10, j=5, k=nimp) # REPAINT (slow, but better)\n", 69 | "\n", 70 | "forestflower = ForestDiffusionModel(\n", 71 | " X=X.copy(),\n", 72 | " n_t=50, duplicate_K=100, diffusion_type='flow',\n", 73 | " bin_indexes=[], cat_indexes=[], int_indexes=[], n_jobs=-1, seed=rix)\n", 74 | "gen_forestflow = forestflower.generate(batch_size=ngen)" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "id": "b209f40a-e281-4e72-ab39-4ed723a2e1fb", 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(7, 5), squeeze=False, dpi=200, sharex=True, sharey=True);\n", 85 | "markersize = 5\n", 86 | "alpha = 0.8\n", 87 | "color = 'blue'\n", 88 | "axes[0, 0].set_title('Original data');\n", 89 | "axes[0, 0].scatter(data[:, 0], data[:, 1], s=markersize, alpha=alpha, color='black', marker='x', linewidth=0.5,);\n", 90 | "axes[1, 0].set_title('Training data');\n", 91 | "axes[1, 0].scatter(X[:, 0], X[:, 1], s=markersize, alpha=alpha, color='green', marker='x', linewidth=0.5,);\n", 92 | "xlim = axes[0, 0].get_xlim();\n", 93 | "ylim = axes[0, 0].get_ylim();\n", 94 | "onlyfirst = X[~np.isnan(X[:, 0]), 0]\n", 95 | "onlyfirst = data[:, 0]\n", 96 | "onlysecond = X[~np.isnan(X[:, 1]), 1]\n", 97 | "onlysecond = data[:, 1]\n", 98 | "axes[1, 0].scatter(onlyfirst, ylim[0]*np.ones_like(onlyfirst), marker='|', s=100, linewidth=0.5, color='green');\n", 99 | "axes[1, 0].scatter(xlim[0]*np.ones_like(onlysecond), onlysecond, marker='_', s=100, linewidth=0.5, color='green');\n", 100 | "axes[1, 0].set_xlim(xlim);\n", 101 | "axes[1, 0].set_ylim(ylim);\n", 102 | "\n", 103 | "axes[0, 0].set_title('Training data');\n", 104 | "axes[0, 0].scatter(X[:, 0], X[:, 1], s=markersize, alpha=alpha, color='black', marker='x', linewidth=0.5,);\n", 105 | "\n", 106 | "axes[0, 1].set_title('Forest-VP (generate)');\n", 107 | "axes[0, 1].scatter(gen_forestvp[:, 0], gen_forestvp[:, 1], s=markersize, alpha=alpha, color=color);\n", 108 | "axes[1, 1].set_title('Forest-Flow (generate)');\n", 109 | "axes[1, 1].scatter(gen_forestflow[:, 0], gen_forestflow[:, 1], s=markersize, alpha=alpha, color=color);\n", 110 | "axes[0, 2].set_title('UnmaskingTrees (generate)');\n", 111 | "axes[0, 2].scatter(gen_utrees[:, 0], gen_utrees[:, 1], s=markersize, alpha=alpha, color=color);\n", 112 | "axes[1, 2].set_title('UnmaskingTabPFN (generate)');\n", 113 | "axes[1, 2].scatter(gen_utab[:, 0], gen_utab[:, 1], s=markersize, alpha=alpha, 
color=color);\n", 114 | "plt.tight_layout();\n", 115 | "\n", 116 | "\n", 117 | "fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(4, 4), squeeze=False, dpi=200, sharex=True, sharey=True);\n", 118 | "markersize = 5\n", 119 | "alpha = 0.7\n", 120 | "color = 'red'\n", 121 | "datacolor = 'green'\n", 122 | "axes[0, 0].set_title('Forest-VP');\n", 123 | "axes[0, 0].scatter(X[:, 0], X[:, 1], s=markersize, alpha=alpha, color=datacolor, marker='x', linewidth=0.5, zorder=5,);\n", 124 | "axes[0, 0].scatter(gen_forestvp[:, 0], gen_forestvp[:, 1], s=markersize, alpha=alpha, color=color);\n", 125 | "axes[1, 0].set_title('Forest-Flow');\n", 126 | "axes[1, 0].scatter(X[:, 0], X[:, 1], s=markersize, alpha=alpha, color=datacolor, marker='x', linewidth=0.5,zorder=5);\n", 127 | "axes[1, 0].scatter(gen_forestflow[:, 0], gen_forestflow[:, 1], s=markersize, alpha=alpha, color=color);\n", 128 | "axes[0, 1].set_title('UnmaskingTrees');\n", 129 | "axes[0, 1].scatter(X[:, 0], X[:, 1], s=markersize, alpha=alpha, color=datacolor, marker='x', linewidth=0.5,zorder=5);\n", 130 | "axes[0, 1].scatter(gen_utrees[:, 0], gen_utrees[:, 1], s=markersize, alpha=alpha, color=color);\n", 131 | "axes[1, 1].set_title('UnmaskingTabPFN');\n", 132 | "axes[1, 1].scatter(X[:, 0], X[:, 1], s=markersize, alpha=alpha, color=datacolor, marker='x', linewidth=0.5,zorder=5,label='data');\n", 133 | "axes[1, 1].scatter(gen_utab[:, 0], gen_utab[:, 1], s=markersize, alpha=alpha, color=color,label='generated');\n", 134 | "plt.legend(handlelength=0.4, borderpad=0.4, labelspacing=0.3,framealpha=0.9,handletextpad=0.3);\n", 135 | "for curax in axes.flatten():\n", 136 | " curax.set_yticks([0, 1]);\n", 137 | "plt.tight_layout();\n", 138 | "plt.savefig('moons-generation.png');" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "id": "bf0923b0-5bdb-4533-9ada-25c10ab98581", 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(6, 4.3), squeeze=False, dpi=200, sharex=True, sharey=True);\n", 149 | "markersize = 5\n", 150 | "alpha = 0.8\n", 151 | "color = 'blue'\n", 152 | "datacolor = 'green'\n", 153 | "\n", 154 | "axes[0, 0].set_title('MissForest');\n", 155 | "axes[0, 0].scatter(\n", 156 | " impute_missf[impute_samples, 0], impute_missf[impute_samples, 1],\n", 157 | " s=markersize, alpha=alpha, color=color);\n", 158 | "axes[0, 0].scatter(\n", 159 | " data[:, 0], data[:, 1],\n", 160 | " s=markersize, alpha=alpha, color=datacolor, marker='x', linewidth=0.5, zorder=5,);\n", 161 | "axes[1, 0].set_title('MICE-Forest');\n", 162 | "axes[1, 0].scatter(\n", 163 | " impute_mice[impute_samples, 0], impute_mice[impute_samples, 1],\n", 164 | " s=markersize, alpha=alpha, color=color);\n", 165 | "axes[1, 0].scatter(\n", 166 | " data[:, 0], data[:, 1],\n", 167 | " s=markersize, alpha=alpha, color=datacolor, marker='x', linewidth=0.5, zorder=5,);\n", 168 | "axes[0, 1].set_title('Forest-VP');\n", 169 | "axes[0, 1].scatter(\n", 170 | " impute_forestvp_fast[impute_samples, 0], impute_forestvp_fast[impute_samples, 1],\n", 171 | " s=markersize, alpha=alpha, color=color);\n", 172 | "axes[0, 1].scatter(\n", 173 | " data[:, 0], data[:, 1],\n", 174 | " s=markersize, alpha=alpha, color=datacolor, marker='x', linewidth=0.5, zorder=5,);\n", 175 | "axes[1, 1].set_title('Forest-VP w/ RePaint');\n", 176 | "axes[1, 1].scatter(\n", 177 | " impute_forestvp_repaint[impute_samples, 0], impute_forestvp_repaint[impute_samples, 1],\n", 178 | " s=markersize, alpha=alpha, color=color);\n", 179 | 
"axes[1, 1].scatter(\n", 180 | " data[:, 0], data[:, 1],\n", 181 | " s=markersize, alpha=alpha, color=datacolor, marker='x', linewidth=0.5, zorder=5,);\n", 182 | "axes[0, 2].set_title('UnmaskingTrees');\n", 183 | "axes[0, 2].scatter(\n", 184 | " impute_utrees[impute_samples, 0], impute_utrees[impute_samples, 1],\n", 185 | " s=markersize, alpha=alpha, color=color);\n", 186 | "axes[0, 2].scatter(\n", 187 | " data[:, 0], data[:, 1],\n", 188 | " s=markersize, alpha=alpha, color=datacolor, marker='x', linewidth=0.5, zorder=5,);\n", 189 | "axes[1, 2].set_title('UnmaskingTabPFN');\n", 190 | "axes[1, 2].scatter(\n", 191 | " data[:, 0], data[:, 1],\n", 192 | " s=markersize, alpha=alpha, color=datacolor, marker='x', linewidth=0.5,zorder=5,label='data');\n", 193 | "axes[1, 2].scatter(\n", 194 | " impute_utab[impute_samples, 0], impute_utab[impute_samples, 1],\n", 195 | " s=markersize, alpha=alpha, color=color,label='imputed');\n", 196 | "\n", 197 | "for curax in axes.flatten():\n", 198 | " curax.set_yticks([0, 1]);\n", 199 | "plt.legend(handlelength=0.4, borderpad=0.4, labelspacing=0.3,framealpha=0.9,handletextpad=0.3);\n", 200 | "plt.tight_layout();\n", 201 | "plt.savefig('moons-imputation.png');" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": null, 207 | "id": "10a7d730-04db-4b7c-94c1-422e1556cf3f", 208 | "metadata": {}, 209 | "outputs": [], 210 | "source": [] 211 | } 212 | ], 213 | "metadata": { 214 | "kernelspec": { 215 | "display_name": "Python 3 (ipykernel)", 216 | "language": "python", 217 | "name": "python3" 218 | }, 219 | "language_info": { 220 | "codemirror_mode": { 221 | "name": "ipython", 222 | "version": 3 223 | }, 224 | "file_extension": ".py", 225 | "mimetype": "text/x-python", 226 | "name": "python", 227 | "nbconvert_exporter": "python", 228 | "pygments_lexer": "ipython3", 229 | "version": "3.9.19" 230 | } 231 | }, 232 | "nbformat": 4, 233 | "nbformat_minor": 5 234 | } 235 | -------------------------------------------------------------------------------- /utrees/kdi_quantizer.py: -------------------------------------------------------------------------------- 1 | # Adapted from scikit-learn/sklearn/preprocessing/_discretization.py 2 | # with the following authorship and license: 3 | 4 | # Author: Henry Lin 5 | # Tom Dupré la Tour 6 | 7 | # License: BSD 8 | import warnings 9 | from numbers import Integral 10 | 11 | import numpy as np 12 | 13 | from kditransform import KDITransformer 14 | from sklearn.base import BaseEstimator, TransformerMixin 15 | from sklearn.cluster import KMeans 16 | from sklearn.preprocessing import OneHotEncoder 17 | from sklearn.utils import ( 18 | check_random_state, 19 | resample, 20 | ) 21 | from sklearn.utils._param_validation import Interval, Options, StrOptions 22 | from sklearn.utils.stats import _weighted_percentile 23 | from sklearn.utils.validation import ( 24 | _check_feature_names_in, 25 | _check_sample_weight, 26 | check_array, 27 | check_is_fitted, 28 | ) 29 | 30 | 31 | class KDIQuantizer(TransformerMixin, BaseEstimator): 32 | _parameter_constraints: dict = { 33 | "n_bins": [Interval(Integral, 2, None, closed="left"), "array-like"], 34 | "encode": [StrOptions({"ordinal"})], 35 | "strategy": [StrOptions({"uniform", "quantile", "kmeans", "kdiquantile"})], 36 | "dtype": [Options(type, {np.float64, np.float32}), None], 37 | "subsample": [Interval(Integral, 1, None, closed="left"), None], 38 | "random_state": ["random_state"], 39 | } 40 | 41 | def __init__( 42 | self, 43 | n_bins=5, 44 | *, 45 | encode="ordinal", 46 | 
strategy="kdiquantile", 47 | dtype=None, 48 | subsample=200_000, 49 | random_state=None, 50 | ): 51 | self.n_bins = n_bins 52 | self.encode = encode 53 | self.strategy = strategy 54 | self.dtype = dtype 55 | self.subsample = subsample 56 | self.random_state = random_state 57 | 58 | def fit(self, X, y=None, sample_weight=None): 59 | """ 60 | Fit the estimator. 61 | 62 | Parameters 63 | ---------- 64 | X : array-like of shape (n_samples, n_features) 65 | Data to be discretized. 66 | 67 | y : None 68 | Ignored. This parameter exists only for compatibility with 69 | :class:`~sklearn.pipeline.Pipeline`. 70 | 71 | sample_weight : ndarray of shape (n_samples,) 72 | Contains weight values to be associated with each sample. 73 | Cannot be used when `strategy` is set to `'uniform'`. 74 | 75 | .. versionadded:: 1.3 76 | 77 | Returns 78 | ------- 79 | self : object 80 | Returns the instance itself. 81 | """ 82 | X = self._validate_data(X, dtype="numeric") 83 | 84 | if self.dtype in (np.float64, np.float32): 85 | output_dtype = self.dtype 86 | else: # self.dtype is None 87 | output_dtype = X.dtype 88 | 89 | n_samples, n_features = X.shape 90 | 91 | if sample_weight is not None and self.strategy in ("uniform", "kdiquantile"): 92 | raise ValueError( 93 | "`sample_weight` was provided but it cannot be " 94 | "used with uniform or kdiquantile. Got strategy=" 95 | f"{self.strategy!r} instead." 96 | ) 97 | 98 | if self.subsample is not None and n_samples > self.subsample: 99 | # Take a subsample of `X` 100 | X = resample( 101 | X, 102 | replace=False, 103 | n_samples=self.subsample, 104 | random_state=self.random_state, 105 | ) 106 | 107 | n_features = X.shape[1] 108 | n_bins = self._validate_n_bins(n_features) 109 | 110 | if sample_weight is not None: 111 | sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype) 112 | 113 | bin_edges = np.zeros(n_features, dtype=object) 114 | for jj in range(n_features): 115 | column = X[:, jj] 116 | col_min, col_max = column.min(), column.max() 117 | 118 | if col_min == col_max: 119 | # warnings.warn( 120 | # 'Feature %d is constant and will be replaced with 0.' 
% jj 121 | # ) 122 | n_bins[jj] = 1 123 | bin_edges[jj] = np.array([-np.inf, np.inf]) 124 | continue 125 | 126 | if self.strategy == "uniform": 127 | bin_edges[jj] = np.linspace(col_min, col_max, n_bins[jj] + 1) 128 | elif self.strategy == "quantile": 129 | quantiles = np.linspace(0, 100, n_bins[jj] + 1) 130 | if sample_weight is None: 131 | bin_edges[jj] = np.asarray(np.percentile(column, quantiles)) 132 | else: 133 | bin_edges[jj] = np.asarray( 134 | [ 135 | _weighted_percentile(column, sample_weight, q) 136 | for q in quantiles 137 | ], 138 | dtype=np.float64, 139 | ) 140 | elif self.strategy == "kdiquantile": 141 | quantiles = np.linspace(0, 1, n_bins[jj] + 1) 142 | kdier = KDITransformer().fit(column.reshape(-1, 1)) 143 | cur_bin_edges = kdier.inverse_transform( 144 | quantiles.reshape(-1, 1) 145 | ).ravel() 146 | if np.unique(cur_bin_edges).shape[0] < n_bins[jj] + 1: 147 | # XXX - this bugfix is gross: find local minima instead 148 | cur_bin_edges = np.linspace(col_min, col_max, n_bins[jj] + 1) 149 | warnings.warn("kdiquantile numerical error detected, backing off.") 150 | bin_edges[jj] = cur_bin_edges # ndarray of shape (n_bins[jj] + 1,) 151 | elif self.strategy == "kmeans": 152 | # Deterministic initialization with uniform spacing 153 | uniform_edges = np.linspace(col_min, col_max, n_bins[jj] + 1) 154 | init = (uniform_edges[1:] + uniform_edges[:-1])[:, None] * 0.5 155 | 156 | # 1D k-means procedure 157 | km = KMeans(n_clusters=n_bins[jj], init=init, n_init=1) 158 | centers = km.fit( 159 | column[:, None], sample_weight=sample_weight 160 | ).cluster_centers_[:, 0] 161 | # Must sort, centers may be unsorted even with sorted init 162 | centers.sort() 163 | bin_edges[jj] = (centers[1:] + centers[:-1]) * 0.5 164 | bin_edges[jj] = np.r_[col_min, bin_edges[jj], col_max] 165 | 166 | """ 167 | # Remove bins whose width are too small (i.e., <= 1e-8) 168 | if self.strategy in ('quantile', 'kmeans', 'kdiquantile'): 169 | mask = np.ediff1d(bin_edges[jj], to_begin=np.inf) > 1e-8 170 | bin_edges[jj] = bin_edges[jj][mask] 171 | if len(bin_edges[jj]) - 1 != n_bins[jj]: 172 | warnings.warn( 173 | 'Bins whose width are too small (i.e., <= ' 174 | '1e-8) in feature %d are removed. Consider ' 175 | 'decreasing the number of bins.' 
% jj 176 | ) 177 | n_bins[jj] = len(bin_edges[jj]) - 1 178 | """ 179 | 180 | self.bin_edges_ = bin_edges 181 | self.n_bins_ = n_bins 182 | 183 | if "onehot" in self.encode: 184 | self._encoder = OneHotEncoder( 185 | categories=[np.arange(i) for i in self.n_bins_], 186 | sparse_output=self.encode == "onehot", 187 | dtype=output_dtype, 188 | ) 189 | # Fit the OneHotEncoder with toy datasets 190 | # so that it's ready for use after the KDIQuantizer is fitted 191 | self._encoder.fit(np.zeros((1, len(self.n_bins_)))) 192 | 193 | return self 194 | 195 | def _validate_n_bins(self, n_features): 196 | """Returns n_bins_, the number of bins per feature.""" 197 | orig_bins = self.n_bins 198 | if isinstance(orig_bins, Integral): 199 | return np.full(n_features, orig_bins, dtype=int) 200 | 201 | n_bins = check_array(orig_bins, dtype=int, copy=True, ensure_2d=False) 202 | 203 | if n_bins.ndim > 1 or n_bins.shape[0] != n_features: 204 | raise ValueError("n_bins must be a scalar or array of shape (n_features,).") 205 | 206 | bad_nbins_value = (n_bins < 2) | (n_bins != orig_bins) 207 | 208 | violating_indices = np.where(bad_nbins_value)[0] 209 | if violating_indices.shape[0] > 0: 210 | indices = ", ".join(str(i) for i in violating_indices) 211 | raise ValueError( 212 | "{} received an invalid number " 213 | "of bins at indices {}. Number of bins " 214 | "must be at least 2, and must be an int.".format( 215 | KDIQuantizer.__name__, indices 216 | ) 217 | ) 218 | return n_bins 219 | 220 | def transform(self, X): 221 | """ 222 | Discretize the data. 223 | 224 | Parameters 225 | ---------- 226 | X : array-like of shape (n_samples, n_features) 227 | Data to be discretized. 228 | 229 | Returns 230 | ------- 231 | Xt : {ndarray, sparse matrix}, dtype={np.float32, np.float64} 232 | Data in the binned space. Will be a sparse matrix if 233 | `self.encode='onehot'` and ndarray otherwise. 234 | """ 235 | check_is_fitted(self) 236 | 237 | # check input and attribute dtypes 238 | dtype = (np.float64, np.float32) if self.dtype is None else self.dtype 239 | Xt = self._validate_data(X, copy=True, dtype=dtype, reset=False) 240 | 241 | bin_edges = self.bin_edges_ 242 | for jj in range(Xt.shape[1]): 243 | Xt[:, jj] = np.searchsorted(bin_edges[jj][1:-1], Xt[:, jj], side="right") 244 | 245 | if self.encode == "ordinal": 246 | return Xt 247 | 248 | dtype_init = None 249 | if "onehot" in self.encode: 250 | dtype_init = self._encoder.dtype 251 | self._encoder.dtype = Xt.dtype 252 | try: 253 | Xt_enc = self._encoder.transform(Xt) 254 | finally: 255 | # revert the initial dtype to avoid modifying self. 256 | self._encoder.dtype = dtype_init 257 | return Xt_enc 258 | 259 | def inverse_transform(self, X): 260 | """ 261 | Transform discretized data back to original feature space. 262 | 263 | Note that this function does not regenerate the original data 264 | due to discretization rounding. 265 | 266 | Parameters 267 | ---------- 268 | X : array-like of shape (n_samples, n_features) 269 | Transformed data in the binned space. 270 | 271 | Returns 272 | ------- 273 | Xinv : ndarray, dtype={np.float32, np.float64} 274 | Data in the original feature space. 275 | """ 276 | check_is_fitted(self) 277 | 278 | if "onehot" in self.encode: 279 | X = self._encoder.inverse_transform(X) 280 | 281 | Xinv = check_array(X, copy=True, dtype=(np.float64, np.float32)) 282 | n_features = self.n_bins_.shape[0] 283 | if Xinv.shape[1] != n_features: 284 | raise ValueError( 285 | "Incorrect number of features. 
Expecting {}, received {}.".format( 286 | n_features, Xinv.shape[1] 287 | ) 288 | ) 289 | 290 | for jj in range(n_features): 291 | bin_edges = self.bin_edges_[jj] 292 | bin_centers = (bin_edges[1:] + bin_edges[:-1]) * 0.5 293 | Xinv[:, jj] = bin_centers[(Xinv[:, jj]).astype(np.int64)] 294 | 295 | return Xinv 296 | 297 | def inverse_transform_sample(self, X=None, random_state=None): 298 | rng = check_random_state(random_state) 299 | check_is_fitted(self) 300 | 301 | if "onehot" in self.encode: 302 | X = self._encoder.inverse_transform(X) 303 | 304 | Xinv = check_array(X, copy=True, dtype=(np.float64, np.float32)) 305 | n_features = self.n_bins_.shape[0] 306 | if Xinv.shape[1] != n_features: 307 | raise ValueError( 308 | "Incorrect number of features. Expecting {}, received {}.".format( 309 | n_features, Xinv.shape[1] 310 | ) 311 | ) 312 | n = X.shape[0] 313 | for jj in range(n_features): 314 | jitter = rng.uniform(0.0, 1.0, size=n) 315 | bin_edges = self.bin_edges_[jj] 316 | bin_centers = (bin_edges[1:] + bin_edges[:-1]) * 0.5 317 | bin_lefts = bin_edges[:-1][(Xinv[:, jj]).astype(np.int64)] 318 | bin_rights = bin_edges[1:][(Xinv[:, jj]).astype(np.int64)] 319 | Xinv[:, jj] = bin_lefts * jitter + bin_rights * (1 - jitter) # uniform draw within the assigned bin 320 | 321 | return Xinv 322 | 323 | def get_feature_names_out(self, input_features=None): 324 | """Get output feature names. 325 | 326 | Parameters 327 | ---------- 328 | input_features : array-like of str or None, default=None 329 | Input features. 330 | 331 | - If `input_features` is `None`, then `feature_names_in_` is 332 | used as feature names in. If `feature_names_in_` is not defined, 333 | then the following input feature names are generated: 334 | `['x0', 'x1', ..., 'x(n_features_in_ - 1)']`. 335 | - If `input_features` is an array-like, then `input_features` must 336 | match `feature_names_in_` if `feature_names_in_` is defined. 337 | 338 | Returns 339 | ------- 340 | feature_names_out : ndarray of str objects 341 | Transformed feature names. 342 | """ 343 | check_is_fitted(self, "n_features_in_") 344 | input_features = _check_feature_names_in(self, input_features) 345 | if hasattr(self, "_encoder"): 346 | return self._encoder.get_feature_names_out(input_features) 347 | 348 | # ordinal encoding 349 | return input_features 350 | -------------------------------------------------------------------------------- /utrees/baltobot.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Optional 2 | from copy import deepcopy 3 | import warnings 4 | 5 | import numpy as np 6 | import xgboost as xgb 7 | 8 | from sklearn.base import BaseEstimator, ClassifierMixin 9 | from sklearn.preprocessing import LabelEncoder 10 | from sklearn.utils import check_random_state 11 | 12 | from utrees.kdi_quantizer import KDIQuantizer 13 | 14 | XGBOOST_DEFAULT_KWARGS = { 15 | "tree_method": "hist", 16 | "verbosity": 1, 17 | "objective": "binary:logistic", 18 | } 19 | 20 | TABPFN_DEFAULT_KWARGS = { 21 | "N_ensemble_configurations": 1, 22 | "seed": 42, 23 | "batch_size_inference": 1, 24 | "subsample_features": True, 25 | } 26 | 27 | 28 | class NanTabPFNClassifier(BaseEstimator, ClassifierMixin): 29 | def __init__( 30 | self, 31 | random_state=None, 32 | **tabpfn_kwargs, 33 | ): 34 | # We do not import this at the top, because Python dependencies are a nightmare.
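# (Importing lazily keeps tabpfn an optional dependency: it is only needed when a NanTabPFNClassifier is actually constructed, i.e. when Baltobot is created with tabpfn=True.)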
35 | from tabpfn import TabPFNClassifier 36 | 37 | self.tabpfn = TabPFNClassifier(**tabpfn_kwargs) 38 | self.rng = check_random_state(random_state) 39 | 40 | def fit( 41 | self, 42 | X, 43 | y, 44 | ): 45 | self.Xtrain = X.copy() 46 | self.ytrain = y.copy() 47 | self.n_classes = np.unique(y).shape[0] 48 | return self 49 | 50 | def predict_proba( 51 | self, 52 | X, 53 | ): 54 | dtype = self.Xtrain.dtype 55 | X = X.astype(dtype) 56 | n_test, d = X.shape 57 | pred_probas = np.zeros((n_test, self.n_classes), dtype=dtype) 58 | 59 | obsXtest = ~np.isnan(X) 60 | obsXtrain = ~np.isnan(self.Xtrain) 61 | obs_patterns = np.unique(obsXtest, axis=0) 62 | n_patterns = obs_patterns.shape[0] 63 | for pix in range(n_patterns): 64 | isobs = obs_patterns[[pix], :] 65 | test_ixs = (obsXtest == isobs).all(axis=1).nonzero()[0] 66 | obs_ixs = isobs.ravel().nonzero()[0] 67 | train_ixs = (obsXtrain | ~isobs).all(axis=1).nonzero()[0] 68 | while train_ixs.shape[0] == 0: 69 | # No training example covers the observed features in this test example. 70 | # This should be rare, but we handle it anyways. 71 | # We introduce fake missingness into test until training covers it. 72 | isobs[0, self.rng.choice(obs_ixs)] = False 73 | obs_ixs = isobs.ravel().nonzero()[0] 74 | train_ixs = (obsXtrain | ~isobs).all(axis=1).nonzero()[0] 75 | if isobs.sum() == 0: 76 | curXtrain = np.zeros((self.Xtrain.shape[0], 1), dtype=dtype) 77 | curytrain = self.ytrain 78 | curXtest = np.zeros((X[test_ixs, :].shape[0], 1), dtype=dtype) 79 | else: 80 | curXtrain = self.Xtrain[np.ix_(train_ixs, obs_ixs)] 81 | curytrain = self.ytrain[train_ixs] 82 | curXtest = X[np.ix_(test_ixs, obs_ixs)] 83 | if curXtrain.shape[0] > 1024: 84 | # warnings.warn(f'TabPFN must shrink from {curXtrain.shape[0]} to 1024 samples') 85 | sixs = self.rng.choice(curXtrain.shape[0], 1024, replace=False) 86 | curXtrain = curXtrain[sixs, :] 87 | curytrain = curytrain[sixs] 88 | if np.unique(curytrain).shape[0] == 1: 89 | # Removing missingness might make us only see one label 90 | pred_probas[test_ixs, curytrain[0]] = 1.0 91 | else: 92 | self.tabpfn.fit(curXtrain, curytrain) 93 | cur_pred_prob = self.tabpfn.predict_proba(curXtest) 94 | pred_probas[test_ixs, :] = cur_pred_prob 95 | return pred_probas 96 | 97 | 98 | class Baltobot(BaseEstimator): 99 | """Performs probabilistic prediction using a BALanced Tree Of BOosted Trees. 100 | 101 | Parameters 102 | ---------- 103 | depth : int >= 0 104 | Depth of balanced binary tree for recursively quantizing each feature. 105 | The total number of quantization bins is 2^depth. 106 | 107 | clf_kwargs : dict 108 | Arguments for XGBoost (or TabPFN) classifier. 109 | 110 | strategy : 'kdiquantile', 'quantile', 'uniform', 'kmeans' 111 | The quantization strategy for discretizing continuous features. 112 | 113 | softmax_temp : float > 0 114 | Softmax temperature for sampling from predicted probabilities. 115 | As temperature decreases below the default of 1, predictions converge 116 | to the argmax of each conditional distribution. 117 | 118 | tabpfn: bool, or int >= 0 119 | Whether to use TabPFN instead of XGBoost classifier. 120 | 121 | random_state : int, RandomState instance or None, default=None 122 | Determines random number generation. 123 | 124 | 125 | Attributes 126 | ---------- 127 | constant_val_ : None or float 128 | Stores the value of a non-varying target variable, avoiding uniform sampling. 129 | This allows us to sample from discrete or mixed-type random variables. 
130 | 131 | quantizer_ : KDIQuantizer 132 | Maps continuous value to bin. Almost always 2 bins. 133 | 134 | encoder_ : LabelEncoder 135 | Maps bin (float) to label (int). Almost always 2 classes. 136 | 137 | clfer_ : XGBClassifier or NanTabPFNClassifier 138 | Binary classifier that tells us whether to go to the left or right bin. 139 | 140 | left_child_, right_child_ : Baltobot 141 | Next level in the balanced binary tree. 142 | """ 143 | 144 | def __init__( 145 | self, 146 | depth: int = 4, 147 | clf_kwargs: dict = {}, 148 | strategy: str = "kdiquantile", 149 | softmax_temp: float = 1.0, 150 | tabpfn: bool = False, 151 | random_state=None, 152 | ): 153 | 154 | self.depth = depth 155 | self.clf_kwargs = clf_kwargs 156 | self.strategy = strategy 157 | self.softmax_temp = softmax_temp 158 | self.tabpfn = tabpfn 159 | self.random_state = check_random_state(random_state) 160 | assert depth < 10  # 2^10 models is insane 161 | 162 | self.left_child_ = None 163 | self.right_child_ = None 164 | self.quantizer_ = KDIQuantizer( 165 | n_bins=2, strategy=strategy, random_state=self.random_state 166 | ) 167 | self.encoder_ = LabelEncoder() 168 | my_kwargs = ( 169 | XGBOOST_DEFAULT_KWARGS | clf_kwargs | {"random_state": self.random_state} 170 | ) 171 | self.clfer_ = xgb.XGBClassifier(**my_kwargs) 172 | 173 | if self.tabpfn: 174 | assert self.tabpfn > 0 175 | my_kwargs = TABPFN_DEFAULT_KWARGS | clf_kwargs 176 | self.clfer_ = NanTabPFNClassifier(**my_kwargs) 177 | 178 | self.constant_val_ = None 179 | 180 | if depth > 0: 181 | self.left_child_ = Baltobot( 182 | depth=depth - 1, 183 | clf_kwargs=clf_kwargs, 184 | strategy=strategy, 185 | softmax_temp=softmax_temp, 186 | tabpfn=tabpfn, 187 | random_state=self.random_state, 188 | ) 189 | self.right_child_ = Baltobot( 190 | depth=depth - 1, 191 | clf_kwargs=clf_kwargs, 192 | strategy=strategy, 193 | softmax_temp=softmax_temp, 194 | tabpfn=tabpfn, 195 | random_state=self.random_state, 196 | ) 197 | 198 | def fit( 199 | self, 200 | X: np.ndarray, 201 | y: np.ndarray, 202 | ): 203 | """Recursively fits balanced tree of boosted tree classifiers. 204 | 205 | Parameters 206 | ---------- 207 | X : array-like of shape (n_samples, n_features) 208 | Input variables of train set. 209 | 210 | y : array-like of shape (n_samples,) 211 | Continuous target variable of train set. 212 | 213 | Returns 214 | ------- 215 | self 216 | """ 217 | 218 | assert np.isnan(y).sum() == 0 219 | if y.size == 0: 220 | self.constant_val_ = 0.0 221 | return self 222 | self.constant_val_ = None 223 | rng = check_random_state(self.random_state) 224 | 225 | self.quantizer_.fit(y.reshape(-1, 1)) 226 | y_quant = self.quantizer_.transform(y.reshape(-1, 1)).ravel() 227 | self.encoder_.fit(y_quant) 228 | if len(self.encoder_.classes_) != 2: 229 | assert len(self.encoder_.classes_) == 1 230 | assert np.unique(y).shape[0] == 1 231 | self.constant_val_ = np.mean(y) 232 | return self 233 | y_enc = self.encoder_.transform(y_quant) 234 | 235 | self.clfer_.fit(X, y_enc) 236 | 237 | if self.depth > 0: 238 | left_ixs = y_enc == 0 239 | right_ixs = y_enc == 1 240 | self.left_child_.fit(X[left_ixs, :], y[left_ixs]) 241 | self.right_child_.fit(X[right_ixs, :], y[right_ixs]) 242 | 243 | return self 244 | 245 | def sample( 246 | self, 247 | X: np.ndarray, 248 | ): 249 | """Samples y conditional on X. 250 | 251 | Parameters 252 | ---------- 253 | X : array-like of shape (n_samples, n_features) 254 | Input variables of test set.
255 | 256 | Returns 257 | ------- 258 | y : array-like of shape (n_samples,) 259 | Samples from the conditional distribution. 260 | """ 261 | 262 | n, _ = X.shape 263 | rng = check_random_state(self.random_state) 264 | 265 | if self.constant_val_ is not None: 266 | return np.full((n,), fill_value=self.constant_val_) 267 | 268 | pred_prob = self.clfer_.predict_proba(X) 269 | 270 | with np.errstate(divide="ignore"): 271 | annealed_logits = np.log(pred_prob) / self.softmax_temp 272 | pred_prob = np.exp(annealed_logits) / np.sum( 273 | np.exp(annealed_logits), axis=1, keepdims=True 274 | ) 275 | pred_enc = np.zeros((n,), dtype=int) 276 | for i in range(n): 277 | pred_enc[i] = rng.choice(a=len(self.encoder_.classes_), p=pred_prob[i, :]) 278 | 279 | if self.depth == 0: 280 | pred_quant = self.encoder_.inverse_transform(pred_enc) 281 | pred_val = self.quantizer_.inverse_transform_sample( 282 | pred_quant.reshape(-1, 1) 283 | ) 284 | return pred_val.ravel() 285 | else: 286 | left_ixs = pred_enc == 0 287 | right_ixs = pred_enc == 1 288 | pred_val = np.zeros((n,)) 289 | if left_ixs.sum() > 0: 290 | pred_val[left_ixs] = self.left_child_.sample(X[left_ixs, :]) 291 | if right_ixs.sum() > 0: 292 | pred_val[right_ixs] = self.right_child_.sample(X[right_ixs, :]) 293 | return pred_val 294 | 295 | def score_samples( 296 | self, 297 | X: np.ndarray, 298 | y: np.ndarray, 299 | ): 300 | """Compute the conditional log-likelihood of y given X. 301 | 302 | Parameters 303 | ---------- 304 | X : array-like of shape (n_samples, n_features) 305 | An array of points to query. Last dimension should match dimension 306 | of training data (n_features). 307 | 308 | y : array-like of shape (n_samples,) 309 | Output variable. 310 | 311 | Returns 312 | ------- 313 | density : ndarray of shape (n_samples,) 314 | Log-likelihood of each sample. These are normalized to be 315 | probability densities, after exponentiating, though will not 316 | necessarily exactly sum to 1 due to numerical error. 
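        Notes
        -----
        The log-density decomposes over the balanced tree: each internal node
        contributes the log-probability of the branch containing ``y``, and the
        leaf contributes ``-log(bin width)``, so the exponentiated scores form
        a piecewise-constant density in ``y``.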
317 | """ 318 | 319 | n, d = X.shape 320 | rng = check_random_state(self.random_state) 321 | if self.constant_val_ is not None: 322 | return np.log(y == self.constant_val_) 323 | 324 | if self.tabpfn: 325 | XX = X.copy() 326 | XX[np.isnan(XX)] = rng.normal(size=XX.shape)[np.isnan(XX)] 327 | pred_prob = self.clfer_.predict_proba(XX) 328 | else: 329 | pred_prob = self.clfer_.predict_proba(X) 330 | 331 | y_quant = self.quantizer_.transform(y.reshape(-1, 1)).ravel() 332 | y_enc = self.encoder_.transform(y_quant) 333 | if self.depth == 0: 334 | left_ixs = y_enc == 0 335 | right_ixs = y_enc == 1 336 | bin_edges = self.quantizer_.bin_edges_[0] 337 | left_width = bin_edges[1] - bin_edges[0] + 1e-15 338 | right_width = bin_edges[2] - bin_edges[1] + 1e-15 339 | scores = np.zeros((n,)) 340 | scores[left_ixs] = np.log(pred_prob[left_ixs, 0] / left_width) 341 | scores[right_ixs] = np.log(pred_prob[right_ixs, 1] / right_width) 342 | return scores 343 | else: 344 | left_ixs = y_enc == 0 345 | right_ixs = y_enc == 1 346 | scores = np.zeros((n,)) 347 | if left_ixs.sum() > 0: 348 | left_scores = self.left_child_.score_samples( 349 | X[left_ixs, :], y[left_ixs] 350 | ) 351 | scores[left_ixs] = np.log(pred_prob[left_ixs, 0]) + left_scores 352 | if right_ixs.sum() > 0: 353 | right_scores = self.right_child_.score_samples( 354 | X[right_ixs, :], y[right_ixs] 355 | ) 356 | scores[right_ixs] = np.log(pred_prob[right_ixs, 1]) + right_scores 357 | return scores 358 | 359 | def score( 360 | self, 361 | X: np.ndarray, 362 | y: np.ndarray, 363 | ): 364 | """Compute the total log probability density under the model. 365 | 366 | Parameters 367 | ---------- 368 | X : array_like, shape (n_samples, n_features) 369 | List of n_features-dimensional data points. Each row 370 | corresponds to a single data point. 371 | 372 | y : array-like of shape (n_samples,) 373 | Output variable. 374 | 375 | Returns 376 | ------- 377 | logprob : float 378 | Total log-likelihood of y given X. 
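        Examples
        --------
        A minimal illustrative sketch (synthetic data invented here purely for
        demonstration; exact values depend on the data and seed)::

            import numpy as np
            from utrees import Baltobot

            rng = np.random.default_rng(0)
            X = rng.normal(size=(500, 2))
            y = X[:, 0] + 0.1 * rng.normal(size=500)
            model = Baltobot(depth=3, random_state=0).fit(X, y)
            y_draw = model.sample(X)            # one conditional draw per row
            avg_ll = model.score(X, y) / 500    # average conditional log-density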
379 | """ 380 | logprob = np.sum(self.score_samples(X, y)) 381 | return logprob 382 | -------------------------------------------------------------------------------- /paper/baltobot-simple-demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "c8be354b-c400-4d08-9434-cc0051072192", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import matplotlib as mpl\n", 11 | "mpl.rcParams['font.family'] = 'Arial'\n", 12 | "mpl.rcParams['text.usetex'] = False\n", 13 | "import matplotlib.pyplot as plt\n", 14 | "import numpy as np\n", 15 | "import seaborn as sns\n", 16 | "import pandas as pd\n", 17 | "import time\n", 18 | "from utrees import Baltobot\n", 19 | "from treeffuser import Treeffuser\n", 20 | "time.time()\n", 21 | "\n", 22 | "# Generate the data\n", 23 | "seed = 0\n", 24 | "n = 5000\n", 25 | "rng = np.random.default_rng(seed=seed)\n", 26 | "x = rng.uniform(0, 2 * np.pi, size=n)\n", 27 | "z = rng.integers(0, 2, size=n)\n", 28 | "y = z * np.sin(x - np.pi / 2) + (1 - z) * np.cos(x) + rng.laplace(scale=x / 30, size=n)" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "id": "6383abd1-52a1-42ab-ae29-eb09b5d28c6b", 35 | "metadata": { 36 | "scrolled": true 37 | }, 38 | "outputs": [ 39 | { 40 | "name": "stderr", 41 | "output_type": "stream", 42 | "text": [ 43 | "/Users/calvinm/miniconda3/envs/maskingtrees/lib/python3.9/site-packages/treeffuser/_base_tabular_diffusion.py:110: CastFloat32Warning: Input array is not float32; it has been recast to float32.\n", 44 | " X = _check_array(X)\n", 45 | "/Users/calvinm/miniconda3/envs/maskingtrees/lib/python3.9/site-packages/treeffuser/_base_tabular_diffusion.py:113: CastFloat32Warning: Input array is not float32; it has been recast to float32.\n", 46 | " y = _check_array(y)\n", 47 | "/Users/calvinm/miniconda3/envs/maskingtrees/lib/python3.9/site-packages/treeffuser/_base_tabular_diffusion.py:110: CastFloat32Warning: Input array is not float32; it has been recast to float32.\n", 48 | " X = _check_array(X)\n", 49 | "/Users/calvinm/sandbox/unmasking-trees/utrees/baltobot.py:49: UserWarning: Support for TabPFN is experimental.\n", 50 | " warnings.warn('Support for TabPFN is experimental.')\n", 51 | "/Users/calvinm/miniconda3/envs/maskingtrees/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 52 | " from .autonotebook import tqdm as notebook_tqdm\n" 53 | ] 54 | } 55 | ], 56 | "source": [ 57 | "# Fit the models\n", 58 | "start_time = time.time()\n", 59 | "tfer = Treeffuser(sde_initialize_from_data=True, seed=seed)\n", 60 | "tfer.fit(x, y)\n", 61 | "tf_train_time = time.time() - start_time\n", 62 | "y_tfer = tfer.sample(x, n_samples=1, seed=seed, verbose=True)\n", 63 | "tf_time = time.time() - start_time\n", 64 | "\n", 65 | "start_time = time.time()\n", 66 | "tber = Baltobot(random_state=seed)\n", 67 | "tber.fit(x.reshape(-1, 1), y)\n", 68 | "tb_train_time = time.time() - start_time\n", 69 | "y_tber = tber.sample(x.reshape(-1, 1))\n", 70 | "tb_time = time.time() - start_time\n", 71 | "\n", 72 | "start_time = time.time()\n", 73 | "tbtaber = Baltobot(tabpfn=True, random_state=seed)\n", 74 | "tbtaber.fit(x.reshape(-1, 1), y)\n", 75 | "tbtab_train_time = time.time() - start_time\n", 76 | "y_tbtaber = tbtaber.sample(x.reshape(-1, 1))\n", 77 | "tbtab_time = time.time() - start_time" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 3, 83 | "id": "2cb69a60-0533-40f0-a405-93be28619f6d", 84 | "metadata": {}, 85 | "outputs": [ 86 | { 87 | "data": { 88 | "text/html": [ 89 | "
\n", 90 | "\n", 103 | "\n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | "
MethodTaskTime
0TreeffuserTotal6.401965
1TreeffuserTraining1.356733
2TreeffuserSampling5.045232
3BaltobotTotal3.175086
4BaltobotTraining2.396088
5BaltobotSampling0.778998
6BaltoboTabPFNTotal12.216128
7BaltoboTabPFNTraining2.148450
8BaltoboTabPFNSampling10.067678
\n", 169 | "
" 170 | ], 171 | "text/plain": [ 172 | " Method Task Time\n", 173 | "0 Treeffuser Total 6.401965\n", 174 | "1 Treeffuser Training 1.356733\n", 175 | "2 Treeffuser Sampling 5.045232\n", 176 | "3 Baltobot Total 3.175086\n", 177 | "4 Baltobot Training 2.396088\n", 178 | "5 Baltobot Sampling 0.778998\n", 179 | "6 BaltoboTabPFN Total 12.216128\n", 180 | "7 BaltoboTabPFN Training 2.148450\n", 181 | "8 BaltoboTabPFN Sampling 10.067678" 182 | ] 183 | }, 184 | "execution_count": 3, 185 | "metadata": {}, 186 | "output_type": "execute_result" 187 | } 188 | ], 189 | "source": [ 190 | "fig, axes = plt.subplots(nrows=3, figsize=(7, 7), sharex=True, dpi=300);\n", 191 | "axes[0].scatter(x, y, s=1, label=\"Observed data\")\n", 192 | "axes[0].scatter(x, y_tfer[0, :], s=1, alpha=0.7, label=\"Treeffuser samples\")\n", 193 | "axes[0].legend();\n", 194 | "\n", 195 | "axes[1].scatter(x, y, s=1, label=\"Observed data\")\n", 196 | "axes[1].scatter(x, y_tber, s=1, alpha=0.7, label=\"Baltobot samples\")\n", 197 | "axes[1].legend();\n", 198 | "\n", 199 | "axes[2].scatter(x, y, s=1, label=\"Observed data\")\n", 200 | "axes[2].scatter(x, y_tbtaber, s=1, alpha=0.7, label=\"BaltoboTabPFN samples\")\n", 201 | "axes[2].legend();\n", 202 | "plt.tight_layout();\n", 203 | "plt.savefig('wave-demo.png');\n", 204 | "plt.close();\n", 205 | "\n", 206 | "plt.figure(dpi=200, figsize=(4,3));\n", 207 | "total_time_df = pd.DataFrame.from_dict({'Treeffuser': [tf_time], 'Baltobot': [tb_time], 'BaltoboTabPFN': [tbtab_time]}).T\n", 208 | "total_time_df.columns = ['Total']\n", 209 | "train_time_df = pd.DataFrame.from_dict({'Treeffuser': [tf_train_time], 'Baltobot': [tb_train_time], 'BaltoboTabPFN': [tbtab_train_time]}).T\n", 210 | "train_time_df.columns = ['Training']\n", 211 | "time_df = pd.concat([total_time_df, train_time_df], axis=1)\n", 212 | "time_df['Sampling'] = time_df['Total'] - time_df['Training']\n", 213 | "\n", 214 | "time_dff = time_df.stack().reset_index()\n", 215 | "time_dff.columns = ['Method', 'Task', 'Time']\n", 216 | "sns.barplot(data=time_dff, y='Method', x='Time', hue='Task');\n", 217 | "plt.ylabel('Method');\n", 218 | "plt.xlabel('Time (s)');\n", 219 | "plt.tight_layout();\n", 220 | "plt.savefig('wave-demo-time.png');\n", 221 | "plt.close()\n", 222 | "time_dff" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": 4, 228 | "id": "ccb3079e-72a3-403a-8ea4-d3b9a2435493", 229 | "metadata": {}, 230 | "outputs": [], 231 | "source": [ 232 | "lhs, rhs = np.meshgrid(np.linspace(-1, 7, 30), np.linspace(-3,2, 30))\n", 233 | "lhsrhs = np.hstack([lhs.reshape(-1, 1), rhs.reshape(-1, 1)])\n", 234 | "plt.figure();\n", 235 | "plt.scatter(lhsrhs[:, 0], lhsrhs[:, 1])\n", 236 | "scores = tber.score_samples(lhs.reshape(-1, 1), rhs.reshape(-1))\n", 237 | "plt.close();\n", 238 | "plt.figure();\n", 239 | "plt.scatter(lhsrhs[:, 0], lhsrhs[:, 1], s=100*np.exp(scores));\n", 240 | "plt.close();" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": 5, 246 | "id": "671dac1d-054b-4f64-8ca8-17676bd064b7", 247 | "metadata": {}, 248 | "outputs": [ 249 | { 250 | "name": "stdout", 251 | "output_type": "stream", 252 | "text": [ 253 | "0.9820769258371294 0.9859963150095391\n" 254 | ] 255 | } 256 | ], 257 | "source": [ 258 | "Xs = 2 * np.ones((1000, 1))\n", 259 | "Ys = np.linspace(-5, 5, 1000)\n", 260 | "tb_scores = tber.score_samples(Xs, Ys)\n", 261 | "tbtab_scores = tbtaber.score_samples(Xs, Ys)\n", 262 | "print(np.exp(tb_scores).sum() * (Ys[1]-Ys[0]), np.exp(tbtab_scores).sum() * (Ys[1]-Ys[0]))\n", 263 | 
"plt.figure(figsize=(4,2), dpi=200);\n", 264 | "plt.plot(Ys, np.exp(tb_scores), label='Baltobot');\n", 265 | "plt.plot(Ys, np.exp(tbtab_scores), '--', label='BaltoboTabPFN');\n", 266 | "plt.xlim(-2, 2);\n", 267 | "plt.legend(loc='upper left', bbox_to_anchor=(1.0, 1.05));\n", 268 | "plt.xlabel('y');\n", 269 | "plt.ylabel('pdf at x=2');\n", 270 | "plt.tight_layout();\n", 271 | "plt.savefig('wave-pdfat2.png');\n", 272 | "plt.close();" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": 6, 278 | "id": "24ca0b83-9b7e-468f-8f34-823a36c917e6", 279 | "metadata": {}, 280 | "outputs": [ 281 | { 282 | "name": "stderr", 283 | "output_type": "stream", 284 | "text": [ 285 | "/Users/calvinm/miniconda3/envs/maskingtrees/lib/python3.9/site-packages/treeffuser/_base_tabular_diffusion.py:110: CastFloat32Warning: Input array is not float32; it has been recast to float32.\n", 286 | " X = _check_array(X)\n", 287 | "/Users/calvinm/miniconda3/envs/maskingtrees/lib/python3.9/site-packages/treeffuser/_base_tabular_diffusion.py:113: CastFloat32Warning: Input array is not float32; it has been recast to float32.\n", 288 | " y = _check_array(y)\n", 289 | "/Users/calvinm/miniconda3/envs/maskingtrees/lib/python3.9/site-packages/treeffuser/_base_tabular_diffusion.py:110: CastFloat32Warning: Input array is not float32; it has been recast to float32.\n", 290 | " X = _check_array(X)\n" 291 | ] 292 | } 293 | ], 294 | "source": [ 295 | "nP = 500\n", 296 | "rng = np.random.default_rng(seed=seed)\n", 297 | "XP = rng.uniform(0, 3, size=nP)\n", 298 | "YP = rng.poisson(np.sqrt(XP), size=nP)\n", 299 | "tfer = Treeffuser(sde_initialize_from_data=True, seed=seed)\n", 300 | "tfer.fit(XP, YP)\n", 301 | "YP_tfer = tfer.sample(XP, n_samples=1, seed=seed, verbose=True)\n", 302 | "tber = Baltobot(random_state=seed)\n", 303 | "tber.fit(XP.reshape(-1, 1), YP)\n", 304 | "YP_tber = tber.sample(XP.reshape(-1, 1))" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": 7, 310 | "id": "12dfdd1a-805c-4f7d-9b58-f06552ab97fb", 311 | "metadata": {}, 312 | "outputs": [], 313 | "source": [ 314 | "dfP = pd.DataFrame(); dfP['x'] = XP; dfP['y'] = YP\n", 315 | "s = 8; linewidth=0.3; edgecolor='white'; markercolor='blue';\n", 316 | "fig, axes = plt.subplots(figsize=(7,3), ncols=3, sharey=True, dpi=500);\n", 317 | "sns.scatterplot(data=dfP, x='x', y='y', s=s, edgecolor=edgecolor, linewidth=linewidth, color=markercolor, ax=axes[0])\n", 318 | "axes[0].set_title('Original data');\n", 319 | "dfP_tfer = pd.DataFrame(); dfP_tfer['x'] = XP; dfP_tfer['y'] = YP_tfer.ravel()\n", 320 | "sns.scatterplot(data=dfP_tfer, x='x', y='y', s=s, edgecolor=edgecolor, linewidth=linewidth, color=markercolor, ax=axes[1])\n", 321 | "axes[1].set_title('Treeffuser')\n", 322 | "dfP_tber = pd.DataFrame(); dfP_tber['x'] = XP; dfP_tber['y'] = YP_tber\n", 323 | "sns.scatterplot(data=dfP_tber, x='x', y='y', s=s, edgecolor=edgecolor, linewidth=linewidth, color=markercolor, ax=axes[2])\n", 324 | "axes[2].set_title('Baltobot')\n", 325 | "plt.tight_layout();\n", 326 | "plt.savefig('poisson-demo.png');\n", 327 | "plt.close();" 328 | ] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "execution_count": null, 333 | "id": "92ae5c29-84e4-403b-8dd4-5fd7f0bb827c", 334 | "metadata": {}, 335 | "outputs": [], 336 | "source": [] 337 | } 338 | ], 339 | "metadata": { 340 | "kernelspec": { 341 | "display_name": "Python 3 (ipykernel)", 342 | "language": "python", 343 | "name": "python3" 344 | }, 345 | "language_info": { 346 | "codemirror_mode": { 347 | "name": 
"ipython", 348 | "version": 3 349 | }, 350 | "file_extension": ".py", 351 | "mimetype": "text/x-python", 352 | "name": "python", 353 | "nbconvert_exporter": "python", 354 | "pygments_lexer": "ipython3", 355 | "version": "3.9.19" 356 | } 357 | }, 358 | "nbformat": 4, 359 | "nbformat_minor": 5 360 | } 361 | -------------------------------------------------------------------------------- /paper/iris.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "1f6d6fc7-a52a-471d-9607-ad42f993159e", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import matplotlib as mpl\n", 11 | "mpl.rcParams['font.family'] = 'Arial'\n", 12 | "mpl.rcParams['text.usetex'] = False\n", 13 | "import numpy as np\n", 14 | "import pandas as pd\n", 15 | "import seaborn as sns\n", 16 | "import matplotlib.pyplot as plt\n", 17 | "\n", 18 | "from sklearn.datasets import load_iris\n", 19 | "from sklearn.utils import check_random_state\n", 20 | "from ForestDiffusion import ForestDiffusionModel\n", 21 | "from utrees import UnmaskingTrees\n", 22 | "from missforest import MissForest" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "id": "6348b6f5-8eff-461d-833a-73c1bb8cfc16", 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "# Iris: numpy dataset with 4 variables (all numerical) and 1 outcome (categorical)\n", 33 | "my_data = load_iris()\n", 34 | "X, y = my_data['data'], my_data['target']\n", 35 | "Xy = np.concatenate((X, np.expand_dims(y, axis=1)), axis=1)\n", 36 | "palette = {'setosa': 'red', 'versicolor': 'green', 'virginica': 'blue'} \n", 37 | "alpha = 0.6\n", 38 | "rix = 0" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "id": "5eb55da3-b595-448a-a776-5b6e67ff9a38", 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "forestflow = ForestDiffusionModel(\n", 49 | " X, label_y=y, diffusion_type='flow', \n", 50 | " n_t=50, duplicate_K=100, bin_indexes=[], cat_indexes=[], int_indexes=[], n_jobs=-1, seed=rix)\n", 51 | "Xy_gen_forestflow = forestflow.generate(batch_size=X.shape[0]) # last variable will be the label_y\n", 52 | "\n", 53 | "forestvp = ForestDiffusionModel(\n", 54 | " X, label_y=y, diffusion_type='vp', \n", 55 | " n_t=50, duplicate_K=100, bin_indexes=[], cat_indexes=[], int_indexes=[], n_jobs=-1, seed=rix)\n", 56 | "Xy_gen_forestvp = forestvp.generate(batch_size=X.shape[0]) # last variable will be the label_y\n", 57 | "\n", 58 | "utrees = UnmaskingTrees(random_state=rix)\n", 59 | "utrees.fit(Xy, quantize_cols=['continuous']*4 + ['categorical'])\n", 60 | "Xy_gen_utrees = utrees.generate(n_generate=X.shape[0]);" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "id": "1b97c3e5-c53b-4493-b5ca-235f8328bc81", 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "X_mcar = X.copy()\n", 71 | "rng = check_random_state(rix)\n", 72 | "hasmissing = rng.normal(size=(150,)) < 0.\n", 73 | "X_hasmissing = X_mcar[hasmissing, :]\n", 74 | "X_hasmissing[rng.normal(size=X_hasmissing.shape) < 0] = np.nan\n", 75 | "X_mcar[hasmissing, :] = X_hasmissing\n", 76 | "Xy_mcar = np.concatenate((X_mcar, np.expand_dims(y, axis=1)), axis=1)\n", 77 | "np.isnan(X_mcar).sum()" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "id": "5047c2aa-cc70-4b08-a153-89bed3ee51d3", 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "missfer = 
MissForest(random_state=rix)\n", 88 | "Xy_impute_missf = missfer.fit_transform(Xy_mcar.copy(), cat_vars=[4])\n", 89 | "\n", 90 | "forestvp_mcar = ForestDiffusionModel(\n", 91 | " X=Xy_mcar, diffusion_type='vp', \n", 92 | " n_t=50, duplicate_K=100, bin_indexes=[], cat_indexes=[4], int_indexes=[0], n_jobs=-1, seed=rix)\n", 93 | "Xy_impute_forest_fast = forestvp_mcar.impute(k=1) # regular (fast)\n", 94 | "Xy_impute_forest = forestvp_mcar.impute(repaint=True, r=10, j=5, k=1) # REPAINT (slow, but better)\n", 95 | "\n", 96 | "utrees_mcar = UnmaskingTrees(random_state=rix)\n", 97 | "utrees_mcar.fit(Xy_mcar, quantize_cols=['continuous']*4 + ['categorical'])\n", 98 | "Xy_impute_utrees = utrees_mcar.impute(n_impute=1)[0, :, :]" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "id": "b4fffa3c-4778-47d7-a18c-7fb778a67528", 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(7, 3.5), sharex=True, sharey=True, squeeze=False);\n", 109 | "\n", 110 | "Xy_df = pd.DataFrame(data=Xy, columns=my_data['feature_names']+['target_names'])\n", 111 | "Xy_df['Species'] = Xy_df['target_names'].apply(lambda x: my_data['target_names'][int(x)].item())\n", 112 | "sns.scatterplot(\n", 113 | " data=Xy_df, x='petal length (cm)', y='petal width (cm)', hue='Species', s=15,\n", 114 | " palette=palette, alpha=alpha, ax=axes[0, 0], legend=False)\n", 115 | "axes[0, 0].set_title('Original data');\n", 116 | "\n", 117 | "Xy_gen_forestvp_df = pd.DataFrame(data=Xy_gen_forestvp, columns=my_data['feature_names']+['target_names'])\n", 118 | "Xy_gen_forestvp_df['Species'] = Xy_gen_forestvp_df['target_names'].apply(lambda x: my_data['target_names'][int(x)].item())\n", 119 | "sns.scatterplot(\n", 120 | " data=Xy_gen_forestvp_df, x='petal length (cm)', y='petal width (cm)', hue='Species', s=15,\n", 121 | " palette=palette, alpha=alpha, ax=axes[0, 1], legend=False)\n", 122 | "axes[0, 1].set_title('Forest-VP\\ngenerated');\n", 123 | "\n", 124 | "Xy_gen_forestflow_df = pd.DataFrame(data=Xy_gen_forestflow, columns=my_data['feature_names']+['target_names'])\n", 125 | "Xy_gen_forestflow_df['Species'] = Xy_gen_forestflow_df['target_names'].apply(lambda x: my_data['target_names'][int(x)].item())\n", 126 | "sns.scatterplot(\n", 127 | " data=Xy_gen_forestflow_df, x='petal length (cm)', y='petal width (cm)', hue='Species', s=15,\n", 128 | " palette=palette, alpha=alpha, ax=axes[0, 2], legend=False)\n", 129 | "axes[0, 2].set_title('Forest-Flow\\ngenerated');\n", 130 | "\n", 131 | "Xy_gen_utrees_df = pd.DataFrame(data=Xy_gen_utrees, columns=my_data['feature_names']+['target_names'])\n", 132 | "Xy_gen_utrees_df['Species'] = Xy_gen_utrees_df['target_names'].apply(lambda x: my_data['target_names'][int(x)].item())\n", 133 | "sns.scatterplot(\n", 134 | " data=Xy_gen_utrees_df, x='petal length (cm)', y='petal width (cm)', hue='Species', s=15,\n", 135 | " palette=palette, alpha=alpha, ax=axes[0, 3], legend=True)\n", 136 | "axes[0, 3].set_title('UnmaskingTrees\\ngenerated');\n", 137 | "\n", 138 | "sns.move_legend(axes[0, 3], \"upper left\", bbox_to_anchor=(1, 1.1))\n", 139 | "\n", 140 | "Xy_impute_missf_df = pd.DataFrame(data=Xy_impute_missf, columns=my_data['feature_names']+['target_names'])\n", 141 | "Xy_impute_missf_df['Species'] = Xy_impute_missf_df['target_names'].apply(lambda x: my_data['target_names'][int(x)].item())\n", 142 | "Xy_impute_missf_df['Imputed'] = hasmissing\n", 143 | "sns.scatterplot(\n", 144 | " data=Xy_impute_missf_df, x='petal length 
(cm)', y='petal width (cm)', hue='Species', style='Imputed',\n", 145 | " size='Imputed', sizes={False: 10, True: 30},\n", 146 | " palette=palette, alpha=alpha, ax=axes[1, 0], legend=False)\n", 147 | "axes[1, 0].set_title('MissForest\\nimputed');\n", 148 | "\n", 149 | "Xy_impute_forest_fast_df = pd.DataFrame(data=Xy_impute_forest_fast, columns=my_data['feature_names']+['target_names'])\n", 150 | "Xy_impute_forest_fast_df['Species'] = Xy_impute_forest_fast_df['target_names'].apply(lambda x: my_data['target_names'][int(x)].item())\n", 151 | "Xy_impute_forest_fast_df['Imputed'] = hasmissing\n", 152 | "sns.scatterplot(\n", 153 | " data=Xy_impute_forest_fast_df, x='petal length (cm)', y='petal width (cm)', hue='Species', style='Imputed',\n", 154 | " size='Imputed', sizes={False: 10, True: 30},\n", 155 | " palette=palette, alpha=alpha, ax=axes[1, 1], legend=False)\n", 156 | "axes[1, 1].set_title('Forest-VP\\nimputed');\n", 157 | "\n", 158 | "Xy_impute_forest_df = pd.DataFrame(data=Xy_impute_forest, columns=my_data['feature_names']+['target_names'])\n", 159 | "Xy_impute_forest_df['Species'] = Xy_impute_forest_df['target_names'].apply(lambda x: my_data['target_names'][int(x)].item())\n", 160 | "Xy_impute_forest_df['Imputed'] = hasmissing\n", 161 | "sns.scatterplot(\n", 162 | " data=Xy_impute_forest_df, x='petal length (cm)', y='petal width (cm)', hue='Species', style='Imputed',\n", 163 | " size='Imputed', sizes={False: 10, True: 30},\n", 164 | " palette=palette, alpha=alpha, ax=axes[1, 2], legend=False)\n", 165 | "axes[1, 2].set_title('Forest-VP\\nimputed w/ RePaint');\n", 166 | "\n", 167 | "Xy_impute_utrees_df = pd.DataFrame(data=Xy_impute_utrees, columns=my_data['feature_names']+['target_names'])\n", 168 | "Xy_impute_utrees_df['Species'] = Xy_impute_utrees_df['target_names'].apply(lambda x: my_data['target_names'][int(x)].item())\n", 169 | "Xy_impute_utrees_df['Imputed'] = hasmissing\n", 170 | "\n", 171 | "sns.scatterplot(\n", 172 | " data=Xy_impute_utrees_df, x='petal length (cm)', y='petal width (cm)', hue='Species', style='Imputed',\n", 173 | " size='Imputed', sizes={False: 10, True: 30},\n", 174 | " palette=palette, alpha=alpha, ax=axes[1, 3], legend=True)\n", 175 | "axes[1, 3].set_title('UnmaskingTrees\\nimputed');\n", 176 | "sns.move_legend(axes[1, 3], \"upper left\", bbox_to_anchor=(1, 1.1))\n", 177 | "\n", 178 | "plt.tight_layout();\n", 179 | "plt.savefig('iris.png');" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": null, 185 | "id": "0d040488-6031-45ea-8f1d-497e86f612a4", 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "fig, axes = plt.subplots(nrows=4, ncols=2, figsize=(5, 7), dpi=200, sharex=True, sharey=True, squeeze=False);\n", 190 | "\n", 191 | "Xy_df = pd.DataFrame(data=Xy, columns=my_data['feature_names']+['target_names'])\n", 192 | "Xy_df['Species'] = Xy_df['target_names'].apply(lambda x: my_data['target_names'][int(x)].item())\n", 193 | "sns.scatterplot(\n", 194 | " data=Xy_df, x='petal length (cm)', y='petal width (cm)', hue='Species', s=9,\n", 195 | " palette=palette, alpha=alpha, ax=axes[0, 0], legend=False)\n", 196 | "axes[0, 0].set_title('Original data');\n", 197 | "\n", 198 | "Xy_gen_forestvp_df = pd.DataFrame(data=Xy_gen_forestvp, columns=my_data['feature_names']+['target_names'])\n", 199 | "Xy_gen_forestvp_df['Species'] = Xy_gen_forestvp_df['target_names'].apply(lambda x: my_data['target_names'][int(x)].item())\n", 200 | "sns.scatterplot(\n", 201 | " data=Xy_gen_forestvp_df, x='petal length (cm)', y='petal width (cm)', 
hue='Species', s=9,\n", 202 | " palette=palette, alpha=alpha, ax=axes[1, 0], legend=False)\n", 203 | "axes[1, 0].set_title('Forest-VP\\ngenerated');\n", 204 | "\n", 205 | "Xy_gen_forestflow_df = pd.DataFrame(data=Xy_gen_forestflow, columns=my_data['feature_names']+['target_names'])\n", 206 | "Xy_gen_forestflow_df['Species'] = Xy_gen_forestflow_df['target_names'].apply(lambda x: my_data['target_names'][int(x)].item())\n", 207 | "sns.scatterplot(\n", 208 | " data=Xy_gen_forestflow_df, x='petal length (cm)', y='petal width (cm)', hue='Species', s=9,\n", 209 | " palette=palette, alpha=alpha, ax=axes[2, 0], legend=False)\n", 210 | "axes[2, 0].set_title('Forest-Flow\\ngenerated');\n", 211 | "\n", 212 | "Xy_gen_utrees_df = pd.DataFrame(data=Xy_gen_utrees, columns=my_data['feature_names']+['target_names'])\n", 213 | "Xy_gen_utrees_df['Species'] = Xy_gen_utrees_df['target_names'].apply(lambda x: my_data['target_names'][int(x)].item())\n", 214 | "sns.scatterplot(\n", 215 | " data=Xy_gen_utrees_df, x='petal length (cm)', y='petal width (cm)', hue='Species', s=9,\n", 216 | " palette=palette, alpha=alpha, ax=axes[3, 0], legend=False)\n", 217 | "axes[3, 0].set_title('UnmaskingTrees\\ngenerated');\n", 218 | "\n", 219 | "#sns.move_legend(axes[3, 0], \"upper left\", bbox_to_anchor=(1, 1.1))\n", 220 | "\n", 221 | "Xy_impute_missf_df = pd.DataFrame(data=Xy_impute_missf, columns=my_data['feature_names']+['target_names'])\n", 222 | "Xy_impute_missf_df['Species'] = Xy_impute_missf_df['target_names'].apply(lambda x: my_data['target_names'][int(x)].item())\n", 223 | "Xy_impute_missf_df['Imputed'] = hasmissing\n", 224 | "sns.scatterplot(\n", 225 | " data=Xy_impute_missf_df, x='petal length (cm)', y='petal width (cm)', hue='Species', style='Imputed',\n", 226 | " size='Imputed', sizes={False: 5, True: 20},\n", 227 | " palette=palette, alpha=alpha, ax=axes[0, 1], legend=False)\n", 228 | "axes[0, 1].set_title('MissForest\\nimputed');\n", 229 | "\n", 230 | "Xy_impute_forest_fast_df = pd.DataFrame(data=Xy_impute_forest_fast, columns=my_data['feature_names']+['target_names'])\n", 231 | "Xy_impute_forest_fast_df['Species'] = Xy_impute_forest_fast_df['target_names'].apply(lambda x: my_data['target_names'][int(x)].item())\n", 232 | "Xy_impute_forest_fast_df['Imputed'] = hasmissing\n", 233 | "sns.scatterplot(\n", 234 | " data=Xy_impute_forest_fast_df, x='petal length (cm)', y='petal width (cm)', hue='Species', style='Imputed',\n", 235 | " size='Imputed', sizes={False: 5, True: 20},\n", 236 | " palette=palette, alpha=alpha, ax=axes[1, 1], legend=False)\n", 237 | "axes[1, 1].set_title('Forest-VP\\nimputed');\n", 238 | "\n", 239 | "Xy_impute_forest_df = pd.DataFrame(data=Xy_impute_forest, columns=my_data['feature_names']+['target_names'])\n", 240 | "Xy_impute_forest_df['Species'] = Xy_impute_forest_df['target_names'].apply(lambda x: my_data['target_names'][int(x)].item())\n", 241 | "Xy_impute_forest_df['Imputed'] = hasmissing\n", 242 | "sns.scatterplot(\n", 243 | " data=Xy_impute_forest_df, x='petal length (cm)', y='petal width (cm)', hue='Species', style='Imputed',\n", 244 | " size='Imputed', sizes={False: 5, True: 20},\n", 245 | " palette=palette, alpha=alpha, ax=axes[2, 1], legend=False)\n", 246 | "axes[2, 1].set_title('Forest-VP\\nimputed w/ RePaint');\n", 247 | "\n", 248 | "Xy_impute_utrees_df = pd.DataFrame(data=Xy_impute_utrees, columns=my_data['feature_names']+['target_names'])\n", 249 | "Xy_impute_utrees_df['Species'] = Xy_impute_utrees_df['target_names'].apply(lambda x: my_data['target_names'][int(x)].item())\n", 250 
| "Xy_impute_utrees_df['Imputed'] = hasmissing\n", 251 | "\n", 252 | "sns.scatterplot(\n", 253 | " data=Xy_impute_utrees_df, x='petal length (cm)', y='petal width (cm)', hue='Species', style='Imputed',\n", 254 | " size='Imputed', sizes={False: 5, True: 20},\n", 255 | " palette=palette, alpha=alpha, ax=axes[3, 1], legend=True)\n", 256 | "axes[3, 1].set_title('UnmaskingTrees\\nimputed');\n", 257 | "sns.move_legend(axes[3, 1], \"upper left\", bbox_to_anchor=(1.1, 1.3))\n", 258 | "\n", 259 | "plt.tight_layout();\n", 260 | "plt.savefig('iris-vertical.png');\n", 261 | "plt.savefig('iris-vertical.pdf');" 262 | ] 263 | } 264 | ], 265 | "metadata": { 266 | "kernelspec": { 267 | "display_name": "Python 3 (ipykernel)", 268 | "language": "python", 269 | "name": "python3" 270 | }, 271 | "language_info": { 272 | "codemirror_mode": { 273 | "name": "ipython", 274 | "version": 3 275 | }, 276 | "file_extension": ".py", 277 | "mimetype": "text/x-python", 278 | "name": "python", 279 | "nbconvert_exporter": "python", 280 | "pygments_lexer": "ipython3", 281 | "version": "3.9.19" 282 | } 283 | }, 284 | "nbformat": 4, 285 | "nbformat_minor": 5 286 | } 287 | -------------------------------------------------------------------------------- /Results/imputation_script.R: -------------------------------------------------------------------------------- 1 | options(show.error.locations = TRUE) 2 | options(repos=c(CRAN = "https://cloud.r-project.org")) 3 | install.packages('plotrix') 4 | install.packages('matrixStats') 5 | install.packages('xtable') 6 | library(matrixStats) 7 | library(xtable) 8 | library(plotrix) 9 | 10 | data = read.csv("Results_tabular - imputation.csv") 11 | data = read.csv("Results_tabular - imputation.csv") 12 | 13 | # Better scaling for clean tables 14 | data$PercentBias = data$PercentBias / 100 15 | 16 | # debug small floats 17 | data$MeanVariance[data$MeanVariance < 1e-12] = 0 18 | data$MeanMAD[data$MeanMAD < 1e-12] = 0 19 | data$MedianMAD[data$MedianMAD < 1e-12] = 0 20 | 21 | all_vars = c('iris', 'wine', 'parkinsons', 'climate_model_crashes', 'concrete_compression', 'yacht_hydrodynamics', 'airfoil_self_noise', 'connectionist_bench_sonar', 'ionosphere', 'qsar_biodegradation', 'seeds', 'glass', 'yeast', 'libras', 'planning_relax', 'blood_transfusion', 'breast_cancer_diagnostic', 'connectionist_bench_vowel', 'concrete_slump', 'wine_quality_red', 'wine_quality_white', 'california', 'bean', 'car','congress','tictactoe') 22 | W_vars = c('iris', 'wine', 'parkinsons', 'climate_model_crashes', 'concrete_compression', 'yacht_hydrodynamics', 'airfoil_self_noise', 'connectionist_bench_sonar', 'ionosphere', 'qsar_biodegradation', 'seeds', 'glass', 'yeast', 'libras', 'planning_relax', 'blood_transfusion', 'breast_cancer_diagnostic', 'connectionist_bench_vowel', 'concrete_slump', 'wine_quality_red', 'wine_quality_white', 'car','congress','tictactoe') 23 | R2_vars = c('concrete_compression', 'yacht_hydrodynamics', 'airfoil_self_noise', 'wine_quality_red', 'wine_quality_white', 'california') 24 | F1_vars = c('iris', 'wine', 'parkinsons', 'climate_model_crashes', 'connectionist_bench_sonar', 'ionosphere', 'qsar_biodegradation', 'seeds', 'glass', 'yeast', 'libras', 'planning_relax', 'blood_transfusion', 'breast_cancer_diagnostic', 'connectionist_bench_vowel', 'bean', 'car','congress','tictactoe') 25 | 26 | methods = c('KNN(n_neighbors=1)', 'ice', 'miceforest', 'MissForest', 'softimpute', 'OT', 'GAIN', 'forest_diffusion_repaint_nt51_ycond', 'utrees') 27 | methods_all = c('KNN(n_neighbors=1)', 'ice', 'miceforest', 
'MissForest', 'softimpute', 'OT', 'GAIN', 'forest_diffusion_repaint_nt51_ycond', 'utrees', 'oracle') 28 | 29 | data_summary = data.frame(matrix(ncol = length(methods_all), nrow = 10)) 30 | colnames(data_summary) = methods_all 31 | rownames(data_summary) = c("MinMAE", "AvgMAE", "W_train", "W_test", "MedianMAD", "R2_imp", "F1_imp", "PercentBias", "CoverageRate", "time") 32 | 33 | for (method in methods_all){ 34 | 35 | MinMAE_mean = mean(data$MinMAE[data$method==method & data$dataset %in% all_vars]) 36 | MinMAE_sd = std.error(data$MinMAE[data$method==method & data$dataset %in% all_vars]) 37 | 38 | AvgMAE_mean = mean(data$AvgMAE[data$method==method & data$dataset %in% all_vars]) 39 | AvgMAE_sd = std.error(data$AvgMAE[data$method==method & data$dataset %in% all_vars]) 40 | 41 | W_train_mean = mean(data$W_train[data$method==method & data$dataset %in% W_vars]) 42 | W_train_sd = std.error(data$W_train[data$method==method & data$dataset %in% W_vars]) 43 | 44 | W_test_mean = mean(data$W_test[data$method==method & data$dataset %in% W_vars]) 45 | W_test_sd = std.error(data$W_test[data$method==method & data$dataset %in% W_vars]) 46 | 47 | PercentBias_mean = mean(data$PercentBias[data$method==method & data$dataset %in% R2_vars]) 48 | PercentBias_sd = std.error(data$PercentBias[data$method==method & data$dataset %in% R2_vars]) 49 | 50 | CoverageRate_mean = mean(data$CoverageRate[data$method==method & data$dataset %in% R2_vars]) 51 | CoverageRate_sd = std.error(data$CoverageRate[data$method==method & data$dataset %in% R2_vars]) 52 | 53 | MedianMAD_mean = mean(data$MedianMAD[data$method==method & data$dataset %in% all_vars]) 54 | MedianMAD_sd = std.error(data$MedianMAD[data$method==method & data$dataset %in% all_vars]) 55 | 56 | R2_imp_mean = mean(data$R2_imp[data$method==method & data$dataset %in% R2_vars]) 57 | R2_imp_sd = std.error(data$R2_imp[data$method==method & data$dataset %in% R2_vars]) 58 | 59 | F1_imp_mean = mean(data$F1_imp[data$method==method & data$dataset %in% F1_vars]) 60 | F1_imp_sd = std.error(data$F1_imp[data$method==method & data$dataset %in% F1_vars]) 61 | 62 | time_mean = mean(data$time[data$method==method & data$dataset %in% all_vars]) 63 | time_sd = std.error(data$time[data$method==method & data$dataset %in% all_vars]) 64 | 65 | data_summary[method] = c( 66 | paste0(as.character(round(MinMAE_mean, 2)), ' (', as.character(round(MinMAE_sd, 2)), ')'), 67 | paste0(as.character(round(AvgMAE_mean, 2)), ' (', as.character(round(AvgMAE_sd, 2)), ')'), 68 | paste0(as.character(round(W_train_mean, 2)), ' (', as.character(round(W_train_sd, 2)), ')'), 69 | paste0(as.character(round(W_test_mean, 2)), ' (', as.character(round(W_test_sd, 2)), ')'), 70 | paste0(as.character(round(MedianMAD_mean, 2)), ' (', as.character(round(MedianMAD_sd, 2)), ')'), 71 | paste0(as.character(round(R2_imp_mean, 2)), ' (', as.character(round(R2_imp_sd, 2)), ')'), 72 | paste0(as.character(round(F1_imp_mean, 2)), ' (', as.character(round(F1_imp_sd, 2)), ')'), 73 | paste0(as.character(round(PercentBias_mean, 2)), ' (', as.character(round(PercentBias_sd, 2)), ')'), 74 | paste0(as.character(round(CoverageRate_mean, 2)), ' (', as.character(round(CoverageRate_sd, 2)), ')'), 75 | paste0(as.character(round(time_mean, 2)), ' (', as.character(round(time_sd, 2)), ')') 76 | ) 77 | 78 | } 79 | data_summary 80 | data_summary_t <- t(data_summary) 81 | data_summary_t 82 | 83 | ###################### RANK 84 | 85 | data_summary_rank = data.frame(matrix(ncol = length(methods), nrow = 10)) 86 | colnames(data_summary_rank) = methods 87 | 
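# Rank methods within each dataset (rank 1 = best). Metrics where higher is
# better (MedianMAD, CoverageRate, R2_imp, F1_imp) are negated below so that
# rank()'s ascending order treats every metric as lower-is-better; the
# per-dataset ranks are then averaged across datasets for each method.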
rownames(data_summary_rank) = c("MinMAE", "AvgMAE", "W_train", "W_test", "MedianMAD", "R2_imp", "F1_imp", "PercentBias", "CoverageRate", "time") 88 | 89 | MinMAE = data.frame(matrix(ncol = length(methods), nrow = length(all_vars))) 90 | colnames(MinMAE) = methods 91 | rownames(MinMAE) = all_vars 92 | AvgMAE = data.frame(matrix(ncol = length(methods), nrow = length(all_vars))) 93 | colnames(AvgMAE) = methods 94 | rownames(AvgMAE) = all_vars 95 | W_train = data.frame(matrix(ncol = length(methods), nrow = length(W_vars))) 96 | colnames(W_train) = methods 97 | rownames(W_train) = W_vars 98 | W_test = data.frame(matrix(ncol = length(methods), nrow = length(W_vars))) 99 | colnames(W_test) = methods 100 | rownames(W_test) = W_vars 101 | MedianMAD = data.frame(matrix(ncol = length(methods), nrow = length(all_vars))) 102 | colnames(MedianMAD) = methods 103 | rownames(MedianMAD) = all_vars 104 | R2_imp = data.frame(matrix(ncol = length(methods), nrow = length(R2_vars))) 105 | colnames(R2_imp) = methods 106 | rownames(R2_imp) = R2_vars 107 | F1_imp = data.frame(matrix(ncol = length(methods), nrow = length(F1_vars))) 108 | colnames(F1_imp) = methods 109 | rownames(F1_imp) = F1_vars 110 | PercentBias = data.frame(matrix(ncol = length(methods), nrow = length(R2_vars))) 111 | colnames(PercentBias) = methods 112 | rownames(PercentBias) = R2_vars 113 | CoverageRate = data.frame(matrix(ncol = length(methods), nrow = length(R2_vars))) 114 | colnames(CoverageRate) = methods 115 | rownames(CoverageRate) = R2_vars 116 | Time = data.frame(matrix(ncol = length(methods), nrow = length(all_vars))) 117 | colnames(Time) = methods 118 | rownames(Time) = all_vars 119 | 120 | for (method in methods){ 121 | print(method) 122 | MinMAE[method] = data$MinMAE[data$method==method & data$dataset %in% all_vars] 123 | AvgMAE[method] = data$AvgMAE[data$method==method & data$dataset %in% all_vars] 124 | print(method) 125 | W_train[method] = data$W_train[data$method==method & data$dataset %in% W_vars] 126 | W_test[method] = data$W_test[data$method==method & data$dataset %in% W_vars] 127 | print(method) 128 | MedianMAD[method] = -data$MedianMAD[data$method==method & data$dataset %in% all_vars] 129 | print(method) 130 | PercentBias[method] = data$PercentBias[data$method==method & data$dataset %in% R2_vars] 131 | CoverageRate[method] = -data$CoverageRate[data$method==method & data$dataset %in% R2_vars] 132 | print(method) 133 | R2_imp[method] = -data$R2_imp[data$method==method & data$dataset %in% R2_vars] 134 | F1_imp[method] = -data$F1_imp[data$method==method & data$dataset %in% F1_vars] 135 | 136 | Time[method] = data$time[data$method==method & data$dataset %in% all_vars] 137 | 138 | } 139 | 140 | # Rank by dataset 141 | 142 | MinMAE_ = as.data.frame(t(sapply(as.data.frame(t(MinMAE)), rank))) 143 | colnames(MinMAE_) = colnames(MinMAE) 144 | MinMAE = MinMAE_ 145 | 146 | AvgMAE_ = as.data.frame(t(sapply(as.data.frame(t(AvgMAE)), rank))) 147 | colnames(AvgMAE_) = colnames(AvgMAE) 148 | AvgMAE = AvgMAE_ 149 | 150 | W_train_ = as.data.frame(t(sapply(as.data.frame(t(W_train)), rank))) 151 | colnames(W_train_) = colnames(W_train) 152 | W_train = W_train_ 153 | 154 | W_test_ = as.data.frame(t(sapply(as.data.frame(t(W_test)), rank))) 155 | colnames(W_test_) = colnames(W_test) 156 | W_test = W_test_ 157 | 158 | MedianMAD_ = as.data.frame(t(sapply(as.data.frame(t(MedianMAD)), rank))) 159 | colnames(MedianMAD_) = colnames(MedianMAD) 160 | MedianMAD = MedianMAD_ 161 | 162 | PercentBias_ = as.data.frame(t(sapply(as.data.frame(t(PercentBias)), 
rank))) 163 | colnames(PercentBias_) = colnames(PercentBias) 164 | PercentBias = PercentBias_ 165 | 166 | CoverageRate_ = as.data.frame(t(sapply(as.data.frame(t(CoverageRate)), rank))) 167 | colnames(CoverageRate_) = colnames(CoverageRate) 168 | CoverageRate = CoverageRate_ 169 | 170 | R2_imp_ = as.data.frame(t(sapply(as.data.frame(t(R2_imp)), rank))) 171 | colnames(R2_imp_) = colnames(R2_imp) 172 | R2_imp = R2_imp_ 173 | 174 | F1_imp_ = as.data.frame(t(sapply(as.data.frame(t(F1_imp)), rank))) 175 | colnames(F1_imp_) = colnames(F1_imp) 176 | F1_imp = F1_imp_ 177 | 178 | Time_ = as.data.frame(t(sapply(as.data.frame(t(Time)), rank))) 179 | colnames(Time_) = colnames(Time) 180 | Time = Time_ 181 | 182 | for (method in methods){ 183 | 184 | AvgMAE_mean = mean(unlist(AvgMAE[method])) 185 | AvgMAE_sd = std.error(unlist(AvgMAE[method])) 186 | 187 | MinMAE_mean = mean(unlist(MinMAE[method])) 188 | MinMAE_sd = std.error(unlist(MinMAE[method])) 189 | 190 | W_train_mean = mean(unlist(W_train[method])) 191 | W_train_sd = std.error(unlist(W_train[method])) 192 | 193 | W_test_mean = mean(unlist(W_test[method])) 194 | W_test_sd = std.error(unlist(W_test[method])) 195 | 196 | MedianMAD_mean = mean(unlist(MedianMAD[method])) 197 | MedianMAD_sd = std.error(unlist(MedianMAD[method])) 198 | 199 | PercentBias_mean = mean(unlist(PercentBias[method])) 200 | PercentBias_sd = std.error(unlist(PercentBias[method])) 201 | 202 | CoverageRate_mean = mean(unlist(CoverageRate[method])) 203 | CoverageRate_sd = std.error(unlist(CoverageRate[method])) 204 | 205 | R2_imp_mean = mean(unlist(R2_imp[method])) 206 | R2_imp_sd = std.error(unlist(R2_imp[method])) 207 | 208 | F1_imp_mean = mean(unlist(F1_imp[method])) 209 | F1_imp_sd = std.error(unlist(F1_imp[method])) 210 | 211 | time_mean = mean(unlist(Time[method])) 212 | time_sd = std.error(unlist(Time[method])) 213 | 214 | data_summary_rank[method] = c( # fill order must match rownames(data_summary_rank) 215 | paste0(as.character(round(MinMAE_mean, 1)), ' (', as.character(round(MinMAE_sd, 1)), ')'), 216 | paste0(as.character(round(AvgMAE_mean, 1)), ' (', as.character(round(AvgMAE_sd, 1)), ')'), 217 | paste0(as.character(round(W_train_mean, 1)), ' (', as.character(round(W_train_sd, 1)), ')'), 218 | paste0(as.character(round(W_test_mean, 1)), ' (', as.character(round(W_test_sd, 1)), ')'), 219 | paste0(as.character(round(MedianMAD_mean, 1)), ' (', as.character(round(MedianMAD_sd, 1)), ')'), 220 | paste0(as.character(round(R2_imp_mean, 1)), ' (', as.character(round(R2_imp_sd, 1)), ')'), 221 | paste0(as.character(round(F1_imp_mean, 1)), ' (', as.character(round(F1_imp_sd, 1)), ')'), 222 | paste0(as.character(round(PercentBias_mean, 1)), ' (', as.character(round(PercentBias_sd, 1)), ')'), 223 | paste0(as.character(round(CoverageRate_mean, 1)), ' (', as.character(round(CoverageRate_sd, 1)), ')'), 224 | paste0(as.character(round(time_mean, 1)), ' (', as.character(round(time_sd, 1)), ')') 225 | ) 226 | 227 | } 228 | data_summary_rank 229 | data_summary_rank_t <- t(data_summary_rank) 230 | data_summary_rank_t 231 | 232 | ########### Latex tables 233 | 234 | rownames(data_summary_t) = c("KNN", "ICE", "MICE-Forest", "MissForest", "Softimpute", "OT", "GAIN", "Forest-VP", "utrees", "Oracle") 235 | rownames(data_summary_rank_t) = c("KNN", "ICE", "MICE-Forest", "MissForest", "Softimpute", "OT", "GAIN", "Forest-VP", "utrees") 236 | 237 | xtable(data_summary_t[,-10], type = "latex") 238 | xtable(data_summary_rank_t[,-10], type = "latex") 239 | 240 | if (FALSE) { 241 | # For building the barplot, to csv 242 | 243 | MinMAE =
data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars))) 244 | colnames(MinMAE) = methods_all 245 | rownames(MinMAE) = all_vars 246 | MinMAE[,] = 0 247 | AvgMAE = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars))) 248 | colnames(AvgMAE) = methods_all 249 | rownames(AvgMAE) = all_vars 250 | AvgMAE[,] = 0 251 | W_train = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars))) 252 | colnames(W_train) = methods_all 253 | rownames(W_train) = all_vars 254 | W_train[,] = 0 255 | W_test = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars))) 256 | colnames(W_test) = methods_all 257 | rownames(W_test) = all_vars 258 | W_test[,] = 0 259 | MedianMAD = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars))) 260 | colnames(MedianMAD) = methods_all 261 | rownames(MedianMAD) = all_vars 262 | MedianMAD[,] = 0 263 | R2_imp = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars))) 264 | colnames(R2_imp) = methods_all 265 | rownames(R2_imp) = all_vars 266 | R2_imp[,] = 0 267 | F1_imp = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars))) 268 | colnames(F1_imp) = methods_all 269 | rownames(F1_imp) = all_vars 270 | F1_imp[,] = 0 271 | PercentBias = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars))) 272 | colnames(PercentBias) = methods_all 273 | rownames(PercentBias) = all_vars 274 | PercentBias[,] = 0 275 | CoverageRate = data.frame(matrix(ncol = length(methods_all), nrow = length(all_vars))) 276 | colnames(CoverageRate) = methods_all 277 | rownames(CoverageRate) = all_vars 278 | CoverageRate[,] = 0 279 | 280 | for (method in methods_all){ 281 | MinMAE[rownames(MinMAE) %in% all_vars, method] = data$MinMAE[data$method==method & data$dataset %in% all_vars] 282 | AvgMAE[rownames(AvgMAE) %in% all_vars, method] = data$AvgMAE[data$method==method & data$dataset %in% all_vars] 283 | 284 | W_train[rownames(W_train) %in% W_vars, method] = data$W_train[data$method==method & data$dataset %in% W_vars] 285 | W_test[rownames(W_test) %in% W_vars, method] = data$W_test[data$method==method & data$dataset %in% W_vars] 286 | MedianMAD[, method] = data$MedianMAD[data$method==method & data$dataset %in% all_vars] 287 | 288 | R2_imp[rownames(R2_imp) %in% R2_vars, method] = data$R2_imp[data$method==method & data$dataset %in% R2_vars] 289 | F1_imp[rownames(F1_imp) %in% F1_vars, method] = data$F1_imp[data$method==method & data$dataset %in% F1_vars] 290 | PercentBias[rownames(PercentBias) %in% R2_vars, method] = data$PercentBias[data$method==method & data$dataset %in% R2_vars] 291 | CoverageRate[rownames(CoverageRate) %in% R2_vars, method] = data$CoverageRate[data$method==method & data$dataset %in% R2_vars] 292 | } 293 | 294 | 295 | # For building the barplot, to csv 296 | write.csv(MinMAE, file="imp_MinMAE.csv") 297 | write.csv(AvgMAE, file="imp_AvgMAE.csv") 298 | write.csv(W_train, file="imp_Wtrain.csv") 299 | write.csv(W_test, file="imp_W.csv") 300 | write.csv(MedianMAD, file="imp_MedianMAD.csv") 301 | write.csv(R2_imp, file="imp_R2.csv") 302 | write.csv(F1_imp, file="imp_F1.csv") 303 | write.csv(PercentBias, file="imp_pb.csv") 304 | write.csv(CoverageRate, file="imp_cr.csv") 305 | } 306 | -------------------------------------------------------------------------------- /utrees/unmasking_trees.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Union, Optional 3 | import warnings 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import xgboost as xgb 8 | 9 | from sklearn.base import BaseEstimator 10 | from sklearn.utils import check_random_state 11 | from
sklearn.preprocessing import LabelEncoder 12 | from tqdm import tqdm 13 | 14 | from utrees.baltobot import Baltobot, NanTabPFNClassifier 15 | from utrees.kdi_quantizer import KDIQuantizer 16 | 17 | 18 | XGBOOST_DEFAULT_KWARGS = { 19 | "tree_method": "hist", 20 | "verbosity": 1, 21 | "objective": "binary:logistic", 22 | } 23 | 24 | TABPFN_DEFAULT_KWARGS = { 25 | "seed": 42, 26 | "batch_size_inference": 1, 27 | "subsample_features": True, 28 | } 29 | 30 | 31 | class UnmaskingTrees(BaseEstimator): 32 | """Performs generative modeling and multiple imputation on tabular data. 33 | 34 | This method generates training data by iteratively masking samples, 35 | then training per-feature XGBoost models to unmask the data. 36 | 37 | Parameters 38 | ---------- 39 | depth : int >= 0 40 | Depth of balanced binary tree for recursively quantizing each feature. 41 | The total number of quantization bins is 2^depth. 42 | 43 | duplicate_K : int > 0 44 | Number of random masking orders per actual sample. 45 | The training dataset will be of size (n_samples * n_dims * duplicate_K, n_dims). 46 | 47 | clf_kwargs : dict 48 | Arguments for XGBoost (or TabPFN) classifier. 49 | 50 | strategy : 'kdiquantile', 'quantile', 'uniform', 'kmeans' 51 | The quantization strategy for discretizing continuous features. 52 | 53 | softmax_temp : float > 0 54 | Softmax temperature for sampling from predicted probabilities. 55 | As temperature decreases below the default of 1, predictions converge 56 | to the argmax of each conditional distribution. 57 | 58 | tabpfn: bool 59 | Whether to use TabPFN instead of XGBoost classifier. 60 | 61 | cast_float32: bool 62 | Whether to always convert inputs to float32. 63 | 64 | random_state : int, RandomState instance or None, default=None 65 | Determines random number generation. 66 | 67 | 68 | Attributes 69 | ---------- 70 | trees_ : list of Baltobot or XGBClassifier 71 | Fitted per-feature sampling models. For each column ix, this 72 | will be Baltobot if quantize_cols_[ix], and XGBClassifier otherwise. 73 | 74 | constant_vals_ : list of float or None, with length n_dims 75 | Gives the values of constant features, or None otherwise. 76 | 77 | quantize_cols : str or list of strs 78 | Saved user-provided quantize_cols input. 79 | 80 | quantize_cols_ : list of bool, with length n_dims 81 | Whether to apply (and un-apply) discretization to each feature. 82 | 83 | encoders_ : list of None or LabelEncoder 84 | Preprocessors for non-continuous features. 85 | 86 | clf_kwargs_ : dict 87 | Args passed to XGBClassifier or NanTabPFNClassifier. 88 | 89 | X_ : np.ndarray (n_samples, n_dims) 90 | Input data. 
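    Examples
    --------
    A minimal illustrative sketch (small settings chosen for speed; outputs
    vary with the data and random seed)::

        import numpy as np
        from utrees import UnmaskingTrees

        rng = np.random.default_rng(0)
        X = rng.normal(size=(200, 3))
        X[:20, 0] = np.nan                    # introduce some missing entries
        ut = UnmaskingTrees(depth=3, duplicate_K=10, random_state=0)
        ut.fit(X, quantize_cols='all')
        X_gen = ut.generate(n_generate=50)    # synthetic data, shape (50, 3)
        X_imp = ut.impute(n_impute=5)         # imputations, shape (5, 200, 3)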
91 | 92 | """ 93 | 94 | def __init__( 95 | self, 96 | depth: int = 4, 97 | duplicate_K: int = 50, 98 | clf_kwargs: dict = {}, 99 | strategy: str = "kdiquantile", 100 | softmax_temp: float = 1.0, 101 | cast_float32: bool = True, 102 | tabpfn: bool = False, 103 | random_state=None, 104 | ): 105 | self.depth = depth 106 | self.duplicate_K = duplicate_K 107 | self.clf_kwargs = clf_kwargs 108 | self.strategy = strategy 109 | self.softmax_temp = softmax_temp 110 | self.cast_float32 = cast_float32 111 | self.tabpfn = tabpfn 112 | self.random_state = random_state 113 | 114 | if self.tabpfn: 115 | self.clf_kwargs_ = TABPFN_DEFAULT_KWARGS.copy() 116 | else: 117 | self.clf_kwargs_ = XGBOOST_DEFAULT_KWARGS.copy() 118 | self.clf_kwargs_.update(clf_kwargs) 119 | 120 | assert 1 <= duplicate_K 121 | assert strategy in ("kdiquantile", "quantile", "uniform", "kmeans") 122 | assert 0 < softmax_temp 123 | 124 | self.random_state_ = check_random_state(random_state) 125 | self.trees_ = None 126 | self.constant_vals_ = None 127 | self.quantize_cols_ = None 128 | self.encoders_ = None 129 | self.X_ = None 130 | 131 | def fit( 132 | self, 133 | X: np.ndarray, 134 | quantize_cols: Union[str, list] = "all", 135 | ): 136 | """ 137 | Fit the estimator. 138 | 139 | Parameters 140 | ---------- 141 | X : array-like of shape (n_samples, n_dims) 142 | Data to be modeled and imputed (possibly with np.nan). 143 | 144 | quantize_cols : 'all', 'none', or list of strs ('continuous', 'categorical', 'integer') 145 | Whether to apply (and un-apply) discretization to each feature. 146 | 147 | Returns 148 | ------- 149 | self : object 150 | Returns the instance itself. 151 | """ 152 | rng = check_random_state(self.random_state_) 153 | assert isinstance(X, np.ndarray) 154 | self.X_ = X.copy() 155 | if self.cast_float32: 156 | X = X.astype(np.float32) 157 | n_samples, n_dims = X.shape 158 | 159 | if isinstance(quantize_cols, list): 160 | assert len(quantize_cols) == n_dims 161 | self.quantize_cols_ = [] 162 | for d, elt in enumerate(quantize_cols): 163 | if elt == "continuous": 164 | self.quantize_cols_.append(True) 165 | elif elt == "categorical": 166 | self.quantize_cols_.append(False) 167 | elif elt == "integer": 168 | self.quantize_cols_.append(True) 169 | else: 170 | assert elt in ("continuous", "categorical", "integer") 171 | elif quantize_cols == "none": 172 | self.quantize_cols_ = [False] * n_dims 173 | elif quantize_cols == "all": 174 | self.quantize_cols_ = [True] * n_dims 175 | else: 176 | raise ValueError(f"unexpected quantize_cols: {quantize_cols}") 177 | self.quantize_cols = deepcopy(quantize_cols) 178 | 179 | # Find features with constant vals, to be unmasked before training and inference 180 | self.constant_vals_ = [] 181 | for d in range(n_dims): 182 | col_d = X[~np.isnan(X[:, d]), d] 183 | if len(np.unique(col_d)) == 1: 184 | self.constant_vals_.append(np.unique(col_d).item()) 185 | else: 186 | self.constant_vals_.append(None) 187 | 188 | # Fit encoders 189 | self.encoders_ = [] 190 | for d in range(n_dims): 191 | if self.quantize_cols_[d]: 192 | cur_enc = None 193 | else: 194 | cur_enc = LabelEncoder() 195 | cur_enc.fit(X[~np.isnan(X[:, d]), d]) 196 | self.encoders_.append(cur_enc) 197 | 198 | # Generate training data 199 | X_train = [] 200 | Y_train = [] 201 | for dupix in range(self.duplicate_K): 202 | mask_ixs = np.repeat(np.arange(n_dims)[np.newaxis, :], n_samples, axis=0) 203 | mask_ixs = np.apply_along_axis( 204 | rng.permutation, axis=1, arr=mask_ixs 205 | ) # n_samples, n_dims 206 | for n in 
198 |         # Generate training data
199 |         X_train = []
200 |         Y_train = []
201 |         for dupix in range(self.duplicate_K):
202 |             mask_ixs = np.repeat(np.arange(n_dims)[np.newaxis, :], n_samples, axis=0)
203 |             mask_ixs = np.apply_along_axis(
204 |                 rng.permutation, axis=1, arr=mask_ixs
205 |             )  # n_samples, n_dims
206 |             for n in range(n_samples):
207 |                 fuller_X = X[n, :]
208 |                 for d in range(n_dims):
209 |                     victim_ix = mask_ixs[n, d]
210 |                     if not np.isnan(fuller_X[victim_ix]):  # skip already-missing entries; (x != np.nan) is always True
211 |                         emptier_X = fuller_X.copy()
212 |                         emptier_X[victim_ix] = np.nan
213 |                         X_train.append(emptier_X.reshape(1, -1))
214 |                         Y_train.append(fuller_X.reshape(1, -1))
215 |                         fuller_X = emptier_X
216 |         X_train = np.concatenate(X_train, axis=0)
217 |         Y_train = np.concatenate(Y_train, axis=0)
218 | 
219 |         # Unmask constant-value columns before training
220 |         for d in range(n_dims):
221 |             if self.constant_vals_[d] is not None:
222 |                 X_train[:, d] = self.constant_vals_[d]
223 | 
224 |         # Fit trees
225 |         self.trees_ = []
226 |         for d in range(n_dims):
227 |             if self.constant_vals_[d] is not None:
228 |                 self.trees_.append(None)
229 |                 continue
230 |             train_ixs = ~np.isnan(Y_train[:, d])
231 |             curX_train = X_train[train_ixs, :]
232 |             if self.quantize_cols_[d]:
233 |                 curY_train = Y_train[train_ixs, d]
234 |                 balto = Baltobot(
235 |                     depth=self.depth,
236 |                     clf_kwargs=self.clf_kwargs,
237 |                     strategy=self.strategy,
238 |                     softmax_temp=self.softmax_temp,
239 |                     tabpfn=self.tabpfn,
240 |                     random_state=self.random_state_,
241 |                 )
242 |                 balto.fit(curX_train, curY_train)
243 |                 self.trees_.append(balto)
244 |             else:
245 |                 curY_train = self.encoders_[d].transform(Y_train[train_ixs, d])
246 |                 if self.tabpfn:
247 |                     clfer = NanTabPFNClassifier(**self.clf_kwargs)
248 |                 else:
249 |                     clfer = xgb.XGBClassifier(**self.clf_kwargs)
250 | 
251 |                 clfer.fit(curX_train, curY_train)
252 |                 self.trees_.append(clfer)
253 |         return self
254 | 
255 |     def generate(
256 |         self,
257 |         n_generate: int = 1,
258 |     ):
259 |         """
260 |         Generate new data.
261 | 
262 |         Parameters
263 |         ----------
264 |         n_generate : int > 0
265 |             Desired number of samples.
266 | 
267 |         Returns
268 |         -------
269 |         X : np.ndarray of size (n_generate, n_dims)
270 |             Generated data.
271 |         """
272 |         _, n_dims = self.X_.shape
273 |         X = np.full(fill_value=np.nan, shape=(n_generate, n_dims))
274 |         for d in range(n_dims):
275 |             if self.constant_vals_[d] is not None:
276 |                 X[:, d] = self.constant_vals_[d]
277 |         genX = self.impute(n_impute=1, X=X)[0, :, :]
278 |         return genX
279 | 
280 |     def impute(
281 |         self,
282 |         n_impute: int = 1,
283 |         X: np.ndarray = None,
284 |     ):
285 |         """
286 |         Impute missing data.
287 | 
288 |         Parameters
289 |         ----------
290 |         n_impute : int > 0
291 |             Desired number of multiple imputations per sample.
292 | 
293 |         X : np.ndarray of size (n_samples, n_dims) or None
294 |             Data to impute. If None, uses the data passed to fit.
295 | 
296 |         Returns
297 |         -------
298 |         imputedX : np.ndarray of size (n_impute, n_samples, n_dims)
299 |             Imputed data.
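
        Notes
        -----
        Missing entries of each sample are unmasked in an independent random
        order for each of the n_impute repetitions, so repeated imputations of
        the same sample will generally differ. Illustrative usage (``utree`` is
        a fitted UnmaskingTrees; ``Xtest`` is a hypothetical held-out array
        containing np.nan entries)::

            imputedX = utree.impute(n_impute=5)            # (5, n_samples, n_dims)
            imputedX2 = utree.impute(n_impute=5, X=Xtest)  # impute held-out data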
300 |         """
301 |         if X is None:
302 |             X = self.X_.copy()
303 |         (n_samples, n_dims) = X.shape
304 |         rng = check_random_state(self.random_state_)
305 | 
306 |         for d in range(n_dims):
307 |             if self.constant_vals_[d] is not None:
308 |                 # Replace only NaNs with the constant value; observed entries must be preserved when imputing
309 |                 X[np.isnan(X[:, d]), d] = self.constant_vals_[d]
310 | 
311 |         imputedX = np.repeat(
312 |             X[np.newaxis, :, :], repeats=n_impute, axis=0
313 |         )  # (n_impute, n_samples, n_dims)
314 |         with tqdm(total=n_samples * n_impute) as pbar:
315 |             for n in range(n_samples):
316 |                 to_unmask = np.where(np.isnan(X[n, :]))[0]  # (n_to_unmask,)
317 |                 unmask_ixs = np.repeat(
318 |                     to_unmask[np.newaxis, :], n_impute, axis=0
319 |                 )  # (n_impute, n_to_unmask)
320 |                 unmask_ixs = np.apply_along_axis(
321 |                     rng.permutation, axis=1, arr=unmask_ixs
322 |                 )  # (n_impute, n_to_unmask)
323 |                 n_to_unmask = unmask_ixs.shape[1]
324 |                 for kix in range(n_impute):
325 |                     pbar.update(1)
326 |                     for dix in range(n_to_unmask):
327 |                         unmask_ix = unmask_ixs[kix, dix]
328 |                         if self.quantize_cols_[unmask_ix]:
329 |                             pred_val = self.trees_[unmask_ix].sample(
330 |                                 imputedX[kix, [n], :]
331 |                             )
332 |                         else:
333 |                             # TODO: use TabPFN if requested
334 |                             pred_probas = self.trees_[unmask_ix].predict_proba(
335 |                                 imputedX[kix, [n], :]
336 |                             )
337 |                             with np.errstate(divide="ignore"):
338 |                                 annealed_logits = (
339 |                                     np.log(pred_probas) / self.softmax_temp
340 |                                 )
341 |                             pred_probas = np.exp(annealed_logits) / np.sum(
342 |                                 np.exp(annealed_logits), axis=1
343 |                             )
344 |                             cur_enc = self.encoders_[unmask_ix]
345 |                             pred_enc = rng.choice(
346 |                                 a=len(cur_enc.classes_), p=pred_probas.ravel()
347 |                             )
348 |                             pred_val = cur_enc.inverse_transform(np.array([pred_enc]))
349 |                         imputedX[kix, n, unmask_ix] = pred_val.item()
350 |         return imputedX
351 | 
352 |     def score_samples(
353 |         self,
354 |         X: np.ndarray,
355 |         n_evals: int = 1,
356 |     ):
357 |         """Compute the log-likelihood of each sample under the model.
358 | 
359 |         Parameters
360 |         ----------
361 |         X : array-like of shape (n_samples, n_features)
362 |             An array of points to query. Last dimension should match dimension
363 |             of training data (n_features).
364 | 
365 |         n_evals : int > 0
366 |             Number of random feature orderings over which to average.
367 | 
368 |         Returns
369 |         -------
370 |         density : ndarray of shape (n_samples,)
371 |             Log-likelihood of each sample in `X`. These are normalized to be
372 |             probability densities, so values will be low for high-dimensional
373 |             data.
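
        Notes
        -----
        For each of n_evals random feature orderings sigma, the log-likelihood
        is accumulated over the non-NaN entries via the chain rule,
        log p(x) = sum_d log p(x[sigma(d)] | x[sigma(1)], ..., x[sigma(d-1)]),
        and the returned value is the average over orderings.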
        """
374 |         (n_samples, n_dims) = X.shape
375 |         rng = check_random_state(self.random_state_)
376 | 
377 |         density = np.zeros((n_samples,))
378 |         for n in range(n_samples):
379 |             cur_density = np.zeros((n_evals,))
380 |             for k in range(n_evals):
381 |                 eval_order = rng.permutation(n_dims)
382 |                 for dix in range(n_dims):
383 |                     eval_ix = eval_order[dix]
384 |                     if not np.isnan(X[n, eval_ix]):
385 |                         evalX = X[[n], :].copy()
386 |                         evalX[:, eval_order[dix:]] = np.nan
387 |                         evaly = X[[n], eval_ix]
388 |                         if self.quantize_cols_[eval_ix]:
389 |                             cur_density[k] += (
390 |                                 self.trees_[eval_ix].score_samples(evalX, evaly).item()
391 |                             )
392 |                         else:
393 |                             probas = self.trees_[eval_ix].predict_proba(evalX).ravel()
394 |                             cur_enc = self.encoders_[eval_ix]
395 |                             true_class = cur_enc.transform(evaly).item()  # evaly is already 1d
396 |                             cur_density[k] += np.log(probas[true_class])
397 |             density[n] = np.mean(cur_density)
398 |         return density
399 | 
--------------------------------------------------------------------------------
/Results/generation_script_nmiss50.R:
--------------------------------------------------------------------------------
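# Summary of tabular generation-with-missingness benchmarks: aggregates
# "Results_tabular - generation_miss.csv" into mean (std. error) tables and
# per-dataset method ranks, then emits LaTeX tables via xtable.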
1 | options(show.error.locations = TRUE)
2 | options(repos=c(CRAN = "https://cloud.r-project.org"))
3 | install.packages('matrixStats')
4 | install.packages('xtable')
5 | install.packages('plotrix')
6 | library(matrixStats)
7 | library(xtable)
8 | library(plotrix)
9 | 
10 | all_vars = c('iris', 'wine', 'parkinsons', 'climate_model_crashes', 'concrete_compression', 'yacht_hydrodynamics', 'airfoil_self_noise', 'connectionist_bench_sonar', 'ionosphere', 'qsar_biodegradation', 'seeds', 'glass', 'yeast', 'libras', 'planning_relax', 'blood_transfusion', 'breast_cancer_diagnostic', 'connectionist_bench_vowel', 'concrete_slump', 'wine_quality_red', 'wine_quality_white', 'california', 'bean', 'car', 'congress', 'tictactoe')
11 | W_vars = c('iris', 'wine', 'parkinsons', 'climate_model_crashes', 'concrete_compression', 'yacht_hydrodynamics', 'airfoil_self_noise', 'connectionist_bench_sonar', 'ionosphere', 'qsar_biodegradation', 'seeds', 'glass', 'yeast', 'libras', 'planning_relax', 'blood_transfusion', 'breast_cancer_diagnostic', 'connectionist_bench_vowel', 'concrete_slump', 'wine_quality_red', 'wine_quality_white', 'car', 'congress', 'tictactoe')
12 | R2_vars = c('concrete_compression', 'yacht_hydrodynamics', 'airfoil_self_noise', 'wine_quality_red', 'wine_quality_white', 'california')
13 | F1_vars = c('iris', 'wine', 'parkinsons', 'climate_model_crashes', 'connectionist_bench_sonar', 'ionosphere', 'qsar_biodegradation', 'seeds', 'glass', 'yeast', 'libras', 'planning_relax', 'blood_transfusion', 'breast_cancer_diagnostic', 'connectionist_bench_vowel', 'bean', 'car', 'congress', 'tictactoe')
14 | 
15 | methods = c('GaussianCopula', 'TVAE', 'CTGAN', 'CTABGAN', 'stasy', 'TabDDPM', 'forest_diffusion_vp_nt51_ycond', 'forest_diffusion_flow_nt51_ycond', 'utrees')
16 | methods_all = c('GaussianCopula', 'TVAE', 'CTGAN', 'CTABGAN', 'stasy', 'TabDDPM', 'forest_diffusion_vp_nt51_ycond', 'forest_diffusion_flow_nt51_ycond', 'oracle', 'utrees')
17 | 
18 | # Change the folder to your own
19 | data = read.csv("Results_tabular - generation_miss.csv")
20 | 
21 | # Better scaling for clean tables
22 | data$percent_bias = data$percent_bias / 100
23 | 
24 | #### CHOOSE HERE
25 | 
26 | #data = data[data$missingness!='MCAR(0.2 MissForest)',] # miceforest results
27 | data = data[data$missingness!='MCAR(0.2 miceforest)',] # MissForest results
28 | 
29 | ####
30 | 
31 | data_summary = data.frame(matrix(ncol = 10, nrow = 10))
32 | colnames(data_summary) = methods_all
33 | rownames(data_summary) = c("W_train", "W_test", "coverage_train", "coverage_test", "R2_fake", "F1_fake", "class_score", "percent_bias", "coverage_rate", "time")
34 | 
35 | for (method in methods_all){
36 |   W_train_mean = mean(data$W_train[data$method==method & data$dataset %in% W_vars])
37 |   W_train_sd = std.error(data$W_train[data$method==method & data$dataset %in% W_vars])
38 | 
39 |   W_test_mean = mean(data$W_test[data$method==method & data$dataset %in% W_vars])
40 |   W_test_sd = std.error(data$W_test[data$method==method & data$dataset %in% W_vars])
41 | 
42 |   coverage_mean = mean(data$coverage_train[data$method==method & data$dataset %in% all_vars])
43 |   coverage_sd = std.error(data$coverage_train[data$method==method & data$dataset %in% all_vars])
44 | 
45 |   coverage_test_mean = mean(data$coverage_test[data$method==method & data$dataset %in% all_vars])
46 |   coverage_test_sd = std.error(data$coverage_test[data$method==method & data$dataset %in% all_vars])
47 | 
48 |   R2_fake_mean = mean(data$R2_fake[data$method==method & data$dataset %in% R2_vars])
49 |   R2_fake_sd = std.error(data$R2_fake[data$method==method & data$dataset %in% R2_vars])
50 | 
51 |   F1_fake_mean = mean(data$F1_fake[data$method==method & data$dataset %in% F1_vars])
52 |   F1_fake_sd = std.error(data$F1_fake[data$method==method & data$dataset %in% F1_vars])
53 | 
54 |   percent_bias_mean = mean(data$percent_bias[data$method==method & data$dataset %in% R2_vars])
55 |   percent_bias_sd = std.error(data$percent_bias[data$method==method & data$dataset %in% R2_vars])
56 | 
57 |   coverage_rate_mean = mean(data$coverage_rate[data$method==method & data$dataset %in% R2_vars])
58 |   coverage_rate_sd = std.error(data$coverage_rate[data$method==method & data$dataset %in% R2_vars])
59 | 
60 |   class_score_mean = mean(data$class_score[data$method==method & data$dataset %in% all_vars])
61 |   class_score_sd = std.error(data$class_score[data$method==method & data$dataset %in% all_vars])
62 | 
63 |   time_mean = mean(data$time[data$method==method & data$dataset %in% all_vars])
64 |   time_sd = std.error(data$time[data$method==method & data$dataset %in% all_vars])
65 | 
66 |   data_summary[method] = c(
67 |     paste0(as.character(round(W_train_mean, 2)), ' (', as.character(round(W_train_sd, 2)), ')'),
68 |     paste0(as.character(round(W_test_mean, 2)), ' (', as.character(round(W_test_sd, 2)), ')'),
69 |     paste0(as.character(round(coverage_mean, 2)), ' (', as.character(round(coverage_sd, 2)), ')'),
70 |     paste0(as.character(round(coverage_test_mean, 2)), ' (', as.character(round(coverage_test_sd, 2)), ')'),
71 |     paste0(as.character(round(R2_fake_mean, 2)), ' (', as.character(round(R2_fake_sd, 2)), ')'),
72 |     paste0(as.character(round(F1_fake_mean, 2)), ' (', as.character(round(F1_fake_sd, 2)), ')'),
73 |     paste0(as.character(round(class_score_mean, 2)), ' (', as.character(round(class_score_sd, 2)), ')'),
74 |     paste0(as.character(round(percent_bias_mean, 2)), ' (', as.character(round(percent_bias_sd, 2)), ')'),
75 |     paste0(as.character(round(coverage_rate_mean, 2)), ' (', as.character(round(coverage_rate_sd, 2)), ')'),
76 |     paste0(as.character(round(time_mean, 2)), ' (', as.character(round(time_sd, 2)), ')')
77 |   )
78 | 
79 | }
80 | data_summary
81 | data_summary_t <- t(data_summary)
82 | data_summary_t
83 | 
84 | 
85 | 
86 | ###################### RANK
87 | 
88 | data_summary_rank = data.frame(matrix(ncol = 9, nrow = 10))
89 | colnames(data_summary_rank) = methods
90 | rownames(data_summary_rank) = c("W_train", "W_test", "coverage_train", "coverage_test", "R2_fake", "F1_fake", "class_score", "percent_bias", "coverage_rate", "time")
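# Per-dataset metric tables, used below to rank methods within each dataset.
# Metrics where higher is better are negated so that rank() (which ranks in
# ascending order) assigns rank 1 to the best method.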
"coverage_test", "R2_fake", "F1_fake", "class_score", "percent_bias", "coverage_rate", "time") 91 | 92 | W_train = data.frame(matrix(ncol = 9, nrow = length(W_vars))) 93 | colnames(W_train) = methods 94 | rownames(W_train) = W_vars 95 | W_test = data.frame(matrix(ncol = 9, nrow = length(W_vars))) 96 | colnames(W_test) = methods 97 | rownames(W_test) = W_vars 98 | coverage = data.frame(matrix(ncol = 9, nrow = length(all_vars))) 99 | colnames(coverage) = methods 100 | rownames(coverage) = all_vars 101 | coverage_test = data.frame(matrix(ncol = 9, nrow = length(all_vars))) 102 | colnames(coverage_test) = methods 103 | rownames(coverage_test) = all_vars 104 | R2_fake = data.frame(matrix(ncol = 9, nrow = length(R2_vars))) 105 | colnames(R2_fake) = methods 106 | rownames(R2_fake) = R2_vars 107 | F1_fake = data.frame(matrix(ncol = 9, nrow = length(F1_vars))) 108 | colnames(F1_fake) = methods 109 | rownames(F1_fake) = F1_vars 110 | percent_bias = data.frame(matrix(ncol = 9, nrow = length(R2_vars))) 111 | colnames(percent_bias) = methods 112 | rownames(percent_bias) = R2_vars 113 | coverage_rate = data.frame(matrix(ncol = 9, nrow = length(R2_vars))) 114 | colnames(coverage_rate) = methods 115 | rownames(coverage_rate) = R2_vars 116 | ClassScore = data.frame(matrix(ncol = 9, nrow = length(all_vars))) 117 | colnames(ClassScore) = methods 118 | rownames(ClassScore) = all_vars 119 | Time = data.frame(matrix(ncol = 9, nrow = length(all_vars))) 120 | colnames(Time) = methods 121 | rownames(Time) = all_vars 122 | 123 | for (method in methods){ 124 | 125 | W_train[method] = data$W_train[data$method==method & data$dataset %in% W_vars] 126 | W_test[method] = data$W_test[data$method==method & data$dataset %in% W_vars] 127 | 128 | coverage[method] = -data$coverage_train[data$method==method & data$dataset %in% all_vars] 129 | coverage_test[method] = -data$coverage_test[data$method==method & data$dataset %in% all_vars] 130 | 131 | R2_fake[method] = -data$R2_fake[data$method==method & data$dataset %in% R2_vars] 132 | F1_fake[method] = -data$F1_fake[data$method==method & data$dataset %in% F1_vars] 133 | percent_bias[method] = data$percent_bias[data$method==method & data$dataset %in% R2_vars] 134 | coverage_rate[method] = -data$coverage_rate[data$method==method & data$dataset %in% R2_vars] 135 | ClassScore[method] = data$class_score[data$method==method & data$dataset %in% all_vars] 136 | Time[method] = data$time[data$method==method & data$dataset %in% all_vars] 137 | 138 | } 139 | 140 | # Rank by dataset 141 | 142 | W_train_ = as.data.frame(t(sapply(as.data.frame(t(W_train)), rank))) 143 | colnames(W_train_) = colnames(W_train) 144 | W_train = W_train_ 145 | 146 | W_test_ = as.data.frame(t(sapply(as.data.frame(t(W_test)), rank))) 147 | colnames(W_test_) = colnames(W_test) 148 | W_test = W_test_ 149 | 150 | coverage_ = as.data.frame(t(sapply(as.data.frame(t(coverage)), rank))) 151 | colnames(coverage_) = colnames(coverage) 152 | coverage = coverage_ 153 | coverage_test_ = as.data.frame(t(sapply(as.data.frame(t(coverage_test)), rank))) 154 | colnames(coverage_test_) = colnames(coverage_test) 155 | coverage_test = coverage_test_ 156 | 157 | R2_fake_ = as.data.frame(t(sapply(as.data.frame(t(R2_fake)), rank))) 158 | colnames(R2_fake_) = colnames(R2_fake) 159 | R2_fake = R2_fake_ 160 | 161 | F1_fake_ = as.data.frame(t(sapply(as.data.frame(t(F1_fake)), rank))) 162 | colnames(F1_fake_) = colnames(F1_fake) 163 | F1_fake = F1_fake_ 164 | 165 | percent_bias_ = as.data.frame(t(sapply(as.data.frame(t(percent_bias)), rank))) 166 
166 | colnames(percent_bias_) = colnames(percent_bias)
167 | percent_bias = percent_bias_
168 | coverage_rate_ = as.data.frame(t(sapply(as.data.frame(t(coverage_rate)), rank)))
169 | colnames(coverage_rate_) = colnames(coverage_rate)
170 | coverage_rate = coverage_rate_
171 | 
172 | ClassScore_ = as.data.frame(t(sapply(as.data.frame(t(ClassScore)), rank)))
173 | colnames(ClassScore_) = colnames(ClassScore)
174 | ClassScore = ClassScore_
175 | Time_ = as.data.frame(t(sapply(as.data.frame(t(Time)), rank)))
176 | colnames(Time_) = colnames(Time)
177 | Time = Time_
178 | 
179 | for (method in methods){
180 |   W_train_mean = mean(unlist(W_train[method]))
181 |   W_train_sd = std.error(unlist(W_train[method]))
182 | 
183 |   W_test_mean = mean(unlist(W_test[method]))
184 |   W_test_sd = std.error(unlist(W_test[method]))
185 | 
186 |   coverage_mean = mean(unlist(coverage[method]))
187 |   coverage_sd = std.error(unlist(coverage[method]))
188 |   coverage_test_mean = mean(unlist(coverage_test[method]))
189 |   coverage_test_sd = std.error(unlist(coverage_test[method]))
190 | 
191 |   R2_fake_mean = mean(unlist(R2_fake[method]))
192 |   R2_fake_sd = std.error(unlist(R2_fake[method]))
193 | 
194 |   F1_fake_mean = mean(unlist(F1_fake[method]))
195 |   F1_fake_sd = std.error(unlist(F1_fake[method]))
196 | 
197 |   percent_bias_mean = mean(unlist(percent_bias[method]))
198 |   percent_bias_sd = std.error(unlist(percent_bias[method]))
199 | 
200 |   coverage_rate_mean = mean(unlist(coverage_rate[method]))
201 |   coverage_rate_sd = std.error(unlist(coverage_rate[method]))
202 | 
203 |   class_score_mean = mean(unlist(ClassScore[method]))
204 |   class_score_sd = std.error(unlist(ClassScore[method]))
205 | 
206 |   time_mean = mean(unlist(Time[method]))
207 |   time_sd = std.error(unlist(Time[method]))
208 | 
209 |   data_summary_rank[method] = c(
210 |     paste0(as.character(round(W_train_mean, 1)), ' (', as.character(round(W_train_sd, 1)), ')'),
211 |     paste0(as.character(round(W_test_mean, 1)), ' (', as.character(round(W_test_sd, 1)), ')'),
212 |     paste0(as.character(round(coverage_mean, 1)), ' (', as.character(round(coverage_sd, 1)), ')'),
213 |     paste0(as.character(round(coverage_test_mean, 1)), ' (', as.character(round(coverage_test_sd, 1)), ')'),
214 |     paste0(as.character(round(R2_fake_mean, 1)), ' (', as.character(round(R2_fake_sd, 1)), ')'),
215 |     paste0(as.character(round(F1_fake_mean, 1)), ' (', as.character(round(F1_fake_sd, 1)), ')'),
216 |     paste0(as.character(round(class_score_mean, 1)), ' (', as.character(round(class_score_sd, 1)), ')'),
217 |     paste0(as.character(round(percent_bias_mean, 1)), ' (', as.character(round(percent_bias_sd, 1)), ')'),
218 |     paste0(as.character(round(coverage_rate_mean, 1)), ' (', as.character(round(coverage_rate_sd, 1)), ')'),
219 |     paste0(as.character(round(time_mean, 1)), ' (', as.character(round(time_sd, 1)), ')')
220 |   )
221 | 
222 | }
223 | rownames(data_summary_rank) = c("W_train", "W_test", "coverage_train", "coverage_test", "R2_fake", "F1_fake", "class_score", "percent_bias", "coverage_rate", "time")
224 | data_summary_rank
225 | data_summary_rank_t <- t(data_summary_rank)
226 | data_summary_rank_t
227 | 
228 | 
229 | ########### Latex tables
230 | 
231 | rownames(data_summary_t) = c("GaussianCopula", "TVAE", "CTGAN", "CTAB-GAN+", "STaSy", "TabDDPM", "Forest-VP", "Forest-Flow", "Oracle", "UTrees") # order follows methods_all, where 'oracle' precedes 'utrees'
232 | rownames(data_summary_rank_t) = c("GaussianCopula", "TVAE", "CTGAN", "CTAB-GAN+", "STaSy", "TabDDPM", "Forest-VP", "Forest-Flow", "UTrees")
233 | 
234 | xtable(data_summary_t[,-10], type = "latex") # column 10 is "time"; [,-10] omits it from the LaTeX table
235 | xtable(data_summary_rank_t[,-10], type = "latex")
236 | 
237 | 
238 | if (FALSE) {
239 |   # For building the barplot, to csv
240 | 
241 |   W_train = data.frame(matrix(ncol = 10, nrow = length(all_vars))) # 10 columns: one per entry of methods_all
242 |   colnames(W_train) = methods_all
243 |   rownames(W_train) = all_vars
244 |   W_train[,] = 0
245 |   W_test = data.frame(matrix(ncol = 10, nrow = length(all_vars)))
246 |   colnames(W_test) = methods_all
247 |   rownames(W_test) = all_vars
248 |   W_test[,] = 0
249 |   coverage = data.frame(matrix(ncol = 10, nrow = length(all_vars)))
250 |   colnames(coverage) = methods_all
251 |   rownames(coverage) = all_vars
252 |   coverage[,] = 0
253 |   coverage_test = data.frame(matrix(ncol = 10, nrow = length(all_vars)))
254 |   colnames(coverage_test) = methods_all
255 |   rownames(coverage_test) = all_vars
256 |   coverage_test[,] = 0
257 |   R2_fake = data.frame(matrix(ncol = 10, nrow = length(all_vars)))
258 |   colnames(R2_fake) = methods_all
259 |   rownames(R2_fake) = all_vars
260 |   R2_fake[,] = 0
261 |   F1_fake = data.frame(matrix(ncol = 10, nrow = length(all_vars)))
262 |   colnames(F1_fake) = methods_all
263 |   rownames(F1_fake) = all_vars
264 |   F1_fake[,] = 0
265 |   ClassScore = data.frame(matrix(ncol = 10, nrow = length(all_vars)))
266 |   colnames(ClassScore) = methods_all
267 |   rownames(ClassScore) = all_vars
268 |   ClassScore[,] = 0
269 |   percent_bias = data.frame(matrix(ncol = 10, nrow = length(all_vars)))
270 |   colnames(percent_bias) = methods_all
271 |   rownames(percent_bias) = all_vars
272 |   percent_bias[,] = 0
273 |   coverage_rate = data.frame(matrix(ncol = 10, nrow = length(all_vars)))
274 |   colnames(coverage_rate) = methods_all
275 |   rownames(coverage_rate) = all_vars
276 |   coverage_rate[,] = 0
277 |   Time = data.frame(matrix(ncol = 10, nrow = length(all_vars)))
278 |   colnames(Time) = methods_all
279 |   rownames(Time) = all_vars
280 |   Time[,] = 0
281 | 
282 |   for (method in methods_all){
283 |     W_train[rownames(W_train) %in% W_vars, method] = data$W_train[data$method==method & data$dataset %in% W_vars]
284 |     W_test[rownames(W_train) %in% W_vars, method] = data$W_test[data$method==method & data$dataset %in% W_vars]
285 | 
286 |     coverage[, method] = data$coverage_train[data$method==method & data$dataset %in% all_vars]
287 |     coverage_test[, method] = data$coverage_test[data$method==method & data$dataset %in% all_vars]
288 | 
289 |     R2_fake[rownames(W_train) %in% R2_vars, method] = data$R2_fake[data$method==method & data$dataset %in% R2_vars]
290 |     F1_fake[rownames(W_train) %in% F1_vars, method] = data$F1_fake[data$method==method & data$dataset %in% F1_vars]
291 |     ClassScore[rownames(W_train) %in% all_vars, method] = data$class_score[data$method==method & data$dataset %in% all_vars]
292 |     percent_bias[rownames(W_train) %in% R2_vars, method] = data$percent_bias[data$method==method & data$dataset %in% R2_vars]
293 |     coverage_rate[rownames(W_train) %in% R2_vars, method] = data$coverage_rate[data$method==method & data$dataset %in% R2_vars]
294 |   }
295 | 
296 |   # For building the barplot, to csv
297 |   write.csv(W_test, file="gen_W_nmiss.csv")
298 |   write.csv(coverage_test, file="gen_cov_nmiss.csv")
299 |   write.csv(R2_fake, file="gen_R2_nmiss.csv")
300 |   write.csv(F1_fake, file="gen_F1_nmiss.csv")
301 |   write.csv(ClassScore, file="gen_disc_nmiss.csv")
302 |   write.csv(percent_bias, file="gen_pb_nmiss.csv")
303 |   write.csv(coverage_rate, file="gen_cr_nmiss.csv")
304 | }
305 | 
--------------------------------------------------------------------------------
/Results/generation_script.R:
--------------------------------------------------------------------------------
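# Summary of tabular generation benchmarks on complete data; companion to
# generation_script_nmiss50.R. Aggregates "Results_tabular - generation.csv"
# into mean (std. error) tables and per-dataset ranks, then emits LaTeX tables.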
1 | options(show.error.locations = TRUE)
2 | options(repos=c(CRAN = "https://cloud.r-project.org"))
3 | install.packages('matrixStats')
4 | install.packages('xtable')
5 | install.packages('plotrix')
6 | library(matrixStats)
7 | library(xtable)
8 | library(plotrix)
9 | 
10 | # Change the folder to your own
11 | data = read.csv("Results_tabular - generation.csv")
12 | 
13 | # Better scaling for clean tables
14 | data$percent_bias = data$percent_bias / 100
15 | 
16 | all_vars = c('iris', 'wine', 'parkinsons', 'climate_model_crashes', 'concrete_compression', 'yacht_hydrodynamics', 'airfoil_self_noise', 'connectionist_bench_sonar', 'ionosphere', 'qsar_biodegradation', 'seeds', 'glass', 'ecoli', 'yeast', 'libras', 'planning_relax', 'blood_transfusion', 'breast_cancer_diagnostic', 'connectionist_bench_vowel', 'concrete_slump', 'wine_quality_red', 'wine_quality_white', 'california', 'bean', 'car', 'congress', 'tictactoe')
17 | W_vars = c('iris', 'wine', 'parkinsons', 'climate_model_crashes', 'concrete_compression', 'yacht_hydrodynamics', 'airfoil_self_noise', 'connectionist_bench_sonar', 'ionosphere', 'qsar_biodegradation', 'seeds', 'glass', 'ecoli', 'yeast', 'libras', 'planning_relax', 'blood_transfusion', 'breast_cancer_diagnostic', 'connectionist_bench_vowel', 'concrete_slump', 'wine_quality_red', 'wine_quality_white', 'car', 'congress', 'tictactoe')
18 | R2_vars = c('concrete_compression', 'yacht_hydrodynamics', 'airfoil_self_noise', 'wine_quality_red', 'wine_quality_white', 'california')
19 | F1_vars = c('iris', 'wine', 'parkinsons', 'climate_model_crashes', 'connectionist_bench_sonar', 'ionosphere', 'qsar_biodegradation', 'seeds', 'glass', 'ecoli', 'yeast', 'libras', 'planning_relax', 'blood_transfusion', 'breast_cancer_diagnostic', 'connectionist_bench_vowel', 'bean', 'car', 'congress', 'tictactoe')
20 | 
21 | methods = c('GaussianCopula', 'TVAE', 'CTGAN', 'CTABGAN', 'stasy', 'TabDDPM', 'forest_diffusion_vp_nt51_ycond', 'forest_diffusion_flow_nt51_ycond', 'utrees')
22 | methods_all = c('GaussianCopula', 'TVAE', 'CTGAN', 'CTABGAN', 'stasy', 'TabDDPM', 'forest_diffusion_vp_nt51_ycond', 'forest_diffusion_flow_nt51_ycond', 'utrees', 'oracle')
23 | 
24 | data_summary = data.frame(matrix(ncol = 10, nrow = 10))
25 | colnames(data_summary) = methods_all
26 | rownames(data_summary) = c("W_train", "W_test", "coverage_train", "coverage_test", "R2_fake", "F1_fake", "class_score", "percent_bias", "coverage_rate", "time")
27 | 
28 | for (method in methods_all){
29 |   W_train_mean = mean(data$W_train[data$method==method & data$dataset %in% W_vars])
30 |   W_train_sd = std.error(data$W_train[data$method==method & data$dataset %in% W_vars])
31 | 
32 |   W_test_mean = mean(data$W_test[data$method==method & data$dataset %in% W_vars])
33 |   W_test_sd = std.error(data$W_test[data$method==method & data$dataset %in% W_vars])
34 | 
35 |   coverage_mean = mean(data$coverage_train[data$method==method & data$dataset %in% all_vars])
36 |   coverage_sd = std.error(data$coverage_train[data$method==method & data$dataset %in% all_vars])
37 | 
38 |   coverage_test_mean = mean(data$coverage_test[data$method==method & data$dataset %in% all_vars])
39 |   coverage_test_sd = std.error(data$coverage_test[data$method==method & data$dataset %in% all_vars])
40 | 
41 |   R2_fake_mean = mean(data$R2_fake[data$method==method & data$dataset %in% R2_vars])
42 |   R2_fake_sd = std.error(data$R2_fake[data$method==method & data$dataset %in% R2_vars])
43 | 
44 |   F1_fake_mean = mean(data$F1_fake[data$method==method & data$dataset %in% F1_vars])
45 |   F1_fake_sd = std.error(data$F1_fake[data$method==method & data$dataset %in% F1_vars])
46 | 
47 |   percent_bias_mean = mean(data$percent_bias[data$method==method & data$dataset %in% R2_vars])
48 |   percent_bias_sd = std.error(data$percent_bias[data$method==method & data$dataset %in% R2_vars])
49 | 
50 |   coverage_rate_mean = mean(data$coverage_rate[data$method==method & data$dataset %in% R2_vars])
51 |   coverage_rate_sd = std.error(data$coverage_rate[data$method==method & data$dataset %in% R2_vars])
52 | 
53 |   class_score_mean = mean(data$class_score[data$method==method & data$dataset %in% all_vars])
54 |   class_score_sd = std.error(data$class_score[data$method==method & data$dataset %in% all_vars])
55 | 
56 |   time_mean = mean(data$time[data$method==method & data$dataset %in% all_vars])
57 |   time_sd = std.error(data$time[data$method==method & data$dataset %in% all_vars])
58 | 
59 |   data_summary[method] = c(
60 |     paste0(as.character(round(W_train_mean, 2)), ' (', as.character(round(W_train_sd, 2)), ')'),
61 |     paste0(as.character(round(W_test_mean, 2)), ' (', as.character(round(W_test_sd, 2)), ')'),
62 |     paste0(as.character(round(coverage_mean, 2)), ' (', as.character(round(coverage_sd, 2)), ')'),
63 |     paste0(as.character(round(coverage_test_mean, 2)), ' (', as.character(round(coverage_test_sd, 2)), ')'),
64 |     paste0(as.character(round(R2_fake_mean, 2)), ' (', as.character(round(R2_fake_sd, 2)), ')'),
65 |     paste0(as.character(round(F1_fake_mean, 2)), ' (', as.character(round(F1_fake_sd, 2)), ')'),
66 |     paste0(as.character(round(class_score_mean, 2)), ' (', as.character(round(class_score_sd, 2)), ')'),
67 |     paste0(as.character(round(percent_bias_mean, 2)), ' (', as.character(round(percent_bias_sd, 2)), ')'),
68 |     paste0(as.character(round(coverage_rate_mean, 2)), ' (', as.character(round(coverage_rate_sd, 2)), ')'),
69 |     paste0(as.character(round(time_mean, 2)), ' (', as.character(round(time_sd, 2)), ')')
70 |   )
71 | 
72 | }
73 | data_summary
74 | data_summary_t <- t(data_summary)
75 | data_summary_t
76 | 
77 | 
78 | ###################### RANK
79 | 
80 | data_summary_rank = data.frame(matrix(ncol = 9, nrow = 10))
81 | colnames(data_summary_rank) = methods
82 | rownames(data_summary_rank) = c("W_train", "W_test", "coverage_train", "coverage_test", "R2_fake", "F1_fake", "class_score", "percent_bias", "coverage_rate", "time")
83 | 
84 | W_train = data.frame(matrix(ncol = 9, nrow = length(W_vars)))
85 | colnames(W_train) = methods
86 | rownames(W_train) = W_vars
87 | W_test = data.frame(matrix(ncol = 9, nrow = length(W_vars)))
88 | colnames(W_test) = methods
89 | rownames(W_test) = W_vars
90 | coverage = data.frame(matrix(ncol = 9, nrow = length(all_vars)))
91 | colnames(coverage) = methods
92 | rownames(coverage) = all_vars
93 | coverage_test = data.frame(matrix(ncol = 9, nrow = length(all_vars)))
94 | colnames(coverage_test) = methods
95 | rownames(coverage_test) = all_vars
96 | R2_fake = data.frame(matrix(ncol = 9, nrow = length(R2_vars)))
97 | colnames(R2_fake) = methods
98 | rownames(R2_fake) = R2_vars
99 | F1_fake = data.frame(matrix(ncol = 9, nrow = length(F1_vars)))
100 | colnames(F1_fake) = methods
101 | rownames(F1_fake) = F1_vars
102 | percent_bias = data.frame(matrix(ncol = 9, nrow = length(R2_vars)))
103 | colnames(percent_bias) = methods
104 | rownames(percent_bias) = R2_vars
105 | coverage_rate = data.frame(matrix(ncol = 9, nrow = length(R2_vars)))
106 | colnames(coverage_rate) = methods
107 | rownames(coverage_rate) = R2_vars
108 | ClassScore = data.frame(matrix(ncol = 9, nrow = length(all_vars)))
109 | colnames(ClassScore) = methods
110 | rownames(ClassScore) = all_vars
111 | Time = data.frame(matrix(ncol = 9, nrow = length(all_vars)))
112 | colnames(Time) = methods
113 | rownames(Time) = all_vars
114 | 
115 | for (method in methods){
116 |   print(method)
117 |   print('W')
118 |   W_train[method] = data$W_train[data$method==method & data$dataset %in% W_vars]
119 |   W_test[method] = data$W_test[data$method==method & data$dataset %in% W_vars]
120 |   print('cov')
121 |   coverage[method] = -data$coverage_train[data$method==method & data$dataset %in% all_vars] # negated: rank() is ascending, so the best (highest) value gets rank 1
122 |   coverage_test[method] = -data$coverage_test[data$method==method & data$dataset %in% all_vars]
123 |   print('R2')
124 |   R2_fake[method] = -data$R2_fake[data$method==method & data$dataset %in% R2_vars]
125 |   print('F1')
126 |   F1_fake[method] = -data$F1_fake[data$method==method & data$dataset %in% F1_vars]
127 |   print('stats')
128 |   percent_bias[method] = data$percent_bias[data$method==method & data$dataset %in% R2_vars]
129 |   coverage_rate[method] = -data$coverage_rate[data$method==method & data$dataset %in% R2_vars]
130 |   print('class')
131 |   ClassScore[method] = data$class_score[data$method==method & data$dataset %in% all_vars]
132 |   print('time')
133 |   Time[method] = data$time[data$method==method & data$dataset %in% all_vars]
134 | }
135 | 
136 | # Rank by dataset
137 | 
138 | W_train_ = as.data.frame(t(sapply(as.data.frame(t(W_train)), rank)))
139 | colnames(W_train_) = colnames(W_train)
140 | W_train = W_train_
141 | 
142 | W_test_ = as.data.frame(t(sapply(as.data.frame(t(W_test)), rank)))
143 | colnames(W_test_) = colnames(W_test)
144 | W_test = W_test_
145 | 
146 | coverage_ = as.data.frame(t(sapply(as.data.frame(t(coverage)), rank)))
147 | colnames(coverage_) = colnames(coverage)
148 | coverage = coverage_
149 | coverage_test_ = as.data.frame(t(sapply(as.data.frame(t(coverage_test)), rank)))
150 | colnames(coverage_test_) = colnames(coverage_test)
151 | coverage_test = coverage_test_
152 | 
153 | R2_fake_ = as.data.frame(t(sapply(as.data.frame(t(R2_fake)), rank)))
154 | colnames(R2_fake_) = colnames(R2_fake)
155 | R2_fake = R2_fake_
156 | 
157 | F1_fake_ = as.data.frame(t(sapply(as.data.frame(t(F1_fake)), rank)))
158 | colnames(F1_fake_) = colnames(F1_fake)
159 | F1_fake = F1_fake_
160 | 
161 | percent_bias_ = as.data.frame(t(sapply(as.data.frame(t(percent_bias)), rank)))
162 | colnames(percent_bias_) = colnames(percent_bias)
163 | percent_bias = percent_bias_
164 | coverage_rate_ = as.data.frame(t(sapply(as.data.frame(t(coverage_rate)), rank)))
165 | colnames(coverage_rate_) = colnames(coverage_rate)
166 | coverage_rate = coverage_rate_
167 | 
168 | ClassScore_ = as.data.frame(t(sapply(as.data.frame(t(ClassScore)), rank)))
169 | colnames(ClassScore_) = colnames(ClassScore)
170 | ClassScore = ClassScore_
171 | 
172 | Time_ = as.data.frame(t(sapply(as.data.frame(t(Time)), rank)))
173 | colnames(Time_) = colnames(Time)
174 | Time = Time_
175 | 
176 | for (method in methods){
177 |   W_train_mean = mean(unlist(W_train[method]))
178 |   W_train_sd = std.error(unlist(W_train[method]))
179 | 
180 |   W_test_mean = mean(unlist(W_test[method]))
181 |   W_test_sd = std.error(unlist(W_test[method]))
182 | 
183 |   coverage_mean = mean(unlist(coverage[method]))
184 |   coverage_sd = std.error(unlist(coverage[method]))
185 |   coverage_test_mean = mean(unlist(coverage_test[method]))
186 |   coverage_test_sd = std.error(unlist(coverage_test[method]))
187 | 
188 |   R2_fake_mean = mean(unlist(R2_fake[method]))
189 |   R2_fake_sd = std.error(unlist(R2_fake[method]))
190 | 
191 |   F1_fake_mean = mean(unlist(F1_fake[method]))
192 |   F1_fake_sd = std.error(unlist(F1_fake[method]))
193 | 
194 |   percent_bias_mean = mean(unlist(percent_bias[method]))
195 |   percent_bias_sd = std.error(unlist(percent_bias[method]))
196 | 
197 |   coverage_rate_mean = mean(unlist(coverage_rate[method]))
198 |   coverage_rate_sd = std.error(unlist(coverage_rate[method]))
199 | 
200 |   class_score_mean = mean(unlist(ClassScore[method]))
201 |   class_score_sd = std.error(unlist(ClassScore[method]))
202 | 
203 |   time_mean = mean(unlist(Time[method]))
204 |   time_sd = std.error(unlist(Time[method]))
205 | 
206 |   data_summary_rank[method] = c(
207 |     paste0(as.character(round(W_train_mean, 1)), ' (', as.character(round(W_train_sd, 1)), ')'),
208 |     paste0(as.character(round(W_test_mean, 1)), ' (', as.character(round(W_test_sd, 1)), ')'),
209 |     paste0(as.character(round(coverage_mean, 1)), ' (', as.character(round(coverage_sd, 1)), ')'),
210 |     paste0(as.character(round(coverage_test_mean, 1)), ' (', as.character(round(coverage_test_sd, 1)), ')'),
211 |     paste0(as.character(round(R2_fake_mean, 1)), ' (', as.character(round(R2_fake_sd, 1)), ')'),
212 |     paste0(as.character(round(F1_fake_mean, 1)), ' (', as.character(round(F1_fake_sd, 1)), ')'),
213 |     paste0(as.character(round(class_score_mean, 1)), ' (', as.character(round(class_score_sd, 1)), ')'),
214 |     paste0(as.character(round(percent_bias_mean, 1)), ' (', as.character(round(percent_bias_sd, 1)), ')'),
215 |     paste0(as.character(round(coverage_rate_mean, 1)), ' (', as.character(round(coverage_rate_sd, 1)), ')'),
216 |     paste0(as.character(round(time_mean, 1)), ' (', as.character(round(time_sd, 1)), ')')
217 | 
218 |   )
219 | }
220 | 
221 | rownames(data_summary_rank) = c("W_train", "W_test", "coverage_train", "coverage_test", "R2_fake", "F1_fake", "class_score", "percent_bias", "coverage_rate", "time")
222 | data_summary_rank
223 | data_summary_rank_t <- t(data_summary_rank)
224 | data_summary_rank_t
225 | 
226 | 
227 | 
228 | ############################## Latex tables
229 | 
230 | 
231 | rownames(data_summary_t) = c("GaussianCopula", "TVAE", "CTGAN", "CTAB-GAN+", "STaSy", "TabDDPM", "Forest-VP", "Forest-Flow", "UTrees", "Oracle")
232 | rownames(data_summary_rank_t) = c("GaussianCopula", "TVAE", "CTGAN", "CTAB-GAN+", "STaSy", "TabDDPM", "Forest-VP", "Forest-Flow", "UTrees")
233 | 
234 | xtable(data_summary_t[,-10], type = "latex")
235 | xtable(data_summary_rank_t[,-10], type = "latex")
236 | 
237 | 
238 | if (FALSE) {
239 |   ### ablation
240 | 
241 |   data_x = read.csv('C:/Users/Alexia-Mini/Downloads/Results_tabular - ablation_iris.csv') # change this path to your own
242 |   xtable(data_x, type = "latex")
243 | 
244 |   # For building the barplot, to csv
245 | 
246 |   W_train = data.frame(matrix(ncol = 10, nrow = length(all_vars))) # 10 columns: one per entry of methods_all
247 |   colnames(W_train) = methods_all
248 |   rownames(W_train) = all_vars
249 |   W_train[,] = 0
250 |   W_test = data.frame(matrix(ncol = 10, nrow = length(all_vars)))
251 |   colnames(W_test) = methods_all
252 |   rownames(W_test) = all_vars
253 |   W_test[,] = 0
254 |   coverage = data.frame(matrix(ncol = 10, nrow = length(all_vars)))
255 |   colnames(coverage) = methods_all
256 |   rownames(coverage) = all_vars
257 |   coverage[,] = 0
258 |   coverage_test = data.frame(matrix(ncol = 10, nrow = length(all_vars)))
259 |   colnames(coverage_test) = methods_all
260 |   rownames(coverage_test) = all_vars
261 |   coverage_test[,] = 0
262 |   R2_fake = data.frame(matrix(ncol = 10, nrow = length(all_vars)))
263 |   colnames(R2_fake) = methods_all
264 |   rownames(R2_fake) = all_vars
265 |   R2_fake[,] = 0
266 |   F1_fake = data.frame(matrix(ncol = 10, nrow = length(all_vars)))
267 |   colnames(F1_fake) = methods_all
268 |   rownames(F1_fake) = all_vars
269 |   F1_fake[,] = 0
270 |   ClassScore = data.frame(matrix(ncol = 10, nrow = length(all_vars)))
271 |   colnames(ClassScore) = methods_all
272 |   rownames(ClassScore) = all_vars
273 |   ClassScore[,] = 0
274 |   percent_bias = data.frame(matrix(ncol = 10, nrow = length(all_vars)))
275 |   colnames(percent_bias) = methods_all
276 |   rownames(percent_bias) = all_vars
277 |   percent_bias[,] = 0
278 |   coverage_rate = data.frame(matrix(ncol = 10, nrow = length(all_vars)))
279 |   colnames(coverage_rate) = methods_all
280 |   rownames(coverage_rate) = all_vars
281 |   coverage_rate[,] = 0
282 |   Time = data.frame(matrix(ncol = 10, nrow = length(all_vars)))
283 |   colnames(Time) = methods_all
284 |   rownames(Time) = all_vars
285 |   Time[,] = 0
286 | 
287 |   for (method in methods_all){
288 |     W_train[rownames(W_train) %in% W_vars, method] = data$W_train[data$method==method & data$dataset %in% W_vars]
289 |     W_test[rownames(W_train) %in% W_vars, method] = data$W_test[data$method==method & data$dataset %in% W_vars]
290 | 
291 |     coverage[, method] = data$coverage_train[data$method==method & data$dataset %in% all_vars]
292 |     coverage_test[, method] = data$coverage_test[data$method==method & data$dataset %in% all_vars]
293 | 
294 |     R2_fake[rownames(W_train) %in% R2_vars, method] = data$R2_fake[data$method==method & data$dataset %in% R2_vars]
295 |     F1_fake[rownames(W_train) %in% F1_vars, method] = data$F1_fake[data$method==method & data$dataset %in% F1_vars]
296 |     ClassScore[rownames(W_train) %in% all_vars, method] = data$class_score[data$method==method & data$dataset %in% all_vars]
297 |     percent_bias[rownames(W_train) %in% R2_vars, method] = data$percent_bias[data$method==method & data$dataset %in% R2_vars]
298 |     coverage_rate[rownames(W_train) %in% R2_vars, method] = data$coverage_rate[data$method==method & data$dataset %in% R2_vars]
299 |   }
300 | 
301 |   # For building the barplot, to csv
302 |   write.csv(W_test, file="gen_W.csv")
303 |   write.csv(coverage_test, file="gen_cov.csv")
304 |   write.csv(R2_fake, file="gen_R2.csv")
305 |   write.csv(F1_fake, file="gen_F1.csv")
306 |   write.csv(ClassScore, file="gen_disc.csv")
307 |   write.csv(percent_bias, file="gen_pb.csv")
308 |   write.csv(coverage_rate, file="gen_cr.csv")
309 | }
310 | 
--------------------------------------------------------------------------------