├── .github └── workflows │ ├── publish.yaml │ └── publish_macos.yaml ├── .readthedocs.yaml ├── Dockerfile ├── LICENSE ├── Makefile ├── README.md ├── build_wheel.sh ├── docs └── source │ ├── _static │ └── style.css │ ├── api │ ├── modules.rst │ ├── sure.distance_metrics.rst │ ├── sure.privacy.rst │ ├── sure.report_generator.rst │ └── sure.rst │ ├── conf.py │ ├── doc_1.md │ ├── doc_2.md │ ├── img │ ├── SURE_workflow_.png │ ├── cb_white_logo_compact.png │ ├── favicon.ico │ └── sure_logo_nobg.png │ └── index.rst ├── make.bat ├── project.toml ├── requirements.txt ├── setup.py ├── sure ├── __init__.py ├── _lazypredict.py ├── distance_metrics │ ├── __init__.py │ ├── distance.py │ └── gower_matrix_c.pyx ├── privacy │ ├── __init__.py │ └── privacy.py ├── report_generator │ ├── .streamlit │ │ └── config.toml │ ├── __init__.py │ ├── pages │ │ └── privacy.py │ ├── report_app.py │ └── report_generator.py └── utility.py ├── tests ├── resources │ ├── dataset.csv │ ├── synthetic_dataset.csv │ └── validation_dataset.csv └── test_sure.py └── tutorials ├── data └── census_dataset │ ├── census_dataset_synthetic.csv │ ├── census_dataset_training.csv │ └── census_dataset_validation.csv └── sure_tutorial.ipynb /.github/workflows/publish.yaml: -------------------------------------------------------------------------------- 1 | name: Publish Libraty to PyPI 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | build: 10 | runs-on: ubuntu-latest 11 | 12 | strategy: 13 | matrix: 14 | python-version: ["3.10", "3.11", "3.12"] 15 | include: 16 | - python-version: "3.10" 17 | python-path: "/opt/python/cp310-cp310/bin" 18 | - python-version: "3.11" 19 | python-path: "/opt/python/cp311-cp311/bin" 20 | - python-version: "3.12" 21 | python-path: "/opt/python/cp312-cp312/bin" 22 | 23 | steps: 24 | - name: Checkout repository 25 | uses: actions/checkout@v2 26 | 27 | - name: Set up Python for tests 28 | uses: actions/setup-python@v2 29 | with: 30 | python-version: ${{ matrix.python-version 
}} 31 | 32 | - name: Install test dependencies 33 | run: | 34 | python -m pip install --upgrade pip 35 | pip install -r requirements.txt 36 | pip install pytest 37 | 38 | - name: Build Python wheels 39 | run: | 40 | python setup.py bdist_wheel 41 | 42 | - name: Install wheel 43 | run: | 44 | pip install dist/*.whl 45 | - name: Run tests 46 | run: | 47 | pytest tests/ 48 | 49 | - name: Set up Docker Buildx 50 | uses: docker/setup-buildx-action@v1 51 | 52 | - name: Set up QEMU 53 | uses: docker/setup-qemu-action@v1 54 | 55 | - name: Build Docker image 56 | run: | 57 | docker build -t my-python-wheels . 58 | 59 | - name: Run build script in Docker container 60 | run: | 61 | docker run --rm \ 62 | -e PYTHON_VERSION=${{ matrix.python-version }} \ 63 | -e PYTHON_PATH=${{ matrix.python-path }} \ 64 | -v ${{ github.workspace }}:/io \ 65 | my-python-wheels /build_wheel.sh 66 | 67 | # - name: Check if version exists on PyPI 68 | # id: check-version 69 | # run: | 70 | # PACKAGE_NAME=$(python setup.py --name) 71 | # PACKAGE_VERSION=$(python setup.py --version) 72 | # if curl --silent -f https://pypi.org/project/${PACKAGE_NAME}/${PACKAGE_VERSION}/; then 73 | # echo "Version ${PACKAGE_VERSION} already exists on PyPI." 74 | # echo "already_published=true" >> $GITHUB_ENV 75 | # else 76 | # echo "Version ${PACKAGE_VERSION} does not exist on PyPI." 
77 | # echo "already_published=false" >> $GITHUB_ENV 78 | # fi 79 | 80 | # - name: Upload Python wheels as artifacts 81 | # uses: actions/upload-artifact@v4 82 | # with: 83 | # name: python-wheels 84 | # path: dist/*.whl 85 | 86 | - name: Publish wheels to PyPI 87 | # if: env.already_published == 'false' 88 | uses: pypa/gh-action-pypi-publish@v1.4.2 89 | with: 90 | user: __token__ 91 | password: ${{ secrets.PYPI_TOKEN }} 92 | 93 | 94 | -------------------------------------------------------------------------------- /.github/workflows/publish_macos.yaml: -------------------------------------------------------------------------------- 1 | name: Publish Library to PyPI (MacOS) 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | build: 10 | runs-on: ${{ matrix.os }} 11 | 12 | strategy: 13 | matrix: 14 | os: [macos-latest] 15 | python-version: ["3.10", "3.11", "3.12"] 16 | 17 | steps: 18 | - name: Checkout repository 19 | uses: actions/checkout@v2 20 | 21 | - name: Set up Python 22 | uses: actions/setup-python@v2 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install wheel setuptools 30 | pip install -r requirements.txt 31 | 32 | - name: Build Python wheels 33 | run: | 34 | python setup.py bdist_wheel 35 | 36 | 37 | # - name: Upload Python wheels as artifacts 38 | # uses: actions/upload-artifact@v4 39 | # with: 40 | # name: python-wheels 41 | # path: dist/*.whl 42 | 43 | - name: Publish on PyPI 44 | run: | 45 | pip install twine 46 | export TWINE_USERNAME=__token__ 47 | export TWINE_PASSWORD=${{ secrets.PYPI_TOKEN }} 48 | twine upload dist/*.whl -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See 
https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the OS, Python version, and other tools you might need 9 | build: 10 | os: ubuntu-22.04 11 | tools: 12 | python: "3.10" 13 | # You can also specify other tool versions: 14 | # nodejs: "19" 15 | # rust: "1.64" 16 | # golang: "1.19" 17 | 18 | # Build documentation in the "docs/" directory with Sphinx 19 | sphinx: 20 | configuration: docs/source/conf.py 21 | 22 | # Optionally build your docs in additional formats such as PDF and ePub 23 | formats: all 24 | 25 | # Optional but recommended, declare the Python requirements required 26 | # to build your documentation 27 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html 28 | python: 29 | install: 30 | - requirements: requirements.txt 31 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM quay.io/pypa/manylinux2014_x86_64 2 | 3 | ARG PYTHON_VERSION 4 | 5 | RUN /opt/python/cp310-cp310/bin/pip install -U pip setuptools wheel Cython 6 | RUN /opt/python/cp311-cp311/bin/pip install -U pip setuptools wheel Cython 7 | RUN /opt/python/cp312-cp312/bin/pip install -U pip setuptools wheel Cython 8 | 9 | COPY build_wheel.sh /build_wheel.sh 10 | RUN chmod +x /build_wheel.sh 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = docs/source 9 | BUILDDIR = docs/build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | [![Documentation Status](https://readthedocs.org/projects/clearbox-sure/badge/?version=latest)](https://clearbox-sure.readthedocs.io/en/latest/?badge=latest) 3 | [![PyPI](https://badge.fury.io/py/clearbox-sure.svg)](https://badge.fury.io/py/clearbox-sure) 4 | [![Downloads](https://pepy.tech/badge/clearbox-sure)](https://pepy.tech/project/clearbox-sure) 5 | [![GitHub stars](https://img.shields.io/github/stars/Clearbox-AI/SURE?style=social)](https://github.com/Clearbox-AI/SURE) 6 | 7 | 8 | 9 | ### Synthetic Data: Utility, Regulatory compliance, and Ethical privacy 10 | 11 | The SURE package is an open-source Python library intended to be used for the assessment of the utility and privacy performance of any tabular synthetic dataset. 12 | 13 | The SURE library features multiple Python modules that can be easily imported and seamlessly integrated into any Python script after installing the library. 14 | 15 | > [!WARNING] 16 | > This is a beta version of the library and only runs on Linux and MacOS for the moment. 17 | 18 | > [!IMPORTANT] 19 | > Requires Python >= 3.10 20 | 21 | # Installation 22 | 23 | To install the library run the following command in your terminal: 24 | 25 | ```shell 26 | $ pip install clearbox-sure 27 | ``` 28 | 29 | # Modules overview 30 | 31 | The SURE library features the following modules: 32 | 33 | 1. Preprocessor 34 | 2. Statistical similarity metrics 35 | 3. Model garden 36 | 4. ML utility metrics 37 | 5. Distance metrics 38 | 6. Privacy attack sandbox 39 | 7. Report generator 40 | 41 | **Preprocessor** 42 | 43 | The input datasets undergo manipulation by the preprocessor module, tailored to conform to the standard structure utilized across the subsequent processes. 
The Polars library used in the preprocessor makes this operation significantly faster compared to the use of other data processing libraries. 44 | 45 | **Utility** 46 | 47 | The statistical similarity metrics, the ML utility metrics and the model garden modules constitute the data **utility evaluation** part. 48 | 49 | The statistical similarity module and the distance metrics module take as input the pre-processed datasets and carry out the operation to assess the statistical similarity between the datasets and how different the content of the synthetic dataset is from the one of the original dataset. In particular, the real and synthetic input datasets are used in the statistical similarity metrics module to assess how close the two datasets are in terms of statistical properties, such as mean, correlation, distribution. 50 | 51 | The model garden executes a classification or regression task on the given dataset with multiple machine learning models, returning the performance metrics of each of the models tested on the given task and dataset. 52 | 53 | The model garden module’s best performing models are employed in the machine learning utility metrics module to compute the usefulness of the synthetic data on a given ML task (classification or regression). 54 | 55 | **Privacy** 56 | 57 | The distance metrics and the privacy attack sandbox make up the synthetic data **privacy assessment** modules. 58 | 59 | The distance metrics module computes the Gower distance between the two input datasets and the distance to the closest record for each line of the first dataset. 60 | 61 | The ML privacy attack sandbox allows you to simulate a Membership Inference Attack for re-identification of vulnerable records identified with the distance metrics module and evaluate how exposed the synthetic dataset is to this kind of assault. 
62 | 63 | **Report** 64 | 65 | Eventually, the report generator provides a summary of the utility and privacy metrics computed in the previous modules, providing a visual digest with charts and tables of the results. 66 | 67 | The following diagram serves as a visual representation of how each module contributes to the utility-privacy assessment process and highlights the seamless interconnection and synergy between individual blocks. 68 | 69 | drawing 70 | 71 | # Usage 72 | 73 | The library leverages Polars, which ensures faster computations compared to other data manipulation libraries. It supports both Polars and Pandas dataframes. 74 | 75 | The user must provide the original real training dataset (which was used to train the generative model that produced the synthetic dataset), the real holdout dataset (which was NOT used to train the generative model that produced the synthetic dataset), and the corresponding synthetic dataset to enable the library's modules to perform the necessary computations for evaluation. 
76 | 77 | Below is a code snippet example for the usage of the library: 78 | 79 | ```python 80 | # Import the necessary modules from the SURE library 81 | from sure import Preprocessor, report 82 | from sure.utility import (compute_statistical_metrics, compute_mutual_info, 83 | compute_utility_metrics_class, 84 | detection, 85 | query_power) 86 | from sure.privacy import (distance_to_closest_record, dcr_stats, number_of_dcr_equal_to_zero, validation_dcr_test, 87 | adversary_dataset, membership_inference_test) 88 | 89 | # Assuming real_data, valid_data and synth_data are three pandas DataFrames 90 | 91 | # Preprocessor initialization and query execution on the real, synthetic and validation datasets 92 | preprocessor = Preprocessor(real_data, num_fill_null='forward', scaling='standardize') 93 | 94 | real_data_preprocessed = preprocessor.transform(real_data) 95 | valid_data_preprocessed = preprocessor.transform(valid_data) 96 | synth_data_preprocessed = preprocessor.transform(synth_data) 97 | 98 | # Statistical properties and mutual information 99 | num_features_stats, cat_features_stats, temporal_feat_stats = compute_statistical_metrics(real_data, synth_data) 100 | corr_real, corr_synth, corr_difference = compute_mutual_info(real_data_preprocessed, synth_data_preprocessed) 101 | 102 | # ML utility: TSTR - Train on Synthetic, Test on Real 103 | X_train = real_data_preprocessed.drop("label", axis=1) # Assuming the datasets have a “label” column for the machine learning task they are intended for 104 | y_train = real_data_preprocessed["label"] 105 | X_synth = synth_data_preprocessed.drop("label", axis=1) 106 | y_synth = synth_data_preprocessed["label"] 107 | X_test = valid_data_preprocessed.drop("label", axis=1).limit(10000) # Test the trained models on a portion of the original real dataset (first 10k rows) 108 | y_test = valid_data_preprocessed["label"].limit(10000) 109 | TSTR_metrics = compute_utility_metrics_class(X_train, X_synth, X_test, y_train, y_synth, y_test) 
110 | 111 | # Distance to closest record 112 | dcr_synth_train = distance_to_closest_record("synth_train", synth_data, real_data) 113 | dcr_synth_valid = distance_to_closest_record("synth_val", synth_data, valid_data) 114 | dcr_stats_synth_train = dcr_stats("synth_train", dcr_synth_train) 115 | dcr_stats_synth_valid = dcr_stats("synth_val", dcr_synth_valid) 116 | dcr_zero_synth_train = number_of_dcr_equal_to_zero("synth_train", dcr_synth_train) 117 | dcr_zero_synth_valid = number_of_dcr_equal_to_zero("synth_val", dcr_synth_valid) 118 | share = validation_dcr_test(dcr_synth_train, dcr_synth_valid) 119 | 120 | # Detection Score 121 | detection_score = detection(real_data, synth_data, preprocessor) 122 | 123 | # Query Power 124 | query_power_score = query_power(real_data, synth_data, preprocessor) 125 | 126 | # ML privacy attack sandbox initialization and simulation 127 | adversary_df = adversary_dataset(real_data, valid_data) 128 | # The function adversary_dataset adds a column "privacy_test_is_training" to the adversary dataset, indicating whether the record was part of the training set or not 129 | adversary_guesses_ground_truth = adversary_df["privacy_test_is_training"] 130 | MIA = membership_inference_test(adversary_df, synth_data, adversary_guesses_ground_truth) 131 | 132 | # Report generation as HTML page 133 | report(real_data, synth_data) 134 | ``` 135 | 136 | Follow the step-by-step [guide](https://github.com/Clearbox-AI/SURE/tree/main/examples) to test the library. 
137 | 138 | 139 | -------------------------------------------------------------------------------- /build_wheel.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | cd /io 5 | 6 | ${PYTHON_PATH}/pip install -U pip setuptools wheel Cython 7 | ${PYTHON_PATH}/pip install -r requirements.txt 8 | ${PYTHON_PATH}/python setup.py bdist_wheel 9 | 10 | # Prepare output directory 11 | mkdir -p /io/wheelhouse 12 | 13 | # Repair the wheels with auditwheel 14 | for whl in dist/*.whl; do 15 | auditwheel repair "$whl" --plat manylinux2014_x86_64 -w /io/wheelhouse/ 16 | done 17 | 18 | # Replace dist with repaired wheels 19 | rm -rf dist 20 | mv wheelhouse dist 21 | -------------------------------------------------------------------------------- /docs/source/_static/style.css: -------------------------------------------------------------------------------- 1 | @import url("https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap"); 2 | @import url('https://fonts.googleapis.com/css2?family=Be+Vietnam+Pro:ital,wght@0,100;0,200;0,300;0,400;0,500;0,600;0,700;0,800;0,900;1,100;1,200;1,300;1,400;1,500;1,600;1,700;1,800;1,900&display=swap'); 3 | 4 | .be-vietnam-pro-regular { 5 | font-family: "Be Vietnam Pro", serif; 6 | font-weight: 400; 7 | font-style: normal; 8 | } 9 | 10 | body { 11 | font-family: "Be Vietnam Pro", serif; 12 | color: #212121; 13 | } 14 | 15 | footer { 16 | color: #646369; 17 | } 18 | 19 | h1, h2, h3, h4, h5, h6 { 20 | font-family: "Be Vietnam Pro", serif; 21 | font-weight: 600; 22 | } 23 | 24 | .wy-nav-content-wrap { 25 | background-color: #fafafa; 26 | } 27 | .wy-nav-content { 28 | background-color: #ffffff; 29 | max-width: 60rem; 30 | min-height: 100vh; 31 | } 32 | .wy-nav-side, 33 | .wy-nav-top { 34 | background-color: #483a8f; 35 | } 36 | 37 | .wy-form input[type="text"] { 38 | border-radius: 0; 39 | border: none; 40 | } 41 | 42 | /** This fixes the horizontal scroll on mobile 
*/ 43 | @media only screen and (max-width: 770px) { 44 | .py.class dt, 45 | .py.function dt { 46 | display: block !important; 47 | overflow-x: auto; 48 | } 49 | } -------------------------------------------------------------------------------- /docs/source/api/modules.rst: -------------------------------------------------------------------------------- 1 | sure 2 | ==== 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | sure 8 | -------------------------------------------------------------------------------- /docs/source/api/sure.distance_metrics.rst: -------------------------------------------------------------------------------- 1 | sure.distance\_metrics package 2 | ============================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | sure.distance\_metrics.distance module 8 | -------------------------------------- 9 | 10 | .. automodule:: sure.distance_metrics.distance 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | sure.distance\_metrics.gower\_matrix\_c module 16 | ---------------------------------------------- 17 | 18 | .. automodule:: sure.distance_metrics.gower_matrix_c 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | Module contents 24 | --------------- 25 | 26 | .. automodule:: sure.distance_metrics 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | -------------------------------------------------------------------------------- /docs/source/api/sure.privacy.rst: -------------------------------------------------------------------------------- 1 | sure.privacy package 2 | ==================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | sure.privacy.privacy module 8 | --------------------------- 9 | 10 | .. automodule:: sure.privacy.privacy 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. 
automodule:: sure.privacy 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /docs/source/api/sure.report_generator.rst: -------------------------------------------------------------------------------- 1 | sure.report\_generator package 2 | ============================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | sure.report\_generator.report\_app module 8 | ----------------------------------------- 9 | 10 | .. automodule:: sure.report_generator.report_app 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | sure.report\_generator.report\_generator module 16 | ----------------------------------------------- 17 | 18 | .. automodule:: sure.report_generator.report_generator 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | Module contents 24 | --------------- 25 | 26 | .. automodule:: sure.report_generator 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | -------------------------------------------------------------------------------- /docs/source/api/sure.rst: -------------------------------------------------------------------------------- 1 | Utils 2 | ===== 3 | .. toctree:: 4 | :maxdepth: 4 5 | 6 | .. automodule:: sure 7 | :members: 8 | :undoc-members: 9 | :show-inheritance: 10 | 11 | Utility 12 | ======= 13 | 14 | .. automodule:: sure.utility 15 | :members: 16 | 17 | 18 | Privacy 19 | ======= 20 | .. toctree:: 21 | :maxdepth: 2 22 | 23 | .. automodule:: sure.privacy 24 | :members: 25 | 26 | .. .. 
automodule:: sure.distance_metrics.distance 27 | :members: -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.insert(0, os.path.abspath('../..')) # Adjust the path as needed 4 | 5 | project = 'SURE' 6 | author = 'Dario Brunelli' 7 | copyright = "2024, Clearbox AI" 8 | release = '0.1.9.9' 9 | 10 | extensions = [ 11 | 'sphinx.ext.autodoc', 12 | "sphinx.ext.coverage", 13 | 'sphinx.ext.napoleon', 14 | "myst_parser", 15 | 'sphinx_rtd_theme' 16 | ] 17 | 18 | myst_enable_extensions = [ 19 | "html_image", # Allows html images format conversion 20 | ] 21 | 22 | source_suffix = { 23 | '.rst': 'restructuredtext', 24 | '.txt': 'markdown', 25 | '.md': 'markdown', 26 | } 27 | templates_path = ['_templates'] 28 | exclude_patterns = [] 29 | 30 | html_theme = "sphinx_rtd_theme" 31 | html_theme_options = { 32 | "logo_only": True, 33 | "style_nav_header_background": "#483a8f", 34 | } 35 | html_css_files = [ 36 | 'style.css', 37 | ] 38 | html_static_path = ['_static', 'img'] 39 | html_logo = "img/cb_white_logo_compact.png" 40 | html_favicon = "img/favicon.ico" 41 | 42 | master_doc = 'index' # Ensure this points to your main document 43 | 44 | 45 | # Napoleon settings 46 | napoleon_google_docstring = True 47 | napoleon_numpy_docstring = False -------------------------------------------------------------------------------- /docs/source/doc_1.md: -------------------------------------------------------------------------------- 1 | [![Documentation Status](https://readthedocs.org/projects/clearbox-sure/badge/?version=latest)](https://clearbox-sure.readthedocs.io/en/latest/?badge=latest) 2 | [![PyPI](https://badge.fury.io/py/clearbox-sure.svg)](https://badge.fury.io/py/clearbox-sure) 3 | [![Downloads](https://pepy.tech/badge/clearbox-sure)](https://pepy.tech/project/clearbox-sure) 4 | [![GitHub 
stars](https://img.shields.io/github/stars/Clearbox-AI/SURE?style=social)](https://github.com/Clearbox-AI/SURE) -------------------------------------------------------------------------------- /docs/source/doc_2.md: -------------------------------------------------------------------------------- 1 | ## SURE 2 | Synthetic Data: Utility, Regulatory compliance, and Ethical privacy
3 | 4 | The SURE package is an open-source Python library for the assessment of the utility and privacy performance of any tabular **synthetic dataset**. 5 | 6 | The SURE library works both with [pandas](https://pandas.pydata.org/) and [polars](https://pola.rs/) DataFrames. 7 | 8 | ## Installation 9 | 10 | It is highly recommended to install the library in a virtual environment or container. 11 | 12 | ```shell 13 | $ pip install clearbox-sure 14 | ``` 15 | 16 | ## Usage 17 | 18 | The user must provide the original real training dataset (which was used to train the generative model that produced the synthetic dataset), the real holdout dataset (which was NOT used to train the generative model that produced the synthetic dataset) and the corresponding synthetic dataset to enable the library's modules to perform the necessary computations for evaluation. 19 | 20 | Follow the step-by-step guide to test the library using the provided [tutorial](https://github.com/Clearbox-AI/SURE/blob/main/tutorials/sure_tutorial.ipynb). 
21 | 22 | ```python 23 | # Import the necessary modules from the SURE library 24 | from sure import Preprocessor, report 25 | from sure.utility import (compute_statistical_metrics, compute_mutual_info, 26 | compute_utility_metrics_class, 27 | detection, 28 | query_power) 29 | from sure.privacy import (distance_to_closest_record, dcr_stats, number_of_dcr_equal_to_zero, validation_dcr_test, 30 | adversary_dataset, membership_inference_test) 31 | 32 | # Assuming real_data, valid_data and synth_data are three pandas DataFrames 33 | 34 | # Preprocessor initialization and query execution on the real, synthetic and validation datasets 35 | preprocessor = Preprocessor(real_data) 36 | real_data_preprocessed = preprocessor.transform(real_data) 37 | valid_data_preprocessed = preprocessor.transform(valid_data) 38 | synth_data_preprocessed = preprocessor.transform(synth_data) 39 | 40 | # Statistical properties and mutual information 41 | num_features_stats, cat_features_stats, temporal_feat_stats = compute_statistical_metrics(real_data, synth_data) 42 | corr_real, corr_synth, corr_difference = compute_mutual_info(real_data_preprocessed, synth_data_preprocessed) 43 | 44 | # ML utility: TSTR - Train on Synthetic, Test on Real 45 | X_train = real_data_preprocessed.drop("label", axis=1) # Assuming the datasets have a “label” column for the machine learning task they are intended for 46 | y_train = real_data_preprocessed["label"] 47 | X_synth = synth_data_preprocessed.drop("label", axis=1) 48 | y_synth = synth_data_preprocessed["label"] 49 | X_test = valid_data_preprocessed.drop("label", axis=1).limit(10000) # Test the trained models on a portion of the original real dataset (first 10k rows) 50 | y_test = valid_data_preprocessed["label"].limit(10000) 51 | TSTR_metrics = compute_utility_metrics_class(X_train, X_synth, X_test, y_train, y_synth, y_test) 52 | 53 | # Distance to closest record 54 | dcr_synth_train = distance_to_closest_record("synth_train", synth_data, real_data) 55 | 
dcr_synth_valid = distance_to_closest_record("synth_val", synth_data, valid_data) 56 | dcr_stats_synth_train = dcr_stats("synth_train", dcr_synth_train) 57 | dcr_stats_synth_valid = dcr_stats("synth_val", dcr_synth_valid) 58 | dcr_zero_synth_train = number_of_dcr_equal_to_zero("synth_train", dcr_synth_train) 59 | dcr_zero_synth_valid = number_of_dcr_equal_to_zero("synth_val", dcr_synth_valid) 60 | share = validation_dcr_test(dcr_synth_train, dcr_synth_valid) 61 | 62 | # Detection Score 63 | detection_score = detection(real_data, synth_data, preprocessor) 64 | 65 | # Query Power 66 | query_power_score = query_power(real_data, synth_data, preprocessor) 67 | 68 | # ML privacy attack sandbox initialization and simulation 69 | adversary_df = adversary_dataset(real_data, valid_data) 70 | # The function adversary_dataset adds a column "privacy_test_is_training" to the adversary dataset, indicating whether the record was part of the training set or not 71 | adversary_guesses_ground_truth = adversary_df["privacy_test_is_training"] 72 | MIA = membership_inference_test(adversary_df, synth_data, adversary_guesses_ground_truth) 73 | 74 | # Report generation as HTML page 75 | report() 76 | ``` 77 | -------------------------------------------------------------------------------- /docs/source/img/SURE_workflow_.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clearbox-AI/SURE/4460454e24d142c73457d8f05326151be54f778f/docs/source/img/SURE_workflow_.png -------------------------------------------------------------------------------- /docs/source/img/cb_white_logo_compact.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clearbox-AI/SURE/4460454e24d142c73457d8f05326151be54f778f/docs/source/img/cb_white_logo_compact.png -------------------------------------------------------------------------------- /docs/source/img/favicon.ico: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clearbox-AI/SURE/4460454e24d142c73457d8f05326151be54f778f/docs/source/img/favicon.ico -------------------------------------------------------------------------------- /docs/source/img/sure_logo_nobg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clearbox-AI/SURE/4460454e24d142c73457d8f05326151be54f778f/docs/source/img/sure_logo_nobg.png -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. Clearbox SURE documentation master file, created by 2 | sphinx-quickstart on Mon Nov 18 15:46:14 2024. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | .. include:: doc_1.md 7 | :parser: markdown 8 | 9 | .. image:: img/sure_logo_nobg.png 10 | :alt: SURE Logo 11 | :width: 450px 12 | 13 | .. include:: doc_2.md 14 | :parser: markdown 15 | 16 | 17 | .. toctree:: 18 | :maxdepth: 2 19 | :caption: Contents: 20 | 21 | 22 | Modules 23 | ------- 24 | 25 | .. toctree:: 26 | :maxdepth: 2 27 | 28 | api/sure.rst 29 | 30 | Indices and tables 31 | ------------------ 32 | 33 | * :ref:`genindex` 34 | * :ref:`modindex` 35 | * :ref:`search` -------------------------------------------------------------------------------- /make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. 
Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /project.toml: -------------------------------------------------------------------------------- 1 | # pyproject.toml (build-time deps only) 2 | [build-system] 3 | requires = [ 4 | "setuptools>=61", 5 | "wheel", 6 | "cython", 7 | "numpy" 8 | ] 9 | build-backend = "setuptools.build_meta" 10 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | clearbox-preprocessor 2 | setuptools 3 | pandas<=2.2.3 4 | polars<=1.17.1 5 | numpy<2.0.0 6 | cython 7 | wheel 8 | streamlit 9 | matplotlib 10 | matplotlib-inline 11 | seaborn 12 | click 13 | scikit-learn 14 | tqdm 15 | joblib 16 | lightgbm 17 | xgboost 18 | sphinx_rtd_theme 19 | myst_parser -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import os 3 | import numpy as np 4 | from pathlib import Path 5 | from setuptools import setup, Extension, find_packages 6 | from Cython.Build import cythonize 7 | from Cython.Distutils import build_ext 8 | 9 | # Read the content of the README.md for the long_description metadata 10 | with open("README.md", "r") as readme: 11 | long_description = readme.read() 
12 | 13 | # Parse the requirements.txt file to get a list of dependencies 14 | with open("requirements.txt") as f: 15 | requirements = f.read().splitlines() 16 | 17 | # List of files to exclude from Cythonization 18 | EXCLUDE_FILES = [ 19 | "sure/utility.py", 20 | "sure/__init__.py", 21 | "sure/privacy/__init__.py", 22 | "sure/privacy/privacy.py", 23 | "sure/report_generator/report_generator.py", 24 | "sure/report_generator/report_app.py", 25 | "sure/report_generator/__init__.py", 26 | "sure/report_generator/pages/privacy.py", 27 | "sure/report_generator/.streamlit/config.toml", 28 | "sure/_lazypredict.py" 29 | ] 30 | 31 | def get_extensions_paths(root_dir, exclude_files): 32 | """ 33 | Retrieve file paths for compilation. 34 | 35 | Parameters 36 | ---------- 37 | root_dir : str 38 | Root directory to start searching for files. 39 | exclude_files : list of str 40 | List of file paths to exclude from the result. 41 | 42 | Returns 43 | ------- 44 | list of str or Extension 45 | A list containing file paths and/or Extension objects. 46 | 47 | """ 48 | paths = [] 49 | 50 | # Walk the directory to find .py and .pyx files 51 | for root, _, files in os.walk(root_dir): 52 | for filename in files: 53 | if ( 54 | os.path.splitext(filename)[1] != ".py" 55 | and os.path.splitext(filename)[1] != ".pyx" 56 | ): 57 | continue 58 | 59 | file_path = os.path.join(root, filename) 60 | 61 | if file_path in exclude_files: 62 | continue 63 | 64 | if os.path.splitext(filename)[1] == ".pyx": 65 | file_path = Extension( 66 | root.replace("/", ".").replace("\\", "."), 67 | [file_path], 68 | include_dirs=[np.get_include()], 69 | ) 70 | 71 | paths.append(file_path) 72 | 73 | return paths 74 | 75 | class CustomBuild(build_ext): 76 | """ 77 | Custom build class that inherits from Cython's build_ext. 78 | 79 | This class is created to override the default build behavior. 
80 | Specifically, it ensures certain non-Cython files are copied 81 | over to the build output directory after the Cythonization process. 82 | """ 83 | 84 | def run(self): 85 | """Override the run method to copy specific files after build.""" 86 | # Run the original run method 87 | build_ext.run(self) 88 | 89 | build_dir = Path(self.build_lib) 90 | root_dir = Path(__file__).parent 91 | target_dir = build_dir if not self.inplace else root_dir 92 | 93 | # List of files to copy after the build process 94 | files_to_copy = [ 95 | "sure/distance_metrics/__init__.py", 96 | "sure/distance_metrics/gower_matrix_c.pyx", 97 | "sure/utility.py", 98 | "sure/__init__.py", 99 | "sure/privacy/__init__.py", 100 | "sure/privacy/privacy.py", 101 | "sure/report_generator/report_generator.py", 102 | "sure/report_generator/report_app.py", 103 | "sure/report_generator/__init__.py", 104 | "sure/report_generator/pages/privacy.py", 105 | "sure/report_generator/.streamlit/config.toml", 106 | ] 107 | 108 | for file in files_to_copy: 109 | self.copy_file(Path(file), root_dir, target_dir) 110 | 111 | def copy_file(self, path, source_dir, destination_dir): 112 | """ 113 | Utility method to copy files from source to destination. 114 | 115 | Parameters 116 | ---------- 117 | path : Path 118 | Path of the file to be copied. 119 | source_dir : Path 120 | Directory where the source file resides. 121 | destination_dir : Path 122 | Directory where the file should be copied to. 
123 | 124 | """ 125 | src_file = source_dir / path 126 | dest_file = destination_dir / path 127 | dest_file.parent.mkdir(parents=True, exist_ok=True) # Ensure the directory exists 128 | shutil.copyfile(str(src_file), str(dest_file)) 129 | 130 | # Main setup configuration 131 | setup( 132 | # Metadata about the package 133 | name="clearbox-sure", 134 | version="0.2.15", 135 | author="Clearbox AI", 136 | author_email="info@clearbox.ai", 137 | description="A utility and privacy evaluation library for synthetic data", 138 | long_description=long_description, 139 | long_description_content_type="text/markdown", 140 | url="https://github.com/Clearbox-AI/SURE", 141 | install_requires=requirements, 142 | python_requires=">=3.7.0", 143 | 144 | # Cython modules compilation 145 | ext_modules=cythonize( 146 | get_extensions_paths("sure", EXCLUDE_FILES), 147 | build_dir="build", 148 | compiler_directives=dict(language_level=3, always_allow_keywords=True), 149 | ), 150 | 151 | # Override the build command with our custom class 152 | cmdclass=dict(build_ext=CustomBuild), 153 | 154 | # List of packages included in the distribution 155 | packages=find_packages(), # Include all packages in the distribution 156 | include_package_data=True, 157 | ) 158 | -------------------------------------------------------------------------------- /sure/__init__.py: -------------------------------------------------------------------------------- 1 | from .report_generator.report_generator import report, _save_to_json 2 | from .utility import _drop_cols 3 | from clearbox_preprocessor import Preprocessor 4 | 5 | __all__ = [ 6 | "report", 7 | "_save_to_json", 8 | "_drop_cols", 9 | "Preprocessor" 10 | ] -------------------------------------------------------------------------------- /sure/_lazypredict.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is taken from the library lazypredict-nightly 3 | """ 4 | 5 | import numpy as np 6 | import 
pandas as pd 7 | from tqdm import tqdm 8 | import time 9 | from sklearn.pipeline import Pipeline 10 | from sklearn.impute import SimpleImputer 11 | from sklearn.preprocessing import StandardScaler, OneHotEncoder, OrdinalEncoder 12 | from sklearn.compose import ColumnTransformer 13 | from sklearn.utils import all_estimators 14 | from sklearn.base import RegressorMixin 15 | from sklearn.base import ClassifierMixin 16 | from sklearn.metrics import ( 17 | accuracy_score, 18 | balanced_accuracy_score, 19 | roc_auc_score, 20 | f1_score, 21 | r2_score, 22 | mean_squared_error, 23 | ) 24 | import warnings 25 | import xgboost 26 | import lightgbm 27 | import inspect 28 | 29 | warnings.filterwarnings("ignore") 30 | pd.set_option("display.precision", 2) 31 | pd.set_option("display.float_format", lambda x: "%.2f" % x) 32 | 33 | removed_classifiers = [ 34 | "ClassifierChain", 35 | "ComplementNB", 36 | "GradientBoostingClassifier", 37 | "GaussianProcessClassifier", 38 | "HistGradientBoostingClassifier", 39 | "MLPClassifier", 40 | "LogisticRegressionCV", 41 | "MultiOutputClassifier", 42 | "MultinomialNB", 43 | "OneVsOneClassifier", 44 | "OneVsRestClassifier", 45 | "OutputCodeClassifier", 46 | "RadiusNeighborsClassifier", 47 | "VotingClassifier", 48 | ] 49 | 50 | removed_regressors = [ 51 | "TheilSenRegressor", 52 | "ARDRegression", 53 | "CCA", 54 | "IsotonicRegression", 55 | "StackingRegressor", 56 | "MultiOutputRegressor", 57 | "MultiTaskElasticNet", 58 | "MultiTaskElasticNetCV", 59 | "MultiTaskLasso", 60 | "MultiTaskLassoCV", 61 | "PLSCanonical", 62 | "PLSRegression", 63 | "RadiusNeighborsRegressor", 64 | "RegressorChain", 65 | "VotingRegressor", 66 | ] 67 | 68 | CLASSIFIERS = [ 69 | est 70 | for est in all_estimators() 71 | if (issubclass(est[1], ClassifierMixin) and (est[0] not in removed_classifiers)) 72 | ] 73 | 74 | REGRESSORS = [ 75 | est 76 | for est in all_estimators() 77 | if (issubclass(est[1], RegressorMixin) and (est[0] not in removed_regressors)) 78 | ] 79 | 80 | 
REGRESSORS.append(("XGBRegressor", xgboost.XGBRegressor)) 81 | REGRESSORS.append(("LGBMRegressor", lightgbm.LGBMRegressor)) 82 | # REGRESSORS.append(('CatBoostRegressor',catboost.CatBoostRegressor)) 83 | 84 | CLASSIFIERS.append(("XGBClassifier", xgboost.XGBClassifier)) 85 | CLASSIFIERS.append(("LGBMClassifier", lightgbm.LGBMClassifier)) 86 | # CLASSIFIERS.append(('CatBoostClassifier',catboost.CatBoostClassifier)) 87 | 88 | numeric_transformer = Pipeline( 89 | steps=[("imputer", SimpleImputer(strategy="mean")), ("scaler", StandardScaler())] 90 | ) 91 | 92 | categorical_transformer_low = Pipeline( 93 | steps=[ 94 | ("imputer", SimpleImputer(strategy="constant", fill_value="missing")), 95 | ("encoding", OneHotEncoder(handle_unknown="ignore", sparse_output=False)), 96 | ] 97 | ) 98 | 99 | categorical_transformer_high = Pipeline( 100 | steps=[ 101 | ("imputer", SimpleImputer(strategy="constant", fill_value="missing")), 102 | # 'OrdianlEncoder' Raise a ValueError when encounters an unknown value. Check https://github.com/scikit-learn/scikit-learn/pull/13423 103 | ("encoding", OrdinalEncoder()), 104 | ] 105 | ) 106 | 107 | 108 | # Helper function 109 | def get_card_split(df, cols, n=11): 110 | """ 111 | Splits categorical columns into 2 lists based on cardinality (i.e # of unique values) 112 | Parameters 113 | ---------- 114 | df : Pandas DataFrame 115 | DataFrame from which the cardinality of the columns is calculated. 116 | cols : list-like 117 | Categorical columns to list 118 | n : int, optional (default=11) 119 | The value of 'n' will be used to split columns. 
120 | Returns 121 | ------- 122 | card_low : list-like 123 | Columns with cardinality < n 124 | card_high : list-like 125 | Columns with cardinality >= n 126 | """ 127 | cond = df[cols].nunique() > n 128 | card_high = cols[cond] 129 | card_low = cols[~cond] 130 | return card_low, card_high 131 | 132 | 133 | # Helper class for performing classification 134 | 135 | 136 | class LazyClassifier: 137 | """ 138 | This module helps in fitting to all the classification algorithms that are available in Scikit-learn 139 | Parameters 140 | ---------- 141 | verbose : int, optional (default=0) 142 | For the liblinear and lbfgs solvers set verbose to any positive 143 | number for verbosity. 144 | ignore_warnings : bool, optional (default=True) 145 | When set to True, the warning related to algorigms that are not able to run are ignored. 146 | custom_metric : function, optional (default=None) 147 | When function is provided, models are evaluated based on the custom evaluation metric provided. 148 | prediction : bool, optional (default=False) 149 | When set to True, the predictions of all the models models are returned as dataframe. 150 | classifiers : list, optional (default="all") 151 | When function is provided, trains the chosen classifier(s). 
152 | 153 | Examples 154 | -------- 155 | >>> from lazypredict.Supervised import LazyClassifier 156 | >>> from sklearn.datasets import load_breast_cancer 157 | >>> from sklearn.model_selection import train_test_split 158 | >>> data = load_breast_cancer() 159 | >>> X = data.data 160 | >>> y= data.target 161 | >>> X_train, X_test, y_train, y_test = train_test_split(X, y,test_size=.5,random_state =123) 162 | >>> clf = LazyClassifier(verbose=0,ignore_warnings=True, custom_metric=None) 163 | >>> models,predictions = clf.fit(X_train, X_test, y_train, y_test) 164 | >>> model_dictionary = clf.provide_models(X_train,X_test,y_train,y_test) 165 | >>> models 166 | | Model | Accuracy | Balanced Accuracy | ROC AUC | F1 Score | Time Taken | 167 | |:-------------------------------|-----------:|--------------------:|----------:|-----------:|-------------:| 168 | | LinearSVC | 0.989474 | 0.987544 | 0.987544 | 0.989462 | 0.0150008 | 169 | | SGDClassifier | 0.989474 | 0.987544 | 0.987544 | 0.989462 | 0.0109992 | 170 | | MLPClassifier | 0.985965 | 0.986904 | 0.986904 | 0.985994 | 0.426 | 171 | | Perceptron | 0.985965 | 0.984797 | 0.984797 | 0.985965 | 0.0120046 | 172 | | LogisticRegression | 0.985965 | 0.98269 | 0.98269 | 0.985934 | 0.0200036 | 173 | | LogisticRegressionCV | 0.985965 | 0.98269 | 0.98269 | 0.985934 | 0.262997 | 174 | | SVC | 0.982456 | 0.979942 | 0.979942 | 0.982437 | 0.0140011 | 175 | | CalibratedClassifierCV | 0.982456 | 0.975728 | 0.975728 | 0.982357 | 0.0350015 | 176 | | PassiveAggressiveClassifier | 0.975439 | 0.974448 | 0.974448 | 0.975464 | 0.0130005 | 177 | | LabelPropagation | 0.975439 | 0.974448 | 0.974448 | 0.975464 | 0.0429988 | 178 | | LabelSpreading | 0.975439 | 0.974448 | 0.974448 | 0.975464 | 0.0310006 | 179 | | RandomForestClassifier | 0.97193 | 0.969594 | 0.969594 | 0.97193 | 0.033 | 180 | | GradientBoostingClassifier | 0.97193 | 0.967486 | 0.967486 | 0.971869 | 0.166998 | 181 | | QuadraticDiscriminantAnalysis | 0.964912 | 0.966206 | 0.966206 | 
0.965052 | 0.0119994 | 182 | | HistGradientBoostingClassifier | 0.968421 | 0.964739 | 0.964739 | 0.968387 | 0.682003 | 183 | | RidgeClassifierCV | 0.97193 | 0.963272 | 0.963272 | 0.971736 | 0.0130029 | 184 | | RidgeClassifier | 0.968421 | 0.960525 | 0.960525 | 0.968242 | 0.0119977 | 185 | | AdaBoostClassifier | 0.961404 | 0.959245 | 0.959245 | 0.961444 | 0.204998 | 186 | | ExtraTreesClassifier | 0.961404 | 0.957138 | 0.957138 | 0.961362 | 0.0270066 | 187 | | KNeighborsClassifier | 0.961404 | 0.95503 | 0.95503 | 0.961276 | 0.0560005 | 188 | | BaggingClassifier | 0.947368 | 0.954577 | 0.954577 | 0.947882 | 0.0559971 | 189 | | BernoulliNB | 0.950877 | 0.951003 | 0.951003 | 0.951072 | 0.0169988 | 190 | | LinearDiscriminantAnalysis | 0.961404 | 0.950816 | 0.950816 | 0.961089 | 0.0199995 | 191 | | GaussianNB | 0.954386 | 0.949536 | 0.949536 | 0.954337 | 0.0139935 | 192 | | NuSVC | 0.954386 | 0.943215 | 0.943215 | 0.954014 | 0.019989 | 193 | | DecisionTreeClassifier | 0.936842 | 0.933693 | 0.933693 | 0.936971 | 0.0170023 | 194 | | NearestCentroid | 0.947368 | 0.933506 | 0.933506 | 0.946801 | 0.0160074 | 195 | | ExtraTreeClassifier | 0.922807 | 0.912168 | 0.912168 | 0.922462 | 0.0109999 | 196 | | CheckingClassifier | 0.361404 | 0.5 | 0.5 | 0.191879 | 0.0170043 | 197 | | DummyClassifier | 0.512281 | 0.489598 | 0.489598 | 0.518924 | 0.0119965 | 198 | """ 199 | 200 | def __init__( 201 | self, 202 | verbose=0, 203 | ignore_warnings=True, 204 | custom_metric=None, 205 | predictions=False, 206 | random_state=42, 207 | classifiers="all", 208 | ): 209 | self.verbose = verbose 210 | self.ignore_warnings = ignore_warnings 211 | self.custom_metric = custom_metric 212 | self.predictions = predictions 213 | self.models = {} 214 | self.random_state = random_state 215 | self.classifiers = classifiers 216 | 217 | def fit(self, X_train, X_test, y_train, y_test): 218 | """Fit Classification algorithms to X_train and y_train, predict and score on X_test, y_test. 
219 | Parameters 220 | ---------- 221 | X_train : array-like, 222 | Training vectors, where rows is the number of samples 223 | and columns is the number of features. 224 | X_test : array-like, 225 | Testing vectors, where rows is the number of samples 226 | and columns is the number of features. 227 | y_train : array-like, 228 | Training vectors, where rows is the number of samples 229 | and columns is the number of features. 230 | y_test : array-like, 231 | Testing vectors, where rows is the number of samples 232 | and columns is the number of features. 233 | Returns 234 | ------- 235 | scores : Pandas DataFrame 236 | Returns metrics of all the models in a Pandas DataFrame. 237 | predictions : Pandas DataFrame 238 | Returns predictions of all the models in a Pandas DataFrame. 239 | """ 240 | Accuracy = [] 241 | B_Accuracy = [] 242 | ROC_AUC = [] 243 | F1 = [] 244 | names = [] 245 | TIME = [] 246 | predictions = {} 247 | 248 | if self.custom_metric is not None: 249 | CUSTOM_METRIC = [] 250 | 251 | if isinstance(X_train, np.ndarray): 252 | X_train = pd.DataFrame(X_train) 253 | X_test = pd.DataFrame(X_test) 254 | 255 | numeric_features = X_train.select_dtypes(include=[np.number]).columns 256 | categorical_features = X_train.select_dtypes(include=["object"]).columns 257 | 258 | categorical_low, categorical_high = get_card_split( 259 | X_train, categorical_features 260 | ) 261 | 262 | preprocessor = ColumnTransformer( 263 | transformers=[ 264 | ("numeric", numeric_transformer, numeric_features), 265 | ("categorical_low", categorical_transformer_low, categorical_low), 266 | ("categorical_high", categorical_transformer_high, categorical_high), 267 | ] 268 | ) 269 | 270 | if self.classifiers == "all": 271 | self.classifiers = CLASSIFIERS 272 | else: 273 | try: 274 | temp_list = [] 275 | for classifier in self.classifiers: 276 | full_name = (classifier.__name__, classifier) 277 | temp_list.append(full_name) 278 | self.classifiers = temp_list 279 | except Exception as 
exception: 280 | print(exception) 281 | print("Invalid Classifier(s)") 282 | 283 | for name, model in tqdm(self.classifiers): 284 | start = time.time() 285 | try: 286 | if "random_state" in model().get_params().keys(): 287 | pipe = Pipeline( 288 | steps=[ 289 | ("preprocessor", preprocessor), 290 | ("classifier", model(random_state=self.random_state)), 291 | ], 292 | verbose=False 293 | ) 294 | else: 295 | pipe = Pipeline( 296 | steps=[("preprocessor", preprocessor), ("classifier", model())], 297 | verbose=False 298 | ) 299 | 300 | if model.__name__=="LGBMClassifier": 301 | callbacks = [lightgbm.early_stopping(10, verbose=0), lightgbm.log_evaluation(period=0)] 302 | pipe.fit(X_train, y_train,callbacks=callbacks) 303 | else: 304 | pipe.fit(X_train, y_train) 305 | 306 | self.models[name] = pipe 307 | y_pred = pipe.predict(X_test) 308 | accuracy = accuracy_score(y_test, y_pred, normalize=True) 309 | b_accuracy = balanced_accuracy_score(y_test, y_pred) 310 | f1 = f1_score(y_test, y_pred, average="weighted") 311 | try: 312 | roc_auc = roc_auc_score(y_test, y_pred) 313 | except Exception as exception: 314 | roc_auc = None 315 | if self.ignore_warnings is False: 316 | print("ROC AUC couldn't be calculated for " + name) 317 | print(exception) 318 | names.append(name) 319 | Accuracy.append(accuracy) 320 | B_Accuracy.append(b_accuracy) 321 | ROC_AUC.append(roc_auc) 322 | F1.append(f1) 323 | TIME.append(time.time() - start) 324 | if self.custom_metric is not None: 325 | custom_metric = self.custom_metric(y_test, y_pred) 326 | CUSTOM_METRIC.append(custom_metric) 327 | if self.verbose > 0: 328 | if self.custom_metric is not None: 329 | print( 330 | { 331 | "Model": name, 332 | "Accuracy": accuracy, 333 | "Balanced Accuracy": b_accuracy, 334 | "ROC AUC": roc_auc, 335 | "F1 Score": f1, 336 | self.custom_metric.__name__: custom_metric, 337 | "Time taken": time.time() - start, 338 | } 339 | ) 340 | else: 341 | print( 342 | { 343 | "Model": name, 344 | "Accuracy": accuracy, 345 | 
"Balanced Accuracy": b_accuracy, 346 | "ROC AUC": roc_auc, 347 | "F1 Score": f1, 348 | "Time taken": time.time() - start, 349 | } 350 | ) 351 | if self.predictions: 352 | predictions[name] = y_pred 353 | except Exception as exception: 354 | if self.ignore_warnings is False: 355 | print(name + " model failed to execute") 356 | print(exception) 357 | if self.custom_metric is None: 358 | scores = pd.DataFrame( 359 | { 360 | "Model": names, 361 | "Accuracy": Accuracy, 362 | "Balanced Accuracy": B_Accuracy, 363 | "ROC AUC": ROC_AUC, 364 | "F1 Score": F1, 365 | "Time Taken": TIME, 366 | } 367 | ) 368 | else: 369 | scores = pd.DataFrame( 370 | { 371 | "Model": names, 372 | "Accuracy": Accuracy, 373 | "Balanced Accuracy": B_Accuracy, 374 | "ROC AUC": ROC_AUC, 375 | "F1 Score": F1, 376 | self.custom_metric.__name__: CUSTOM_METRIC, 377 | "Time Taken": TIME, 378 | } 379 | ) 380 | scores = scores.sort_values(by="Balanced Accuracy", ascending=False).set_index( 381 | "Model" 382 | ) 383 | 384 | if self.predictions: 385 | predictions_df = pd.DataFrame.from_dict(predictions) 386 | return scores, predictions_df if self.predictions is True else scores 387 | 388 | def provide_models(self, X_train, X_test, y_train, y_test): 389 | """ 390 | This function returns all the model objects trained in fit function. 391 | If fit is not called already, then we call fit and then return the models. 392 | Parameters 393 | ---------- 394 | X_train : array-like, 395 | Training vectors, where rows is the number of samples 396 | and columns is the number of features. 397 | X_test : array-like, 398 | Testing vectors, where rows is the number of samples 399 | and columns is the number of features. 400 | y_train : array-like, 401 | Training vectors, where rows is the number of samples 402 | and columns is the number of features. 403 | y_test : array-like, 404 | Testing vectors, where rows is the number of samples 405 | and columns is the number of features. 
406 | Returns 407 | ------- 408 | models: dict-object, 409 | Returns a dictionary with each model pipeline as value 410 | with key as name of models. 411 | """ 412 | if len(self.models.keys()) == 0: 413 | self.fit(X_train, X_test, y_train, y_test) 414 | 415 | return self.models 416 | 417 | 418 | def adjusted_rsquared(r2, n, p): 419 | return 1 - (1 - r2) * ((n - 1) / (n - p - 1)) 420 | 421 | 422 | # Helper class for performing classification 423 | 424 | 425 | class LazyRegressor: 426 | """ 427 | This module helps in fitting regression models that are available in Scikit-learn 428 | Parameters 429 | ---------- 430 | verbose : int, optional (default=0) 431 | For the liblinear and lbfgs solvers set verbose to any positive 432 | number for verbosity. 433 | ignore_warnings : bool, optional (default=True) 434 | When set to True, the warning related to algorigms that are not able to run are ignored. 435 | custom_metric : function, optional (default=None) 436 | When function is provided, models are evaluated based on the custom evaluation metric provided. 437 | prediction : bool, optional (default=False) 438 | When set to True, the predictions of all the models models are returned as dataframe. 439 | regressors : list, optional (default="all") 440 | When function is provided, trains the chosen regressor(s). 
441 | 442 | Examples 443 | -------- 444 | >>> from lazypredict.Supervised import LazyRegressor 445 | >>> from sklearn import datasets 446 | >>> from sklearn.utils import shuffle 447 | >>> import numpy as np 448 | 449 | >>> boston = datasets.load_boston() 450 | >>> X, y = shuffle(boston.data, boston.target, random_state=13) 451 | >>> X = X.astype(np.float32) 452 | 453 | >>> offset = int(X.shape[0] * 0.9) 454 | >>> X_train, y_train = X[:offset], y[:offset] 455 | >>> X_test, y_test = X[offset:], y[offset:] 456 | 457 | >>> reg = LazyRegressor(verbose=0, ignore_warnings=False, custom_metric=None) 458 | >>> models, predictions = reg.fit(X_train, X_test, y_train, y_test) 459 | >>> model_dictionary = reg.provide_models(X_train, X_test, y_train, y_test) 460 | >>> models 461 | | Model | Adjusted R-Squared | R-Squared | RMSE | Time Taken | 462 | |:------------------------------|-------------------:|----------:|------:|-----------:| 463 | | SVR | 0.83 | 0.88 | 2.62 | 0.01 | 464 | | BaggingRegressor | 0.83 | 0.88 | 2.63 | 0.03 | 465 | | NuSVR | 0.82 | 0.86 | 2.76 | 0.03 | 466 | | RandomForestRegressor | 0.81 | 0.86 | 2.78 | 0.21 | 467 | | XGBRegressor | 0.81 | 0.86 | 2.79 | 0.06 | 468 | | GradientBoostingRegressor | 0.81 | 0.86 | 2.84 | 0.11 | 469 | | ExtraTreesRegressor | 0.79 | 0.84 | 2.98 | 0.12 | 470 | | AdaBoostRegressor | 0.78 | 0.83 | 3.04 | 0.07 | 471 | | HistGradientBoostingRegressor | 0.77 | 0.83 | 3.06 | 0.17 | 472 | | PoissonRegressor | 0.77 | 0.83 | 3.11 | 0.01 | 473 | | LGBMRegressor | 0.77 | 0.83 | 3.11 | 0.07 | 474 | | KNeighborsRegressor | 0.77 | 0.83 | 3.12 | 0.01 | 475 | | DecisionTreeRegressor | 0.65 | 0.74 | 3.79 | 0.01 | 476 | | MLPRegressor | 0.65 | 0.74 | 3.80 | 1.63 | 477 | | HuberRegressor | 0.64 | 0.74 | 3.84 | 0.01 | 478 | | GammaRegressor | 0.64 | 0.73 | 3.88 | 0.01 | 479 | | LinearSVR | 0.62 | 0.72 | 3.96 | 0.01 | 480 | | RidgeCV | 0.62 | 0.72 | 3.97 | 0.01 | 481 | | BayesianRidge | 0.62 | 0.72 | 3.97 | 0.01 | 482 | | Ridge | 0.62 | 0.72 | 3.97 | 
0.01 | 483 | | TransformedTargetRegressor | 0.62 | 0.72 | 3.97 | 0.01 | 484 | | LinearRegression | 0.62 | 0.72 | 3.97 | 0.01 | 485 | | ElasticNetCV | 0.62 | 0.72 | 3.98 | 0.04 | 486 | | LassoCV | 0.62 | 0.72 | 3.98 | 0.06 | 487 | | LassoLarsIC | 0.62 | 0.72 | 3.98 | 0.01 | 488 | | LassoLarsCV | 0.62 | 0.72 | 3.98 | 0.02 | 489 | | Lars | 0.61 | 0.72 | 3.99 | 0.01 | 490 | | LarsCV | 0.61 | 0.71 | 4.02 | 0.04 | 491 | | SGDRegressor | 0.60 | 0.70 | 4.07 | 0.01 | 492 | | TweedieRegressor | 0.59 | 0.70 | 4.12 | 0.01 | 493 | | GeneralizedLinearRegressor | 0.59 | 0.70 | 4.12 | 0.01 | 494 | | ElasticNet | 0.58 | 0.69 | 4.16 | 0.01 | 495 | | Lasso | 0.54 | 0.66 | 4.35 | 0.02 | 496 | | RANSACRegressor | 0.53 | 0.65 | 4.41 | 0.04 | 497 | | OrthogonalMatchingPursuitCV | 0.45 | 0.59 | 4.78 | 0.02 | 498 | | PassiveAggressiveRegressor | 0.37 | 0.54 | 5.09 | 0.01 | 499 | | GaussianProcessRegressor | 0.23 | 0.43 | 5.65 | 0.03 | 500 | | OrthogonalMatchingPursuit | 0.16 | 0.38 | 5.89 | 0.01 | 501 | | ExtraTreeRegressor | 0.08 | 0.32 | 6.17 | 0.01 | 502 | | DummyRegressor | -0.38 | -0.02 | 7.56 | 0.01 | 503 | | LassoLars | -0.38 | -0.02 | 7.56 | 0.01 | 504 | | KernelRidge | -11.50 | -8.25 | 22.74 | 0.01 | 505 | """ 506 | 507 | def __init__( 508 | self, 509 | verbose=0, 510 | ignore_warnings=True, 511 | custom_metric=None, 512 | predictions=False, 513 | random_state=42, 514 | regressors="all", 515 | ): 516 | self.verbose = verbose 517 | self.ignore_warnings = ignore_warnings 518 | self.custom_metric = custom_metric 519 | self.predictions = predictions 520 | self.models = {} 521 | self.random_state = random_state 522 | self.regressors = regressors 523 | 524 | def fit(self, X_train, X_test, y_train, y_test): 525 | """Fit Regression algorithms to X_train and y_train, predict and score on X_test, y_test. 526 | Parameters 527 | ---------- 528 | X_train : array-like, 529 | Training vectors, where rows is the number of samples 530 | and columns is the number of features. 
531 | X_test : array-like, 532 | Testing vectors, where rows is the number of samples 533 | and columns is the number of features. 534 | y_train : array-like, 535 | Training vectors, where rows is the number of samples 536 | and columns is the number of features. 537 | y_test : array-like, 538 | Testing vectors, where rows is the number of samples 539 | and columns is the number of features. 540 | Returns 541 | ------- 542 | scores : Pandas DataFrame 543 | Returns metrics of all the models in a Pandas DataFrame. 544 | predictions : Pandas DataFrame 545 | Returns predictions of all the models in a Pandas DataFrame. 546 | """ 547 | R2 = [] 548 | ADJR2 = [] 549 | RMSE = [] 550 | # WIN = [] 551 | names = [] 552 | TIME = [] 553 | predictions = {} 554 | 555 | if self.custom_metric: 556 | CUSTOM_METRIC = [] 557 | 558 | if isinstance(X_train, np.ndarray): 559 | X_train = pd.DataFrame(X_train) 560 | X_test = pd.DataFrame(X_test) 561 | 562 | numeric_features = X_train.select_dtypes(include=[np.number]).columns 563 | categorical_features = X_train.select_dtypes(include=["object"]).columns 564 | 565 | categorical_low, categorical_high = get_card_split( 566 | X_train, categorical_features 567 | ) 568 | 569 | preprocessor = ColumnTransformer( 570 | transformers=[ 571 | ("numeric", numeric_transformer, numeric_features), 572 | ("categorical_low", categorical_transformer_low, categorical_low), 573 | ("categorical_high", categorical_transformer_high, categorical_high), 574 | ] 575 | ) 576 | 577 | if self.regressors == "all": 578 | self.regressors = REGRESSORS 579 | else: 580 | try: 581 | temp_list = [] 582 | for regressor in self.regressors: 583 | full_name = (regressor.__name__, regressor) 584 | temp_list.append(full_name) 585 | self.regressors = temp_list 586 | except Exception as exception: 587 | print(exception) 588 | print("Invalid Regressor(s)") 589 | 590 | for name, model in tqdm(self.regressors): 591 | start = time.time() 592 | try: 593 | if "random_state" in 
model().get_params().keys(): 594 | pipe = Pipeline( 595 | steps=[ 596 | ("preprocessor", preprocessor), 597 | ("regressor", model(random_state=self.random_state)), 598 | ] 599 | ) 600 | else: 601 | pipe = Pipeline( 602 | steps=[("preprocessor", preprocessor), ("regressor", model())] 603 | ) 604 | 605 | if model.__name__=="LGBMRegressor": 606 | callbacks = [lightgbm.early_stopping(10, verbose=0), lightgbm.log_evaluation(period=0)] 607 | pipe.fit(X_train, y_train,callbacks=callbacks) 608 | else: 609 | pipe.fit(X_train, y_train) 610 | 611 | self.models[name] = pipe 612 | y_pred = pipe.predict(X_test) 613 | 614 | r_squared = r2_score(y_test, y_pred) 615 | adj_rsquared = adjusted_rsquared( 616 | r_squared, X_test.shape[0], X_test.shape[1] 617 | ) 618 | rmse = np.sqrt(mean_squared_error(y_test, y_pred)) 619 | 620 | names.append(name) 621 | R2.append(r_squared) 622 | ADJR2.append(adj_rsquared) 623 | RMSE.append(rmse) 624 | TIME.append(time.time() - start) 625 | 626 | if self.custom_metric: 627 | custom_metric = self.custom_metric(y_test, y_pred) 628 | CUSTOM_METRIC.append(custom_metric) 629 | 630 | if self.verbose > 0: 631 | scores_verbose = { 632 | "Model": name, 633 | "R-Squared": r_squared, 634 | "Adjusted R-Squared": adj_rsquared, 635 | "RMSE": rmse, 636 | "Time taken": time.time() - start, 637 | } 638 | 639 | if self.custom_metric: 640 | scores_verbose[self.custom_metric.__name__] = custom_metric 641 | 642 | print(scores_verbose) 643 | if self.predictions: 644 | predictions[name] = y_pred 645 | except Exception as exception: 646 | if self.ignore_warnings is False: 647 | print(name + " model failed to execute") 648 | print(exception) 649 | 650 | scores = { 651 | "Model": names, 652 | "Adjusted R-Squared": ADJR2, 653 | "R-Squared": R2, 654 | "RMSE": RMSE, 655 | "Time Taken": TIME, 656 | } 657 | 658 | if self.custom_metric: 659 | scores[self.custom_metric.__name__] = CUSTOM_METRIC 660 | 661 | scores = pd.DataFrame(scores) 662 | scores = scores.sort_values(by="Adjusted 
R-Squared", ascending=False).set_index( 663 | "Model" 664 | ) 665 | 666 | if self.predictions: 667 | predictions_df = pd.DataFrame.from_dict(predictions) 668 | return scores, predictions_df if self.predictions is True else scores 669 | 670 | def provide_models(self, X_train, X_test, y_train, y_test): 671 | """ 672 | This function returns all the model objects trained in fit function. 673 | If fit is not called already, then we call fit and then return the models. 674 | Parameters 675 | ---------- 676 | X_train : array-like, 677 | Training vectors, where rows is the number of samples 678 | and columns is the number of features. 679 | X_test : array-like, 680 | Testing vectors, where rows is the number of samples 681 | and columns is the number of features. 682 | y_train : array-like, 683 | Training vectors, where rows is the number of samples 684 | and columns is the number of features. 685 | y_test : array-like, 686 | Testing vectors, where rows is the number of samples 687 | and columns is the number of features. 688 | Returns 689 | ------- 690 | models: dict-object, 691 | Returns a dictionary with each model pipeline as value 692 | with key as name of models. 
693 | """ 694 | if len(self.models.keys()) == 0: 695 | self.fit(X_train, X_test, y_train, y_test) 696 | 697 | return self.models 698 | -------------------------------------------------------------------------------- /sure/distance_metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clearbox-AI/SURE/4460454e24d142c73457d8f05326151be54f778f/sure/distance_metrics/__init__.py -------------------------------------------------------------------------------- /sure/distance_metrics/distance.py: -------------------------------------------------------------------------------- 1 | import pyximport 2 | import os 3 | from multiprocessing import Pool 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import polars as pl 8 | from sklearn.preprocessing import OrdinalEncoder 9 | 10 | from typing import Dict, List, Tuple, Union 11 | int_type = Union[int, np.int8, np.uint8, np.int16, np.uint16, np.int32, np.uint32, np.int64, np.uint64] 12 | float_type = Union[float, np.float16, np.float32, np.float64] 13 | 14 | from ..report_generator.report_generator import _save_to_json 15 | 16 | from sure import _drop_cols 17 | 18 | pyximport.install(setup_args={"include_dirs": np.get_include()}) 19 | from sure.distance_metrics.gower_matrix_c import gower_matrix_c 20 | 21 | 22 | def _polars_to_pandas(dataframe: pl.DataFrame | pl.LazyFrame): 23 | if isinstance(dataframe, pl.DataFrame): 24 | dataframe = dataframe.to_pandas() 25 | if isinstance(dataframe, pl.LazyFrame): 26 | dataframe = dataframe.collect().to_pandas() 27 | return dataframe 28 | 29 | def _gower_matrix( 30 | X_categorical: np.ndarray, 31 | X_numerical: np.ndarray, 32 | Y_categorical: np.ndarray, 33 | Y_numerical: np.ndarray, 34 | numericals_ranges: np.ndarray, 35 | features_weight_sum: float, 36 | fill_diagonal: bool, 37 | first_index: int = -1, 38 | ) -> np.ndarray: 39 | """ 40 | _summary_ 41 | 42 | Parameters 43 | ---------- 44 | X_categorical : 
np.ndarray 45 | 2D array containing only the categorical features of the X dataframe as uint8 values, shape (x_rows, cat_features). 46 | X_numerical : np.ndarray 47 | 2D array containing only the numerical features of the X dataframe as float32 values, shape (x_rows, num_features). 48 | Y_categorical : np.ndarray 49 | 2D array containing only the categorical features of the Y dataframe as uint8 values, shape (y_rows, cat_features). 50 | Y_numerical : np.ndarray 51 | 2D array containing only the numerical features of the Y dataframe as float32 values, shape (y_rows, num_features). 52 | numericals_ranges : np.ndarray 53 | 1D array containing the range (max-min) of each numerical feature as float32 values, shap (num_features,). 54 | features_weight_sum : float 55 | Sum of the feature weights used for the final average computation (usually it's just the number of features, each 56 | feature has a weigth of 1). 57 | fill_diagonal : bool 58 | Whether to fill the matrix diagonal values with a value larger than 1 (5.0). It must be True to get correct values 59 | if your computing the matrix just for one dataset (comparing a dataset with itself), otherwise you will get DCR==0 60 | for each row because on the diagonal you will compare a pair of identical instances. 61 | first_index : int, optioanl 62 | This is required only in case of parallel computation: the computation will occur batch by batch so ther original 63 | diagonal values will no longer be on the diagonal on each batch. We use this index to fill correctly the diagonal 64 | values. 
If -1 it's assumed there's no parallel computation, by default -1 65 | 66 | Returns 67 | ------- 68 | np.ndarray 69 | 1D array containing the Distance to the Closest Record for each row of x_dataframe shape (x_dataframe rows, ) 70 | """ 71 | return gower_matrix_c( 72 | X_categorical, 73 | X_numerical, 74 | Y_categorical, 75 | Y_numerical, 76 | numericals_ranges, 77 | features_weight_sum, 78 | fill_diagonal, 79 | first_index, 80 | ) 81 | 82 | def distance_to_closest_record( 83 | dcr_name: str, 84 | x_dataframe: pd.DataFrame | pl.DataFrame | pl.LazyFrame, 85 | y_dataframe: pd.DataFrame | pl.DataFrame | pl.LazyFrame = None, 86 | feature_weights: np.ndarray | List = None, 87 | parallel: bool = True, 88 | save_data: bool = True, 89 | path_to_json: str = "" 90 | ) -> np.ndarray: 91 | """ 92 | Compute the distancees to closest record of dataframe X from dataframe Y using 93 | a modified version of the Gower's distance. 94 | The two dataframes may contain mixed datatypes (numerical and categorical). 95 | 96 | Paper references: 97 | * A General Coefficient of Similarity and Some of Its Properties, J. C. Gower 98 | * Dimensionality Invariant Similarity Measure, Ahmad Basheer Hassanat 99 | 100 | Parameters 101 | ---------- 102 | dcr_name : str 103 | Name with which the DCR will be saved with in the JSON file used to generate the final report. 104 | Can be one of the following: 105 | - synth_train 106 | - synth_val 107 | - other 108 | x_dataframe : pd.DataFrame 109 | A dataset containing numerical and categorical data. 110 | categorical_features : List 111 | List of booleans that indicates which features are categorical. 112 | If categoricals_features[i] is True, feature i is categorical. 113 | Must have same length of x_dataframe.columns. 114 | y_dataframe : pd.DataFrame, optional 115 | Another dataset containing numerical and categorical data, by default None. 116 | It must contains the same columns of x_dataframe. 
117 | If None, the distance matrix is computed between x_dataframe and x_dataframe 118 | feature_weights : List, optional 119 | List of features weights to use computing distances, by default None. 120 | If None, each feature weight is 1.0 121 | parallel : Boolean, optional 122 | Whether to enable the parallelization to compute Gower matrix, by default True 123 | save_data : bool 124 | If True, saves the DCR information into the JSON file used to generate the final report. 125 | path_to_json : str 126 | Path to the JSON file used to generate the final report. 127 | 128 | Returns 129 | ------- 130 | np.ndarray 131 | 1D array containing the Distance to the Closest Record for each row of x_dataframe 132 | shape (x_dataframe rows, ) 133 | 134 | Raises 135 | ------ 136 | TypeError 137 | If dc_name is not one of the names listed above. 138 | TypeError 139 | If X and Y don't have the same (number of) columns. 140 | """ 141 | if dcr_name != "synth_train" and dcr_name != "synth_val" and dcr_name != "other": 142 | raise TypeError("dcr_name must be one of the following:\n -\"synth_train\"\n -\"synth_val\"\n -\"other\"") 143 | 144 | # Converting X Dataset to pd.DataFrame 145 | X = _polars_to_pandas(x_dataframe) 146 | 147 | # Convert any temporal features to int 148 | temporal_columns = X.select_dtypes(include=['datetime']).columns 149 | X[temporal_columns] = X[temporal_columns].astype('int64') 150 | 151 | # If a second dataframe is provided, the distances are calculated using it; otherwise, they are calculated within X itself. 
152 | if y_dataframe is None: 153 | Y = X 154 | fill_diagonal = True 155 | else: 156 | # Converting X Dataset into pd.DataFrame 157 | Y = _polars_to_pandas(y_dataframe) 158 | fill_diagonal = False 159 | Y[temporal_columns] = Y[temporal_columns].astype('int64') 160 | 161 | # Drop columns that are present in Y but are missing in X and vice versa 162 | X, Y = _drop_cols(X, Y) 163 | 164 | if not isinstance(X, np.ndarray): 165 | if not np.array_equal(X.columns, Y.columns): 166 | raise TypeError("X and Y dataframes have different columns.") 167 | else: 168 | if not X.shape[1] == Y.shape[1]: 169 | raise TypeError("X and Y arrays have different number of columns.") 170 | 171 | # Get categorical features 172 | categorical_features = np.array(X.dtypes)==pl.Utf8 173 | 174 | # Both datafrmaes are turned into numpy arrays 175 | if not isinstance(X, np.ndarray): 176 | X = np.asarray(X) 177 | if not isinstance(Y, np.ndarray): 178 | Y = np.asarray(Y) 179 | 180 | X_categorical = X[:, categorical_features] 181 | Y_categorical = Y[:, categorical_features] 182 | 183 | if feature_weights is None: 184 | # If no weights are specified, all weights are 1 185 | feature_weights = np.ones(X.shape[1]) 186 | else: 187 | feature_weights = np.array(feature_weights) 188 | 189 | # The sum of the weights is necessary to compute the mean value in the end (division) 190 | weight_sum = feature_weights.sum().astype("float32") 191 | 192 | # Perform label encoding on categorical features of X and Y 193 | if categorical_features.any(): 194 | encoder = OrdinalEncoder( 195 | handle_unknown='use_encoded_value', 196 | dtype="uint8", 197 | unknown_value=-1, 198 | encoded_missing_value=-1 199 | ) 200 | 201 | # Apply the encoder on all categorical columns at once 202 | # Categorical feature matrix of X (num_rows_X x num_cat_feat) 203 | X_categorical = encoder.fit_transform(X_categorical) 204 | # Categorical feature matrix of Y (num_rows_Y x num_cat_feat) 205 | Y_categorical = encoder.transform(Y_categorical) 206 
| else: 207 | X_categorical = X_categorical.astype("uint8") 208 | Y_categorical = Y_categorical.astype("uint8") 209 | 210 | # Numerical feature matrix of X (num_rows_X x num_num_feat) 211 | X_numerical = X[:, np.logical_not(categorical_features)].astype("float32") 212 | 213 | # Numerical feature matrix ofMatrice delle feature numeriche di Y (num_rows_Y x num_num_feat) 214 | Y_numerical = Y[:, np.logical_not(categorical_features)].astype("float32") 215 | 216 | # The range of the numerical features is necessary for the way Gower distances are calculated. 217 | # I find the minimum and maximum for each numerical feature by concatenating X and Y, and then 218 | # calculate the range by subtracting all the minimum values from the maximum values. 219 | 220 | numericals_mins = np.amin(np.concatenate((X_numerical, Y_numerical)), axis=0) 221 | numericals_maxs = np.amax(np.concatenate((X_numerical, Y_numerical)), axis=0) 222 | numericals_ranges = numericals_maxs - numericals_mins 223 | 224 | X_rows = X.shape[0] 225 | 226 | """ 227 | Perform a parallel calculation on a DataFrame by dividing the data into chunks 228 | and distributing the chunks across multiple CPUs (all the ones available CPUs except one). 229 | divide the total number of rows in X by the number of CPUs used to determine the size of the chunks on 230 | which the parallel calculation is performed. 231 | The for loop with index executes the actual calculation, and index is passed to fill the value that 232 | corresponds to the distance from the same instance in case of calculation on a single DataFrame 233 | (fill_diagonal == True). 
234 | """ 235 | if parallel: 236 | result_objs = [] 237 | number_of_cpus = os.cpu_count() - 1 238 | chunk_size = int(X_rows / number_of_cpus) 239 | chunk_size = chunk_size if chunk_size > 0 else 1 240 | with Pool(processes=number_of_cpus) as pool: 241 | for index in range(0, X_rows, chunk_size): 242 | result = pool.apply_async( 243 | _gower_matrix, 244 | ( 245 | X_categorical[index : index + chunk_size], 246 | X_numerical[index : index + chunk_size], 247 | Y_categorical, 248 | Y_numerical, 249 | numericals_ranges, 250 | weight_sum, 251 | fill_diagonal, 252 | index, 253 | ), 254 | ) 255 | result_objs.append(result) 256 | results = [result.get() for result in result_objs] 257 | dcr = np.concatenate(results) 258 | else: 259 | dcr = _gower_matrix( 260 | X_categorical, 261 | X_numerical, 262 | Y_categorical, 263 | Y_numerical, 264 | numericals_ranges, 265 | weight_sum, 266 | fill_diagonal, 267 | ) 268 | if save_data: 269 | _save_to_json("dcr_"+dcr_name, dcr, path_to_json) 270 | return dcr 271 | 272 | def dcr_stats(dcr_name: str, 273 | distances_to_closest_record: np.ndarray, 274 | save_data: bool = True, 275 | path_to_json: str = "") -> Dict: 276 | """ 277 | This function returns the statisitcs for an array containing DCR computed previously. 278 | 279 | Parameters 280 | ---------- 281 | dcr_name : str 282 | Name with which the DCR will be saved with in the JSON file used to generate the final report. 283 | Can be one of the following: 284 | - synth_train 285 | - synth_val 286 | - other 287 | distances_to_closest_record : np.ndarray 288 | A 1D-array containing the Distance to the Closest Record for each row of a dataframe 289 | shape (dataframe rows, ) 290 | save_data : bool 291 | If True, saves the DCR information into the JSON file used to generate the final report. 292 | path_to_json : str 293 | Path to the JSON file used to generate the final report. 
294 | 295 | Returns 296 | ------- 297 | Dict 298 | A dictionary containing mean and percentiles of the given DCR array. 299 | """ 300 | if dcr_name != "synth_train" and dcr_name != "synth_val" and dcr_name != "other": 301 | raise TypeError("dcr_name must be one of the following:\n -\"synth_train\"\n -\"synth_val\"\n -\"other\"") 302 | 303 | dcr_mean = np.mean(distances_to_closest_record) 304 | dcr_percentiles = np.percentile(distances_to_closest_record, [0, 25, 50, 75, 100]) 305 | dcr_stats = { 306 | "mean": dcr_mean.item(), 307 | "min": dcr_percentiles[0].item(), 308 | "25%": dcr_percentiles[1].item(), 309 | "median": dcr_percentiles[2].item(), 310 | "75%": dcr_percentiles[3].item(), 311 | "max": dcr_percentiles[4].item(), 312 | } 313 | if save_data: 314 | _save_to_json("dcr_"+dcr_name+"_stats", dcr_stats, path_to_json) 315 | return dcr_stats 316 | 317 | def number_of_dcr_equal_to_zero(dcr_name: str, 318 | distances_to_closest_record: np.ndarray, 319 | save_data: bool = True, 320 | path_to_json: str = "") -> int_type: 321 | """ 322 | Return the number of 0s in the given DCR array, that is the number of duplicates/clones detected. 323 | 324 | Parameters 325 | ---------- 326 | distances_to_closest_record : np.ndarray 327 | A 1D-array containing the Distance to the Closest Record for each row of a dataframe 328 | shape (dataframe rows, ) 329 | save_data : bool 330 | If True, saves the DCR information into the JSON file used to generate the final report. 331 | path_to_json : str 332 | Path to the JSON file used to generate the final report. 333 | 334 | Returns 335 | ------- 336 | int 337 | The number of 0s in the given DCR array. 
338 | """ 339 | if dcr_name != "synth_train" and dcr_name != "synth_val" and dcr_name != "other": 340 | raise TypeError("dcr_name must be one of the following:\n -\"synth_train\"\n -\"synth_val\"\n -\"other\"") 341 | 342 | zero_values_mask = distances_to_closest_record == 0.0 343 | if save_data: 344 | _save_to_json("dcr_"+dcr_name+"_num_of_zeros", zero_values_mask.sum(), path_to_json) 345 | return zero_values_mask.sum() 346 | 347 | # def dcr_histogram( 348 | # dcr_name: str, 349 | # distances_to_closest_record: np.ndarray, 350 | # bins: int = 20, 351 | # scale_to_100: bool = True, 352 | # save_data: bool = True, 353 | # path_to_json: str = "" 354 | # ) -> Dict: 355 | # """ 356 | # Compute the histogram of a DCR array: the DCR values equal to 0 are extracted before the 357 | # histogram computation so that the first bar represent only the 0 (duplicates/clones) 358 | # and the following bars represent the standard bins (with edge) of an histogram. 359 | 360 | # Parameters 361 | # ---------- 362 | # distances_to_closest_record : np.ndarray 363 | # A 1D-array containing the Distance to the Closest Record for each row of a dataframe 364 | # shape (dataframe rows, ) 365 | # bins : int, optional 366 | # _description_, by default 20 367 | # scale_to_100 : bool, optional 368 | # Wheter to scale the histogram bins between 0 and 100 (instead of 0 and 1), by default True 369 | 370 | # Returns 371 | # ------- 372 | # Dict 373 | # A dict containing the following items: 374 | # * bins, histogram bins detected as string labels. 375 | # The first bin/label is 0 (duplicates/clones), then the format is [inf_edge, sup_edge). 376 | # * count, histogram values for each bin in bins 377 | # * bins_edge_without_zero, the bin edges as returned by the np.histogram function without 0. 
378 | # """ 379 | # if dcr_name != "synth_train" and dcr_name != "synth_val" and dcr_name != "other": 380 | # raise TypeError("dcr_name must be one of the following:\n -\"synth_train\"\n -\"synth_val\"\n -\"other\"") 381 | 382 | # range_bins_with_zero = ["0.0"] 383 | # number_of_dcr_zeros = number_of_dcr_equal_to_zero(dcr_name, distances_to_closest_record) 384 | # dcr_non_zeros = distances_to_closest_record[distances_to_closest_record > 0] 385 | # counts_without_zero, bins_without_zero = np.histogram( 386 | # dcr_non_zeros, bins=bins, range=(0.0, 1.0), density=False 387 | # ) 388 | # if scale_to_100: 389 | # scaled_bins_without_zero = bins_without_zero * 100 390 | # else: 391 | # scaled_bins_without_zero = bins_without_zero 392 | 393 | # range_bins_with_zero.append("(0.0-{:.2f})".format(scaled_bins_without_zero[1])) 394 | # for i, left_edge in enumerate(scaled_bins_without_zero[1:-2]): 395 | # range_bins_with_zero.append( 396 | # "[{:.2f}-{:.2f})".format(left_edge, scaled_bins_without_zero[i + 2]) 397 | # ) 398 | # range_bins_with_zero.append( 399 | # "[{:.2f}-{:.2f}]".format( 400 | # scaled_bins_without_zero[-2], scaled_bins_without_zero[-1] 401 | # ) 402 | # ) 403 | 404 | # counts_with_zero = np.insert(counts_without_zero, 0, number_of_dcr_zeros) 405 | 406 | # dcr_hist = { 407 | # "bins": range_bins_with_zero, 408 | # "counts": counts_with_zero.tolist(), 409 | # "bins_edge_without_zero": bins_without_zero.tolist(), 410 | # } 411 | 412 | # if save_data: 413 | # _save_to_json("dcr_"+dcr_name+"_hist", dcr_hist, path_to_json) 414 | # return dcr_hist 415 | 416 | def validation_dcr_test( 417 | dcr_synth_train: np.ndarray, 418 | dcr_synth_validation: np.ndarray, 419 | save_data: bool = True, 420 | path_to_json: str = "" 421 | ) -> float_type: 422 | """ 423 | - If the returned percentage is close to (or smaller than) 50%, then the synthetic datset's records are equally close to the original training set and to the validation set. 
424 | In this casse the synthetic data does not allow to conjecture whether a record was or was not contained in the training dataset. 425 | - If the returned percentage is greater than 50%, then the synthetic datset's records are closer to the training set than to the validation set, indicating 426 | that vulnerable records are present in the synthetic dataset. 427 | 428 | Parameters 429 | ---------- 430 | dcr_synth_train : np.ndarray 431 | A 1D-array containing the Distance to the Closest Record for each row of the synthetic 432 | dataset wrt the training dataset, shape (synthetic rows, ) 433 | dcr_synth_validation : np.ndarray 434 | A 1D-array containing the Distance to the Closest Record for each row of the synthetic 435 | dataset wrt the validation dataset, shape (synthetic rows, ) 436 | save_data : bool 437 | If True, saves the DCR information into the JSON file used to generate the final report. 438 | path_to_json : str 439 | Path to the JSON file used to generate the final report. 440 | 441 | Returns 442 | ------- 443 | float 444 | The percentage of synthetic rows closer to the training dataset than to the validation dataset. 445 | 446 | Raises 447 | ------ 448 | ValueError 449 | If the two DCR array given as parameters have different shapes. 450 | """ 451 | if dcr_synth_train.shape != dcr_synth_validation.shape: 452 | raise ValueError("Dcr arrays have different shapes.") 453 | 454 | warnings = "" 455 | percentage = 0.0 456 | 457 | if dcr_synth_train.sum() == 0: 458 | percentage = 100.0 459 | warnings = ( 460 | "The synthetic dataset is an exact copy/clone of the training dataset." 461 | ) 462 | elif (dcr_synth_train == dcr_synth_validation).all(): 463 | percentage = 0.0 464 | warnings = ( 465 | "The validation dataset is an exact copy/clone of the training dataset." 466 | ) 467 | else: 468 | if dcr_synth_validation.sum() == 0: 469 | warnings = "The synthetic dataset is an exact copy/clone of the validation dataset." 
470 | 471 | number_of_rows = dcr_synth_train.shape[0] 472 | synth_dcr_smaller_than_holdout_dcr_mask = dcr_synth_train < dcr_synth_validation 473 | synth_dcr_smaller_than_holdout_dcr_sum = ( 474 | synth_dcr_smaller_than_holdout_dcr_mask.sum() 475 | ) 476 | percentage = synth_dcr_smaller_than_holdout_dcr_sum / number_of_rows * 100 477 | 478 | dcr_validation = {"percentage": round(percentage,4), "warnings": warnings} 479 | if save_data: 480 | _save_to_json("dcr_validation", dcr_validation, path_to_json) 481 | return dcr_validation -------------------------------------------------------------------------------- /sure/distance_metrics/gower_matrix_c.pyx: -------------------------------------------------------------------------------- 1 | # cython: language_level = 3 2 | # distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION 3 | 4 | import numpy 5 | cimport numpy 6 | cimport cython 7 | 8 | cdef numpy.ndarray[numpy.float32_t, ndim=1] gower_row( 9 | numpy.ndarray[numpy.uint8_t, ndim=1] x_categoricals, 10 | numpy.ndarray[numpy.float32_t, ndim=1] x_numericals, 11 | numpy.ndarray[numpy.uint8_t, ndim=2] Y_categoricals, 12 | numpy.ndarray[numpy.float32_t, ndim=2] Y_numericals, 13 | numpy.ndarray[numpy.float32_t, ndim=1] numericals_ranges, 14 | numpy.float32_t features_weight_sum 15 | ): 16 | cdef numpy.ndarray[numpy.uint8_t, ndim=2] dist_x_Y_categoricals 17 | cdef numpy.ndarray[numpy.float32_t, ndim=1] dist_x_Y_categoricals_sum 18 | cdef numpy.ndarray[numpy.float32_t, ndim=2] dist_x_Y_numericals 19 | cdef numpy.ndarray[numpy.float32_t, ndim=1] dist_x_Y_numericals_sum 20 | cdef numpy.ndarray[numpy.float32_t, ndim=1] dist_x_Y 21 | 22 | dist_x_Y_categoricals = (x_categoricals!=Y_categoricals).astype(numpy.uint8) 23 | dist_x_Y_categoricals_sum = dist_x_Y_categoricals.sum(axis=1).astype(numpy.float32) 24 | 25 | dist_x_Y_numericals = numpy.abs((x_numericals-Y_numericals)/numericals_ranges) 26 | dist_x_Y_numericals = numpy.nan_to_num(dist_x_Y_numericals) 27 | 
dist_x_Y_numericals_sum = dist_x_Y_numericals.sum(axis=1).astype(numpy.float32) 28 | 29 | dist_x_Y = dist_x_Y_categoricals_sum + dist_x_Y_numericals_sum 30 | dist_x_Y = dist_x_Y / features_weight_sum 31 | 32 | return dist_x_Y 33 | 34 | 35 | @cython.boundscheck(False) # turn off bounds-checking for entire function 36 | @cython.wraparound(False) # turn off negative index wrapping for entire function 37 | def gower_matrix_c( 38 | numpy.ndarray[numpy.uint8_t, ndim=2] X_categorical, 39 | numpy.ndarray[numpy.float32_t, ndim=2] X_numerical, 40 | numpy.ndarray[numpy.uint8_t, ndim=2] Y_categorical, 41 | numpy.ndarray[numpy.float32_t, ndim=2] Y_numerical, 42 | numpy.ndarray[numpy.float32_t, ndim=1] numericals_ranges, 43 | numpy.float32_t features_weight_sum, 44 | bint fill_diagonal, 45 | Py_ssize_t first_index 46 | ): 47 | cdef Py_ssize_t i 48 | cdef numpy.int32_t X_rows 49 | cdef numpy.ndarray[numpy.float32_t, ndim=1] dist_x_Y 50 | cdef numpy.ndarray[numpy.float32_t, ndim=1] distance_matrix 51 | 52 | X_rows = X_categorical.shape[0] 53 | distance_matrix = numpy.zeros(X_rows, dtype=numpy.float32) 54 | 55 | # per ogni istanza della matrice X 56 | for i in range(X_rows): 57 | dist_x_Y = gower_row(X_categorical[i, :], X_numerical[i, :], Y_categorical, Y_numerical, numericals_ranges, features_weight_sum) 58 | 59 | if (fill_diagonal and first_index < 0): 60 | dist_x_Y[i] = 5.0 61 | elif (fill_diagonal): 62 | dist_x_Y[first_index+i] = 5.0 63 | distance_matrix[i] = numpy.amin(dist_x_Y) 64 | 65 | return distance_matrix -------------------------------------------------------------------------------- /sure/privacy/__init__.py: -------------------------------------------------------------------------------- 1 | from ..distance_metrics.distance import (distance_to_closest_record, 2 | dcr_stats, 3 | number_of_dcr_equal_to_zero, 4 | validation_dcr_test) 5 | from .privacy import (adversary_dataset, 6 | membership_inference_test) 7 | 8 | __all__ = [ 9 | "distance_to_closest_record", 10 | 
import numpy as np
import pandas as pd
import polars as pl

from typing import Dict, List, Tuple

from sklearn.metrics import precision_score

from sure.privacy import distance_to_closest_record
from sure import _save_to_json, _drop_cols

# ATTACK SANDBOX
def _polars_to_pandas(dataframe: pl.DataFrame | pl.LazyFrame) -> pd.DataFrame:
    """Convert a polars DataFrame/LazyFrame to a pandas DataFrame.

    Objects that are already pandas DataFrames pass through unchanged;
    LazyFrames are collected before conversion.
    """
    if isinstance(dataframe, pl.DataFrame):
        dataframe = dataframe.to_pandas()
    if isinstance(dataframe, pl.LazyFrame):
        dataframe = dataframe.collect().to_pandas()
    return dataframe

def _pl_pd_to_numpy(dataframe: pl.DataFrame | pl.LazyFrame | pl.Series | pd.DataFrame) -> np.ndarray:
    """Convert a polars or pandas container to a numpy array.

    Inputs that are already numpy arrays (or any other type) pass through
    unchanged.
    """
    if isinstance(dataframe, (pl.DataFrame, pl.Series, pd.DataFrame)):
        dataframe = dataframe.to_numpy()
    if isinstance(dataframe, pl.LazyFrame):
        dataframe = dataframe.collect().to_numpy()
    return dataframe

def adversary_dataset(
    training_set: pd.DataFrame | pl.DataFrame | pl.LazyFrame,
    validation_set: pd.DataFrame | pl.DataFrame | pl.LazyFrame,
    original_dataset_sample_fraction: float = 0.2,
) -> pd.DataFrame:
    """
    Create an adversary dataset for the Membership Inference Test given a training
    and validation set. The validation set must be smaller than the training set.

    The size of the resulting adversary dataset is a fraction of the sum of the
    training set size and the validation set size.

    It takes half of the final rows from the training set and the other half from
    the validation set. It adds a boolean column marking which rows were sampled
    from the training set.

    Parameters
    ----------
    training_set : pd.DataFrame | pl.DataFrame | pl.LazyFrame
        The training set.
    validation_set : pd.DataFrame | pl.DataFrame | pl.LazyFrame
        The validation set.
    original_dataset_sample_fraction : float, optional
        How many rows (a fraction from 0 to 1) to sample from the concatenation of
        the training and validation set, by default 0.2.

    Returns
    -------
    pd.DataFrame
        A new, shuffled pandas DataFrame in which half of the rows come from the
        training set and the other half come from the validation set, with an
        extra "privacy_test_is_training" boolean column.
    """
    training_set = _polars_to_pandas(training_set)
    validation_set = _polars_to_pandas(validation_set)

    sample_number_of_rows = (
        training_set.shape[0] + validation_set.shape[0]
    ) * original_dataset_sample_fraction

    # If the validation set is very small, cap the per-source sample size at the
    # validation set size, i.e. every validation row goes into the adversary set.
    sample_number_of_rows = min(int(sample_number_of_rows / 2), validation_set.shape[0])

    sampled_from_training = training_set.sample(
        sample_number_of_rows, replace=False, random_state=42
    )
    sampled_from_training["privacy_test_is_training"] = True

    sampled_from_validation = validation_set.sample(
        sample_number_of_rows, replace=False, random_state=42
    )
    sampled_from_validation["privacy_test_is_training"] = False

    adversary_dataset = pd.concat(
        [sampled_from_training, sampled_from_validation], ignore_index=True
    )
    # BUGFIX: seed the final shuffle as well. Both draws above are seeded with
    # random_state=42, so the unseeded shuffle was the only source of
    # non-reproducibility in an otherwise deterministic function.
    adversary_dataset = adversary_dataset.sample(
        frac=1, random_state=42
    ).reset_index(drop=True)
    return adversary_dataset
def membership_inference_test(
    adversary_dataset: pd.DataFrame | pl.DataFrame | pl.LazyFrame,
    synthetic_dataset: pd.DataFrame | pl.DataFrame | pl.LazyFrame,
    adversary_guesses_ground_truth: np.ndarray | pd.DataFrame | pl.DataFrame | pl.LazyFrame | pl.Series,
    parallel: bool = True,
    save_data = True,
    path_to_json: str = ""
):
    """
    Simulate a Membership Inference Attack on the provided synthetic dataset using an adversary dataset.

    The adversary guesses "member of the training set" whenever a record's
    distance to its closest synthetic record (DCR) falls below a threshold;
    precision of that guess is evaluated at several thresholds and folded into
    a single risk score in [0, 1].

    Parameters
    ----------
    adversary_dataset : pd.DataFrame, pl.DataFrame, or pl.LazyFrame
        The dataset used by the adversary (as produced by ``adversary_dataset``),
        containing the "privacy_test_is_training" marker column.
    synthetic_dataset : pd.DataFrame, pl.DataFrame, or pl.LazyFrame
        The synthetic dataset on which the Membership Inference Attack is performed.
    adversary_guesses_ground_truth : np.ndarray, pd.DataFrame, pl.DataFrame, pl.LazyFrame, or pl.Series
        Ground truth labels indicating whether a sample is from the original training dataset or not.
    parallel : bool, optional
        Whether to use parallel processing for distance calculations, by default True.
    save_data : bool
        If True, saves the attack output into the JSON file used to generate the final report, by default True.
    path_to_json : str, optional
        Path to save the attack output as a JSON file. If empty, the default location is used, by default "".

    Returns
    -------
    dict
        A dictionary with keys "adversary_distance_thresholds",
        "adversary_precisions" and "membership_inference_mean_risk_score".
    """

    # Normalise all inputs to pandas / numpy.
    adversary_dataset = _polars_to_pandas(adversary_dataset)
    synthetic_dataset = _polars_to_pandas(synthetic_dataset)
    adversary_guesses_ground_truth = _pl_pd_to_numpy(adversary_guesses_ground_truth)

    # The marker column is metadata added by adversary_dataset(); it must not
    # take part in the distance computation.
    adversary_dataset=adversary_dataset.drop(["privacy_test_is_training"],axis=1)

    # Drop columns that are present in adversary_dataset but are missing in synthetic_dataset and vice versa
    synthetic_dataset, adversary_dataset = _drop_cols(synthetic_dataset, adversary_dataset)

    # DCR of every adversary record w.r.t. the synthetic dataset.
    dcr_adversary_synth = distance_to_closest_record("other",
                            adversary_dataset,
                            synthetic_dataset,
                            parallel=parallel,
                            save_data=False
                            )

    adversary_precisions = []
    # NOTE(review): the last entry of the q-list is a DATA value (min DCR + 0.01)
    # used as a quantile; np.quantile requires q in [0, 1] and will raise if the
    # minimum DCR exceeds 0.99. Looks like it was meant as a distance threshold
    # rather than a quantile — TODO confirm intent.
    distance_thresholds = np.quantile(
        dcr_adversary_synth, [0.5, 0.25, 0.2, np.min(dcr_adversary_synth) + 0.01]
    )
    for distance_threshold in distance_thresholds:
        # Adversary rule: "is a training member" iff DCR below the threshold.
        adversary_guesses = dcr_adversary_synth < distance_threshold
        adversary_precision = precision_score(
            adversary_guesses_ground_truth, adversary_guesses, zero_division=0
        )
        # 0.5 is the random-guess baseline; precision is floored there.
        adversary_precisions.append(max(adversary_precision, 0.5))
    adversary_precision_mean = np.mean(adversary_precisions).item()
    # Rescale mean precision from [0.5, 1] to a risk score in [0, 1].
    membership_inference_mean_risk_score = max(
        (adversary_precision_mean - 0.5) * 2, 0.0
    )

    attack_output = {
        "adversary_distance_thresholds": distance_thresholds.tolist(),
        "adversary_precisions": adversary_precisions,
        "membership_inference_mean_risk_score": membership_inference_mean_risk_score,
    }

    if save_data:
        _save_to_json("MIA_attack", attack_output, path_to_json)
    return attack_output
Default is None. 21 | 22 | Returns 23 | ------- 24 | None 25 | The function does not return any value. It plots the histograms using Streamlit. 26 | """ 27 | # Convert data to pandas DataFrame 28 | df = pd.DataFrame({'DCR': train_data, 'Data': 'Synthetic-Trainining DCR'}) 29 | if val_data is not None: 30 | df_val = pd.DataFrame({'DCR': val_data, 'Data': 'Synthetic-Validation DCR'}) 31 | df = pd.concat([df, df_val]) 32 | 33 | # colors = ['#6268ff','#ccccff'] 34 | # chart = alt.Chart(df).mark_bar(opacity=0.6).encode( 35 | # alt.X('DCR:Q', bin=alt.Bin(maxbins=15)), 36 | # alt.Y('count()', stack=None), 37 | # color=alt.Color('Data:N', scale=alt.Scale(range=colors)) 38 | # ).properties( 39 | # title='Histograms of Synthetic Train and Validation Data' if val_data is not None else 'Histograms of Synthetic Train', 40 | # width=600, 41 | # height=400 42 | # ) 43 | # # Display chart in Streamlit 44 | # st.altair_chart(chart) 45 | 46 | f = plt.figure(figsize=(8, 4)) 47 | sf = f.subfigures(1, 1) 48 | ( 49 | so.Plot(df, x="DCR") 50 | .facet("Data") 51 | .add(so.Bars(color="#6268ff"), so.Hist()) 52 | .on(sf) 53 | .plot() 54 | ) 55 | 56 | for ax in sf.axes: 57 | plt.setp(ax.get_xticklabels(), rotation=45, ha='right', fontsize=6) 58 | plt.setp(ax.get_yticklabels(), fontsize=8) 59 | ax.set_xlabel(ax.get_xlabel(), fontsize=6) # Set x-axis label font size 60 | ax.set_ylabel(ax.get_ylabel(), fontsize=6) # Set y-axis label font size 61 | ax.set_title(ax.get_title(), fontsize=8) # Set title font size 62 | 63 | # Display the plot in Streamlit 64 | st.pyplot(f) 65 | 66 | @st.cache_data 67 | def dcr_stats_table(train_stats, val_stats=None): 68 | """ 69 | Display a table with overall Distance to Closest Record (DCR) statistics. 70 | 71 | Parameters 72 | ---------- 73 | train_stats : dict 74 | Dictionary containing statistics for the synthetic training dataset. 75 | val_stats : dict, optional 76 | Dictionary containing statistics for the synthetic validation dataset. Default is None. 
@st.cache_data
def dcr_stats_table(train_stats, val_stats=None):
    """
    Display a table with overall Distance to Closest Record (DCR) statistics.

    Parameters
    ----------
    train_stats : dict
        Dictionary containing statistics for the synthetic training dataset.
    val_stats : dict, optional
        Dictionary containing statistics for the synthetic validation dataset. Default is None.

    Returns
    -------
    None
        The function outputs a table using Streamlit.
    """
    df1 = pd.DataFrame.from_dict(train_stats, orient='index', columns=['Synth-Train'])
    # BUGFIX: default to the train-only table so merged_df is always defined;
    # previously it was only assigned inside the if-branch, raising NameError
    # when val_stats was None.
    merged_df = df1
    if val_stats:
        df2 = pd.DataFrame.from_dict(val_stats, orient='index', columns=['Synth-Val'])
        # Merge the train and val DataFrames
        merged_df = pd.concat([df1, df2], axis=1)
    st.table(merged_df)

@st.cache_data
def dcr_validation(dcr_val, dcr_zero_train=None, dcr_zero_val=None):
    """
    Display the DCR share value and additional metrics for clones.

    Parameters
    ----------
    dcr_val : dict
        Dictionary containing DCR share values ("percentage") and "warnings".
    dcr_zero_train : int, optional
        Number of clones in the synthetic training dataset. Default is None.
    dcr_zero_val : int, optional
        Number of clones in the synthetic validation dataset. Default is None.

    Returns
    -------
    None
        The function outputs metrics and information using Streamlit.
    """
    perc = dcr_val["percentage"]
    cols = st.columns(2)
    with cols[0]:
        # BUGFIX: help-string typos corrected ("privacy.The" -> "privacy. The").
        st.metric("DCR closer to training", str(round(perc,1))+"%", help="For each synthetic record we computed its DCR with respect to the training set as well as with respect to the validation set. The validation set was not used to train the generative model. If we can find that the DCR values between synthetic and training (histogram on the left) are not systematically smaller than the DCR values between synthetic and validation (histogram on the right), we gain evidence of a high level of privacy. The share of records that are then closer to a training than to a validation record serves us as our proposed privacy risk measure. If that resulting share is then close to (or smaller than) 50%, we gain empirical evidence of the training and validation data being interchangeable with respect to the synthetic data. This in turn allows to make a strong case for plausible deniability for any individual, as the synthetic data records do not allow to conjecture whether an individual was or was not contained in the training dataset.")
    with cols[1]:
        if dcr_val["warnings"] != '':
            st.write(dcr_val["warnings"])

    cols2 = st.columns(2)
    if dcr_zero_train:
        # Table with the number of DCR equal to zero.
        # BUGFIX: help-string typo "duplicatestable" -> "duplicates table".
        st.metric("Clones Synth-Train", dcr_zero_train, help="The number of clones shows how many rows of the synthetic dataset have an identical match in the training dataset. A very low value indicates a low risk in terms of privacy. Ideally this value should be close to 0, but some peculiar characteristics of the training dataset (small size or low column cardinality) may lead to a higher value. The duplicates table shows the number of duplicates (identical rows) in the training dataset and the synthetic dataset: similar percentages mean higher utility.")
    if dcr_zero_val:
        st.metric("Clones Synth-Val", dcr_zero_val, help="The number of clones shows how many rows of the synthetic dataset have an identical match in the validation dataset. A very low value indicates a low risk in terms of privacy. Ideally this value should be close to 0, but some peculiar characteristics of the training dataset (small size or low column cardinality) may lead to a higher value.")
def _MIA():
    """Render the Membership Inference Attack metric and results table."""
    cols = st.columns([1,3,2])
    with cols[0]:
        # BUGFIX: the previous str(round(score, 3) * 100) + "%" could display
        # float noise such as "12.299999999999999%"; format after scaling.
        st.metric("MI mean risk score", f'{st.session_state["MIA_attack"]["membership_inference_mean_risk_score"]*100:.1f}%', help="The MI Risk score is computed as (precision - 0.5) * 2.\n MI Risk Score smaller than 0.2 (20%) are considered to be very LOW RISK of disclosure due to membership inference.")
    with cols[1]:
        df_MIA = pd.DataFrame(st.session_state["MIA_attack"])
        # Drop the risk-score column (already shown as a metric) and reverse the
        # row order so the largest threshold appears first.
        st.dataframe(df_MIA.drop(columns=df_MIA.columns[-1]).iloc[::-1], hide_index=True)

def main():
    """
    Main function to configure and run the Streamlit application.

    The application provides a report on privacy metrics for a synthetic dataset,
    using the SURE library to visualize and interpret the data.

    Returns
    -------
    None
        The function runs the Streamlit app.
    """
    # Set app configurations
    st.set_page_config(layout="wide", page_title='SURE', page_icon=':large_purple_square:')

    # Header, subheader and description
    st.title('SURE')
    st.subheader('Synthetic Data: Utility, Regulatory compliance, and Ethical privacy')
    st.write(
        """This report provides a visual digest of the privacy metrics computed 
        with the library [SURE](https://github.com/Clearbox-AI/SURE) on the synthetic dataset under test.""")

    ### PRIVACY
    st.header("Privacy", divider='violet')
    st.sidebar.markdown("# Privacy")

    ## Distance to closest record
    st.subheader("Distance to closest record", help="Distances-to-Closest-Record are individual-level distances of synthetic records with respect to their corresponding nearest neighboring records from the training dataset. A DCR of 0 corresponds to an identical match. These histograms are used to assess whether the synthetic data is a simple copy or minor perturbation of the training data, resulting in high risk of disclosure. There is one DCR histogram computed only on the Training Set (histogram on the left) and one computed for the Synthetic Set vs Training Set (histogram on the right). Ideally, the two histograms should have a similar shape and, above all, the histogram on the right should be far enough away from 0.")
    # DCR statistics
    st.write("DCR statistics")
    col1, buff, col2 = st.columns([3,0.5,3])
    with col1:
        if "dcr_synth_train_stats" in st.session_state:
            # NOTE(review): the val-stats key is read without its own presence
            # check — assumes it is always written alongside the train stats;
            # TODO confirm against the report generator.
            dcr_stats_table(st.session_state["dcr_synth_train_stats"],
                            st.session_state["dcr_synth_val_stats"])
    with col2:
        if "dcr_validation" in st.session_state:
            dcr_validation(st.session_state["dcr_validation"],
                           st.session_state["dcr_synth_train_num_of_zeros"],
                           st.session_state["dcr_synth_val_num_of_zeros"])

    # Synth-train DCR and synth-validation DCR histograms
    if "dcr_synth_train" in st.session_state:
        plot_DCR(st.session_state["dcr_synth_train"],
                 st.session_state["dcr_synth_val"])

    st.divider()

    ## Membership Inference Attack
    st.subheader("Membership Inference Attack", help="Membership inference attacks seek to infer membership of an individual record in the training set from which the synthetic data was generated. We consider a hypothetical adversary which has access to a subset of records K containing half instances from the training set and half instances from the validation set and that attempts a membership inference attack as follows: given an individual record k in K and the synthetic set S, the adversary identifies the closest record s in S; the adversary determines that k is part of the training set if d(k, s) is lower than a certain threshold. We evaluate the success rate of such attack strategy. Precision represents the number of correct decisions the adversary has made. Since 50% of the instances in K come from the training set and 50% come from the validation set, the baseline precision is 0.5, corresponding to a random choice and any value above that reflects an increasing levels of disclosure risk.")
    if "MIA_attack" in st.session_state:
        _MIA()

if __name__ == "__main__":
    main()
.add(so.Bars(color="#3e42a8"), so.Hist()) 40 | .on(sf) 41 | .plot() 42 | ) 43 | 44 | for ax in sf.axes: 45 | plt.setp(ax.get_xticklabels(), rotation=45, ha='right', fontsize=6) 46 | plt.setp(ax.get_yticklabels(), fontsize=8) 47 | ax.set_xlabel(ax.get_xlabel(), fontsize=6) # Set x-axis label font size 48 | ax.set_ylabel(ax.get_ylabel(), fontsize=6) # Set y-axis label font size 49 | ax.set_title(ax.get_title(), fontsize=8) # Set title font size 50 | 51 | # Display the plot in Streamlit 52 | st.pyplot(f) 53 | 54 | 55 | def _plot_heatmap(data, title): 56 | df = pd.DataFrame(data) 57 | # Generate a mask for the upper triangle 58 | mask = np.triu(np.ones_like(df, dtype=bool), 1) 59 | 60 | # Set up the matplotlib figure 61 | f, ax = plt.subplots(figsize=(13, 11)) 62 | 63 | # Generate a custom diverging colormap 64 | cmap = sns.diverging_palette(20, 265, as_cmap=True) 65 | 66 | # Draw the heatmap with the mask, correct aspect ratio, and column names as labels 67 | sns.heatmap(df, cmap=cmap, center=0, 68 | square=True, mask=mask, linewidths=.5, 69 | cbar_kws={"shrink": .75}, 70 | xticklabels=df.columns, # Display column names on x-axis 71 | yticklabels=df.columns # Display column names on y-axis 72 | ) 73 | 74 | # Rotate x-axis labels for better readability and add title 75 | plt.xticks(rotation=45, ha='right') 76 | ax.set_title(title, fontsize=16, pad=20) 77 | 78 | # Display the plot 79 | st.pyplot(f) 80 | 81 | def _display_feature_data(data): 82 | """Display the data for a selected feature""" 83 | # Get list of feature names 84 | feature_names = list(data.keys()) 85 | 86 | # Create dropdown menu for feature selection 87 | selected_feature = st.selectbox(label = 'Select a statistical quantity:', 88 | options = ["Select a statistical quantity..."] + feature_names, 89 | index = None, 90 | placeholder = "Select a statistical quantity...", 91 | label_visibility = "collapsed") 92 | 93 | # If a feature has been selected, display its data and create another dropdown menu 94 | if 
def _display_feature_data(data):
    """Display the data for a selected feature.

    Recursively re-renders with the chosen feature removed, so the user can
    stack several statistic tables on the page.
    """
    # Get list of feature names
    feature_names = list(data.keys())

    # Create dropdown menu for feature selection
    selected_feature = st.selectbox(label = 'Select a statistical quantity:',
                                    options = ["Select a statistical quantity..."] + feature_names,
                                    index = None,
                                    placeholder = "Select a statistical quantity...",
                                    label_visibility = "collapsed")

    # If a feature has been selected, display its data and create another dropdown menu
    if selected_feature and selected_feature!="Select a statistical quantity...":
        # Get data for selected feature
        feature_data = data[selected_feature]

        # Convert data to DataFrame
        df_real = pd.DataFrame(feature_data['real'], index=["Original"])
        df_synthetic = pd.DataFrame(feature_data['synthetic'], index=["Synthetic"])
        df = pd.concat([df_real, df_synthetic])
        # Add row with the difference between the synthetic dataset and the real one
        df.loc['Difference'] = df.iloc[0] - df.iloc[1]

        # Display DataFrame
        st.write(df)

        # Remove selected feature from list of feature names
        feature_names.remove(selected_feature)

        # If there are still features left, create another dropdown menu
        if feature_names:
            _display_feature_data({name: data[name] for name in feature_names})

def _ml_utility():
    """Render the ML-utility tables (per-model metrics on real vs synthetic
    data, plus the delta table) with a model multiselect."""
    # NOTE(review): these callbacks are defined but never attached to any
    # widget (the buttons below do not pass on_click) — apparently dead code;
    # TODO confirm before removing.
    def _select_all():
        st.session_state["selected_models"] = models_df.index.values
    def _deselect_all():
        st.session_state["selected_models"] = []

    if 'selected_models' not in st.session_state:
        st.session_state["selected_models"] = []

    # NOTE(review): `container` is created but never used.
    container = st.container()

    # NOTE(review): `cols` is created but the layout columns are never entered.
    cols = st.columns([1.2,1])
    default = ["LinearSVC", "Perceptron", "LogisticRegression", "XGBClassifier", "SVC", "BaggingClassifier", "SGDClassifier", "RandomForestClassifier", "AdaBoostClassifier", "KNeighborsClassifier", "DecisionTreeClassifier", "DummyClassifier"]

    # Build one table with real and synthetic metrics interleaved column-wise.
    models_real_df = _convert_to_dataframe(st.session_state["models"]).set_index(['Model'])
    models_synth_df = _convert_to_dataframe(st.session_state["models_synth"]).set_index(['Model'])
    models_df = pd.concat([models_real_df, models_synth_df], axis=1)
    interleaved_columns = [col for pair in zip(models_real_df.columns, models_synth_df.columns) for col in pair]
    models_df = models_df[interleaved_columns]
    models_delta_df = _convert_to_dataframe(st.session_state["models_delta"]).set_index(['Model'])

    # NOTE(review): this unconditionally overwrites the session-state entry
    # initialised above, so the `if 'selected_models' not in ...` guard and the
    # user's previous selection are effectively ignored on every rerun — TODO
    # confirm whether resetting to `default` each run is intended.
    st.session_state["selected_models"] = default

    options = st.multiselect(label="Select ML models to show in the table:",
                             options=models_df.index.values,
                             default= [x for x in st.session_state["selected_models"] if x in models_df.index.values],
                             placeholder="Select ML models...",
                             key="models_multiselect")
    subcols = st.columns([1,1,3])
    with subcols[0]:
        butt1 = st.button("Select all models")
    with subcols[1]:
        butt2 = st.button("Deselect all models")
    # NOTE(review): reassigning `options` after the multiselect has rendered
    # changes only the rows shown below for this run; the widget itself keeps
    # its previous selection.
    if butt1:
        options = models_df.index.values
    if butt2:
        options = []

    st.text("")  # vertical space
    st.text("")  # vertical space
    st.text("ML Metrics")
    # Highlight the best value per metric column (last two columns excluded).
    st.dataframe(models_df.loc[options].style.highlight_max(axis=0, subset=models_df.columns[:-2], color="#5cbd91"))
    st.text("")  # vertical space
    st.text("")  # vertical space
    st.text("ML Delta Metrics", help="Difference between the metrics of the original dataset and the synthetic dataset.")
    # Smallest absolute delta = best (green), largest = worst (red).
    st.dataframe(models_delta_df.abs().loc[options].style.highlight_min(axis=0, subset=models_delta_df.columns[:-1], color="#5cbd91").highlight_max(axis=0, subset=models_delta_df.columns[:-1], color="#c45454"))
230 | 231 | Parameters 232 | ---------- 233 | real_df : str 234 | Path to the pickle file containing the real dataset. 235 | synth_df : str 236 | Path to the pickle file containing the synthetic dataset. 237 | path_to_json : str 238 | Path to the JSON file for loading session state data. If not provided, the default session state is used. 239 | 240 | Returns 241 | ------- 242 | None 243 | The function initializes and runs a Streamlit app for visualizing and comparing utility metrics. 244 | """ 245 | # Set app conifgurations 246 | st.set_page_config(layout="wide", page_title='SURE', page_icon='https://raw.githubusercontent.com/Clearbox-AI/SURE/main/docs/source/img/favicon.ico') 247 | 248 | # Header and subheader and description 249 | st.title('SURE') 250 | st.subheader('Synthetic Data: Utility, Regulatory compliance, and Ethical privacy') 251 | st.write( 252 | """This report provides a visual digest of the utility metrics computed 253 | with the library [SURE](https://github.com/Clearbox-AI/SURE) on the synthetic dataset under test.""") 254 | 255 | ### UTILITY 256 | st.header('Utility', divider='gray') 257 | st.sidebar.markdown("# Utility") 258 | 259 | # Load real dataset 260 | real_df = pd.read_pickle(real_df) 261 | synth_df = pd.read_pickle(synth_df) 262 | 263 | ## Plot real and synthetic data distributions 264 | st.subheader("Dataset feature distribution", help="Distribution histogram plot of the selected feature for the real and synthetic dataset (before pre-processing).") 265 | _plot_hist(real_df, synth_df) 266 | 267 | # Load data in the session state, so that it is available in all the pages of the app 268 | if path_to_json: 269 | st.session_state = _load_from_json(path_to_json) 270 | else: 271 | st.session_state = _load_from_json("") 272 | 273 | ## Statistical similarity 274 | st.subheader("Statistical similarity", help="Statistical quantities computed for each feature of the real and synthetic dataset (after pre-processing).") 275 | 276 | # Features 
distribution 277 | # plot_distribution() 278 | 279 | # General statistics 280 | if "num_features_comparison" in st.session_state and st.session_state["num_features_comparison"]: 281 | features_comparison = st.session_state["num_features_comparison"] 282 | if "cat_features_comparison" in st.session_state and st.session_state["cat_features_comparison"]: 283 | if "features_comparison" in locals(): 284 | features_comparison = {**features_comparison, **st.session_state["cat_features_comparison"]} 285 | else: 286 | features_comparison = st.session_state["cat_features_comparison"] 287 | if "time_features_comparison" in st.session_state and st.session_state["time_features_comparison"]: 288 | if "features_comparison" in locals(): 289 | features_comparison = {**features_comparison, **st.session_state["time_features_comparison"]} 290 | else: 291 | features_comparison = st.session_state["time_features_comparison"] 292 | if features_comparison: 293 | _display_feature_data(features_comparison) 294 | 295 | st.markdown('###') 296 | 297 | # Correlation 298 | st.subheader("Feature correlation", help="This matrix shows the correlation between ordinal and categorical features. These correlation coefficients are obtained using the mutual information metric. 
Mutual information describes relationships in terms of uncertainty.") 299 | if "real_corr" in st.session_state: 300 | cb_rea_corr = st.checkbox("Show original dataset correlation matrix", value=False) 301 | if cb_rea_corr: 302 | _plot_heatmap(st.session_state["real_corr"], 'Original Dataset Correlation Matrix Heatmap') 303 | if "synth_corr" in st.session_state: 304 | cb_synth_corr = st.checkbox("Show synthetic dataset correlation matrix", value=False) 305 | if cb_synth_corr: 306 | _plot_heatmap(st.session_state["synth_corr"], 'Synthetic Dataset Correlation Matrix Heatmap') 307 | if "diff_corr" in st.session_state: 308 | cb_diff_corr = st.checkbox("Show difference between the correlation matrix of the original dataset and the one of the synthetic dataset", value=False) 309 | if cb_diff_corr: 310 | _plot_heatmap(st.session_state["diff_corr"], 'Original-Synthetic Dataset Correlation Matrix Difference Heatmap') 311 | 312 | st.divider() 313 | 314 | ## ML Utility 315 | st.subheader("ML utility", help="The datasets were evaluated using various machine learning models. 
# Function to run the streamlit app
def report(df_real: pd.DataFrame | pl.DataFrame | pl.LazyFrame,
           df_synth: pd.DataFrame | pl.DataFrame | pl.LazyFrame,
           path_to_json: str = ""):
    """Launch the Streamlit report app comparing a real and a synthetic dataset.

    Both datasets are normalised to pandas, dumped to temporary pickle files and
    handed to the Streamlit report app as command-line arguments.

    Parameters
    ----------
    df_real : pd.DataFrame | pl.DataFrame | pl.LazyFrame
        The real (original) dataset.
    df_synth : pd.DataFrame | pl.DataFrame | pl.LazyFrame
        The synthetic dataset.
    path_to_json : str, optional
        Path to the folder containing the "data.json" results file, by default "".

    Returns
    -------
    subprocess.CompletedProcess
        The completed streamlit process.

    Raises
    ------
    TypeError
        If either dataset is not a pandas DataFrame, polars DataFrame or polars
        LazyFrame. (Previously this called sys.exit, killing the host process;
        raising is friendlier to library callers. Mixed pandas/polars inputs are
        now accepted since each frame is converted independently.)
    """
    def _as_pandas(df):
        # Normalise any supported frame type to pandas.
        if isinstance(df, pd.DataFrame):
            return df
        if isinstance(df, pl.LazyFrame):
            return df.collect().to_pandas()
        if isinstance(df, pl.DataFrame):
            return df.to_pandas()
        raise TypeError(
            "The datatype provided is not supported: expected a pandas DataFrame, "
            "a polars DataFrame or a polars LazyFrame."
        )

    df_real = _as_pandas(df_real)
    df_synth = _as_pandas(df_synth)

    # Save each DataFrame to a temporary file (pickle format) so the report app
    # can reload it; delete=False because streamlit runs in a separate process.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.pkl') as tmpfile:
        df_path_real = tmpfile.name
        df_real.to_pickle(df_path_real)

    with tempfile.NamedTemporaryFile(delete=False, suffix='.pkl') as tmpfile:
        df_path_synth = tmpfile.name
        df_synth.to_pickle(df_path_synth)

    report_path = pkg_resources.resource_filename('sure.report_generator', 'report_app.py')
    process = subprocess.run(['streamlit', 'run', report_path, df_path_real, df_path_synth, path_to_json])
    return process

def _convert_to_serializable(obj: object):
    """Recursively convert DataFrames and other non-serializable objects in a
    nested structure to JSON-serializable formats.

    DataFrames become lists of record dicts; numpy scalars/arrays become native
    Python values; dicts and lists are walked recursively; anything else is
    returned unchanged.
    """
    if isinstance(obj, (pd.DataFrame, pl.DataFrame, pl.LazyFrame)):
        if isinstance(obj, pl.DataFrame):
            obj = obj.to_pandas()
        if isinstance(obj, pl.LazyFrame):
            obj = obj.collect().to_pandas()

        # Work on a copy: the datetime-to-string conversion below would
        # otherwise mutate the caller's DataFrame in place.
        obj = obj.copy()

        # Convert index to column only if index is non-numerical
        if obj.index.dtype == 'object' or pd.api.types.is_string_dtype(obj.index):
            obj = obj.reset_index()

        # Convert datetime columns to string (JSON has no datetime type)
        for col in obj.columns:
            if pd.api.types.is_datetime64_any_dtype(obj[col]):
                obj[col] = obj[col].astype(str)

        return obj.to_dict(orient='records')
    elif isinstance(obj, dict):
        return {k: _convert_to_serializable(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [_convert_to_serializable(item) for item in obj]
    elif isinstance(obj, np.integer):
        return int(obj)
    elif isinstance(obj, np.floating):
        return float(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    else:
        return obj
the folder where the user is working""" 79 | # Check if the file exists 80 | path = os.path.join(path_to_json,"data.json") 81 | if os.path.exists(path): 82 | # Read the existing data from the file 83 | with open(path, 'r') as file: 84 | try: 85 | data = json.load(file) 86 | except json.JSONDecodeError: 87 | data = {} # Initialize to an empty dictionary if the file is empty or invalid 88 | else: 89 | data = {} 90 | 91 | # Convert new_data to a serializable format if it is a DataFrame or contains DataFrames 92 | serializable_data = _convert_to_serializable(new_data) 93 | 94 | # Update the data dictionary with the new_data 95 | data[data_name] = serializable_data 96 | 97 | # Write the updated data back to the file 98 | with open(path, 'w') as file: 99 | json.dump(data, file, indent=4) 100 | 101 | def _load_from_json(path_to_json: str, 102 | data_name: str = None 103 | ): 104 | """Load data from a JSON file "data.json" in the folder where the user is working""" 105 | # Check if the file exists 106 | path = os.path.join(path_to_json,"data.json") 107 | if not os.path.exists(path): 108 | raise FileNotFoundError("The data.json file does not exist.") 109 | 110 | # Read the data from the file 111 | with open(path, 'r') as file: 112 | try: 113 | data = json.load(file) 114 | except json.JSONDecodeError: 115 | raise ValueError("The data.json file is empty or invalid.") 116 | 117 | if data_name: 118 | # Extract the relevant data 119 | if data_name not in data: 120 | raise KeyError(f"Key '{data_name}' not found in dictionary.") 121 | data = data.get(data_name, None) 122 | 123 | return data 124 | 125 | def _convert_to_dataframe(obj): 126 | """Convert nested dictionaries back to DataFrames""" 127 | if isinstance(obj, list) and all(isinstance(item, dict) for item in obj): 128 | return pd.DataFrame(obj) 129 | elif isinstance(obj, dict): 130 | return {k: _convert_to_dataframe(v) for k, v in obj.items()} 131 | else: 132 | return obj 
def _to_polars_df(df):
    """Convert a pandas DataFrame, numpy array or polars LazyFrame into a polars DataFrame.

    Raises
    ------
    TypeError
        If `df` is not one of the supported frame/array types.
    """
    if isinstance(df, pd.DataFrame):
        return pl.from_pandas(df)
    if isinstance(df, np.ndarray):
        return pl.from_numpy(df)
    if isinstance(df, pl.LazyFrame):
        return df.collect()
    if isinstance(df, pl.DataFrame):
        return df
    raise TypeError("Invalid type for dataframe")

def _to_numpy(data):
    '''Transform polars or pandas DataFrames, Series or LazyFrames into numpy arrays.

    Raises
    ------
    TypeError
        If `data` is not one of the supported frame/series/array types.
        (The original printed a message and returned None, silently propagating
        None into callers; raising matches the sibling `_to_polars_df`.)
    '''
    if isinstance(data, pl.LazyFrame):
        return data.collect().to_numpy()
    if isinstance(data, (pl.DataFrame, pd.DataFrame, pl.Series, pd.Series)):
        return data.to_numpy()
    if isinstance(data, np.ndarray):
        return data
    raise TypeError("The dataframe must be a Polars LazyFrame or a Pandas DataFrame")
def _drop_cols(synth, real):
    '''Align the column sets of the synthetic and real datasets.

    Drops from each dataset the columns that are missing from the other, so
    both end up with a common column set. Inputs may be pandas/polars
    DataFrames or polars LazyFrames (LazyFrames are collected before the
    drop); numpy-array inputs are returned untouched.

    Returns
    -------
    tuple
        (synth, real) with mismatched columns removed.
    '''
    if not isinstance(synth, np.ndarray) and not isinstance(real, np.ndarray):
        col_synth = set(synth.columns)
        col_real = set(real.columns)
        not_in_real = col_synth - col_real
        not_in_synth = col_real - col_synth

        # Drop columns that are present in the synthetic dataset but are missing in the real one
        if len(not_in_real) > 0:
            print(f"""Warning: The following columns of the synthetic dataset are not present in the real dataset and were dropped to carry on with the computation:\n{not_in_real}\nIf you used only a subset of the dataset for computation, consider increasing the number of rows to ensure that all categorical values are adequately represented after one-hot-encoding.""")
            if isinstance(synth, pd.DataFrame):
                synth = synth.drop(columns=list(not_in_real))
            if isinstance(synth, pl.DataFrame):
                synth = synth.drop(list(not_in_real))
            if isinstance(synth, pl.LazyFrame):
                # BUG FIX: the original accessed `collect` without calling it
                # (`synth.collect.drop(...)`), raising AttributeError at runtime.
                synth = synth.collect().drop(list(not_in_real))

        # Drop columns that are present in the real dataset but are missing in the synthetic one
        if isinstance(real, pd.DataFrame):
            real = real.drop(columns=list(not_in_synth))
        if isinstance(real, pl.DataFrame):
            real = real.drop(list(not_in_synth))
        if isinstance(real, pl.LazyFrame):
            # BUG FIX: same missing call as above.
            real = real.collect().drop(list(not_in_synth))
    return synth, real
class ClassificationGarden:
    """
    Thin wrapper that trains and evaluates a battery of classification models
    through LazyClassifier.

    Parameters
    ----------
    verbose : int, optional
        Verbosity level, by default 0.
    ignore_warnings : bool, optional
        Whether to ignore warnings, by default True.
    custom_metric : callable, optional
        Custom metric function for evaluation, by default None.
    predictions : bool, optional
        Whether to return predictions along with model performance, by default False.
    classifiers : str or list, optional
        List of classifiers to use, or "all" for using all available classifiers, by default "all".

    :meta private:
    """

    def __init__(
        self,
        verbose = 0,
        ignore_warnings = True,
        custom_metric = None,
        predictions = False,
        classifiers = "all",
    ):
        self.verbose = verbose
        self.ignore_warnings = ignore_warnings
        self.custom_metric = custom_metric
        self.predictions = predictions
        self.classifiers = classifiers

        # All heavy lifting is delegated to LazyClassifier.
        self.clf = LazyClassifier(verbose=verbose,
                                  ignore_warnings=ignore_warnings,
                                  custom_metric=custom_metric,
                                  predictions=predictions,
                                  classifiers=classifiers)

    def fit(
        self,
        X_train: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray,
        X_test: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray,
        y_train: pl.DataFrame | pl.LazyFrame | pl.Series | pd.Series | pd.DataFrame | np.ndarray,
        y_test: pl.DataFrame | pl.LazyFrame | pl.Series | pd.Series | pd.DataFrame | np.ndarray
    ) -> pd.DataFrame | np.ndarray:
        """
        Fit multiple classification models on the training data and evaluate
        them on the test data.

        Parameters
        ----------
        X_train : pl.DataFrame, pl.LazyFrame, pd.DataFrame, or np.ndarray
            Training features.
        X_test : pl.DataFrame, pl.LazyFrame, pd.DataFrame, or np.ndarray
            Test features.
        y_train : pl.DataFrame, pl.LazyFrame, pl.Series, pd.Series, pd.DataFrame, or np.ndarray
            Training labels.
        y_test : pl.DataFrame, pl.LazyFrame, pl.Series, pd.Series, pd.DataFrame, or np.ndarray
            Test labels.

        Returns
        -------
        pd.DataFrame or np.ndarray
            Model performance metrics and predictions if specified.

        :meta private:
        """
        # Normalise every input to a numpy array before delegating.
        arrays = [_to_numpy(part) for part in (X_train, X_test, y_train, y_test)]
        models, predictions = self.clf.fit(*arrays)
        return models, predictions
""" 196 | Fit multiple regression models on the provided training data and evaluate on the test data. 197 | 198 | Parameters 199 | ---------- 200 | X_train : pl.DataFrame, pl.LazyFrame, pd.DataFrame, or np.ndarray 201 | Training features. 202 | X_test : pl.DataFrame, pl.LazyFrame, pd.DataFrame, or np.ndarray 203 | Test features. 204 | y_train : pl.DataFrame, pl.LazyFrame, pl.Series, pd.Series, pd.DataFrame, or np.ndarray 205 | Training labels. 206 | y_test : pl.DataFrame, pl.LazyFrame, pl.Series, pd.Series, pd.DataFrame, or np.ndarray 207 | Test labels. 208 | 209 | Returns 210 | ------- 211 | pd.DataFrame or np.ndarray 212 | Model performance metrics and predictions if specified. 213 | 214 | :meta private: 215 | """ 216 | data = [X_train, X_test, y_train, y_test] 217 | 218 | for count, el in enumerate(data): 219 | data[count] = _to_numpy(el) 220 | 221 | models, predictions = self.reg.fit(data[0], data[1], data[2], data[3]) 222 | return models, predictions 223 | 224 | # ML UTILITY METRICS MODULE 225 | def compute_utility_metrics_class( 226 | X_train: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray, 227 | X_synth: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray, 228 | X_test: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray, 229 | y_train: pl.DataFrame | pl.LazyFrame | pl.Series | pd.Series | pd.DataFrame | np.ndarray, 230 | y_synth: pl.DataFrame | pl.LazyFrame | pl.Series | pd.Series | pd.DataFrame | np.ndarray, 231 | y_test: pl.DataFrame | pl.LazyFrame | pl.Series | pd.Series | pd.DataFrame | np.ndarray, 232 | custom_metric: Callable = None, 233 | classifiers: List[Callable] = "all", 234 | predictions: bool = False, 235 | save_data = True, 236 | path_to_json: str = "" 237 | ): 238 | """ 239 | Train and evaluate classification models on both real and synthetic datasets. 240 | 241 | Parameters 242 | ---------- 243 | X_train : Union[pl.DataFrame, pl.LazyFrame, pd.DataFrame, np.ndarray] 244 | Training features for real data. 
def compute_utility_metrics_class(
        X_train: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray,
        X_synth: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray,
        X_test: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray,
        y_train: pl.DataFrame | pl.LazyFrame | pl.Series | pd.Series | pd.DataFrame | np.ndarray,
        y_synth: pl.DataFrame | pl.LazyFrame | pl.Series | pd.Series | pd.DataFrame | np.ndarray,
        y_test: pl.DataFrame | pl.LazyFrame | pl.Series | pd.Series | pd.DataFrame | np.ndarray,
        custom_metric: Callable = None,
        classifiers: List[Callable] = "all",
        predictions: bool = False,
        save_data: bool = True,
        path_to_json: str = ""
        ):
    """
    Train and evaluate classification models on both real and synthetic datasets.

    Parameters
    ----------
    X_train : Union[pl.DataFrame, pl.LazyFrame, pd.DataFrame, np.ndarray]
        Training features for real data.
    X_synth : Union[pl.DataFrame, pl.LazyFrame, pd.DataFrame, np.ndarray]
        Training features for synthetic data.
    X_test : Union[pl.DataFrame, pl.LazyFrame, pd.DataFrame, np.ndarray]
        Test features for evaluation.
    y_train : Union[pl.DataFrame, pl.LazyFrame, pl.Series, pd.Series, pd.DataFrame, np.ndarray]
        Training labels for real data.
    y_synth : Union[pl.DataFrame, pl.LazyFrame, pl.Series, pd.Series, pd.DataFrame, np.ndarray]
        Training labels for synthetic data.
    y_test : Union[pl.DataFrame, pl.LazyFrame, pl.Series, pd.Series, pd.DataFrame, np.ndarray]
        Test labels for evaluation.
    custom_metric : Callable, optional
        Custom metric for model evaluation, by default None.
    classifiers : List[Callable], optional
        List of classifiers to use, or "all" for all available classifiers, by default "all".
    predictions : bool, optional
        If True, returns predictions along with model performance, by default False.
    save_data : bool
        If True, saves the computed ML utility metrics into the JSON file used to
        generate the final report, by default True.
    path_to_json : str, optional
        Path to save the output JSON files, by default "".

    Returns
    -------
    Union[pd.DataFrame, pl.DataFrame, np.ndarray]
        Model performance metrics for both real and synthetic datasets, and
        optionally predictions, returned in the same frame type as `X_train`.
    """
    # Store DataFrame type information for returning the same type
    was_pl = isinstance(X_train, (pl.DataFrame, pl.LazyFrame))
    was_np = isinstance(X_train, np.ndarray)

    # Restrict training features to the columns shared with the test set
    train_cols = [w for w in X_train.columns if w in X_test.columns]
    synth_cols = [w for w in X_synth.columns if w in X_test.columns]

    # Initialise ClassificationGarden class and start training
    classifier = ClassificationGarden(predictions=predictions, classifiers=classifiers, custom_metric=custom_metric)
    print('Fitting original models:')
    models_train, pred_train = classifier.fit(X_train[train_cols],
                                              X_test[train_cols],
                                              y_train,
                                              y_test)

    print('Fitting synthetic models:')
    classifier_synth = ClassificationGarden(predictions=predictions, classifiers=classifiers, custom_metric=custom_metric)
    models_synth, pred_synth = classifier_synth.fit(X_synth[synth_cols],
                                                    X_test[synth_cols],
                                                    y_synth,
                                                    y_test)

    # Delta = real-model score minus synthetic-model score, per metric
    delta = models_train - models_synth
    delta.columns = [s + " Delta" for s in list(delta.columns)]
    models_train.columns = [s + " Real" for s in list(models_train.columns)]
    models_synth.columns = [s + " Synth" for s in list(models_synth.columns)]

    if save_data:
        _save_to_json("models", models_train, path_to_json)
        _save_to_json("models_synth", models_synth, path_to_json)
        _save_to_json("models_delta", delta, path_to_json)

    # Transform the output DataFrames into the type used for the input DataFrames
    if was_pl:
        models_train = pl.from_pandas(models_train)
        models_synth = pl.from_pandas(models_synth)
        delta = pl.from_pandas(delta)
    elif was_np:
        # BUG FIX: the original called pl.from_numpy on pandas DataFrames,
        # which raises; numpy input should yield numpy output.
        models_train = models_train.to_numpy()
        models_synth = models_synth.to_numpy()
        delta = delta.to_numpy()

    if predictions:
        # Normalise the ground-truth labels to pandas before concatenation.
        if isinstance(y_train, pl.LazyFrame):
            # BUG FIX: LazyFrame has no to_pandas(); it must be collected first.
            y_train = y_train.collect().to_pandas()
            y_synth = y_synth.collect().to_pandas()
        elif isinstance(y_train, (pl.DataFrame, pl.Series)):
            y_train = y_train.to_pandas()
            y_synth = y_synth.to_pandas()
        elif isinstance(y_train, np.ndarray):
            y_train = pd.DataFrame(y_train)
            y_synth = pd.DataFrame(y_synth)

        pred_train = pd.concat([y_train, pred_train], axis=1)
        pred_train.columns.values[0] = 'Ground truth'
        pred_synth = pd.concat([y_synth, pred_synth], axis=1)
        pred_synth.columns.values[0] = 'Ground truth'

        if was_pl:
            pred_train = pl.from_pandas(pred_train)
            pred_synth = pl.from_pandas(pred_synth)
        elif was_np:
            pred_train = pred_train.to_numpy()
            pred_synth = pred_synth.to_numpy()
        return models_train, models_synth, delta, pred_train, pred_synth
    else:
        return models_train, models_synth, delta
def compute_utility_metrics_regr(
        X_train: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray,
        X_synth: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray,
        X_test: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray,
        y_train: pl.DataFrame | pl.LazyFrame | pl.Series | pd.Series | pd.DataFrame | np.ndarray,
        y_synth: pl.DataFrame | pl.LazyFrame | pl.Series | pd.Series | pd.DataFrame | np.ndarray,
        y_test: pl.DataFrame | pl.LazyFrame | pl.Series | pd.Series | pd.DataFrame | np.ndarray,
        custom_metric: Callable = None,
        regressors: List[Callable] = "all",
        predictions: bool = False,
        save_data: bool = True,
        path_to_json: str = ""
        ):
    """
    Train and evaluate regression models on both real and synthetic datasets.

    Parameters
    ----------
    X_train : Union[pl.DataFrame, pl.LazyFrame, pd.DataFrame, np.ndarray]
        Training features for real data.
    X_synth : Union[pl.DataFrame, pl.LazyFrame, pd.DataFrame, np.ndarray]
        Training features for synthetic data.
    X_test : Union[pl.DataFrame, pl.LazyFrame, pd.DataFrame, np.ndarray]
        Test features for evaluation.
    y_train : Union[pl.DataFrame, pl.LazyFrame, pl.Series, pd.Series, pd.DataFrame, np.ndarray]
        Training labels for real data.
    y_synth : Union[pl.DataFrame, pl.LazyFrame, pl.Series, pd.Series, pd.DataFrame, np.ndarray]
        Training labels for synthetic data.
    y_test : Union[pl.DataFrame, pl.LazyFrame, pl.Series, pd.Series, pd.DataFrame, np.ndarray]
        Test labels for evaluation.
    custom_metric : Callable, optional
        Custom metric for model evaluation, by default None.
    regressors : List[Callable], optional
        List of regressors to use, or "all" for all available regressors, by default "all".
    predictions : bool, optional
        If True, returns predictions along with model performance, by default False.
    save_data : bool
        If True, saves the computed ML utility metrics into the JSON file used to
        generate the final report, by default True.
    path_to_json : str, optional
        Path to save the output JSON files, by default "".

    Returns
    -------
    Union[pd.DataFrame, pl.DataFrame, np.ndarray]
        Model performance metrics for both real and synthetic datasets, and
        optionally predictions, returned in the same frame type as `X_train`.
    """
    # Store DataFrame type information for returning the same type
    was_pl = isinstance(X_train, (pl.DataFrame, pl.LazyFrame))
    was_np = isinstance(X_train, np.ndarray)

    # Restrict training features to the columns shared with the test set
    train_cols = [w for w in X_train.columns if w in X_test.columns]
    synth_cols = [w for w in X_synth.columns if w in X_test.columns]

    # Initialise RegressionGarden class and start training
    regressor = RegressionGarden(predictions=predictions, regressors=regressors, custom_metric=custom_metric)
    print('Fitting original models:')
    models_train, pred_train = regressor.fit(X_train[train_cols],
                                             X_test[train_cols],
                                             y_train,
                                             y_test)

    print('Fitting synthetic models:')
    regressor_synth = RegressionGarden(predictions=predictions, regressors=regressors, custom_metric=custom_metric)
    models_synth, pred_synth = regressor_synth.fit(X_synth[synth_cols],
                                                   X_test[synth_cols],
                                                   y_synth,
                                                   y_test)

    # Delta = real-model score minus synthetic-model score, per metric
    delta = models_train - models_synth
    delta.columns = [s + " Delta" for s in list(delta.columns)]
    models_train.columns = [s + " Real" for s in list(models_train.columns)]
    models_synth.columns = [s + " Synth" for s in list(models_synth.columns)]

    if save_data:
        _save_to_json("models", models_train, path_to_json)
        _save_to_json("models_synth", models_synth, path_to_json)
        _save_to_json("models_delta", delta, path_to_json)

    # Transform the output DataFrames into the type used for the input DataFrames
    if was_pl:
        models_train = pl.from_pandas(models_train)
        models_synth = pl.from_pandas(models_synth)
        delta = pl.from_pandas(delta)
    elif was_np:
        # BUG FIX: the original called pl.from_numpy on pandas DataFrames,
        # which raises; numpy input should yield numpy output.
        models_train = models_train.to_numpy()
        models_synth = models_synth.to_numpy()
        delta = delta.to_numpy()

    if predictions:
        # BUG FIX: the original concatenated the predictions with the labels
        # TWICE (once unconditionally with y.to_pandas(), crashing for pandas
        # and numpy inputs and running even when predictions=False, and once
        # inside this branch with a 'Groud truth' typo). A single, correctly
        # guarded concatenation is performed here, mirroring the
        # classification counterpart.
        if isinstance(y_train, pl.LazyFrame):
            # LazyFrame has no to_pandas(); it must be collected first.
            y_train = y_train.collect().to_pandas()
            y_synth = y_synth.collect().to_pandas()
        elif isinstance(y_train, (pl.DataFrame, pl.Series)):
            y_train = y_train.to_pandas()
            y_synth = y_synth.to_pandas()
        elif isinstance(y_train, np.ndarray):
            y_train = pd.DataFrame(y_train)
            y_synth = pd.DataFrame(y_synth)

        pred_train = pd.concat([y_train, pred_train], axis=1)
        pred_train.columns.values[0] = 'Ground truth'
        pred_synth = pd.concat([y_synth, pred_synth], axis=1)
        pred_synth.columns.values[0] = 'Ground truth'

        if was_pl:
            pred_train = pl.from_pandas(pred_train)
            pred_synth = pl.from_pandas(pred_synth)
        elif was_np:
            pred_train = pred_train.to_numpy()
            pred_synth = pred_synth.to_numpy()
        return models_train, models_synth, delta, pred_train, pred_synth
    else:
        return models_train, models_synth, delta

# STATISTICAL SIMILARITIES METRICS MODULE
def _value_count(data: pl.DataFrame,
                 features: List
                 ) -> Dict:
    ''' Return, for each requested feature of a Polars DataFrame, a DataFrame
    of its unique values with their count and percentage frequency, sorted by
    descending frequency.
    '''
    values = dict()
    for feature in data.select(pl.col(features)).columns:
        # Get the value counts for the specified feature
        value_counts_df = data.group_by(feature).len(name="count")

        # Calculate the frequency as a percentage of all rows
        total_counts = value_counts_df['count'].sum()
        value_counts_df = value_counts_df.with_columns(
            (pl.col('count') / total_counts * 100).round(2).alias('freq_%')
        ).sort("freq_%", descending=True)

        values[feature] = value_counts_df

    return values
| def _most_frequent_values(data: pl.DataFrame, 484 | features: List 485 | ) -> Dict: 486 | ''' This function returns the most frequent value for each feature of the Polars DataFrame 487 | ''' 488 | most_frequent_dict = dict() 489 | for feature in features: 490 | # Calculate the most frequent value (mode) for original dataset 491 | most_frequent = data[feature].mode().cast(pl.String).to_list() 492 | 493 | # Create the dictionary entry for the current feature 494 | if len(most_frequent)>5: 495 | most_frequent_dict[feature] = None 496 | else: 497 | most_frequent_dict[feature] = most_frequent 498 | 499 | return most_frequent_dict 500 | 501 | def compute_statistical_metrics( 502 | real_data: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray, 503 | synth_data: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray, 504 | save_data = True, 505 | path_to_json: str = "" 506 | ) -> Tuple[Dict, Dict, Dict]: 507 | """ 508 | Compute statistical metrics for numerical, categorical, and temporal features 509 | in both real and synthetic datasets. 510 | 511 | Parameters 512 | ---------- 513 | real_data : Union[pl.DataFrame, pl.LazyFrame, pd.DataFrame, np.ndarray] 514 | The real dataset containing numerical, categorical, and/or temporal features. 515 | synth_data : Union[pl.DataFrame, pl.LazyFrame, pd.DataFrame, np.ndarray] 516 | The synthetic dataset containing numerical, categorical, and/or temporal features. 517 | save_data : bool 518 | If True, saves the DCR information into the JSON file used to generate the final report, by default True. 519 | path_to_json : str, optional 520 | The file path to save the comparison metrics in JSON format, by default "". 521 | 522 | Returns 523 | ------- 524 | Tuple[Dict, Dict, Dict] 525 | A tuple containing three dictionaries with statistical comparisons for 526 | numerical, categorical, and temporal features, respectively. 
527 | """ 528 | num_features_comparison = None 529 | cat_features_comparison = None 530 | time_features_comparison = None 531 | 532 | # Converting Real and Synthetic Dataset into pl.DataFrame 533 | real_data = _to_polars_df(real_data) 534 | synth_data = _to_polars_df(synth_data) 535 | 536 | # Drop columns that are present in the real dataset but not in the synthetic dataset and vice versa 537 | synth_data, real_data = _drop_cols(synth_data, real_data) 538 | 539 | # Check that the real features and the synthetic ones match 540 | if not real_data.columns==synth_data.columns: 541 | raise ValueError("The features from the real dataset and the synthetic one do not match.") 542 | 543 | # Transform boolean into int 544 | real_data = real_data.with_columns(pl.col(pl.Boolean).cast(pl.UInt8)) 545 | synth_data = synth_data.with_columns(pl.col(pl.Boolean).cast(pl.UInt8)) 546 | 547 | cat_features = cs.expand_selector(real_data, cs.string()) 548 | num_features = cs.expand_selector(real_data, cs.numeric()) 549 | time_features = cs.expand_selector(real_data, cs.temporal()) 550 | 551 | # Numerical features 552 | if len(num_features) != 0: 553 | num_features_comparison = dict() 554 | num_features_comparison["null_count"] = { "real" : real_data.select(pl.col(num_features)).null_count(), 555 | "synthetic" : synth_data.select(pl.col(num_features)).null_count()} 556 | num_features_comparison["unique_val_number"]= { "real" : real_data.select(pl.n_unique(num_features)), 557 | "synthetic" : synth_data.select(pl.n_unique(num_features))} 558 | num_features_comparison["mean"] = { "real" : real_data.select(pl.mean(num_features)), 559 | "synthetic" : synth_data.select(pl.mean(num_features))} 560 | num_features_comparison["std"] = { "real" : real_data.select(num_features).std(), 561 | "synthetic" : synth_data.select(num_features).std()} 562 | num_features_comparison["min"] = { "real" : real_data.select(pl.min(num_features)), 563 | "synthetic" : synth_data.select(pl.min(num_features))} 564 | 
num_features_comparison["first_quartile"] = { "real" : real_data.select(num_features).quantile(0.25,"nearest"), 565 | "synthetic" : synth_data.select(num_features).quantile(0.25,"nearest")} 566 | num_features_comparison["second_quartile"] = { "real" : real_data.select(num_features).quantile(0.5,"nearest"), 567 | "synthetic" : synth_data.select(num_features).quantile(0.5,"nearest")} 568 | num_features_comparison["third_quartile"] = { "real" : real_data.select(num_features).quantile(0.75,"nearest"), 569 | "synthetic" : synth_data.select(num_features).quantile(0.75,"nearest")} 570 | num_features_comparison["max"] = { "real" : real_data.select(pl.max(num_features)), 571 | "synthetic" : synth_data.select(pl.max(num_features))} 572 | num_features_comparison["skewness"] = { "real" : real_data.select(pl.col(num_features).skew()), 573 | "synthetic" : synth_data.select(pl.col(num_features).skew())} 574 | num_features_comparison["kurtosis"] = { "real" : real_data.select(pl.col(num_features).kurtosis()), 575 | "synthetic" : synth_data.select(pl.col(num_features).kurtosis())} 576 | 577 | # Categorical features 578 | if len(cat_features) != 0: 579 | cat_features_comparison = dict() 580 | cat_features_comparison["null_count"] = { "real" : real_data.select(pl.col(cat_features)).null_count(), 581 | "synthetic" : synth_data.select(pl.col(cat_features)).null_count()} 582 | cat_features_comparison["unique_val_number"]= { "real" : real_data.select(pl.n_unique(cat_features)), 583 | "synthetic" : synth_data.select(pl.n_unique(cat_features))} 584 | cat_features_comparison["unique_val_stats"] = { "real" : _value_count(real_data, cat_features), 585 | "synthetic" : _value_count(synth_data, cat_features)} 586 | 587 | # Temporal features 588 | if len(time_features) != 0: 589 | time_features_comparison = dict() 590 | time_features_comparison["null_count"] = { "real" : real_data.select(pl.col(time_features)).null_count(), 591 | "synthetic" : synth_data.select(pl.col(time_features)).null_count()} 
def compute_mutual_info(
    real_data: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray,
    synth_data: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray,
    exclude_columns: List = None,
    save_data: bool = True,
    path_to_json: str = ""
) -> Tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
    """
    Compute the correlation matrices for both real and synthetic datasets, and
    calculate the difference between these matrices.

    Parameters
    ----------
    real_data : Union[pl.DataFrame, pl.LazyFrame, pd.DataFrame, np.ndarray]
        The real dataset, which can be in the form of a Polars DataFrame, LazyFrame,
        pandas DataFrame, or numpy ndarray.
    synth_data : Union[pl.DataFrame, pl.LazyFrame, pd.DataFrame, np.ndarray]
        The synthetic dataset, provided in the same format as `real_data`.
    exclude_columns : List, optional
        A list of columns to exclude from the computation of the correlation
        matrices, by default None (no column is excluded).
    save_data : bool
        If True, saves the correlation matrices into the JSON file used to generate
        the final report, by default True.
    path_to_json : str, optional
        File path to save the correlation matrices and their differences in JSON format,
        by default "".

    Returns
    -------
    Tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]
        A tuple containing:
        - real_corr: Correlation matrix of the real dataset with column names included.
        - synth_corr: Correlation matrix of the synthetic dataset with column names included.
        - diff_corr: Difference between the correlation matrices of the real and synthetic
          datasets, with values smaller than 1e-5 substituted with 0.

    Raises
    ------
    KeyError
        If a column listed in `exclude_columns` is not present in the real dataset.
    ValueError
        If the features in the real and synthetic datasets do not match.
    """
    # NOTE(review): despite the function name, this computes Pearson correlation
    # (DataFrame.corr), not mutual information — confirm the intended metric.

    # Guard against the shared-mutable-default-argument pitfall.
    if exclude_columns is None:
        exclude_columns = []

    # Converting Real and Synthetic Dataset into pl.DataFrame
    real_data = _to_polars_df(real_data)
    synth_data = _to_polars_df(synth_data)

    for col in exclude_columns:
        if col not in real_data.columns:
            raise KeyError(f"Column {col} not found in DataFrame.")

    real_data = real_data.drop(exclude_columns)
    synth_data = synth_data.drop(exclude_columns)

    # Drop columns that are present in the real dataset but not in the synthetic dataset and vice versa
    synth_data, real_data = _drop_cols(synth_data, real_data)

    # Check that the real features and the synthetic ones match
    if not real_data.columns == synth_data.columns:
        raise ValueError("The features from the real dataset and the synthetic one do not match.")

    # Convert Boolean and Temporal types to numerical so they can enter the
    # correlation computation.
    real_data = real_data.with_columns(cs.boolean().cast(pl.UInt8))
    synth_data = synth_data.with_columns(cs.boolean().cast(pl.UInt8))
    real_data = real_data.with_columns(cs.temporal().as_expr().dt.timestamp('ms'))
    synth_data = synth_data.with_columns(cs.temporal().as_expr().dt.timestamp('ms'))

    # Label Encoding of categorical (string) features.
    # NOTE(review): the encoder is re-fitted independently on the real and the
    # synthetic column, so the same category may map to different integer codes
    # in the two datasets — verify this is intended before comparing matrices
    # cell by cell.
    encoder = LabelEncoder()
    for col in real_data.select(cs.string()).columns:
        real_data = real_data.with_columns([
            pl.Series(col, encoder.fit_transform(real_data[col].to_list())),
        ])
        synth_data = synth_data.with_columns([
            pl.Series(col, encoder.fit_transform(synth_data[col].to_list())),
        ])

    # Real and Synthetic dataset correlation matrix
    real_corr = real_data.corr()
    synth_corr = synth_data.corr()

    # Difference between the correlation matrix of the real dataset and the
    # correlation matrix of the synthetic dataset
    diff_corr = real_corr - synth_corr

    # Substitute elements with abs value lower than 1e-5 with 0 (noise floor)
    diff_corr = diff_corr.with_columns(
        [pl.when(abs(pl.col(col)) < 1e-5).then(0).otherwise(pl.col(col)).alias(col)
         for col in diff_corr.columns]
    )

    if save_data:
        _save_to_json("real_corr", real_corr, path_to_json)
        _save_to_json("synth_corr", synth_corr, path_to_json)
        _save_to_json("diff_corr", diff_corr, path_to_json)

    return real_corr, synth_corr, diff_corr
def detection(
    df_original: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray,
    df_synthetic: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray,
    preprocessor: Preprocessor = None,
    features_to_hide: List = None,
    save_data: bool = True,
    path_to_json: str = ""
) -> Dict:
    """
    Computes the detection score by training an XGBoost model to differentiate between
    original and synthetic data. The lower the model's accuracy, the higher the quality
    of the synthetic data.

    Parameters
    ----------
    df_original : Union[pl.DataFrame, pl.LazyFrame, pd.DataFrame, np.ndarray]
        The original dataset containing real data.
    df_synthetic : Union[pl.DataFrame, pl.LazyFrame, pd.DataFrame, np.ndarray]
        The synthetic dataset to be evaluated.
    preprocessor : Preprocessor, optional
        A preprocessor object for transforming the datasets. If None, a new Preprocessor
        instance will be created. Defaults to None.
    features_to_hide : list, optional
        List of features to exclude from the reported importances. Defaults to None
        (no feature is hidden).
    save_data : bool
        If True, saves the detection score into the JSON file used to generate the
        final report, by default True.
    path_to_json : str, optional
        Path to save the output JSON files, by default "".

    Returns
    -------
    dict
        A dictionary containing accuracy, ROC AUC score, detection score, and feature importances.

    Notes
    -----
    The method operates through the following steps:

    1. Prepares the datasets:

       - Truncates the original dataset (first rows) to match the size of the synthetic dataset.
       - Preprocesses both datasets to ensure consistent feature representation.
       - Labels original data as ``0`` and synthetic data as ``1``.

    2. Builds a classification model:

       - Uses XGBoost to train a model that classifies data points as either real or synthetic.
       - Splits the data into training and test sets (33% test split).

    3. Computes Evaluation Metrics:

       - Accuracy: Measures classification correctness.
       - ROC-AUC Score: Measures the model's discriminatory power.
       - Detection Score

    4. Extracts Feature Importances:

       - Identifies which features contribute most to distinguishing real vs. synthetic data.

    The detection score is calculated as:

    .. code-block:: python

        detection_score["score"] = (1 - detection_score["ROC_AUC"]) * 2

    - If ``ROC_AUC <= 0.5``, the synthetic data is considered indistinguishable from the real dataset (``score = 1``).
    - A lower ROC-AUC means better synthetic data quality.

    Examples
    --------
    Example of dictionary returned:

    .. code-block:: python

        >>> detection_results = detection(original_df, synthetic_df)
        >>> print(detection_results)
        {
            "accuracy": 0.85,
            "ROC_AUC": 0.90,
            "score": 0.2,
            "feature_importances": {"feature_1": 0.34, "feature_2": 0.21, ...}
        }

    """
    import xgboost as xgb
    from sklearn.model_selection import train_test_split

    # Guard against the shared-mutable-default-argument pitfall.
    if features_to_hide is None:
        features_to_hide = []

    if preprocessor is None:
        preprocessor = Preprocessor(df_original)

    df_original = _to_polars_df(df_original)
    df_synthetic = _to_polars_df(df_synthetic)

    # Truncate the original dataset to match the size of the synthetic dataset.
    # NOTE(review): this takes the FIRST rows, not a random sample — confirm a
    # random sample is not required for an unbiased detection score.
    df_original = df_original.head(len(df_synthetic))

    # Replace minority labels in the original data so both datasets share the
    # same category vocabulary used by the preprocessor.
    for feat_name in preprocessor.discarded[1].keys():
        for minority_label in preprocessor.discarded[1][feat_name]:
            df_original = df_original.with_columns(
                pl.when(pl.col(feat_name) == minority_label)
                .then(pl.lit("other"))
                .otherwise(pl.col(feat_name))
                .alias(feat_name)
            )

    # Preprocess and label the original data (label 0 = real)
    preprocessed_df_original = preprocessor.transform(df_original)
    df_original = preprocessor.inverse_transform(preprocessed_df_original)
    df_original = df_original.with_columns([
        pl.Series("label", np.zeros(len(df_original)).astype(int)),
    ])

    # Preprocess and label the synthetic data (label 1 = synthetic)
    preprocessed_df_synthetic = preprocessor.transform(df_synthetic)
    df_synthetic = preprocessor.inverse_transform(preprocessed_df_synthetic)
    df_synthetic = df_synthetic.with_columns([
        pl.Series("label", np.ones(len(df_synthetic)).astype(int)),
    ])

    df = pl.concat([df_original, df_synthetic])
    preprocessor_ = Preprocessor(df, target_column="label")
    df_preprocessed = preprocessor_.transform(df)

    X_train, X_test, y_train, y_test = train_test_split(
        df_preprocessed.select(pl.exclude("label")),
        df_preprocessed.select(pl.col("label")),
        test_size=0.33,
        random_state=42
    )

    # NOTE(review): `use_label_encoder` is deprecated/removed in recent xgboost
    # releases — kept for compatibility with the pinned version, verify against
    # requirements.txt.
    model = xgb.XGBClassifier(max_depth=3, n_estimators=50, use_label_encoder=False, eval_metric="logloss")
    model.fit(X_train, y_train)

    # Make predictions and compute metrics
    y_pred = model.predict(X_test)
    detection_score = {}
    detection_score["accuracy"] = round(
        accuracy_score(y_true=y_test, y_pred=y_pred), 4
    )
    detection_score["ROC_AUC"] = round(
        roc_auc_score(
            y_true=y_test,
            y_score=model.predict_proba(X_test)[:, 1],
            average=None,
        ),
        4,
    )
    # ROC_AUC <= 0.5 means the classifier cannot tell real from synthetic.
    detection_score["score"] = (
        1
        if detection_score["ROC_AUC"] <= 0.5
        else (1 - detection_score["ROC_AUC"]) * 2
    )

    detection_score["feature_importances"] = {}

    # Determine feature importances. The model's importance vector is laid out
    # as [numerical..., datetime..., one-hot categorical...]; categorical
    # importances are summed over each feature's one-hot width.
    numerical_features_sizes, categorical_features_sizes = preprocessor.get_features_sizes()

    numerical_features = preprocessor.numerical_features
    categorical_features = preprocessor.categorical_features
    datetime_features = preprocessor.datetime_features

    index = 0
    for feature, importance in zip(numerical_features, model.feature_importances_):
        if feature not in features_to_hide:
            detection_score["feature_importances"][feature] = round(float(importance), 4)
        index += 1

    if len(datetime_features) > 0:
        for feature, importance in zip(datetime_features, model.feature_importances_[index:]):
            if feature not in features_to_hide:
                detection_score["feature_importances"][feature] = round(float(importance), 4)
            index += 1

    for feature, feature_size in zip(categorical_features, categorical_features_sizes):
        importance = np.sum(model.feature_importances_[index: index + feature_size])
        index += feature_size
        if feature not in features_to_hide:
            detection_score["feature_importances"][feature] = round(float(importance), 4)

    if save_data:
        _save_to_json("detection_score", detection_score, path_to_json)

    return detection_score
def query_power(
    df_original: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray,
    df_synthetic: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray,
    preprocessor: Preprocessor = None,
    save_data: bool = True,
    path_to_json: str = ""
) -> dict:
    """
    Generates and runs queries to compare the original and synthetic datasets.

    This method creates random queries that filter data from both datasets.
    The similarity between the sizes of the filtered results is used to score
    the quality of the synthetic data.

    Parameters
    ----------
    df_original : Union[pl.DataFrame, pl.LazyFrame, pd.DataFrame, np.ndarray]
        The original dataset containing real data.
    df_synthetic : Union[pl.DataFrame, pl.LazyFrame, pd.DataFrame, np.ndarray]
        The synthetic dataset to be evaluated.
    preprocessor : Preprocessor, optional
        A preprocessor object for transforming the datasets. If None, a new
        Preprocessor instance will be created. Defaults to None.
    save_data : bool
        If True, saves the query-power information into the JSON file used to
        generate the final report, by default True.
    path_to_json : str, optional
        Path to save the output JSON files, by default "".

    Returns
    -------
    dict
        A dictionary containing query texts, the number of matches for each
        query in both datasets, and an overall score indicating the quality
        of the synthetic data. The score is 0.0 if no query could be generated
        (fewer than two usable features).
    """
    def polars_query(feat_type, feature, op, value):
        """Build a single polars filter expression for one query condition."""
        if feat_type == 'num':
            if op == "<=":
                return pl.col(feature) <= value
            if op == ">=":
                return pl.col(feature) >= value
        elif feat_type == 'cat':
            if op == "==":
                return pl.col(feature) == value
            if op == "!=":
                return pl.col(feature) != value
        # Previously an unsupported combination left `query` unbound and raised
        # an opaque UnboundLocalError; fail explicitly instead.
        raise ValueError(f"Unsupported query condition: {feat_type!r} {op!r}")

    query_power = {"queries": []}

    if preprocessor is None:
        preprocessor = Preprocessor(df_original)

    df_original = _to_polars_df(df_original)
    df_synthetic = _to_polars_df(df_synthetic)

    # NOTE(review): sample() assumes len(df_original) >= len(df_synthetic) —
    # confirm callers guarantee this.
    df_original = df_original.sample(len(df_synthetic)).clone()
    df_original_preprocessed = preprocessor.transform(df_original)
    df_original = preprocessor.inverse_transform(df_original_preprocessed)

    df_synthetic_preprocessed = preprocessor.transform(df_synthetic)
    df_synthetic = preprocessor.inverse_transform(df_synthetic_preprocessed)

    # Extract feature types
    numerical_features = preprocessor.numerical_features
    categorical_features = preprocessor.categorical_features
    datetime_features = preprocessor.datetime_features
    boolean_features = preprocessor.boolean_features

    # Prepare the feature list, excluding datetime features
    features = list(set(df_original.columns) - set(datetime_features))

    # Define query parameters
    quantiles = [0.25, 0.5, 0.75]
    numerical_ops = ["<=", ">="]
    categorical_ops = ["==", "!="]
    logical_ops = ["and"]

    queries_score = []

    # Generate and run up to 5 queries, each combining two random features
    while len(features) >= 2 and len(query_power["queries"]) < 5:
        # Randomly select two distinct features for the query
        feats = [random.choice(features)]
        features.remove(feats[0])
        feats.append(random.choice(features))
        features.remove(feats[1])

        queries = []
        queries_text = []
        # Construct query conditions for each selected feature
        for feature in feats:
            if feature in numerical_features:
                feat_type = 'num'
                op = random.choice(numerical_ops)
                # Threshold is a random quantile of the original data
                value = df_original.select(
                    pl.col(feature).quantile(
                        random.choice(quantiles), interpolation="nearest"
                    )).item()
            elif feature in categorical_features or feature in boolean_features:
                feat_type = 'cat'
                op = random.choice(categorical_ops)
                value = random.choice(df_original[feature].unique())
            else:
                continue

            queries_text.append(f"`{feature}` {op} `{value}`")
            queries.append(polars_query(feat_type, feature, op, value))

        # Combine query conditions with a logical operator
        text = f" {random.choice(logical_ops)} ".join(queries_text)
        combined_query = reduce(lambda a, b: a & b, queries)

        try:
            query = {
                "text": text,
                "original_df": len(df_original.filter(combined_query)),
                "synthetic_df": len(df_synthetic.filter(combined_query)),
            }
        except Exception:
            query = {"text": "Invalid query", "original_df": 0, "synthetic_df": 0}

        # Append the query and calculate the score: 1 when both datasets match
        # the query with identical frequency, decreasing linearly with the gap.
        query_power["queries"].append(query)
        queries_score.append(
            1 - abs(query["original_df"] - query["synthetic_df"]) / len(df_original)
        )

    # Calculate the overall query power score; previously this raised
    # ZeroDivisionError when no query could be generated.
    if queries_score:
        query_power["score"] = round(float(sum(queries_score) / len(queries_score)), 4)
    else:
        query_power["score"] = 0.0

    if save_data:
        _save_to_json("query_power", query_power, path_to_json)

    return query_power
34,Private,Some-college,Married-civ-spouse,Craft-repair,Husband,White,Male,0,0,48,United-States,>50K 3 | 49,Private,Some-college,Separated,Exec-managerial,Unmarried,Black,Female,0,0,60,United-States,<=50K 4 | 44,Federal-gov,Bachelors,Married-civ-spouse,Adm-clerical,Husband,Asian-Pac-Islander,Male,0,0,40,*,>50K 5 | 35,Private,HS-grad,Divorced,Craft-repair,Not-in-family,White,Male,0,0,40,United-States,<=50K 6 | 44,Private,HS-grad,Married-civ-spouse,Farming-fishing,Husband,*,Male,0,0,40,United-States,<=50K 7 | 28,Private,Assoc-voc,Never-married,Sales,Not-in-family,White,Male,0,0,50,United-States,<=50K 8 | 44,Private,HS-grad,Married-civ-spouse,Machine-op-inspct,Husband,White,Male,0,0,40,United-States,<=50K 9 | 54,Local-gov,10th,Married-civ-spouse,Protective-serv,Husband,White,Male,0,0,45,United-States,<=50K 10 | 33,Private,Bachelors,Married-civ-spouse,Exec-managerial,Husband,White,Male,0,0,50,United-States,>50K 11 | 30,Private,10th,Divorced,Other-service,Unmarried,White,Female,0,0,40,United-States,<=50K 12 | 50,Private,Bachelors,Married-civ-spouse,Exec-managerial,Husband,White,Male,0,1902,40,United-States,>50K 13 | 30,State-gov,Bachelors,Never-married,Prof-specialty,Other-relative,White,Male,0,0,35,United-States,<=50K 14 | 44,Private,Masters,Married-civ-spouse,Adm-clerical,Husband,White,Male,3998,0,50,United-States,<=50K 15 | 20,nan,Some-college,Never-married,nan,Own-child,White,Female,0,0,40,United-States,<=50K 16 | 36,Private,Bachelors,Divorced,Adm-clerical,Not-in-family,White,Male,0,0,54,United-States,>50K 17 | 54,Local-gov,Some-college,Married-civ-spouse,Transport-moving,Husband,Black,Male,0,0,30,United-States,<=50K 18 | 23,Private,HS-grad,Never-married,Adm-clerical,Not-in-family,White,Female,0,0,50,United-States,<=50K 19 | 26,Private,11th,Separated,Craft-repair,Other-relative,White,Male,3121,0,50,United-States,<=50K 20 | 20,nan,Some-college,Never-married,nan,Own-child,White,Female,0,0,40,United-States,<=50K 21 | 
57,Private,Some-college,Married-civ-spouse,Craft-repair,Husband,White,Male,0,1709,45,United-States,<=50K 22 | 27,Local-gov,Bachelors,Never-married,Prof-specialty,Not-in-family,White,Female,0,0,40,United-States,<=50K 23 | 19,Private,Some-college,Never-married,Other-service,Own-child,White,Female,0,0,30,United-States,<=50K 24 | 41,Private,Bachelors,Never-married,Prof-specialty,Not-in-family,White,Female,0,0,45,United-States,<=50K 25 | 33,Private,HS-grad,Married-civ-spouse,Craft-repair,Husband,White,Male,0,0,35,United-States,<=50K 26 | 33,Private,*,Never-married,Other-service,Unmarried,Black,Female,0,0,14,*,<=50K 27 | 61,Private,11th,Divorced,*,Not-in-family,White,Female,0,0,12,nan,<=50K 28 | 40,Private,Bachelors,*,Transport-moving,Own-child,*,Male,0,0,45,nan,<=50K 29 | 29,Private,Bachelors,Never-married,Exec-managerial,Not-in-family,White,Female,0,0,45,United-States,>50K 30 | 18,Private,HS-grad,Never-married,Handlers-cleaners,Own-child,White,Male,0,0,40,United-States,<=50K 31 | 43,Self-emp-inc,Bachelors,Married-civ-spouse,Sales,Husband,White,Male,0,0,45,United-States,>50K 32 | 42,Self-emp-not-inc,Assoc-voc,Married-civ-spouse,Transport-moving,Husband,White,Male,4797,0,80,United-States,>50K 33 | 37,Private,Some-college,Married-civ-spouse,Sales,Husband,*,Male,0,0,45,United-States,>50K 34 | 29,Private,Assoc-voc,Never-married,Prof-specialty,Not-in-family,White,Male,0,0,44,United-States,<=50K 35 | 45,Private,HS-grad,Married-civ-spouse,Transport-moving,Husband,*,Male,4132,0,40,United-States,<=50K 36 | 36,Self-emp-not-inc,*,Married-civ-spouse,Prof-specialty,Husband,White,Male,0,0,35,United-States,<=50K 37 | 27,Private,HS-grad,Never-married,Machine-op-inspct,Not-in-family,White,Female,0,0,40,United-States,<=50K 38 | 18,Private,HS-grad,Never-married,Adm-clerical,Own-child,White,Male,0,0,20,United-States,<=50K 39 | 46,Private,HS-grad,*,Other-service,Not-in-family,White,Male,0,0,35,Mexico,<=50K 40 | 
30,Self-emp-not-inc,Some-college,Never-married,Sales,Own-child,White,Male,0,0,25,United-States,<=50K 41 | 42,Private,Assoc-voc,Divorced,Adm-clerical,Unmarried,White,Female,0,0,40,United-States,<=50K 42 | 33,Private,*,*,Transport-moving,Unmarried,*,Male,0,0,20,*,<=50K 43 | 45,Private,HS-grad,Never-married,Exec-managerial,Not-in-family,Black,Female,0,0,40,United-States,<=50K 44 | 54,Private,Some-college,Married-civ-spouse,Craft-repair,Wife,White,Female,0,0,40,United-States,<=50K 45 | 37,Self-emp-inc,HS-grad,Never-married,Transport-moving,Not-in-family,White,Male,0,0,50,United-States,<=50K 46 | 37,Self-emp-not-inc,HS-grad,Married-civ-spouse,Transport-moving,Husband,White,Male,0,0,30,United-States,<=50K 47 | 23,Private,Assoc-acdm,Never-married,Sales,Own-child,Black,Female,0,0,36,United-States,<=50K 48 | 26,Private,*,Never-married,Handlers-cleaners,Unmarried,White,Male,0,0,40,United-States,<=50K 49 | 25,Private,*,Never-married,Adm-clerical,Not-in-family,White,Female,0,0,34,Mexico,<=50K 50 | 31,Private,Bachelors,Married-civ-spouse,Sales,Husband,White,Male,0,0,40,United-States,>50K 51 | 26,Local-gov,Bachelors,Never-married,Prof-specialty,Not-in-family,White,Female,0,0,40,United-States,<=50K 52 | 74,nan,Bachelors,Widowed,nan,Not-in-family,*,Female,0,0,35,United-States,<=50K 53 | 25,Private,HS-grad,Never-married,Craft-repair,Not-in-family,White,Male,0,0,35,United-States,<=50K 54 | 18,Private,HS-grad,Never-married,Sales,Own-child,White,Female,0,0,27,United-States,<=50K 55 | 40,Private,HS-grad,Married-civ-spouse,Farming-fishing,Husband,White,Male,0,0,40,United-States,<=50K 56 | 32,Private,Assoc-voc,Divorced,Exec-managerial,Not-in-family,White,Female,0,0,50,United-States,<=50K 57 | 54,Private,HS-grad,Married-civ-spouse,Tech-support,Husband,White,Male,7687,0,25,United-States,>50K 58 | 23,Private,Some-college,Never-married,Handlers-cleaners,Own-child,White,Male,2961,0,50,United-States,<=50K 59 | 
56,Self-emp-not-inc,*,Divorced,Other-service,Unmarried,White,Female,0,0,40,United-States,<=50K 60 | 33,Local-gov,HS-grad,Widowed,Other-service,Unmarried,White,Female,0,0,40,United-States,<=50K 61 | 21,State-gov,Some-college,Never-married,Adm-clerical,Own-child,White,Female,0,0,30,United-States,<=50K 62 | 21,Private,HS-grad,Never-married,Sales,Not-in-family,White,Male,0,0,54,United-States,<=50K 63 | 27,Private,HS-grad,Widowed,Craft-repair,Unmarried,White,Female,0,1610,26,United-States,<=50K 64 | 47,Self-emp-not-inc,HS-grad,Married-civ-spouse,Exec-managerial,Husband,White,Male,0,2415,50,United-States,>50K 65 | 35,Private,Masters,Never-married,Exec-managerial,Not-in-family,White,Female,0,0,54,United-States,<=50K 66 | 39,Self-emp-not-inc,HS-grad,Married-civ-spouse,Machine-op-inspct,Husband,White,Male,0,0,70,nan,<=50K 67 | 19,Private,11th,Married-civ-spouse,Craft-repair,Husband,White,Male,0,0,60,United-States,<=50K 68 | 52,Private,HS-grad,Married-civ-spouse,Craft-repair,Husband,White,Male,0,0,40,United-States,>50K 69 | 50,Private,HS-grad,Married-civ-spouse,Other-service,Husband,White,Male,0,0,40,United-States,>50K 70 | 39,Private,Some-college,Never-married,Other-service,Unmarried,Black,Female,0,0,40,United-States,<=50K 71 | 22,nan,Some-college,Never-married,nan,Not-in-family,Asian-Pac-Islander,Female,0,0,12,*,<=50K 72 | 31,Private,HS-grad,Never-married,Other-service,Not-in-family,White,Male,0,0,50,United-States,<=50K 73 | 26,Self-emp-not-inc,Some-college,Never-married,Other-service,Not-in-family,White,Female,0,0,30,United-States,<=50K 74 | 26,Private,HS-grad,Married-civ-spouse,Exec-managerial,Wife,White,Female,5013,0,40,United-States,<=50K 75 | 30,Private,10th,Never-married,Craft-repair,Not-in-family,White,Male,0,0,50,United-States,<=50K 76 | 22,Private,Some-college,Married-civ-spouse,Tech-support,Husband,White,Male,0,0,40,United-States,<=50K 77 | 49,Private,HS-grad,Married-civ-spouse,Machine-op-inspct,Husband,White,Male,0,0,40,United-States,>50K 78 | 
35,Private,Some-college,Divorced,Exec-managerial,Not-in-family,White,Female,0,0,40,United-States,>50K 79 | 25,Private,HS-grad,Married-civ-spouse,Transport-moving,Husband,White,Male,0,0,40,United-States,<=50K 80 | 42,Private,Bachelors,Married-civ-spouse,Craft-repair,Wife,Black,Female,0,0,35,United-States,>50K 81 | 28,Local-gov,Assoc-voc,Married-civ-spouse,Prof-specialty,Wife,White,Female,7687,0,37,United-States,>50K 82 | 20,Private,Some-college,Never-married,Other-service,Not-in-family,White,Male,0,0,40,United-States,<=50K 83 | 30,Self-emp-not-inc,HS-grad,Married-civ-spouse,Exec-managerial,Husband,White,Male,0,1887,50,United-States,>50K 84 | 35,Private,HS-grad,Married-civ-spouse,Handlers-cleaners,Husband,White,Male,0,0,40,Mexico,<=50K 85 | 56,Private,HS-grad,Widowed,Other-service,Unmarried,White,Female,0,0,40,United-States,<=50K 86 | 25,Private,Bachelors,Never-married,Exec-managerial,Own-child,White,Male,0,261,40,United-States,<=50K 87 | 39,Private,Masters,Married-civ-spouse,Machine-op-inspct,Husband,White,Male,0,0,40,United-States,<=50K 88 | 60,Self-emp-not-inc,*,Married-civ-spouse,Craft-repair,Husband,White,Male,0,0,40,United-States,<=50K 89 | 51,Local-gov,*,Divorced,Exec-managerial,Not-in-family,White,Female,0,0,45,United-States,>50K 90 | 53,Private,Masters,Never-married,Prof-specialty,Not-in-family,White,Female,0,1591,54,United-States,>50K 91 | 24,Private,HS-grad,Never-married,Transport-moving,Unmarried,White,Female,0,0,40,Mexico,<=50K 92 | 43,Private,Some-college,Divorced,Transport-moving,Not-in-family,White,Male,0,0,45,United-States,<=50K 93 | 51,Self-emp-not-inc,Masters,Divorced,Exec-managerial,Unmarried,White,Male,26552,0,50,United-States,>50K 94 | 31,Private,Assoc-acdm,Married-civ-spouse,Machine-op-inspct,Husband,White,Male,0,0,20,United-States,<=50K 95 | 60,Private,HS-grad,Married-civ-spouse,Handlers-cleaners,Husband,White,Male,0,0,40,United-States,<=50K 96 | 
24,Private,Some-college,Never-married,Adm-clerical,Not-in-family,White,Female,0,0,40,United-States,<=50K 97 | 45,State-gov,Masters,Married-civ-spouse,Prof-specialty,Husband,White,Male,0,0,40,*,<=50K 98 | 23,Private,Some-college,Never-married,Sales,Not-in-family,White,Female,0,0,30,United-States,<=50K 99 | 34,Self-emp-not-inc,Assoc-acdm,Married-civ-spouse,Prof-specialty,Wife,White,Female,0,0,25,United-States,>50K 100 | 18,Private,Some-college,Never-married,Adm-clerical,Own-child,White,Female,0,0,38,United-States,<=50K 101 | -------------------------------------------------------------------------------- /tests/resources/validation_dataset.csv: -------------------------------------------------------------------------------- 1 | age,work_class,education,marital_status,occupation,relationship,race,sex,capital_gain,capital_loss,hours_per_week,native_country,income 2 | 25,Private,11th,Never-married,Machine-op-inspct,Own-child,Black,Male,0,0,40,United-States,<=50K 3 | 38,Private,HS-grad,Married-civ-spouse,Farming-fishing,Husband,White,Male,0,0,50,United-States,<=50K 4 | 28,Local-gov,Assoc-acdm,Married-civ-spouse,Protective-serv,Husband,White,Male,0,0,40,United-States,>50K 5 | 44,Private,Some-college,Married-civ-spouse,Machine-op-inspct,Husband,Black,Male,7688,0,40,United-States,>50K 6 | 18,,Some-college,Never-married,,Own-child,White,Female,0,0,30,United-States,<=50K 7 | 34,Private,10th,Never-married,Other-service,Not-in-family,White,Male,0,0,30,United-States,<=50K 8 | 29,,HS-grad,Never-married,,Unmarried,Black,Male,0,0,40,United-States,<=50K 9 | 63,Self-emp-not-inc,Prof-school,Married-civ-spouse,Prof-specialty,Husband,White,Male,3103,0,32,United-States,>50K 10 | 24,Private,Some-college,Never-married,Other-service,Unmarried,White,Female,0,0,40,United-States,<=50K 11 | 55,Private,7th-8th,Married-civ-spouse,Craft-repair,Husband,White,Male,0,0,10,United-States,<=50K 12 | 65,Private,HS-grad,Married-civ-spouse,Machine-op-inspct,Husband,White,Male,6418,0,40,United-States,>50K 
13 | 36,Federal-gov,Bachelors,Married-civ-spouse,Adm-clerical,Husband,White,Male,0,0,40,United-States,<=50K 14 | 26,Private,HS-grad,Never-married,Adm-clerical,Not-in-family,White,Female,0,0,39,United-States,<=50K 15 | 58,,HS-grad,Married-civ-spouse,,Husband,White,Male,0,0,35,United-States,<=50K 16 | 48,Private,HS-grad,Married-civ-spouse,Machine-op-inspct,Husband,White,Male,3103,0,48,United-States,>50K 17 | 43,Private,Masters,Married-civ-spouse,Exec-managerial,Husband,White,Male,0,0,50,United-States,>50K 18 | 20,State-gov,Some-college,Never-married,Other-service,Own-child,White,Male,0,0,25,United-States,<=50K 19 | 43,Private,HS-grad,Married-civ-spouse,Adm-clerical,Wife,White,Female,0,0,30,United-States,<=50K 20 | 37,Private,HS-grad,Widowed,Machine-op-inspct,Unmarried,White,Female,0,0,20,United-States,<=50K 21 | 40,Private,Doctorate,Married-civ-spouse,Prof-specialty,Husband,Asian-Pac-Islander,Male,0,0,45,,>50K 22 | 34,Private,Bachelors,Married-civ-spouse,Tech-support,Husband,White,Male,0,0,47,United-States,>50K 23 | 34,Private,Some-college,Never-married,Other-service,Own-child,Black,Female,0,0,35,United-States,<=50K 24 | 72,,7th-8th,Divorced,,Not-in-family,White,Female,0,0,6,United-States,<=50K 25 | 25,Private,Bachelors,Never-married,Prof-specialty,Not-in-family,White,Male,0,0,43,Peru,<=50K 26 | 25,Private,Bachelors,Married-civ-spouse,Prof-specialty,Husband,White,Male,0,0,40,United-States,<=50K 27 | 45,Self-emp-not-inc,HS-grad,Married-civ-spouse,Craft-repair,Husband,White,Male,7298,0,90,United-States,>50K 28 | 22,Private,HS-grad,Never-married,Adm-clerical,Own-child,White,Male,0,0,20,United-States,<=50K 29 | 23,Private,HS-grad,Separated,Machine-op-inspct,Unmarried,Black,Male,0,0,54,United-States,<=50K 30 | 54,Private,HS-grad,Married-civ-spouse,Craft-repair,Husband,White,Male,0,0,35,United-States,<=50K 31 | 32,Self-emp-not-inc,Some-college,Never-married,Prof-specialty,Not-in-family,White,Male,0,0,60,United-States,<=50K 32 | 
46,State-gov,Some-college,Married-civ-spouse,Exec-managerial,Husband,Black,Male,7688,0,38,United-States,>50K 33 | 56,Self-emp-not-inc,11th,Widowed,Other-service,Unmarried,White,Female,0,0,50,United-States,<=50K 34 | 24,Self-emp-not-inc,Bachelors,Never-married,Sales,Not-in-family,White,Male,0,0,50,United-States,<=50K 35 | 23,Local-gov,Some-college,Married-civ-spouse,Protective-serv,Husband,White,Male,0,0,40,United-States,<=50K 36 | 26,Private,HS-grad,Divorced,Exec-managerial,Unmarried,White,Female,0,0,40,United-States,<=50K 37 | 65,,HS-grad,Married-civ-spouse,,Husband,White,Male,0,0,40,United-States,<=50K 38 | 36,Local-gov,Bachelors,Married-civ-spouse,Prof-specialty,Husband,White,Male,0,0,40,United-States,>50K 39 | 22,Private,5th-6th,Never-married,Priv-house-serv,Not-in-family,White,Male,0,0,50,Guatemala,<=50K 40 | 17,Private,10th,Never-married,Machine-op-inspct,Not-in-family,White,Male,0,0,40,United-States,<=50K 41 | 20,Private,HS-grad,Never-married,Craft-repair,Own-child,White,Male,0,0,40,United-States,<=50K 42 | 65,Private,Masters,Married-civ-spouse,Prof-specialty,Husband,White,Male,0,0,50,United-States,>50K 43 | 44,Self-emp-inc,Assoc-voc,Married-civ-spouse,Sales,Husband,White,Male,0,0,45,United-States,>50K 44 | 36,Private,HS-grad,Married-civ-spouse,Farming-fishing,Husband,White,Male,0,0,40,United-States,<=50K 45 | 29,Private,11th,Married-civ-spouse,Other-service,Husband,White,Male,0,0,40,United-States,<=50K 46 | 20,State-gov,Some-college,Never-married,Farming-fishing,Own-child,White,Male,0,0,32,United-States,<=50K 47 | 28,Private,Assoc-voc,Married-civ-spouse,Prof-specialty,Wife,White,Female,0,0,36,United-States,>50K 48 | 39,Private,7th-8th,Married-civ-spouse,Craft-repair,Husband,White,Male,0,0,40,Mexico,<=50K 49 | 54,Private,Some-college,Married-civ-spouse,Transport-moving,Husband,White,Male,3908,0,50,United-States,<=50K 50 | 52,Private,11th,Separated,Priv-house-serv,Not-in-family,Black,Female,0,0,18,United-States,<=50K 51 | 
import pytest
import numpy as np
import polars as pl
import pandas as pd
from pathlib import Path
from sure import Preprocessor
from sure.utility import (compute_statistical_metrics, compute_mutual_info,
                          compute_utility_metrics_class)
from sure.privacy import (distance_to_closest_record, dcr_stats, number_of_dcr_equal_to_zero,
                          validation_dcr_test, adversary_dataset, membership_inference_test)


def _load_resource(filename: str) -> pl.DataFrame:
    """Load a CSV fixture from tests/resources as a polars DataFrame."""
    return pl.read_csv(Path(__file__).parent / "resources" / filename)


@pytest.fixture
def real_data():
    """Original (training) dataset the synthetic data was generated from."""
    return _load_resource("dataset.csv")


@pytest.fixture
def synthetic_data():
    """Synthetic dataset produced by the generative model."""
    return _load_resource("synthetic_dataset.csv")


@pytest.fixture
def validation_data():
    """Hold-out dataset NOT used to train the generative model."""
    return _load_resource("validation_dataset.csv")


def test_statistical_metrics(real_data, synthetic_data):
    """Test computation of statistical metrics between real and synthetic data."""
    num_stats, cat_stats, _ = compute_statistical_metrics(real_data, synthetic_data)

    assert isinstance(num_stats, dict)
    # Categorical statistics are produced alongside the numerical ones;
    # previously this value was unpacked but never checked.
    assert cat_stats is not None


def test_mutual_info(real_data, synthetic_data):
    """Test computation of mutual information metrics."""
    preprocessor = Preprocessor(real_data, num_fill_null='forward', scaling='standardize')
    real_preprocessed = preprocessor.transform(real_data)
    synth_preprocessed = preprocessor.transform(synthetic_data)

    corr_real, corr_synth, corr_diff = compute_mutual_info(real_preprocessed, synth_preprocessed)

    # Real and synthetic correlation matrices must be comparable element-wise,
    # and their difference must share that common shape.
    assert corr_real.shape == corr_synth.shape
    assert corr_diff.shape == corr_real.shape


def test_privacy_metrics(real_data, synthetic_data, validation_data):
    """Test computation of privacy metrics (DCR and validation share)."""
    # DCR of synthetic records against training and validation data.
    dcr_synth_train = distance_to_closest_record("synth_train", synthetic_data, real_data)
    dcr_synth_valid = distance_to_closest_record("synth_val", synthetic_data, validation_data)
    assert dcr_synth_train is not None
    assert dcr_synth_valid is not None

    # DCR summary stats — previously computed but never verified.
    stats_train = dcr_stats("synth_train", dcr_synth_train)
    stats_valid = dcr_stats("synth_val", dcr_synth_valid)
    assert stats_train is not None
    assert stats_valid is not None

    # Share of synthetic records closer to the training set than to the
    # validation set must be a percentage in [0, 100].
    share = validation_dcr_test(dcr_synth_train, dcr_synth_valid)
    assert isinstance(share['percentage'], float)
    assert 0 <= share['percentage'] <= 100


def test_membership_inference(real_data, synthetic_data, validation_data):
    """Test membership inference attack simulation."""
    # adversary_dataset adds a "privacy_test_is_training" column flagging which
    # records came from the training set; that column is the ground truth the
    # attack is scored against.
    adversary_df = adversary_dataset(real_data, validation_data)
    ground_truth = adversary_df["privacy_test_is_training"]

    mia_results = membership_inference_test(adversary_df, synthetic_data, ground_truth)

    assert isinstance(mia_results, dict)
\\\n", 15 | "If you encounter any issues, please open an issue on our [GitHub repository](https://github.com/Clearbox-AI/SURE).\n", 16 | "\n", 17 | "### Datasets description\n", 18 | "\n", 19 | "The three datasets provided are the following:\n", 20 | "\n", 21 | "- *census_dataset_training.csv* \\\n", 22 | " The original real dataset used to train the generative model from which *census_dataset_synthetic* was produced.\n", 23 | " \n", 24 | "- *census_dataset_validation.csv* \\\n", 25 | " This dataset was also part of the original real dataset, but it was NOT used to train the generative model that produced *census_dataset_synthetic*.\n", 26 | " \n", 27 | "- *census_dataset_synthetic.csv* \\\n", 28 | " The synthetic dataset produced with the generative model trained on *census_dataset_training.*\n", 29 | " \n", 30 | "\n", 31 | "The three census datasets include various demographic, social, economic, and housing characteristics of individuals. Every row of the datasets corresponds to an individual.\n", 32 | "\n", 33 | "The machine learning task related to these datasets is a classification task, where, based on all the features, a ML classifier model must decide whether the individual earns more than 50k dollars per year (label=1) or less (label=0).\\\n", 34 | "The column \"label\" in each dataset is the ground truth for this classification task." 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "## 0. 
Installing the library and importing dependencies " 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "# install the SURE library \n", 51 | "%pip install clearbox-sure" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "# importing dependencies\n", 61 | "import polars as pl # you can use polars or pandas for importing the datasets\n", 62 | "import pandas as pd\n", 63 | "import os\n", 64 | "\n", 65 | "from sure import Preprocessor, report\n", 66 | "from sure.utility import (compute_statistical_metrics, compute_mutual_info,\n", 67 | "\t\t\t \t\t\t compute_utility_metrics_class,\n", 68 | "\t\t\t\t\t\t detection,\n", 69 | "\t\t\t\t\t\t query_power)\n", 70 | "from sure.privacy import (distance_to_closest_record, dcr_stats, number_of_dcr_equal_to_zero, validation_dcr_test, \n", 71 | "\t\t\t adversary_dataset, membership_inference_test)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": {}, 77 | "source": [ 78 | "## 1. 
Dataset import and preparation" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": {}, 84 | "source": [ 85 | "#### 1.1 Import the datasets" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "file_path = \"https://raw.githubusercontent.com/Clearbox-AI/SURE/main/examples/data/census_dataset\"\n", 95 | "\n", 96 | "real_data = pl.from_pandas(pd.read_csv(os.path.join(file_path,\"census_dataset_training.csv\"))).lazy()\n", 97 | "valid_data = pl.from_pandas(pd.read_csv(os.path.join(file_path,\"census_dataset_validation.csv\"))).lazy()\n", 98 | "synth_data = pl.from_pandas(pd.read_csv(os.path.join(file_path,\"census_dataset_synthetic.csv\"))).lazy()" 99 | ] 100 | }, 101 | { 102 | "cell_type": "markdown", 103 | "metadata": {}, 104 | "source": [ 105 | "#### 1.2 Datasets preparation\n", 106 | "Apply a series of transformations to the raw dataset to prepare it for the subsequent steps." 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "# Preprocessor initialization and query execution on the real, synthetic and validation datasets\n", 116 | "preprocessor = Preprocessor(real_data, num_fill_null='forward', scaling='standardize')\n", 117 | "\n", 118 | "real_data_preprocessed = preprocessor.transform(real_data)\n", 119 | "valid_data_preprocessed = preprocessor.transform(valid_data)\n", 120 | "synth_data_preprocessed = preprocessor.transform(synth_data)" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "metadata": {}, 126 | "source": [ 127 | "## 2. 
Utility assessment" 128 | ] 129 | }, 130 | { 131 | "cell_type": "markdown", 132 | "metadata": {}, 133 | "source": [ 134 | "#### 2.1 Statistical properties and mutual information\n", 135 | "These functions compute general statistical features, the correlation matrices and the difference between the correlation matrix of the real and synthetic dataset." 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 6, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "# Compute statistical properties and features mutual information\n", 145 | "num_features_stats, cat_features_stats, temporal_feat_stats = compute_statistical_metrics(real_data, synth_data)\n", 146 | "corr_real, corr_synth, corr_difference = compute_mutual_info(real_data_preprocessed, synth_data_preprocessed)" 147 | ] 148 | }, 149 | { 150 | "cell_type": "markdown", 151 | "metadata": {}, 152 | "source": [ 153 | "#### 2.2 ML utility - Train on Synthetic Test on Real\n", 154 | "The `compute_utility_metrics_class` trains multiple machine learning classification models on the synthetic dataset and evaluates their performance on the validation set.\n", 155 | "\n", 156 | "For comparison, it also trains the same models on the original training set and evaluates them on the same validation set. This allows a direct comparison between models trained on synthetic data and those trained on real data." 
157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": null, 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "# Assessing the machine learning utility of the synthetic dataset on the classification task\n", 166 | "\n", 167 | "# ML utility: TSTR - Train on Synthetic, Test on Real\n", 168 | "X_train = real_data_preprocessed.drop(\"label\") # Assuming the datasets have a “label” column for the machine learning task they are intended for\n", 169 | "y_train = real_data_preprocessed[\"label\"]\n", 170 | "X_synth = synth_data_preprocessed.drop(\"label\")\n", 171 | "y_synth = synth_data_preprocessed[\"label\"]\n", 172 | "X_test = valid_data_preprocessed.drop(\"label\").limit(10000) # Test the trained models on a portion of the original real dataset (first 10k rows)\n", 173 | "y_test = valid_data_preprocessed[\"label\"].limit(10000)\n", 174 | "TSTR_metrics = compute_utility_metrics_class(X_train, X_synth, X_test, y_train, y_synth, y_test)" 175 | ] 176 | }, 177 | { 178 | "cell_type": "markdown", 179 | "metadata": {}, 180 | "source": [ 181 | "#### 2.3 Detection Score\n", 182 | "Computes the detection score by training an XGBoost model to differentiate between original and synthetic data. 
\n", 183 | "\n", 184 | "The lower the model's accuracy, the higher the quality of the synthetic data.\n", 185 | "\n", 186 | "\n", 187 | "The detection score is computed as\n", 188 | "\n", 189 | "detection_score = 2*(1 - ROC_AUC)\n", 190 | "\n", 191 | "So if ROC_AUC<=0.5 the synthetic dataset is considered indistinguishable from the real dataset (detection score =1)\n" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": null, 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [ 200 | "detection_score = detection(real_data, synth_data, preprocessor)\n", 201 | "print(\"Detection accuracy: \", detection_score[\"accuracy\"])\n", 202 | "print(\"Detection ROC_AUC: \", detection_score[\"ROC_AUC\"])\n", 203 | "print(\"Detection score: \", detection_score[\"score\"])\n", 204 | "print(\"Detection feature importances: \", detection_score[\"feature_importances\"])" 205 | ] 206 | }, 207 | { 208 | "cell_type": "markdown", 209 | "metadata": {}, 210 | "source": [ 211 | "#### Query Power\n", 212 | "Generates and runs queries to compare the original and synthetic datasets.\n", 213 | "\n", 214 | "This method creates random queries that filter data from both datasets.\n", 215 | "\n", 216 | "The similarity between the sizes of the filtered results is used to score the quality of the synthetic data." 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": null, 222 | "metadata": {}, 223 | "outputs": [], 224 | "source": [ 225 | "query_power_score = query_power(real_data, synth_data, preprocessor)\n", 226 | "\n", 227 | "print(\"Query Power score: \", query_power_score[\"score\"])\n", 228 | "for query in query_power_score[\"queries\"]:\n", 229 | " print(\"\\n\", query[\"text\"])\n", 230 | " print(\"Query result on real: \", query[\"original_df\"])\n", 231 | " print(\"Query result on synthetic: \", query[\"synthetic_df\"])" 232 | ] 233 | }, 234 | { 235 | "cell_type": "markdown", 236 | "metadata": {}, 237 | "source": [ 238 | "## 3. 
Privacy assessment" 239 | ] 240 | }, 241 | { 242 | "cell_type": "markdown", 243 | "metadata": {}, 244 | "source": [ 245 | "#### 3.1 Distance to closest record (DCR)" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 10, 251 | "metadata": {}, 252 | "outputs": [], 253 | "source": [ 254 | "# Compute the distances to closest record between the synthetic dataset and the real dataset\n", 255 | "# and the distances to closest record between the synthetic dataset and the validation dataset\n", 256 | "\n", 257 | "dcr_synth_train = distance_to_closest_record(\"synth_train\", synth_data, real_data)\n", 258 | "dcr_synth_valid = distance_to_closest_record(\"synth_val\", synth_data, valid_data)" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": null, 264 | "metadata": {}, 265 | "outputs": [], 266 | "source": [ 267 | "# Check for any clones shared between the synthetic and real datasets (DCR=0).\n", 268 | "\n", 269 | "dcr_zero_synth_train = number_of_dcr_equal_to_zero(\"synth_train\", dcr_synth_train)\n", 270 | "dcr_zero_synth_valid = number_of_dcr_equal_to_zero(\"synth_val\", dcr_synth_valid)" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": null, 276 | "metadata": {}, 277 | "outputs": [], 278 | "source": [ 279 | "# Compute some general statistics for the DCR arrays computed above\n", 280 | "\n", 281 | "dcr_stats_synth_train = dcr_stats(\"synth_train\", dcr_synth_train)\n", 282 | "dcr_stats_synth_valid = dcr_stats(\"synth_val\", dcr_synth_valid)" 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": null, 288 | "metadata": {}, 289 | "outputs": [], 290 | "source": [ 291 | "# Compute the share of records that are closer to the training set than to the validation set\n", 292 | "\n", 293 | "share = validation_dcr_test(dcr_synth_train, dcr_synth_valid)" 294 | ] 295 | }, 296 | { 297 | "cell_type": "markdown", 298 | "metadata": {}, 299 | "source": [ 300 | "#### 3.2 Membership 
Inference Attack test" 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": null, 306 | "metadata": {}, 307 | "outputs": [], 308 | "source": [ 309 | "# Simulate a Membership Inference Attack on your synthetic dataset\n", 310 | "# To do so, you'll need to produce an adversary dataset and some labels as adversary guesses ground truth\n", 311 | "\n", 312 | "# The label is automatically produced by the function adversary_dataset and is added as a column named \n", 313 | "# \"privacy_test_is_training\" in the adversary dataset returned\n", 314 | "\n", 315 | "# ML privacy attack sandbox initialization and simulation\n", 316 | "adversary_df = adversary_dataset(real_data, valid_data)\n", 317 | "\n", 318 | "# The function adversary_dataset adds a column \"privacy_test_is_training\" to the adversary dataset, indicating whether the record was part of the training set or not\n", 319 | "adversary_guesses_ground_truth = adversary_df[\"privacy_test_is_training\"] \n", 320 | "MIA = membership_inference_test(adversary_df, synth_data, adversary_guesses_ground_truth)" 321 | ] 322 | }, 323 | { 324 | "cell_type": "markdown", 325 | "metadata": {}, 326 | "source": [ 327 | "## 4. 
Utility-Privacy report" 328 | ] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "execution_count": null, 333 | "metadata": {}, 334 | "outputs": [], 335 | "source": [ 336 | "# Produce the utility privacy report with the information computed above\n", 337 | "\n", 338 | "report(real_data, synth_data)" 339 | ] 340 | } 341 | ], 342 | "metadata": { 343 | "kernelspec": { 344 | "display_name": "Python (test_sure)", 345 | "language": "python", 346 | "name": "test_sure" 347 | }, 348 | "language_info": { 349 | "codemirror_mode": { 350 | "name": "ipython", 351 | "version": 3 352 | }, 353 | "file_extension": ".py", 354 | "mimetype": "text/x-python", 355 | "name": "python", 356 | "nbconvert_exporter": "python", 357 | "pygments_lexer": "ipython3", 358 | "version": "3.10.12" 359 | } 360 | }, 361 | "nbformat": 4, 362 | "nbformat_minor": 2 363 | } 364 | --------------------------------------------------------------------------------