├── .github
└── workflows
│ ├── publish.yaml
│ └── publish_macos.yaml
├── .readthedocs.yaml
├── Dockerfile
├── LICENSE
├── Makefile
├── README.md
├── build_wheel.sh
├── docs
└── source
│ ├── _static
│ └── style.css
│ ├── api
│ ├── modules.rst
│ ├── sure.distance_metrics.rst
│ ├── sure.privacy.rst
│ ├── sure.report_generator.rst
│ └── sure.rst
│ ├── conf.py
│ ├── doc_1.md
│ ├── doc_2.md
│ ├── img
│ ├── SURE_workflow_.png
│ ├── cb_white_logo_compact.png
│ ├── favicon.ico
│ └── sure_logo_nobg.png
│ └── index.rst
├── make.bat
├── project.toml
├── requirements.txt
├── setup.py
├── sure
├── __init__.py
├── _lazypredict.py
├── distance_metrics
│ ├── __init__.py
│ ├── distance.py
│ └── gower_matrix_c.pyx
├── privacy
│ ├── __init__.py
│ └── privacy.py
├── report_generator
│ ├── .streamlit
│ │ └── config.toml
│ ├── __init__.py
│ ├── pages
│ │ └── privacy.py
│ ├── report_app.py
│ └── report_generator.py
└── utility.py
├── tests
├── resources
│ ├── dataset.csv
│ ├── synthetic_dataset.csv
│ └── validation_dataset.csv
└── test_sure.py
└── tutorials
├── data
└── census_dataset
│ ├── census_dataset_synthetic.csv
│ ├── census_dataset_training.csv
│ └── census_dataset_validation.csv
└── sure_tutorial.ipynb
/.github/workflows/publish.yaml:
--------------------------------------------------------------------------------
1 | name: Publish Library to PyPI
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 |
8 | jobs:
9 | build:
10 | runs-on: ubuntu-latest
11 |
12 | strategy:
13 | matrix:
14 | python-version: ["3.10", "3.11", "3.12"]
15 | include:
16 | - python-version: "3.10"
17 | python-path: "/opt/python/cp310-cp310/bin"
18 | - python-version: "3.11"
19 | python-path: "/opt/python/cp311-cp311/bin"
20 | - python-version: "3.12"
21 | python-path: "/opt/python/cp312-cp312/bin"
22 |
23 | steps:
24 | - name: Checkout repository
25 | uses: actions/checkout@v2
26 |
27 | - name: Set up Python for tests
28 | uses: actions/setup-python@v2
29 | with:
30 | python-version: ${{ matrix.python-version }}
31 |
32 | - name: Install test dependencies
33 | run: |
34 | python -m pip install --upgrade pip
35 | pip install -r requirements.txt
36 | pip install pytest
37 |
38 | - name: Build Python wheels
39 | run: |
40 | python setup.py bdist_wheel
41 |
42 | - name: Install wheel
43 | run: |
44 | pip install dist/*.whl
45 | - name: Run tests
46 | run: |
47 | pytest tests/
48 |
49 | - name: Set up Docker Buildx
50 | uses: docker/setup-buildx-action@v1
51 |
52 | - name: Set up QEMU
53 | uses: docker/setup-qemu-action@v1
54 |
55 | - name: Build Docker image
56 | run: |
57 | docker build -t my-python-wheels .
58 |
59 | - name: Run build script in Docker container
60 | run: |
61 | docker run --rm \
62 | -e PYTHON_VERSION=${{ matrix.python-version }} \
63 | -e PYTHON_PATH=${{ matrix.python-path }} \
64 | -v ${{ github.workspace }}:/io \
65 | my-python-wheels /build_wheel.sh
66 |
67 | # - name: Check if version exists on PyPI
68 | # id: check-version
69 | # run: |
70 | # PACKAGE_NAME=$(python setup.py --name)
71 | # PACKAGE_VERSION=$(python setup.py --version)
72 | # if curl --silent -f https://pypi.org/project/${PACKAGE_NAME}/${PACKAGE_VERSION}/; then
73 | # echo "Version ${PACKAGE_VERSION} already exists on PyPI."
74 | # echo "already_published=true" >> $GITHUB_ENV
75 | # else
76 | # echo "Version ${PACKAGE_VERSION} does not exist on PyPI."
77 | # echo "already_published=false" >> $GITHUB_ENV
78 | # fi
79 |
80 | # - name: Upload Python wheels as artifacts
81 | # uses: actions/upload-artifact@v4
82 | # with:
83 | # name: python-wheels
84 | # path: dist/*.whl
85 |
86 | - name: Publish wheels to PyPI
87 | # if: env.already_published == 'false'
88 | uses: pypa/gh-action-pypi-publish@v1.4.2
89 | with:
90 | user: __token__
91 | password: ${{ secrets.PYPI_TOKEN }}
92 |
93 |
94 |
--------------------------------------------------------------------------------
/.github/workflows/publish_macos.yaml:
--------------------------------------------------------------------------------
1 | name: Publish Library to PyPI (MacOS)
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 |
8 | jobs:
9 | build:
10 | runs-on: ${{ matrix.os }}
11 |
12 | strategy:
13 | matrix:
14 | os: [macos-latest]
15 | python-version: ["3.10", "3.11", "3.12"]
16 |
17 | steps:
18 | - name: Checkout repository
19 | uses: actions/checkout@v2
20 |
21 | - name: Set up Python
22 | uses: actions/setup-python@v2
23 | with:
24 | python-version: ${{ matrix.python-version }}
25 |
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | pip install wheel setuptools
30 | pip install -r requirements.txt
31 |
32 | - name: Build Python wheels
33 | run: |
34 | python setup.py bdist_wheel
35 |
36 |
37 | # - name: Upload Python wheels as artifacts
38 | # uses: actions/upload-artifact@v4
39 | # with:
40 | # name: python-wheels
41 | # path: dist/*.whl
42 |
43 | - name: Publish on PyPI
44 | run: |
45 | pip install twine
46 | export TWINE_USERNAME=__token__
47 | export TWINE_PASSWORD=${{ secrets.PYPI_TOKEN }}
48 | twine upload dist/*.whl
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yaml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | # Required
6 | version: 2
7 |
8 | # Set the OS, Python version, and other tools you might need
9 | build:
10 | os: ubuntu-22.04
11 | tools:
12 | python: "3.10"
13 | # You can also specify other tool versions:
14 | # nodejs: "19"
15 | # rust: "1.64"
16 | # golang: "1.19"
17 |
18 | # Build documentation in the "docs/" directory with Sphinx
19 | sphinx:
20 | configuration: docs/source/conf.py
21 |
22 | # Optionally build your docs in additional formats such as PDF and ePub
23 | formats: all
24 |
25 | # Optional but recommended, declare the Python requirements required
26 | # to build your documentation
27 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
28 | python:
29 | install:
30 | - requirements: requirements.txt
31 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM quay.io/pypa/manylinux2014_x86_64
2 |
3 | ARG PYTHON_VERSION
4 |
5 | RUN /opt/python/cp310-cp310/bin/pip install -U pip setuptools wheel Cython
6 | RUN /opt/python/cp311-cp311/bin/pip install -U pip setuptools wheel Cython
7 | RUN /opt/python/cp312-cp312/bin/pip install -U pip setuptools wheel Cython
8 |
9 | COPY build_wheel.sh /build_wheel.sh
10 | RUN chmod +x /build_wheel.sh
11 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = docs/source
9 | BUILDDIR = docs/build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | [](https://clearbox-sure.readthedocs.io/en/latest/?badge=latest)
3 | [](https://badge.fury.io/py/clearbox-sure)
4 | [](https://pepy.tech/project/clearbox-sure)
5 | [](https://github.com/Clearbox-AI/SURE)
6 |
7 |
8 |
9 | ### Synthetic Data: Utility, Regulatory compliance, and Ethical privacy
10 |
11 | The SURE package is an open-source Python library intended to be used for the assessment of the utility and privacy performance of any tabular synthetic dataset.
12 |
13 | The SURE library features multiple Python modules that can be easily imported and seamlessly integrated into any Python script after installing the library.
14 |
15 | > [!WARNING]
16 | > This is a beta version of the library and only runs on Linux and MacOS for the moment.
17 |
18 | > [!IMPORTANT]
19 | > Requires Python >= 3.10
20 |
21 | # Installation
22 |
23 | To install the library run the following command in your terminal:
24 |
25 | ```shell
26 | $ pip install clearbox-sure
27 | ```
28 |
29 | # Modules overview
30 |
31 | The SURE library features the following modules:
32 |
33 | 1. Preprocessor
34 | 2. Statistical similarity metrics
35 | 3. Model garden
36 | 4. ML utility metrics
37 | 5. Distance metrics
38 | 6. Privacy attack sandbox
39 | 7. Report generator
40 |
41 | **Preprocessor**
42 |
43 | The input datasets undergo manipulation by the preprocessor module, tailored to conform to the standard structure utilized across the subsequent processes. The Polars library used in the preprocessor makes this operation significantly faster compared to the use of other data processing libraries.
44 |
45 | **Utility**
46 |
47 | The statistical similarity metrics, the ML utility metrics and the model garden modules constitute the data **utility evaluation** part.
48 |
49 | The statistical similarity module and the distance metrics module take as input the pre-processed datasets and carry out the operation to assess the statistical similarity between the datasets and how different the content of the synthetic dataset is from the one of the original dataset. In particular, the real and synthetic input datasets are used in the statistical similarity metrics module to assess how close the two datasets are in terms of statistical properties, such as mean, correlation, distribution.
50 |
51 | The model garden executes a classification or regression task on the given dataset with multiple machine learning models, returning the performance metrics of each of the models tested on the given task and dataset.
52 |
53 | The model garden module’s best performing models are employed in the machine learning utility metrics module to compute the usefulness of the synthetic data on a given ML task (classification or regression).
54 |
55 | **Privacy**
56 |
57 | The distance metrics and the privacy attack sandbox make up the synthetic data **privacy assessment** modules.
58 |
59 | The distance metrics module computes the Gower distance between the two input datasets and the distance to the closest record for each line of the first dataset.
60 |
61 | The ML privacy attack sandbox makes it possible to simulate a Membership Inference Attack for the re-identification of vulnerable records identified with the distance metrics module and to evaluate how exposed the synthetic dataset is to this kind of attack.
62 |
63 | **Report**
64 |
65 | Eventually, the report generator provides a summary of the utility and privacy metrics computed in the previous modules, providing a visual digest with charts and tables of the results.
66 |
67 | This following diagram serves as a visual representation of how each module contributes to the utility-privacy assessment process and highlights the seamless interconnection and synergy between individual blocks.
68 |
69 |
70 |
71 | # Usage
72 |
73 | The library leverages Polars, which ensures faster computations compared to other data manipulation libraries. It supports both Polars and Pandas dataframes.
74 |
75 | The user must provide both the original real training dataset (which was used to train the generative model that produced the synthetic dataset), the real holdout dataset (which was NOT used to train the generative model that produced the synthetic dataset) and the corresponding synthetic dataset to enable the library's modules to perform the necessary computations for evaluation.
76 |
77 | Below is a code snippet example for the usage of the library:
78 |
79 | ```python
80 | # Import the necessary modules from the SURE library
81 | from sure import Preprocessor, report
82 | from sure.utility import (compute_statistical_metrics, compute_mutual_info,
83 | compute_utility_metrics_class,
84 | detection,
85 | query_power)
86 | from sure.privacy import (distance_to_closest_record, dcr_stats, number_of_dcr_equal_to_zero, validation_dcr_test,
87 | adversary_dataset, membership_inference_test)
88 |
89 | # Assuming real_data, valid_data and synth_data are three pandas DataFrames
90 |
91 | # Preprocessor initialization and query execution on the real, synthetic and validation datasets
92 | preprocessor = Preprocessor(real_data, num_fill_null='forward', scaling='standardize')
93 |
94 | real_data_preprocessed = preprocessor.transform(real_data)
95 | valid_data_preprocessed = preprocessor.transform(valid_data)
96 | synth_data_preprocessed = preprocessor.transform(synth_data)
97 |
98 | # Statistical properties and mutual information
99 | num_features_stats, cat_features_stats, temporal_feat_stats = compute_statistical_metrics(real_data, synth_data)
100 | corr_real, corr_synth, corr_difference = compute_mutual_info(real_data_preprocessed, synth_data_preprocessed)
101 |
102 | # ML utility: TSTR - Train on Synthetic, Test on Real
103 | X_train = real_data_preprocessed.drop("label", axis=1) # Assuming the datasets have a “label” column for the machine learning task they are intended for
104 | y_train = real_data_preprocessed["label"]
105 | X_synth = synth_data_preprocessed.drop("label", axis=1)
106 | y_synth = synth_data_preprocessed["label"]
107 | X_test = valid_data_preprocessed.drop("label", axis=1).limit(10000) # Test the trained models on a portion of the original real dataset (first 10k rows)
108 | y_test = valid_data_preprocessed["label"].limit(10000)
109 | TSTR_metrics = compute_utility_metrics_class(X_train, X_synth, X_test, y_train, y_synth, y_test)
110 |
111 | # Distance to closest record
112 | dcr_synth_train = distance_to_closest_record("synth_train", synth_data, real_data)
113 | dcr_synth_valid = distance_to_closest_record("synth_val", synth_data, valid_data)
114 | dcr_stats_synth_train = dcr_stats("synth_train", dcr_synth_train)
115 | dcr_stats_synth_valid = dcr_stats("synth_val", dcr_synth_valid)
116 | dcr_zero_synth_train = number_of_dcr_equal_to_zero("synth_train", dcr_synth_train)
117 | dcr_zero_synth_valid = number_of_dcr_equal_to_zero("synth_val", dcr_synth_valid)
118 | share = validation_dcr_test(dcr_synth_train, dcr_synth_valid)
119 |
120 | # Detection Score
121 | detection_score = detection(real_data, synth_data, preprocessor)
122 |
123 | # Query Power
124 | query_power_score = query_power(real_data, synth_data, preprocessor)
125 |
126 | # ML privacy attack sandbox initialization and simulation
127 | adversary_df = adversary_dataset(real_data, valid_data)
128 | # The function adversary_dataset adds a column "privacy_test_is_training" to the adversary dataset, indicating whether the record was part of the training set or not
129 | adversary_guesses_ground_truth = adversary_df["privacy_test_is_training"]
130 | MIA = membership_inference_test(adversary_df, synth_data, adversary_guesses_ground_truth)
131 |
132 | # Report generation as HTML page
133 | report(real_data, synth_data)
134 | ```
135 |
136 | Follow the step-by-step [guide](https://github.com/Clearbox-AI/SURE/tree/main/tutorials) to test the library.
137 |
138 |
139 |
--------------------------------------------------------------------------------
/build_wheel.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e
4 | cd /io
5 |
6 | ${PYTHON_PATH}/pip install -U pip setuptools wheel Cython
7 | ${PYTHON_PATH}/pip install -r requirements.txt
8 | ${PYTHON_PATH}/python setup.py bdist_wheel
9 |
10 | # Prepare output directory
11 | mkdir -p /io/wheelhouse
12 |
13 | # Repair the wheels with auditwheel
14 | for whl in dist/*.whl; do
15 | auditwheel repair "$whl" --plat manylinux2014_x86_64 -w /io/wheelhouse/
16 | done
17 |
18 | # Replace dist with repaired wheels
19 | rm -rf dist
20 | mv wheelhouse dist
21 |
--------------------------------------------------------------------------------
/docs/source/_static/style.css:
--------------------------------------------------------------------------------
1 | @import url("https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap");
2 | @import url('https://fonts.googleapis.com/css2?family=Be+Vietnam+Pro:ital,wght@0,100;0,200;0,300;0,400;0,500;0,600;0,700;0,800;0,900;1,100;1,200;1,300;1,400;1,500;1,600;1,700;1,800;1,900&display=swap');
3 |
4 | .be-vietnam-pro-regular {
5 | font-family: "Be Vietnam Pro", serif;
6 | font-weight: 400;
7 | font-style: normal;
8 | }
9 |
10 | body {
11 | font-family: "Be Vietnam Pro", serif;
12 | color: #212121;
13 | }
14 |
15 | footer {
16 | color: #646369;
17 | }
18 |
19 | h1, h2, h3, h4, h5, h6 {
20 | font-family: "Be Vietnam Pro", serif;
21 | font-weight: 600;
22 | }
23 |
24 | .wy-nav-content-wrap {
25 | background-color: #fafafa;
26 | }
27 | .wy-nav-content {
28 | background-color: #ffffff;
29 | max-width: 60rem;
30 | min-height: 100vh;
31 | }
32 | .wy-nav-side,
33 | .wy-nav-top {
34 | background-color: #483a8f;
35 | }
36 |
37 | .wy-form input[type="text"] {
38 | border-radius: 0;
39 | border: none;
40 | }
41 |
42 | /** This fixes the horizontal scroll on mobile */
43 | @media only screen and (max-width: 770px) {
44 | .py.class dt,
45 | .py.function dt {
46 | display: block !important;
47 | overflow-x: auto;
48 | }
49 | }
--------------------------------------------------------------------------------
/docs/source/api/modules.rst:
--------------------------------------------------------------------------------
1 | sure
2 | ====
3 |
4 | .. toctree::
5 | :maxdepth: 4
6 |
7 | sure
8 |
--------------------------------------------------------------------------------
/docs/source/api/sure.distance_metrics.rst:
--------------------------------------------------------------------------------
1 | sure.distance\_metrics package
2 | ==============================
3 |
4 | Submodules
5 | ----------
6 |
7 | sure.distance\_metrics.distance module
8 | --------------------------------------
9 |
10 | .. automodule:: sure.distance_metrics.distance
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | sure.distance\_metrics.gower\_matrix\_c module
16 | ----------------------------------------------
17 |
18 | .. automodule:: sure.distance_metrics.gower_matrix_c
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | Module contents
24 | ---------------
25 |
26 | .. automodule:: sure.distance_metrics
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
--------------------------------------------------------------------------------
/docs/source/api/sure.privacy.rst:
--------------------------------------------------------------------------------
1 | sure.privacy package
2 | ====================
3 |
4 | Submodules
5 | ----------
6 |
7 | sure.privacy.privacy module
8 | ---------------------------
9 |
10 | .. automodule:: sure.privacy.privacy
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | Module contents
16 | ---------------
17 |
18 | .. automodule:: sure.privacy
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
--------------------------------------------------------------------------------
/docs/source/api/sure.report_generator.rst:
--------------------------------------------------------------------------------
1 | sure.report\_generator package
2 | ==============================
3 |
4 | Submodules
5 | ----------
6 |
7 | sure.report\_generator.report\_app module
8 | -----------------------------------------
9 |
10 | .. automodule:: sure.report_generator.report_app
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | sure.report\_generator.report\_generator module
16 | -----------------------------------------------
17 |
18 | .. automodule:: sure.report_generator.report_generator
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | Module contents
24 | ---------------
25 |
26 | .. automodule:: sure.report_generator
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
--------------------------------------------------------------------------------
/docs/source/api/sure.rst:
--------------------------------------------------------------------------------
1 | Utils
2 | =====
3 | .. toctree::
4 | :maxdepth: 4
5 |
6 | .. automodule:: sure
7 | :members:
8 | :undoc-members:
9 | :show-inheritance:
10 |
11 | Utility
12 | =======
13 |
14 | .. automodule:: sure.utility
15 | :members:
16 |
17 |
18 | Privacy
19 | =======
20 | .. toctree::
21 | :maxdepth: 2
22 |
23 | .. automodule:: sure.privacy
24 | :members:
25 |
26 | .. .. automodule:: sure.distance_metrics.distance
27 | :members:
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.insert(0, os.path.abspath('../..')) # Adjust the path as needed
4 |
5 | project = 'SURE'
6 | author = 'Dario Brunelli'
7 | copyright = "2024, Clearbox AI"
8 | release = '0.1.9.9'
9 |
10 | extensions = [
11 | 'sphinx.ext.autodoc',
12 | "sphinx.ext.coverage",
13 | 'sphinx.ext.napoleon',
14 | "myst_parser",
15 | 'sphinx_rtd_theme'
16 | ]
17 |
18 | myst_enable_extensions = [
19 | "html_image", # Allows html images format conversion
20 | ]
21 |
22 | source_suffix = {
23 | '.rst': 'restructuredtext',
24 | '.txt': 'markdown',
25 | '.md': 'markdown',
26 | }
27 | templates_path = ['_templates']
28 | exclude_patterns = []
29 |
30 | html_theme = "sphinx_rtd_theme"
31 | html_theme_options = {
32 | "logo_only": True,
33 | "style_nav_header_background": "#483a8f",
34 | }
35 | html_css_files = [
36 | 'style.css',
37 | ]
38 | html_static_path = ['_static', 'img']
39 | html_logo = "img/cb_white_logo_compact.png"
40 | html_favicon = "img/favicon.ico"
41 |
42 | master_doc = 'index' # Ensure this points to your main document
43 |
44 |
45 | # Napoleon settings
46 | napoleon_google_docstring = True
47 | napoleon_numpy_docstring = False
--------------------------------------------------------------------------------
/docs/source/doc_1.md:
--------------------------------------------------------------------------------
1 | [](https://clearbox-sure.readthedocs.io/en/latest/?badge=latest)
2 | [](https://badge.fury.io/py/clearbox-sure)
3 | [](https://pepy.tech/project/clearbox-sure)
4 | [](https://github.com/Clearbox-AI/SURE)
--------------------------------------------------------------------------------
/docs/source/doc_2.md:
--------------------------------------------------------------------------------
1 | ## SURE
2 | Synthetic Data: Utility, Regulatory compliance, and Ethical privacy
3 |
4 | The SURE package is an open-source Python library for the assessment of the utility and privacy performance of any tabular **synthetic dataset**.
5 |
6 | The SURE library works both with [pandas](https://pandas.pydata.org/) and [polars](https://pola.rs/) DataFrames.
7 |
8 | ## Installation
9 |
10 | It is highly recommended to install the library in a virtual environment or container.
11 |
12 | ```shell
13 | $ pip install clearbox-sure
14 | ```
15 |
16 | ## Usage
17 |
18 | The user must provide both the original real training dataset (which was used to train the generative model that produced the synthetic dataset), the real holdout dataset (which was NOT used to train the generative model that produced the synthetic dataset) and the corresponding synthetic dataset to enable the library's modules to perform the necessary computations for evaluation.
19 |
Follow the step-by-step guide to test the library using the provided [instructions](https://github.com/Clearbox-AI/SURE/blob/main/tutorials/sure_tutorial.ipynb).
21 |
22 | ```python
23 | # Import the necessary modules from the SURE library
24 | from sure import Preprocessor, report
25 | from sure.utility import (compute_statistical_metrics, compute_mutual_info,
26 | compute_utility_metrics_class,
27 | detection,
28 | query_power)
29 | from sure.privacy import (distance_to_closest_record, dcr_stats, number_of_dcr_equal_to_zero, validation_dcr_test,
30 | adversary_dataset, membership_inference_test)
31 |
32 | # Assuming real_data, valid_data and synth_data are three pandas DataFrames
33 |
34 | # Preprocessor initialization and query execution on the real, synthetic and validation datasets
35 | preprocessor = Preprocessor(real_data)
36 | real_data_preprocessed = preprocessor.transform(real_data)
37 | valid_data_preprocessed = preprocessor.transform(valid_data)
38 | synth_data_preprocessed = preprocessor.transform(synth_data)
39 |
40 | # Statistical properties and mutual information
41 | num_features_stats, cat_features_stats, temporal_feat_stats = compute_statistical_metrics(real_data, synth_data)
42 | corr_real, corr_synth, corr_difference = compute_mutual_info(real_data_preprocessed, synth_data_preprocessed)
43 |
44 | # ML utility: TSTR - Train on Synthetic, Test on Real
45 | X_train = real_data_preprocessed.drop("label", axis=1) # Assuming the datasets have a “label” column for the machine learning task they are intended for
46 | y_train = real_data_preprocessed["label"]
47 | X_synth = synth_data_preprocessed.drop("label", axis=1)
48 | y_synth = synth_data_preprocessed["label"]
49 | X_test = valid_data_preprocessed.drop("label", axis=1).limit(10000) # Test the trained models on a portion of the original real dataset (first 10k rows)
50 | y_test = valid_data_preprocessed["label"].limit(10000)
51 | TSTR_metrics = compute_utility_metrics_class(X_train, X_synth, X_test, y_train, y_synth, y_test)
52 |
53 | # Distance to closest record
54 | dcr_synth_train = distance_to_closest_record("synth_train", synth_data, real_data)
55 | dcr_synth_valid = distance_to_closest_record("synth_val", synth_data, valid_data)
56 | dcr_stats_synth_train = dcr_stats("synth_train", dcr_synth_train)
57 | dcr_stats_synth_valid = dcr_stats("synth_val", dcr_synth_valid)
58 | dcr_zero_synth_train = number_of_dcr_equal_to_zero("synth_train", dcr_synth_train)
59 | dcr_zero_synth_valid = number_of_dcr_equal_to_zero("synth_val", dcr_synth_valid)
60 | share = validation_dcr_test(dcr_synth_train, dcr_synth_valid)
61 |
62 | # Detection Score
63 | detection_score = detection(real_data, synth_data, preprocessor)
64 |
65 | # Query Power
66 | query_power_score = query_power(real_data, synth_data, preprocessor)
67 |
68 | # ML privacy attack sandbox initialization and simulation
69 | adversary_df = adversary_dataset(real_data, valid_data)
70 | # The function adversary_dataset adds a column "privacy_test_is_training" to the adversary dataset, indicating whether the record was part of the training set or not
71 | adversary_guesses_ground_truth = adversary_df["privacy_test_is_training"]
MIA = membership_inference_test(adversary_df, synth_data, adversary_guesses_ground_truth)
73 |
74 | # Report generation as HTML page
75 | report()
76 | ```
77 |
--------------------------------------------------------------------------------
/docs/source/img/SURE_workflow_.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Clearbox-AI/SURE/4460454e24d142c73457d8f05326151be54f778f/docs/source/img/SURE_workflow_.png
--------------------------------------------------------------------------------
/docs/source/img/cb_white_logo_compact.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Clearbox-AI/SURE/4460454e24d142c73457d8f05326151be54f778f/docs/source/img/cb_white_logo_compact.png
--------------------------------------------------------------------------------
/docs/source/img/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Clearbox-AI/SURE/4460454e24d142c73457d8f05326151be54f778f/docs/source/img/favicon.ico
--------------------------------------------------------------------------------
/docs/source/img/sure_logo_nobg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Clearbox-AI/SURE/4460454e24d142c73457d8f05326151be54f778f/docs/source/img/sure_logo_nobg.png
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | .. Clearbox SURE documentation master file, created by
2 | sphinx-quickstart on Mon Nov 18 15:46:14 2024.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | .. include:: doc_1.md
7 | :parser: markdown
8 |
9 | .. image:: img/sure_logo_nobg.png
10 | :alt: SURE Logo
11 | :width: 450px
12 |
13 | .. include:: doc_2.md
14 | :parser: markdown
15 |
16 |
17 | .. toctree::
18 | :maxdepth: 2
19 | :caption: Contents:
20 |
21 |
22 | Modules
23 | -------
24 |
25 | .. toctree::
26 | :maxdepth: 2
27 |
28 | api/sure.rst
29 |
30 | Indices and tables
31 | ------------------
32 |
33 | * :ref:`genindex`
34 | * :ref:`modindex`
35 | * :ref:`search`
--------------------------------------------------------------------------------
/make.bat:
--------------------------------------------------------------------------------
@ECHO OFF

pushd %~dp0

REM Command file for Sphinx documentation

REM Use the sphinx-build on PATH unless the caller set SPHINXBUILD explicitly.
if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build

REM Probe for sphinx-build; errorlevel 9009 is Windows' "command not found".
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
	echo.
	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
	echo.installed, then set the SPHINXBUILD environment variable to point
	echo.to the full path of the 'sphinx-build' executable. Alternatively you
	echo.may add the Sphinx directory to PATH.
	echo.
	echo.If you don't have Sphinx installed, grab it from
	echo.https://www.sphinx-doc.org/
	exit /b 1
)

REM No build target given: fall through to Sphinx's built-in help.
if "%1" == "" goto help

%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%

:end
popd
--------------------------------------------------------------------------------
/project.toml:
--------------------------------------------------------------------------------
1 | # pyproject.toml (build-time deps only)
2 | [build-system]
3 | requires = [
4 | "setuptools>=61",
5 | "wheel",
6 | "cython",
7 | "numpy"
8 | ]
9 | build-backend = "setuptools.build_meta"
10 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | clearbox-preprocessor
2 | setuptools
3 | pandas<=2.2.3
4 | polars<=1.17.1
5 | numpy<2.0.0
6 | cython
7 | wheel
8 | streamlit
9 | matplotlib
10 | matplotlib-inline
11 | seaborn
12 | click
13 | scikit-learn
14 | tqdm
15 | joblib
16 | lightgbm
17 | xgboost
18 | sphinx_rtd_theme
19 | myst_parser
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import shutil
import os
import numpy as np
from pathlib import Path
from setuptools import setup, Extension, find_packages
from Cython.Build import cythonize
from Cython.Distutils import build_ext

# Read the content of the README.md for the long_description metadata
with open("README.md", "r") as readme:
    long_description = readme.read()

# Parse the requirements.txt file to get a list of dependencies.
# Each non-empty line becomes one install_requires entry.
with open("requirements.txt") as f:
    requirements = f.read().splitlines()
16 |
# List of files to exclude from Cythonization.
# NOTE: every entry must be a separate, comma-terminated string literal.
# A missing comma makes Python concatenate the adjacent literals into one
# bogus path, silently dropping BOTH files from the exclusion list (this
# previously fused privacy.py with report_generator.py).
EXCLUDE_FILES = [
    "sure/utility.py",
    "sure/__init__.py",
    "sure/privacy/__init__.py",
    "sure/privacy/privacy.py",
    "sure/report_generator/report_generator.py",
    "sure/report_generator/report_app.py",
    "sure/report_generator/__init__.py",
    "sure/report_generator/pages/privacy.py",
    "sure/report_generator/.streamlit/config.toml",
    "sure/_lazypredict.py",
]
30 |
def get_extensions_paths(root_dir, exclude_files):
    """
    Retrieve file paths for compilation.

    Parameters
    ----------
    root_dir : str
        Root directory to start searching for files.
    exclude_files : list of str
        List of file paths to exclude from the result.

    Returns
    -------
    list of str or Extension
        A list containing file paths and/or Extension objects.
    """
    collected = []

    # Recursively scan for Python (.py) and Cython (.pyx) sources.
    for current_dir, _, filenames in os.walk(root_dir):
        for name in filenames:
            extension = os.path.splitext(name)[1]
            if extension not in (".py", ".pyx"):
                continue

            candidate = os.path.join(current_dir, name)

            # Exclusions are matched on the joined path, before any wrapping.
            if candidate in exclude_files:
                continue

            if extension == ".pyx":
                # Cython sources become Extension objects named after their
                # package path and compiled against the NumPy headers.
                candidate = Extension(
                    current_dir.replace("/", ".").replace("\\", "."),
                    [candidate],
                    include_dirs=[np.get_include()],
                )

            collected.append(candidate)

    return collected
74 |
class CustomBuild(build_ext):
    """
    Custom build class that inherits from Cython's build_ext.

    This class is created to override the default build behavior.
    Specifically, it ensures certain non-Cython files are copied
    over to the build output directory after the Cythonization process.
    """

    def run(self):
        """Override the run method to copy specific files after build."""
        # Run the original run method
        build_ext.run(self)

        build_dir = Path(self.build_lib)
        root_dir = Path(__file__).parent
        # For in-place builds the files already live under the source tree.
        target_dir = build_dir if not self.inplace else root_dir

        # List of files to copy after the build process.
        # These mirror EXCLUDE_FILES above: sources kept as plain .py/.toml
        # (not compiled) still need to ship in the built package.
        files_to_copy = [
            "sure/distance_metrics/__init__.py",
            "sure/distance_metrics/gower_matrix_c.pyx",
            "sure/utility.py",
            "sure/__init__.py",
            "sure/privacy/__init__.py",
            "sure/privacy/privacy.py",
            "sure/report_generator/report_generator.py",
            "sure/report_generator/report_app.py",
            "sure/report_generator/__init__.py",
            "sure/report_generator/pages/privacy.py",
            "sure/report_generator/.streamlit/config.toml",
        ]

        for file in files_to_copy:
            self.copy_file(Path(file), root_dir, target_dir)

    def copy_file(self, path, source_dir, destination_dir):
        """
        Utility method to copy files from source to destination.

        NOTE(review): this shadows the inherited distutils ``copy_file``
        method, which has a different signature — confirm nothing in the
        build machinery calls the inherited form on this command object.

        Parameters
        ----------
        path : Path
            Path of the file to be copied.
        source_dir : Path
            Directory where the source file resides.
        destination_dir : Path
            Directory where the file should be copied to.

        """
        src_file = source_dir / path
        dest_file = destination_dir / path
        dest_file.parent.mkdir(parents=True, exist_ok=True)  # Ensure the directory exists
        shutil.copyfile(str(src_file), str(dest_file))
129 |
# Main setup configuration
setup(
    # Metadata about the package
    name="clearbox-sure",
    version="0.2.15",
    author="Clearbox AI",
    author_email="info@clearbox.ai",
    description="A utility and privacy evaluation library for synthetic data",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/Clearbox-AI/SURE",
    install_requires=requirements,
    # NOTE(review): CI builds wheels only for Python 3.10-3.12; confirm the
    # >=3.7 floor declared here is still accurate.
    python_requires=">=3.7.0",

    # Cython modules compilation
    ext_modules=cythonize(
        get_extensions_paths("sure", EXCLUDE_FILES),
        build_dir="build",
        compiler_directives=dict(language_level=3, always_allow_keywords=True),
    ),

    # Override the build command with our custom class
    cmdclass=dict(build_ext=CustomBuild),

    # List of packages included in the distribution
    packages=find_packages(),  # Include all packages in the distribution
    include_package_data=True,
)
158 |
--------------------------------------------------------------------------------
/sure/__init__.py:
--------------------------------------------------------------------------------
# Package-level re-exports for the ``sure`` library.
from .report_generator.report_generator import report, _save_to_json
from .utility import _drop_cols
# Re-exported from the external clearbox-preprocessor dependency so users can
# write ``from sure import Preprocessor``.
from clearbox_preprocessor import Preprocessor

# NOTE(review): the underscore-prefixed helpers are deliberately listed in
# __all__ — presumably consumed by other sure.* modules; confirm before hiding.
__all__ = [
    "report",
    "_save_to_json",
    "_drop_cols",
    "Preprocessor"
]
--------------------------------------------------------------------------------
/sure/_lazypredict.py:
--------------------------------------------------------------------------------
1 | """
2 | This file is taken from the library lazypredict-nightly
3 | """
4 |
5 | import numpy as np
6 | import pandas as pd
7 | from tqdm import tqdm
8 | import time
9 | from sklearn.pipeline import Pipeline
10 | from sklearn.impute import SimpleImputer
11 | from sklearn.preprocessing import StandardScaler, OneHotEncoder, OrdinalEncoder
12 | from sklearn.compose import ColumnTransformer
13 | from sklearn.utils import all_estimators
14 | from sklearn.base import RegressorMixin
15 | from sklearn.base import ClassifierMixin
16 | from sklearn.metrics import (
17 | accuracy_score,
18 | balanced_accuracy_score,
19 | roc_auc_score,
20 | f1_score,
21 | r2_score,
22 | mean_squared_error,
23 | )
24 | import warnings
25 | import xgboost
26 | import lightgbm
27 | import inspect
28 |
29 | warnings.filterwarnings("ignore")
30 | pd.set_option("display.precision", 2)
31 | pd.set_option("display.float_format", lambda x: "%.2f" % x)
32 |
33 | removed_classifiers = [
34 | "ClassifierChain",
35 | "ComplementNB",
36 | "GradientBoostingClassifier",
37 | "GaussianProcessClassifier",
38 | "HistGradientBoostingClassifier",
39 | "MLPClassifier",
40 | "LogisticRegressionCV",
41 | "MultiOutputClassifier",
42 | "MultinomialNB",
43 | "OneVsOneClassifier",
44 | "OneVsRestClassifier",
45 | "OutputCodeClassifier",
46 | "RadiusNeighborsClassifier",
47 | "VotingClassifier",
48 | ]
49 |
50 | removed_regressors = [
51 | "TheilSenRegressor",
52 | "ARDRegression",
53 | "CCA",
54 | "IsotonicRegression",
55 | "StackingRegressor",
56 | "MultiOutputRegressor",
57 | "MultiTaskElasticNet",
58 | "MultiTaskElasticNetCV",
59 | "MultiTaskLasso",
60 | "MultiTaskLassoCV",
61 | "PLSCanonical",
62 | "PLSRegression",
63 | "RadiusNeighborsRegressor",
64 | "RegressorChain",
65 | "VotingRegressor",
66 | ]
67 |
68 | CLASSIFIERS = [
69 | est
70 | for est in all_estimators()
71 | if (issubclass(est[1], ClassifierMixin) and (est[0] not in removed_classifiers))
72 | ]
73 |
74 | REGRESSORS = [
75 | est
76 | for est in all_estimators()
77 | if (issubclass(est[1], RegressorMixin) and (est[0] not in removed_regressors))
78 | ]
79 |
80 | REGRESSORS.append(("XGBRegressor", xgboost.XGBRegressor))
81 | REGRESSORS.append(("LGBMRegressor", lightgbm.LGBMRegressor))
82 | # REGRESSORS.append(('CatBoostRegressor',catboost.CatBoostRegressor))
83 |
84 | CLASSIFIERS.append(("XGBClassifier", xgboost.XGBClassifier))
85 | CLASSIFIERS.append(("LGBMClassifier", lightgbm.LGBMClassifier))
86 | # CLASSIFIERS.append(('CatBoostClassifier',catboost.CatBoostClassifier))
87 |
88 | numeric_transformer = Pipeline(
89 | steps=[("imputer", SimpleImputer(strategy="mean")), ("scaler", StandardScaler())]
90 | )
91 |
92 | categorical_transformer_low = Pipeline(
93 | steps=[
94 | ("imputer", SimpleImputer(strategy="constant", fill_value="missing")),
95 | ("encoding", OneHotEncoder(handle_unknown="ignore", sparse_output=False)),
96 | ]
97 | )
98 |
99 | categorical_transformer_high = Pipeline(
100 | steps=[
101 | ("imputer", SimpleImputer(strategy="constant", fill_value="missing")),
102 | # 'OrdianlEncoder' Raise a ValueError when encounters an unknown value. Check https://github.com/scikit-learn/scikit-learn/pull/13423
103 | ("encoding", OrdinalEncoder()),
104 | ]
105 | )
106 |
107 |
108 | # Helper function
def get_card_split(df, cols, n=11):
    """
    Splits categorical columns into 2 lists based on cardinality (i.e. # of unique values)

    Parameters
    ----------
    df : Pandas DataFrame
        DataFrame from which the cardinality of the columns is calculated.
    cols : list-like
        Categorical columns to list
    n : int, optional (default=11)
        Cardinality threshold used to split the columns.

    Returns
    -------
    card_low : list-like
        Columns with cardinality <= n
    card_high : list-like
        Columns with cardinality > n

    Notes
    -----
    The docstring previously claimed ``< n`` / ``>= n``; the implementation
    has always used a strict ``> n`` test, so the documentation now matches
    the actual split.
    """
    # Boolean mask over `cols`: True where the column has more than n uniques.
    cond = df[cols].nunique() > n
    card_high = cols[cond]
    card_low = cols[~cond]
    return card_low, card_high
131 |
132 |
133 | # Helper class for performing classification
134 |
135 |
class LazyClassifier:
    """
    This module helps in fitting to all the classification algorithms that are available in Scikit-learn
    Parameters
    ----------
    verbose : int, optional (default=0)
        For the liblinear and lbfgs solvers set verbose to any positive
        number for verbosity.
    ignore_warnings : bool, optional (default=True)
        When set to True, the warnings related to algorithms that are not able to run are ignored.
    custom_metric : function, optional (default=None)
        When function is provided, models are evaluated based on the custom evaluation metric provided.
    prediction : bool, optional (default=False)
        When set to True, the predictions of all the models are returned as dataframe.
    classifiers : list, optional (default="all")
        When function is provided, trains the chosen classifier(s).

    Examples
    --------
    >>> from lazypredict.Supervised import LazyClassifier
    >>> from sklearn.datasets import load_breast_cancer
    >>> from sklearn.model_selection import train_test_split
    >>> data = load_breast_cancer()
    >>> X = data.data
    >>> y= data.target
    >>> X_train, X_test, y_train, y_test = train_test_split(X, y,test_size=.5,random_state =123)
    >>> clf = LazyClassifier(verbose=0,ignore_warnings=True, custom_metric=None)
    >>> models,predictions = clf.fit(X_train, X_test, y_train, y_test)
    >>> model_dictionary = clf.provide_models(X_train,X_test,y_train,y_test)
    >>> models
    | Model                          |   Accuracy |   Balanced Accuracy |   ROC AUC |   F1 Score |   Time Taken |
    |:-------------------------------|-----------:|--------------------:|----------:|-----------:|-------------:|
    | LinearSVC                      |   0.989474 |            0.987544 |  0.987544 |   0.989462 |    0.0150008 |
    | SGDClassifier                  |   0.989474 |            0.987544 |  0.987544 |   0.989462 |    0.0109992 |
    | MLPClassifier                  |   0.985965 |            0.986904 |  0.986904 |   0.985994 |    0.426     |
    | Perceptron                     |   0.985965 |            0.984797 |  0.984797 |   0.985965 |    0.0120046 |
    | LogisticRegression             |   0.985965 |            0.98269  |  0.98269  |   0.985934 |    0.0200036 |
    | LogisticRegressionCV           |   0.985965 |            0.98269  |  0.98269  |   0.985934 |    0.262997  |
    | SVC                            |   0.982456 |            0.979942 |  0.979942 |   0.982437 |    0.0140011 |
    | CalibratedClassifierCV         |   0.982456 |            0.975728 |  0.975728 |   0.982357 |    0.0350015 |
    | PassiveAggressiveClassifier    |   0.975439 |            0.974448 |  0.974448 |   0.975464 |    0.0130005 |
    | LabelPropagation               |   0.975439 |            0.974448 |  0.974448 |   0.975464 |    0.0429988 |
    | LabelSpreading                 |   0.975439 |            0.974448 |  0.974448 |   0.975464 |    0.0310006 |
    | RandomForestClassifier         |   0.97193  |            0.969594 |  0.969594 |   0.97193  |    0.033     |
    | GradientBoostingClassifier     |   0.97193  |            0.967486 |  0.967486 |   0.971869 |    0.166998  |
    | QuadraticDiscriminantAnalysis  |   0.964912 |            0.966206 |  0.966206 |   0.965052 |    0.0119994 |
    | HistGradientBoostingClassifier |   0.968421 |            0.964739 |  0.964739 |   0.968387 |    0.682003  |
    | RidgeClassifierCV              |   0.97193  |            0.963272 |  0.963272 |   0.971736 |    0.0130029 |
    | RidgeClassifier                |   0.968421 |            0.960525 |  0.960525 |   0.968242 |    0.0119977 |
    | AdaBoostClassifier             |   0.961404 |            0.959245 |  0.959245 |   0.961444 |    0.204998  |
    | ExtraTreesClassifier           |   0.961404 |            0.957138 |  0.957138 |   0.961362 |    0.0270066 |
    | KNeighborsClassifier           |   0.961404 |            0.95503  |  0.95503  |   0.961276 |    0.0560005 |
    | BaggingClassifier              |   0.947368 |            0.954577 |  0.954577 |   0.947882 |    0.0559971 |
    | BernoulliNB                    |   0.950877 |            0.951003 |  0.951003 |   0.951072 |    0.0169988 |
    | LinearDiscriminantAnalysis     |   0.961404 |            0.950816 |  0.950816 |   0.961089 |    0.0199995 |
    | GaussianNB                     |   0.954386 |            0.949536 |  0.949536 |   0.954337 |    0.0139935 |
    | NuSVC                          |   0.954386 |            0.943215 |  0.943215 |   0.954014 |    0.019989  |
    | DecisionTreeClassifier         |   0.936842 |            0.933693 |  0.933693 |   0.936971 |    0.0170023 |
    | NearestCentroid                |   0.947368 |            0.933506 |  0.933506 |   0.946801 |    0.0160074 |
    | ExtraTreeClassifier            |   0.922807 |            0.912168 |  0.912168 |   0.922462 |    0.0109999 |
    | CheckingClassifier             |   0.361404 |            0.5      |  0.5      |   0.191879 |    0.0170043 |
    | DummyClassifier                |   0.512281 |            0.489598 |  0.489598 |   0.518924 |    0.0119965 |
    """

    def __init__(
        self,
        verbose=0,
        ignore_warnings=True,
        custom_metric=None,
        predictions=False,
        random_state=42,
        classifiers="all",
    ):
        self.verbose = verbose
        self.ignore_warnings = ignore_warnings
        self.custom_metric = custom_metric
        self.predictions = predictions
        self.models = {}  # name -> fitted sklearn Pipeline, filled by fit()
        self.random_state = random_state
        self.classifiers = classifiers

    def fit(self, X_train, X_test, y_train, y_test):
        """Fit Classification algorithms to X_train and y_train, predict and score on X_test, y_test.
        Parameters
        ----------
        X_train : array-like,
            Training vectors, where rows is the number of samples
            and columns is the number of features.
        X_test : array-like,
            Testing vectors, where rows is the number of samples
            and columns is the number of features.
        y_train : array-like,
            Training vectors, where rows is the number of samples
            and columns is the number of features.
        y_test : array-like,
            Testing vectors, where rows is the number of samples
            and columns is the number of features.
        Returns
        -------
        scores : Pandas DataFrame
            Returns metrics of all the models in a Pandas DataFrame.
        predictions : Pandas DataFrame
            Returns predictions of all the models in a Pandas DataFrame.
            Only returned (as the second element of a tuple) when the
            instance was created with ``predictions=True``.
        """
        Accuracy = []
        B_Accuracy = []
        ROC_AUC = []
        F1 = []
        names = []
        TIME = []
        predictions = {}

        if self.custom_metric is not None:
            CUSTOM_METRIC = []

        # Normalize bare arrays to DataFrames so dtype-based column selection works.
        if isinstance(X_train, np.ndarray):
            X_train = pd.DataFrame(X_train)
            X_test = pd.DataFrame(X_test)

        numeric_features = X_train.select_dtypes(include=[np.number]).columns
        categorical_features = X_train.select_dtypes(include=["object"]).columns

        categorical_low, categorical_high = get_card_split(
            X_train, categorical_features
        )

        preprocessor = ColumnTransformer(
            transformers=[
                ("numeric", numeric_transformer, numeric_features),
                ("categorical_low", categorical_transformer_low, categorical_low),
                ("categorical_high", categorical_transformer_high, categorical_high),
            ]
        )

        # NOTE: this mutates self.classifiers, replacing "all" (or a list of
        # classes) with a list of (name, class) pairs on the first call.
        if self.classifiers == "all":
            self.classifiers = CLASSIFIERS
        else:
            try:
                temp_list = []
                for classifier in self.classifiers:
                    full_name = (classifier.__name__, classifier)
                    temp_list.append(full_name)
                self.classifiers = temp_list
            except Exception as exception:
                print(exception)
                print("Invalid Classifier(s)")

        for name, model in tqdm(self.classifiers):
            start = time.time()
            try:
                if "random_state" in model().get_params().keys():
                    pipe = Pipeline(
                        steps=[
                            ("preprocessor", preprocessor),
                            ("classifier", model(random_state=self.random_state)),
                        ],
                        verbose=False
                    )
                else:
                    pipe = Pipeline(
                        steps=[("preprocessor", preprocessor), ("classifier", model())],
                        verbose=False
                    )

                if model.__name__=="LGBMClassifier":
                    callbacks = [lightgbm.early_stopping(10, verbose=0), lightgbm.log_evaluation(period=0)]
                    pipe.fit(X_train, y_train,callbacks=callbacks)
                else:
                    pipe.fit(X_train, y_train)

                self.models[name] = pipe
                y_pred = pipe.predict(X_test)
                accuracy = accuracy_score(y_test, y_pred, normalize=True)
                b_accuracy = balanced_accuracy_score(y_test, y_pred)
                f1 = f1_score(y_test, y_pred, average="weighted")
                try:
                    # roc_auc_score on hard labels only works for binary targets;
                    # multi-class raises and is recorded as None.
                    roc_auc = roc_auc_score(y_test, y_pred)
                except Exception as exception:
                    roc_auc = None
                    if self.ignore_warnings is False:
                        print("ROC AUC couldn't be calculated for " + name)
                        print(exception)
                names.append(name)
                Accuracy.append(accuracy)
                B_Accuracy.append(b_accuracy)
                ROC_AUC.append(roc_auc)
                F1.append(f1)
                TIME.append(time.time() - start)
                if self.custom_metric is not None:
                    custom_metric = self.custom_metric(y_test, y_pred)
                    CUSTOM_METRIC.append(custom_metric)
                if self.verbose > 0:
                    if self.custom_metric is not None:
                        print(
                            {
                                "Model": name,
                                "Accuracy": accuracy,
                                "Balanced Accuracy": b_accuracy,
                                "ROC AUC": roc_auc,
                                "F1 Score": f1,
                                self.custom_metric.__name__: custom_metric,
                                "Time taken": time.time() - start,
                            }
                        )
                    else:
                        print(
                            {
                                "Model": name,
                                "Accuracy": accuracy,
                                "Balanced Accuracy": b_accuracy,
                                "ROC AUC": roc_auc,
                                "F1 Score": f1,
                                "Time taken": time.time() - start,
                            }
                        )
                if self.predictions:
                    predictions[name] = y_pred
            except Exception as exception:
                # Best-effort benchmark: a failing estimator is skipped, not fatal.
                if self.ignore_warnings is False:
                    print(name + " model failed to execute")
                    print(exception)
        if self.custom_metric is None:
            scores = pd.DataFrame(
                {
                    "Model": names,
                    "Accuracy": Accuracy,
                    "Balanced Accuracy": B_Accuracy,
                    "ROC AUC": ROC_AUC,
                    "F1 Score": F1,
                    "Time Taken": TIME,
                }
            )
        else:
            scores = pd.DataFrame(
                {
                    "Model": names,
                    "Accuracy": Accuracy,
                    "Balanced Accuracy": B_Accuracy,
                    "ROC AUC": ROC_AUC,
                    "F1 Score": F1,
                    self.custom_metric.__name__: CUSTOM_METRIC,
                    "Time Taken": TIME,
                }
            )
        scores = scores.sort_values(by="Balanced Accuracy", ascending=False).set_index(
            "Model"
        )

        # BUGFIX: the original return statement lived inside the ``if`` and used
        # ``return scores, predictions_df if self.predictions is True else scores``,
        # whose precedence always produced a 2-tuple (and the method returned
        # None whenever predictions was falsy). Now: a (scores, predictions)
        # tuple when predictions were requested, otherwise just scores.
        if self.predictions:
            predictions_df = pd.DataFrame.from_dict(predictions)
            return scores, predictions_df
        return scores

    def provide_models(self, X_train, X_test, y_train, y_test):
        """
        This function returns all the model objects trained in fit function.
        If fit is not called already, then we call fit and then return the models.
        Parameters
        ----------
        X_train : array-like,
            Training vectors, where rows is the number of samples
            and columns is the number of features.
        X_test : array-like,
            Testing vectors, where rows is the number of samples
            and columns is the number of features.
        y_train : array-like,
            Training vectors, where rows is the number of samples
            and columns is the number of features.
        y_test : array-like,
            Testing vectors, where rows is the number of samples
            and columns is the number of features.
        Returns
        -------
        models: dict-object,
            Returns a dictionary with each model pipeline as value
            with key as name of models.
        """
        if len(self.models.keys()) == 0:
            self.fit(X_train, X_test, y_train, y_test)

        return self.models
416 |
417 |
def adjusted_rsquared(r2, n, p):
    """Return the adjusted R-squared for n samples and p predictors."""
    # Penalize the plain R^2 by the degrees-of-freedom ratio.
    dof_ratio = (n - 1) / (n - p - 1)
    return 1 - (1 - r2) * dof_ratio
420 |
421 |
422 | # Helper class for performing classification
423 |
424 |
class LazyRegressor:
    """
    This module helps in fitting regression models that are available in Scikit-learn
    Parameters
    ----------
    verbose : int, optional (default=0)
        For the liblinear and lbfgs solvers set verbose to any positive
        number for verbosity.
    ignore_warnings : bool, optional (default=True)
        When set to True, the warnings related to algorithms that are not able to run are ignored.
    custom_metric : function, optional (default=None)
        When function is provided, models are evaluated based on the custom evaluation metric provided.
    predictions : bool, optional (default=False)
        When set to True, the predictions of all the models are returned as dataframe.
    random_state : int, optional (default=42)
        Seed forwarded to every model whose constructor accepts a ``random_state`` argument.
    regressors : list, optional (default="all")
        When a list of regressor classes is provided, only the chosen regressor(s) are trained.

    Examples
    --------
    >>> from lazypredict.Supervised import LazyRegressor
    >>> from sklearn import datasets
    >>> from sklearn.utils import shuffle
    >>> import numpy as np

    >>> boston = datasets.load_boston()
    >>> X, y = shuffle(boston.data, boston.target, random_state=13)
    >>> X = X.astype(np.float32)

    >>> offset = int(X.shape[0] * 0.9)
    >>> X_train, y_train = X[:offset], y[:offset]
    >>> X_test, y_test = X[offset:], y[offset:]

    >>> reg = LazyRegressor(verbose=0, ignore_warnings=False, custom_metric=None)
    >>> models, predictions = reg.fit(X_train, X_test, y_train, y_test)
    >>> model_dictionary = reg.provide_models(X_train, X_test, y_train, y_test)
    >>> models
    | Model                         | Adjusted R-Squared | R-Squared |  RMSE | Time Taken |
    |:------------------------------|-------------------:|----------:|------:|-----------:|
    | SVR                           |               0.83 |      0.88 |  2.62 |       0.01 |
    | BaggingRegressor              |               0.83 |      0.88 |  2.63 |       0.03 |
    | NuSVR                         |               0.82 |      0.86 |  2.76 |       0.03 |
    | RandomForestRegressor         |               0.81 |      0.86 |  2.78 |       0.21 |
    | XGBRegressor                  |               0.81 |      0.86 |  2.79 |       0.06 |
    | GradientBoostingRegressor     |               0.81 |      0.86 |  2.84 |       0.11 |
    | ExtraTreesRegressor           |               0.79 |      0.84 |  2.98 |       0.12 |
    | AdaBoostRegressor             |               0.78 |      0.83 |  3.04 |       0.07 |
    | HistGradientBoostingRegressor |               0.77 |      0.83 |  3.06 |       0.17 |
    | PoissonRegressor              |               0.77 |      0.83 |  3.11 |       0.01 |
    | LGBMRegressor                 |               0.77 |      0.83 |  3.11 |       0.07 |
    | KNeighborsRegressor           |               0.77 |      0.83 |  3.12 |       0.01 |
    | DecisionTreeRegressor         |               0.65 |      0.74 |  3.79 |       0.01 |
    | MLPRegressor                  |               0.65 |      0.74 |  3.80 |       1.63 |
    | HuberRegressor                |               0.64 |      0.74 |  3.84 |       0.01 |
    | GammaRegressor                |               0.64 |      0.73 |  3.88 |       0.01 |
    | LinearSVR                     |               0.62 |      0.72 |  3.96 |       0.01 |
    | RidgeCV                       |               0.62 |      0.72 |  3.97 |       0.01 |
    | BayesianRidge                 |               0.62 |      0.72 |  3.97 |       0.01 |
    | Ridge                         |               0.62 |      0.72 |  3.97 |       0.01 |
    | TransformedTargetRegressor    |               0.62 |      0.72 |  3.97 |       0.01 |
    | LinearRegression              |               0.62 |      0.72 |  3.97 |       0.01 |
    | ElasticNetCV                  |               0.62 |      0.72 |  3.98 |       0.04 |
    | LassoCV                       |               0.62 |      0.72 |  3.98 |       0.06 |
    | LassoLarsIC                   |               0.62 |      0.72 |  3.98 |       0.01 |
    | LassoLarsCV                   |               0.62 |      0.72 |  3.98 |       0.02 |
    | Lars                          |               0.61 |      0.72 |  3.99 |       0.01 |
    | LarsCV                        |               0.61 |      0.71 |  4.02 |       0.04 |
    | SGDRegressor                  |               0.60 |      0.70 |  4.07 |       0.01 |
    | TweedieRegressor              |               0.59 |      0.70 |  4.12 |       0.01 |
    | GeneralizedLinearRegressor    |               0.59 |      0.70 |  4.12 |       0.01 |
    | ElasticNet                    |               0.58 |      0.69 |  4.16 |       0.01 |
    | Lasso                         |               0.54 |      0.66 |  4.35 |       0.02 |
    | RANSACRegressor               |               0.53 |      0.65 |  4.41 |       0.04 |
    | OrthogonalMatchingPursuitCV   |               0.45 |      0.59 |  4.78 |       0.02 |
    | PassiveAggressiveRegressor    |               0.37 |      0.54 |  5.09 |       0.01 |
    | GaussianProcessRegressor      |               0.23 |      0.43 |  5.65 |       0.03 |
    | OrthogonalMatchingPursuit     |               0.16 |      0.38 |  5.89 |       0.01 |
    | ExtraTreeRegressor            |               0.08 |      0.32 |  6.17 |       0.01 |
    | DummyRegressor                |              -0.38 |     -0.02 |  7.56 |       0.01 |
    | LassoLars                     |              -0.38 |     -0.02 |  7.56 |       0.01 |
    | KernelRidge                   |             -11.50 |     -8.25 | 22.74 |       0.01 |
    """

    def __init__(
        self,
        verbose=0,
        ignore_warnings=True,
        custom_metric=None,
        predictions=False,
        random_state=42,
        regressors="all",
    ):
        self.verbose = verbose
        self.ignore_warnings = ignore_warnings
        self.custom_metric = custom_metric
        self.predictions = predictions
        # Maps model name -> fitted sklearn Pipeline; populated by fit().
        self.models = {}
        self.random_state = random_state
        self.regressors = regressors

    def fit(self, X_train, X_test, y_train, y_test):
        """Fit Regression algorithms to X_train and y_train, predict and score on X_test, y_test.
        Parameters
        ----------
        X_train : array-like,
            Training vectors, where rows is the number of samples
            and columns is the number of features.
        X_test : array-like,
            Testing vectors, where rows is the number of samples
            and columns is the number of features.
        y_train : array-like,
            Training vectors, where rows is the number of samples
            and columns is the number of features.
        y_test : array-like,
            Testing vectors, where rows is the number of samples
            and columns is the number of features.
        Returns
        -------
        scores : Pandas DataFrame
            Returns metrics of all the models in a Pandas DataFrame.
        predictions : Pandas DataFrame
            Returns predictions of all the models in a Pandas DataFrame.
            Only returned (as the second element of a tuple) when
            ``self.predictions`` is truthy.
        """
        R2 = []
        ADJR2 = []
        RMSE = []
        names = []
        TIME = []
        predictions = {}

        if self.custom_metric:
            CUSTOM_METRIC = []

        if isinstance(X_train, np.ndarray):
            X_train = pd.DataFrame(X_train)
            X_test = pd.DataFrame(X_test)

        numeric_features = X_train.select_dtypes(include=[np.number]).columns
        categorical_features = X_train.select_dtypes(include=["object"]).columns

        # Split categoricals by cardinality so low- and high-cardinality
        # columns get different encoders.
        categorical_low, categorical_high = get_card_split(
            X_train, categorical_features
        )

        preprocessor = ColumnTransformer(
            transformers=[
                ("numeric", numeric_transformer, numeric_features),
                ("categorical_low", categorical_transformer_low, categorical_low),
                ("categorical_high", categorical_transformer_high, categorical_high),
            ]
        )

        if self.regressors == "all":
            self.regressors = REGRESSORS
        else:
            try:
                # Normalize a user-supplied list of classes into (name, class) pairs.
                temp_list = []
                for regressor in self.regressors:
                    full_name = (regressor.__name__, regressor)
                    temp_list.append(full_name)
                self.regressors = temp_list
            except Exception as exception:
                print(exception)
                print("Invalid Regressor(s)")

        for name, model in tqdm(self.regressors):
            start = time.time()
            try:
                # Seed models that support it for reproducibility.
                if "random_state" in model().get_params().keys():
                    pipe = Pipeline(
                        steps=[
                            ("preprocessor", preprocessor),
                            ("regressor", model(random_state=self.random_state)),
                        ]
                    )
                else:
                    pipe = Pipeline(
                        steps=[("preprocessor", preprocessor), ("regressor", model())]
                    )

                if model.__name__ == "LGBMRegressor":
                    callbacks = [lightgbm.early_stopping(10, verbose=0), lightgbm.log_evaluation(period=0)]
                    # Fit parameters must be routed to the pipeline step with the
                    # "<step>__<param>" syntax; a bare "callbacks=" kwarg makes
                    # Pipeline.fit raise, silently skipping LGBMRegressor.
                    pipe.fit(X_train, y_train, regressor__callbacks=callbacks)
                else:
                    pipe.fit(X_train, y_train)

                self.models[name] = pipe
                y_pred = pipe.predict(X_test)

                r_squared = r2_score(y_test, y_pred)
                adj_rsquared = adjusted_rsquared(
                    r_squared, X_test.shape[0], X_test.shape[1]
                )
                rmse = np.sqrt(mean_squared_error(y_test, y_pred))

                names.append(name)
                R2.append(r_squared)
                ADJR2.append(adj_rsquared)
                RMSE.append(rmse)
                TIME.append(time.time() - start)

                if self.custom_metric:
                    custom_metric = self.custom_metric(y_test, y_pred)
                    CUSTOM_METRIC.append(custom_metric)

                if self.verbose > 0:
                    scores_verbose = {
                        "Model": name,
                        "R-Squared": r_squared,
                        "Adjusted R-Squared": adj_rsquared,
                        "RMSE": rmse,
                        "Time taken": time.time() - start,
                    }

                    if self.custom_metric:
                        scores_verbose[self.custom_metric.__name__] = custom_metric

                    print(scores_verbose)
                if self.predictions:
                    predictions[name] = y_pred
            except Exception as exception:
                # Best-effort benchmarking: a failing model is skipped, not fatal.
                if self.ignore_warnings is False:
                    print(name + " model failed to execute")
                    print(exception)

        scores = {
            "Model": names,
            "Adjusted R-Squared": ADJR2,
            "R-Squared": R2,
            "RMSE": RMSE,
            "Time Taken": TIME,
        }

        if self.custom_metric:
            scores[self.custom_metric.__name__] = CUSTOM_METRIC

        scores = pd.DataFrame(scores)
        scores = scores.sort_values(by="Adjusted R-Squared", ascending=False).set_index(
            "Model"
        )

        # The original one-liner "return scores, predictions_df if ... else scores"
        # returned a tuple only when predictions was requested and None otherwise;
        # always return the scores frame, plus predictions when asked.
        if self.predictions:
            predictions_df = pd.DataFrame.from_dict(predictions)
            return scores, predictions_df
        return scores

    def provide_models(self, X_train, X_test, y_train, y_test):
        """
        This function returns all the model objects trained in fit function.
        If fit is not called already, then we call fit and then return the models.
        Parameters
        ----------
        X_train : array-like,
            Training vectors, where rows is the number of samples
            and columns is the number of features.
        X_test : array-like,
            Testing vectors, where rows is the number of samples
            and columns is the number of features.
        y_train : array-like,
            Training vectors, where rows is the number of samples
            and columns is the number of features.
        y_test : array-like,
            Testing vectors, where rows is the number of samples
            and columns is the number of features.
        Returns
        -------
        models: dict-object,
            Returns a dictionary with each model pipeline as value
            with key as name of models.
        """
        if len(self.models.keys()) == 0:
            self.fit(X_train, X_test, y_train, y_test)

        return self.models
698 |
--------------------------------------------------------------------------------
/sure/distance_metrics/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Clearbox-AI/SURE/4460454e24d142c73457d8f05326151be54f778f/sure/distance_metrics/__init__.py
--------------------------------------------------------------------------------
/sure/distance_metrics/distance.py:
--------------------------------------------------------------------------------
1 | import pyximport
2 | import os
3 | from multiprocessing import Pool
4 |
5 | import numpy as np
6 | import pandas as pd
7 | import polars as pl
8 | from sklearn.preprocessing import OrdinalEncoder
9 |
10 | from typing import Dict, List, Tuple, Union
11 | int_type = Union[int, np.int8, np.uint8, np.int16, np.uint16, np.int32, np.uint32, np.int64, np.uint64]
12 | float_type = Union[float, np.float16, np.float32, np.float64]
13 |
14 | from ..report_generator.report_generator import _save_to_json
15 |
16 | from sure import _drop_cols
17 |
18 | pyximport.install(setup_args={"include_dirs": np.get_include()})
19 | from sure.distance_metrics.gower_matrix_c import gower_matrix_c
20 |
21 |
def _polars_to_pandas(dataframe: pl.DataFrame | pl.LazyFrame):
    """Return a pandas version of a Polars frame; other inputs pass through.

    A ``pl.LazyFrame`` is collected first and then converted; a ``pl.DataFrame``
    is converted directly; anything else is returned unchanged.
    """
    if isinstance(dataframe, pl.LazyFrame):
        return dataframe.collect().to_pandas()
    if isinstance(dataframe, pl.DataFrame):
        return dataframe.to_pandas()
    return dataframe
28 |
def _gower_matrix(
    X_categorical: np.ndarray,
    X_numerical: np.ndarray,
    Y_categorical: np.ndarray,
    Y_numerical: np.ndarray,
    numericals_ranges: np.ndarray,
    features_weight_sum: float,
    fill_diagonal: bool,
    first_index: int = -1,
) -> np.ndarray:
    """
    Thin wrapper around the compiled Cython kernel ``gower_matrix_c``: for each
    row of X, compute the Gower distance to its closest record in Y.

    Parameters
    ----------
    X_categorical : np.ndarray
        2D array containing only the categorical features of the X dataframe as uint8 values, shape (x_rows, cat_features).
    X_numerical : np.ndarray
        2D array containing only the numerical features of the X dataframe as float32 values, shape (x_rows, num_features).
    Y_categorical : np.ndarray
        2D array containing only the categorical features of the Y dataframe as uint8 values, shape (y_rows, cat_features).
    Y_numerical : np.ndarray
        2D array containing only the numerical features of the Y dataframe as float32 values, shape (y_rows, num_features).
    numericals_ranges : np.ndarray
        1D array containing the range (max-min) of each numerical feature as float32 values, shape (num_features,).
    features_weight_sum : float
        Sum of the feature weights used for the final average computation (usually it's just the number of features, each
        feature has a weight of 1).
    fill_diagonal : bool
        Whether to fill the matrix diagonal values with a value larger than 1 (5.0). It must be True to get correct values
        if you're computing the matrix just for one dataset (comparing a dataset with itself), otherwise you will get DCR==0
        for each row because on the diagonal you will compare a pair of identical instances.
    first_index : int, optional
        This is required only in case of parallel computation: the computation will occur batch by batch so the original
        diagonal values will no longer be on the diagonal on each batch. We use this index to fill correctly the diagonal
        values. If -1 it's assumed there's no parallel computation, by default -1

    Returns
    -------
    np.ndarray
        1D array containing the Distance to the Closest Record for each row of x_dataframe shape (x_dataframe rows, )
    """
    return gower_matrix_c(
        X_categorical,
        X_numerical,
        Y_categorical,
        Y_numerical,
        numericals_ranges,
        features_weight_sum,
        fill_diagonal,
        first_index,
    )
81 |
def distance_to_closest_record(
    dcr_name: str,
    x_dataframe: pd.DataFrame | pl.DataFrame | pl.LazyFrame,
    y_dataframe: pd.DataFrame | pl.DataFrame | pl.LazyFrame = None,
    feature_weights: np.ndarray | List = None,
    parallel: bool = True,
    save_data: bool = True,
    path_to_json: str = ""
) -> np.ndarray:
    """
    Compute the distances to the closest record of dataframe X from dataframe Y using
    a modified version of the Gower's distance.
    The two dataframes may contain mixed datatypes (numerical and categorical).

    Paper references:
    * A General Coefficient of Similarity and Some of Its Properties, J. C. Gower
    * Dimensionality Invariant Similarity Measure, Ahmad Basheer Hassanat

    Parameters
    ----------
    dcr_name : str
        Name with which the DCR will be saved in the JSON file used to generate the final report.
        Can be one of the following:
        - synth_train
        - synth_val
        - other
    x_dataframe : pd.DataFrame
        A dataset containing numerical and categorical data.
    y_dataframe : pd.DataFrame, optional
        Another dataset containing numerical and categorical data, by default None.
        It must contain the same columns of x_dataframe.
        If None, the distance matrix is computed between x_dataframe and x_dataframe
    feature_weights : List, optional
        List of features weights to use computing distances, by default None.
        If None, each feature weight is 1.0
    parallel : Boolean, optional
        Whether to enable the parallelization to compute Gower matrix, by default True
    save_data : bool
        If True, saves the DCR information into the JSON file used to generate the final report.
    path_to_json : str
        Path to the JSON file used to generate the final report.

    Returns
    -------
    np.ndarray
        1D array containing the Distance to the Closest Record for each row of x_dataframe
        shape (x_dataframe rows, )

    Raises
    ------
    TypeError
        If dcr_name is not one of the names listed above.
    TypeError
        If X and Y don't have the same (number of) columns.
    """
    if dcr_name != "synth_train" and dcr_name != "synth_val" and dcr_name != "other":
        raise TypeError("dcr_name must be one of the following:\n -\"synth_train\"\n -\"synth_val\"\n -\"other\"")

    # Converting X Dataset to pd.DataFrame
    X = _polars_to_pandas(x_dataframe)

    # Convert any temporal features to int (epoch nanoseconds) so they take
    # part in the numerical branch of the Gower distance
    temporal_columns = X.select_dtypes(include=['datetime']).columns
    X[temporal_columns] = X[temporal_columns].astype('int64')

    # If a second dataframe is provided, the distances are calculated using it; otherwise, they are calculated within X itself.
    if y_dataframe is None:
        Y = X
        fill_diagonal = True
    else:
        # Converting Y Dataset into pd.DataFrame
        Y = _polars_to_pandas(y_dataframe)
        fill_diagonal = False
        # Apply the same temporal conversion to the matching columns of Y
        Y[temporal_columns] = Y[temporal_columns].astype('int64')

    # Drop columns that are present in Y but are missing in X and vice versa
    X, Y = _drop_cols(X, Y)

    if not isinstance(X, np.ndarray):
        if not np.array_equal(X.columns, Y.columns):
            raise TypeError("X and Y dataframes have different columns.")
    else:
        if not X.shape[1] == Y.shape[1]:
            raise TypeError("X and Y arrays have different number of columns.")

    # Get categorical features
    # NOTE(review): X is a pandas DataFrame at this point, so this compares
    # pandas dtypes against pl.Utf8 — presumably to flag object/string columns
    # as categorical; confirm it behaves as intended for non-string object columns.
    categorical_features = np.array(X.dtypes)==pl.Utf8

    # Both dataframes are turned into numpy arrays
    if not isinstance(X, np.ndarray):
        X = np.asarray(X)
    if not isinstance(Y, np.ndarray):
        Y = np.asarray(Y)

    X_categorical = X[:, categorical_features]
    Y_categorical = Y[:, categorical_features]

    if feature_weights is None:
        # If no weights are specified, all weights are 1
        feature_weights = np.ones(X.shape[1])
    else:
        feature_weights = np.array(feature_weights)

    # The sum of the weights is necessary to compute the mean value in the end (division)
    weight_sum = feature_weights.sum().astype("float32")

    # Perform label encoding on categorical features of X and Y
    # (the encoder is fitted on X only; categories present only in Y map to
    # unknown_value)
    if categorical_features.any():
        encoder = OrdinalEncoder(
            handle_unknown='use_encoded_value',
            dtype="uint8",
            unknown_value=-1,
            encoded_missing_value=-1
        )

        # Apply the encoder on all categorical columns at once
        # Categorical feature matrix of X (num_rows_X x num_cat_feat)
        X_categorical = encoder.fit_transform(X_categorical)
        # Categorical feature matrix of Y (num_rows_Y x num_cat_feat)
        # NOTE(review): -1 stored into a uint8 wraps to 255 — presumably a
        # deliberate "unknown/missing" bucket; confirm this is intended.
        Y_categorical = encoder.transform(Y_categorical)
    else:
        X_categorical = X_categorical.astype("uint8")
        Y_categorical = Y_categorical.astype("uint8")

    # Numerical feature matrix of X (num_rows_X x num_num_feat)
    X_numerical = X[:, np.logical_not(categorical_features)].astype("float32")

    # Numerical feature matrix of Y (num_rows_Y x num_num_feat)
    Y_numerical = Y[:, np.logical_not(categorical_features)].astype("float32")

    # The range of the numerical features is necessary for the way Gower distances are calculated.
    # I find the minimum and maximum for each numerical feature by concatenating X and Y, and then
    # calculate the range by subtracting all the minimum values from the maximum values.

    numericals_mins = np.amin(np.concatenate((X_numerical, Y_numerical)), axis=0)
    numericals_maxs = np.amax(np.concatenate((X_numerical, Y_numerical)), axis=0)
    numericals_ranges = numericals_maxs - numericals_mins

    X_rows = X.shape[0]

    """
    Perform a parallel calculation on a DataFrame by dividing the data into chunks
    and distributing the chunks across multiple CPUs (all the ones available CPUs except one).
    divide the total number of rows in X by the number of CPUs used to determine the size of the chunks on
    which the parallel calculation is performed.
    The for loop with index executes the actual calculation, and index is passed to fill the value that
    corresponds to the distance from the same instance in case of calculation on a single DataFrame
    (fill_diagonal == True).
    """
    if parallel:
        result_objs = []
        number_of_cpus = os.cpu_count() - 1
        chunk_size = int(X_rows / number_of_cpus)
        chunk_size = chunk_size if chunk_size > 0 else 1
        with Pool(processes=number_of_cpus) as pool:
            for index in range(0, X_rows, chunk_size):
                result = pool.apply_async(
                    _gower_matrix,
                    (
                        X_categorical[index : index + chunk_size],
                        X_numerical[index : index + chunk_size],
                        Y_categorical,
                        Y_numerical,
                        numericals_ranges,
                        weight_sum,
                        fill_diagonal,
                        index,
                    ),
                )
                result_objs.append(result)
            results = [result.get() for result in result_objs]
        dcr = np.concatenate(results)
    else:
        dcr = _gower_matrix(
            X_categorical,
            X_numerical,
            Y_categorical,
            Y_numerical,
            numericals_ranges,
            weight_sum,
            fill_diagonal,
        )
    if save_data:
        _save_to_json("dcr_"+dcr_name, dcr, path_to_json)
    return dcr
271 |
def dcr_stats(dcr_name: str,
              distances_to_closest_record: np.ndarray,
              save_data: bool = True,
              path_to_json: str = "") -> Dict:
    """
    Return summary statistics (mean, min, quartiles, max) for a previously
    computed DCR array.

    Parameters
    ----------
    dcr_name : str
        Name with which the DCR stats will be saved in the JSON file used to
        generate the final report. Can be one of the following:
        - synth_train
        - synth_val
        - other
    distances_to_closest_record : np.ndarray
        A 1D-array containing the Distance to the Closest Record for each row of a dataframe
        shape (dataframe rows, )
    save_data : bool
        If True, saves the DCR information into the JSON file used to generate the final report.
    path_to_json : str
        Path to the JSON file used to generate the final report.

    Returns
    -------
    Dict
        A dictionary containing mean and percentiles of the given DCR array.

    Raises
    ------
    TypeError
        If dcr_name is not one of the names listed above.
    """
    if dcr_name not in ("synth_train", "synth_val", "other"):
        raise TypeError("dcr_name must be one of the following:\n -\"synth_train\"\n -\"synth_val\"\n -\"other\"")

    dcr_mean = np.mean(distances_to_closest_record)
    # One percentile call yields min/quartiles/max; .item() converts the numpy
    # scalars to plain Python floats so the dict is JSON-serializable.
    dcr_percentiles = np.percentile(distances_to_closest_record, [0, 25, 50, 75, 100])
    stats = {
        "mean": dcr_mean.item(),
        "min": dcr_percentiles[0].item(),
        "25%": dcr_percentiles[1].item(),
        "median": dcr_percentiles[2].item(),
        "75%": dcr_percentiles[3].item(),
        "max": dcr_percentiles[4].item(),
    }
    if save_data:
        _save_to_json("dcr_"+dcr_name+"_stats", stats, path_to_json)
    return stats
316 |
def number_of_dcr_equal_to_zero(dcr_name: str,
                                distances_to_closest_record: np.ndarray,
                                save_data: bool = True,
                                path_to_json: str = "") -> int_type:
    """
    Return the number of 0s in the given DCR array, that is the number of duplicates/clones detected.

    Parameters
    ----------
    dcr_name : str
        Name with which the count will be saved in the JSON file used to
        generate the final report. Can be one of the following:
        - synth_train
        - synth_val
        - other
    distances_to_closest_record : np.ndarray
        A 1D-array containing the Distance to the Closest Record for each row of a dataframe
        shape (dataframe rows, )
    save_data : bool
        If True, saves the DCR information into the JSON file used to generate the final report.
    path_to_json : str
        Path to the JSON file used to generate the final report.

    Returns
    -------
    int
        The number of 0s in the given DCR array.

    Raises
    ------
    TypeError
        If dcr_name is not one of the names listed above.
    """
    if dcr_name not in ("synth_train", "synth_val", "other"):
        raise TypeError("dcr_name must be one of the following:\n -\"synth_train\"\n -\"synth_val\"\n -\"other\"")

    # Count the zeros once instead of re-evaluating the mask sum for both the
    # save and the return value.
    num_zeros = (distances_to_closest_record == 0.0).sum()
    if save_data:
        _save_to_json("dcr_"+dcr_name+"_num_of_zeros", num_zeros, path_to_json)
    return num_zeros
346 |
347 | # def dcr_histogram(
348 | # dcr_name: str,
349 | # distances_to_closest_record: np.ndarray,
350 | # bins: int = 20,
351 | # scale_to_100: bool = True,
352 | # save_data: bool = True,
353 | # path_to_json: str = ""
354 | # ) -> Dict:
355 | # """
356 | # Compute the histogram of a DCR array: the DCR values equal to 0 are extracted before the
357 | # histogram computation so that the first bar represent only the 0 (duplicates/clones)
358 | # and the following bars represent the standard bins (with edge) of an histogram.
359 |
360 | # Parameters
361 | # ----------
362 | # distances_to_closest_record : np.ndarray
363 | # A 1D-array containing the Distance to the Closest Record for each row of a dataframe
364 | # shape (dataframe rows, )
365 | # bins : int, optional
366 | # _description_, by default 20
367 | # scale_to_100 : bool, optional
368 | # Wheter to scale the histogram bins between 0 and 100 (instead of 0 and 1), by default True
369 |
370 | # Returns
371 | # -------
372 | # Dict
373 | # A dict containing the following items:
374 | # * bins, histogram bins detected as string labels.
375 | # The first bin/label is 0 (duplicates/clones), then the format is [inf_edge, sup_edge).
376 | # * count, histogram values for each bin in bins
377 | # * bins_edge_without_zero, the bin edges as returned by the np.histogram function without 0.
378 | # """
379 | # if dcr_name != "synth_train" and dcr_name != "synth_val" and dcr_name != "other":
380 | # raise TypeError("dcr_name must be one of the following:\n -\"synth_train\"\n -\"synth_val\"\n -\"other\"")
381 |
382 | # range_bins_with_zero = ["0.0"]
383 | # number_of_dcr_zeros = number_of_dcr_equal_to_zero(dcr_name, distances_to_closest_record)
384 | # dcr_non_zeros = distances_to_closest_record[distances_to_closest_record > 0]
385 | # counts_without_zero, bins_without_zero = np.histogram(
386 | # dcr_non_zeros, bins=bins, range=(0.0, 1.0), density=False
387 | # )
388 | # if scale_to_100:
389 | # scaled_bins_without_zero = bins_without_zero * 100
390 | # else:
391 | # scaled_bins_without_zero = bins_without_zero
392 |
393 | # range_bins_with_zero.append("(0.0-{:.2f})".format(scaled_bins_without_zero[1]))
394 | # for i, left_edge in enumerate(scaled_bins_without_zero[1:-2]):
395 | # range_bins_with_zero.append(
396 | # "[{:.2f}-{:.2f})".format(left_edge, scaled_bins_without_zero[i + 2])
397 | # )
398 | # range_bins_with_zero.append(
399 | # "[{:.2f}-{:.2f}]".format(
400 | # scaled_bins_without_zero[-2], scaled_bins_without_zero[-1]
401 | # )
402 | # )
403 |
404 | # counts_with_zero = np.insert(counts_without_zero, 0, number_of_dcr_zeros)
405 |
406 | # dcr_hist = {
407 | # "bins": range_bins_with_zero,
408 | # "counts": counts_with_zero.tolist(),
409 | # "bins_edge_without_zero": bins_without_zero.tolist(),
410 | # }
411 |
412 | # if save_data:
413 | # _save_to_json("dcr_"+dcr_name+"_hist", dcr_hist, path_to_json)
414 | # return dcr_hist
415 |
def validation_dcr_test(
    dcr_synth_train: np.ndarray,
    dcr_synth_validation: np.ndarray,
    save_data: bool = True,
    path_to_json: str = ""
) -> Dict:
    """
    Compare the synthetic dataset's DCRs to the training set against its DCRs
    to the validation set.

    - If the returned percentage is close to (or smaller than) 50%, then the synthetic dataset's records are equally close to the original training set and to the validation set.
      In this case the synthetic data does not allow to conjecture whether a record was or was not contained in the training dataset.
    - If the returned percentage is greater than 50%, then the synthetic dataset's records are closer to the training set than to the validation set, indicating
      that vulnerable records are present in the synthetic dataset.

    Parameters
    ----------
    dcr_synth_train : np.ndarray
        A 1D-array containing the Distance to the Closest Record for each row of the synthetic
        dataset wrt the training dataset, shape (synthetic rows, )
    dcr_synth_validation : np.ndarray
        A 1D-array containing the Distance to the Closest Record for each row of the synthetic
        dataset wrt the validation dataset, shape (synthetic rows, )
    save_data : bool
        If True, saves the DCR information into the JSON file used to generate the final report.
    path_to_json : str
        Path to the JSON file used to generate the final report.

    Returns
    -------
    Dict
        A dictionary with two keys: "percentage", the percentage of synthetic
        rows closer to the training dataset than to the validation dataset
        (rounded to 4 decimals), and "warnings", a message describing any
        degenerate case detected (empty string otherwise).

    Raises
    ------
    ValueError
        If the two DCR array given as parameters have different shapes.
    """
    if dcr_synth_train.shape != dcr_synth_validation.shape:
        raise ValueError("Dcr arrays have different shapes.")

    warnings = ""
    percentage = 0.0

    if dcr_synth_train.sum() == 0:
        # Every DCR to the training set is 0: the synthetic data is a clone.
        percentage = 100.0
        warnings = (
            "The synthetic dataset is an exact copy/clone of the training dataset."
        )
    elif (dcr_synth_train == dcr_synth_validation).all():
        # Identical DCR arrays mean training and validation sets coincide.
        percentage = 0.0
        warnings = (
            "The validation dataset is an exact copy/clone of the training dataset."
        )
    else:
        if dcr_synth_validation.sum() == 0:
            warnings = "The synthetic dataset is an exact copy/clone of the validation dataset."

        number_of_rows = dcr_synth_train.shape[0]
        synth_dcr_smaller_than_holdout_dcr_mask = dcr_synth_train < dcr_synth_validation
        synth_dcr_smaller_than_holdout_dcr_sum = (
            synth_dcr_smaller_than_holdout_dcr_mask.sum()
        )
        percentage = synth_dcr_smaller_than_holdout_dcr_sum / number_of_rows * 100

    dcr_validation = {"percentage": round(percentage,4), "warnings": warnings}
    if save_data:
        _save_to_json("dcr_validation", dcr_validation, path_to_json)
    return dcr_validation
--------------------------------------------------------------------------------
/sure/distance_metrics/gower_matrix_c.pyx:
--------------------------------------------------------------------------------
1 | # cython: language_level = 3
2 | # distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION
3 |
4 | import numpy
5 | cimport numpy
6 | cimport cython
7 |
# Gower distance between one record x (given as separate categorical/numerical
# parts) and every record of Y; returns a 1D float32 array of length y_rows.
cdef numpy.ndarray[numpy.float32_t, ndim=1] gower_row(
    numpy.ndarray[numpy.uint8_t, ndim=1] x_categoricals,
    numpy.ndarray[numpy.float32_t, ndim=1] x_numericals,
    numpy.ndarray[numpy.uint8_t, ndim=2] Y_categoricals,
    numpy.ndarray[numpy.float32_t, ndim=2] Y_numericals,
    numpy.ndarray[numpy.float32_t, ndim=1] numericals_ranges,
    numpy.float32_t features_weight_sum
):
    cdef numpy.ndarray[numpy.uint8_t, ndim=2] dist_x_Y_categoricals
    cdef numpy.ndarray[numpy.float32_t, ndim=1] dist_x_Y_categoricals_sum
    cdef numpy.ndarray[numpy.float32_t, ndim=2] dist_x_Y_numericals
    cdef numpy.ndarray[numpy.float32_t, ndim=1] dist_x_Y_numericals_sum
    cdef numpy.ndarray[numpy.float32_t, ndim=1] dist_x_Y

    # Categorical part: per-feature distance is 0 on equality, 1 otherwise
    # (x is broadcast against every row of Y).
    dist_x_Y_categoricals = (x_categoricals!=Y_categoricals).astype(numpy.uint8)
    dist_x_Y_categoricals_sum = dist_x_Y_categoricals.sum(axis=1).astype(numpy.float32)

    # Numerical part: range-normalised absolute difference per feature;
    # nan_to_num maps the NaNs produced by zero-range features (0/0) to 0.
    dist_x_Y_numericals = numpy.abs((x_numericals-Y_numericals)/numericals_ranges)
    dist_x_Y_numericals = numpy.nan_to_num(dist_x_Y_numericals)
    dist_x_Y_numericals_sum = dist_x_Y_numericals.sum(axis=1).astype(numpy.float32)

    # Average over all features; the caller pre-sums the feature weights.
    dist_x_Y = dist_x_Y_categoricals_sum + dist_x_Y_numericals_sum
    dist_x_Y = dist_x_Y / features_weight_sum

    return dist_x_Y
33 |
34 |
# Public entry point: for every record of X, compute the minimum Gower
# distance to any record of Y ("distance to closest record").
#
# Parameters:
#   X_categorical / X_numerical -- the query records, split by feature kind
#   Y_categorical / Y_numerical -- the reference records, split the same way
#   numericals_ranges           -- per-feature ranges for normalisation
#   features_weight_sum         -- total feature weight (divisor)
#   fill_diagonal               -- when true, mask each record's distance to
#                                  itself so it is never its own closest match
#   first_index                 -- NOTE(review): presumably the offset of this
#                                  X-chunk inside Y when X is a slice of Y
#                                  processed in parallel; a negative value
#                                  appears to mean X and Y are index-aligned —
#                                  confirm with the caller.
# Returns:
#   1-D float32 array of length X.shape[0] with each record's DCR.
@cython.boundscheck(False) # turn off bounds-checking for entire function
@cython.wraparound(False) # turn off negative index wrapping for entire function
def gower_matrix_c(
    numpy.ndarray[numpy.uint8_t, ndim=2] X_categorical,
    numpy.ndarray[numpy.float32_t, ndim=2] X_numerical,
    numpy.ndarray[numpy.uint8_t, ndim=2] Y_categorical,
    numpy.ndarray[numpy.float32_t, ndim=2] Y_numerical,
    numpy.ndarray[numpy.float32_t, ndim=1] numericals_ranges,
    numpy.float32_t features_weight_sum,
    bint fill_diagonal,
    Py_ssize_t first_index
):
    cdef Py_ssize_t i
    cdef numpy.int32_t X_rows
    cdef numpy.ndarray[numpy.float32_t, ndim=1] dist_x_Y
    cdef numpy.ndarray[numpy.float32_t, ndim=1] distance_matrix

    X_rows = X_categorical.shape[0]
    distance_matrix = numpy.zeros(X_rows, dtype=numpy.float32)

    # For each record of matrix X (comment translated from Italian).
    for i in range(X_rows):
        dist_x_Y = gower_row(X_categorical[i, :], X_numerical[i, :], Y_categorical, Y_numerical, numericals_ranges, features_weight_sum)

        # 5.0 is a sentinel larger than any normalised Gower distance (the
        # per-pair distance is a weighted mean of per-feature values in
        # [0, 1]), so the masked self-distance can never win the min below.
        if (fill_diagonal and first_index < 0):
            dist_x_Y[i] = 5.0
        elif (fill_diagonal):
            dist_x_Y[first_index+i] = 5.0
        distance_matrix[i] = numpy.amin(dist_x_Y)

    return distance_matrix
--------------------------------------------------------------------------------
/sure/privacy/__init__.py:
--------------------------------------------------------------------------------
1 | from ..distance_metrics.distance import (distance_to_closest_record,
2 | dcr_stats,
3 | number_of_dcr_equal_to_zero,
4 | validation_dcr_test)
5 | from .privacy import (adversary_dataset,
6 | membership_inference_test)
7 |
8 | __all__ = [
9 | "distance_to_closest_record",
10 | "dcr_stats",
11 | "number_of_dcr_equal_to_zero",
12 | "validation_dcr_test",
13 | "adversary_dataset",
14 | "membership_inference_test",
15 | ]
--------------------------------------------------------------------------------
/sure/privacy/privacy.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import polars as pl
4 |
5 | from typing import Dict, List, Tuple
6 |
7 | from sklearn.metrics import precision_score
8 |
9 | from sure.privacy import distance_to_closest_record
10 | from sure import _save_to_json, _drop_cols
11 |
12 | # ATTACK SANDBOX
def _polars_to_pandas(dataframe: pl.DataFrame | pl.LazyFrame):
    """Return ``dataframe`` as a pandas DataFrame.

    Eager polars frames are converted directly; lazy frames are collected
    first. Any other input (e.g. an object that is already a pandas
    DataFrame) is passed through unchanged.
    """
    result = dataframe
    if isinstance(result, pl.DataFrame):
        result = result.to_pandas()
    if isinstance(result, pl.LazyFrame):
        result = result.collect().to_pandas()
    return result
20 |
def _pl_pd_to_numpy(dataframe: pl.DataFrame | pl.LazyFrame | pl.Series | pd.DataFrame):
    """Return ``dataframe`` as a numpy array.

    Eager polars frames/series and pandas DataFrames are converted via
    ``to_numpy``; polars LazyFrames are collected first. Anything else
    (e.g. an object that is already a numpy array) is passed through
    unchanged.
    """
    converted = dataframe
    if isinstance(converted, (pl.DataFrame, pl.Series, pd.DataFrame)):
        converted = converted.to_numpy()
    if isinstance(converted, pl.LazyFrame):
        converted = converted.collect().to_numpy()
    return converted
28 |
def adversary_dataset(
    training_set: pd.DataFrame | pl.DataFrame | pl.LazyFrame,
    validation_set: pd.DataFrame | pl.DataFrame | pl.LazyFrame,
    original_dataset_sample_fraction: float = 0.2,
) -> pd.DataFrame:
    """
    Create an adversary dataset for the Membership Inference Test given a training
    and validation set. The validation set must be smaller than the training set.

    The size of the resulting adversary dataset is a fraction of the sum of the training
    set size and the validation set size.

    It takes half of the final rows from the training set and the other half from the
    validation set. It adds a column ("privacy_test_is_training") marking which rows
    were sampled from the training set. The result is fully deterministic: both the
    per-source sampling and the final shuffle use a fixed seed.

    Parameters
    ----------
    training_set : pd.DataFrame
        The training set as a pandas DataFrame.
    validation_set : pd.DataFrame
        The validation set as a pandas DataFrame.
    original_dataset_sample_fraction : float, optional
        How many rows (a fraction from 0 to 1) to sample from the concatenation of the
        training and validation set, by default 0.2

    Returns
    -------
    pd.DataFrame
        A new pandas DataFrame in which half of the rows come from the training set and
        the other half come from the validation set.
    """
    training_set = _polars_to_pandas(training_set)
    validation_set = _polars_to_pandas(validation_set)

    sample_number_of_rows = (
        training_set.shape[0] + validation_set.shape[0]
    ) * original_dataset_sample_fraction

    # Half of the adversary set comes from each source; if the validation set
    # is very small, cap at its size so that every row of the validation set
    # goes into the adversary set.
    sample_number_of_rows = min(int(sample_number_of_rows / 2), validation_set.shape[0])

    sampled_from_training = training_set.sample(
        sample_number_of_rows, replace=False, random_state=42
    )
    sampled_from_training["privacy_test_is_training"] = True

    sampled_from_validation = validation_set.sample(
        sample_number_of_rows, replace=False, random_state=42
    )
    sampled_from_validation["privacy_test_is_training"] = False

    adversary_dataset = pd.concat(
        [sampled_from_training, sampled_from_validation], ignore_index=True
    )
    # Fix: seed the final shuffle as well. The two draws above are seeded, but
    # the shuffle previously was not, making the returned dataset (and any
    # downstream test relying on row order) non-reproducible between runs.
    adversary_dataset = adversary_dataset.sample(frac=1, random_state=42).reset_index(drop=True)
    return adversary_dataset
87 |
def membership_inference_test(
    adversary_dataset: pd.DataFrame | pl.DataFrame | pl.LazyFrame,
    synthetic_dataset: pd.DataFrame | pl.DataFrame | pl.LazyFrame,
    adversary_guesses_ground_truth: np.ndarray | pd.DataFrame | pl.DataFrame | pl.LazyFrame | pl.Series,
    parallel: bool = True,
    save_data = True,
    path_to_json: str = ""
):
    """
    Simulate a Membership Inference Attack on the provided synthetic dataset using an adversary dataset.

    Parameters
    ----------
    adversary_dataset : pd.DataFrame, pl.DataFrame, or pl.LazyFrame
        The dataset used by the adversary, containing features for the attack simulation.
    synthetic_dataset : pd.DataFrame, pl.DataFrame, or pl.LazyFrame
        The synthetic dataset on which the Membership Inference Attack is performed.
    adversary_guesses_ground_truth : np.ndarray, pd.DataFrame, pl.DataFrame, pl.LazyFrame, or pl.Series
        Ground truth labels indicating whether a sample is from the original training dataset or not.
    parallel : bool, optional
        Whether to use parallel processing for distance calculations, by default True.
    save_data : bool
        If True, saves the DCR information into the JSON file used to generate the final report, by default True.
    path_to_json : str, optional
        Path to save the attack output as a JSON file. If empty, the output is not saved, by default "".

    Returns
    -------
    dict
        A dictionary containing the attack results, including distance thresholds, precisions, and risk score.

    Notes
    -----
    - This function simulates an attack where an adversary attempts to distinguish between real and synthetic samples.
    - The attack results are saved to a JSON file if `path_to_json` is provided.
    """

    # Convert any polars input to pandas / numpy so the rest of the function
    # works with a single representation.
    adversary_dataset = _polars_to_pandas(adversary_dataset)
    synthetic_dataset = _polars_to_pandas(synthetic_dataset)
    adversary_guesses_ground_truth = _pl_pd_to_numpy(adversary_guesses_ground_truth)

    # The attacker must not see the membership label added by adversary_dataset().
    # NOTE(review): raises KeyError if the column is absent — this assumes the
    # input was produced by adversary_dataset(); confirm callers guarantee it.
    adversary_dataset=adversary_dataset.drop(["privacy_test_is_training"],axis=1)

    # Drop columns that are present in adversary_dataset but are missing in synthetic_dataset and vice versa
    synthetic_dataset, adversary_dataset = _drop_cols(synthetic_dataset, adversary_dataset)

    # DCR of each adversary record to its closest synthetic record.
    dcr_adversary_synth = distance_to_closest_record("other",
                                                    adversary_dataset,
                                                    synthetic_dataset,
                                                    parallel=parallel,
                                                    save_data=False
                                                    )

    adversary_precisions = []
    # NOTE(review): the fourth element passed as a quantile level is a DCR
    # value (min + 0.01), not a probability; numpy.quantile requires
    # 0 <= q <= 1, so this implicitly relies on min(DCR) <= 0.99 — confirm
    # whether a quantile level or a raw distance threshold was intended.
    distance_thresholds = np.quantile(
        dcr_adversary_synth, [0.5, 0.25, 0.2, np.min(dcr_adversary_synth) + 0.01]
    )
    for distance_threshold in distance_thresholds:
        # The adversary guesses "training member" when the record is closer
        # to the synthetic data than the current threshold.
        adversary_guesses = dcr_adversary_synth < distance_threshold
        adversary_precision = precision_score(
            adversary_guesses_ground_truth, adversary_guesses, zero_division=0
        )
        # On a balanced adversary set random guessing scores 0.5, so each
        # precision is floored at that baseline.
        adversary_precisions.append(max(adversary_precision, 0.5))
    adversary_precision_mean = np.mean(adversary_precisions).item()
    # Rescale the mean precision from [0.5, 1] to a [0, 1] risk score.
    membership_inference_mean_risk_score = max(
        (adversary_precision_mean - 0.5) * 2, 0.0
    )

    attack_output = {
        "adversary_distance_thresholds": distance_thresholds.tolist(),
        "adversary_precisions": adversary_precisions,
        "membership_inference_mean_risk_score": membership_inference_mean_risk_score,
    }

    if save_data:
        _save_to_json("MIA_attack", attack_output, path_to_json)
    return attack_output
--------------------------------------------------------------------------------
/sure/report_generator/.streamlit/config.toml:
--------------------------------------------------------------------------------
1 | [theme]
2 | base="light"
3 | primaryColor="#6f329a"
4 | secondaryBackgroundColor="#a58cc9"
5 | textColor="#3d3d3d"
6 |
7 | [browser]
8 | gatherUsageStats = false
--------------------------------------------------------------------------------
/sure/report_generator/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Clearbox-AI/SURE/4460454e24d142c73457d8f05326151be54f778f/sure/report_generator/__init__.py
--------------------------------------------------------------------------------
/sure/report_generator/pages/privacy.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | import pandas as pd
3 | import matplotlib.pyplot as plt
4 | import seaborn.objects as so
5 |
@st.cache_data
def plot_DCR(train_data, val_data=None):
    """
    Plot histograms for Distance to Closest Record (DCR) for synthetic training and validation datasets.

    This function creates histograms to visualize the distribution of DCR values for synthetic training data.
    If validation data is provided, it also plots the DCR distribution for the synthetic validation data.

    Parameters
    ----------
    train_data : array-like
        DCR values for the synthetic training dataset.
    val_data : array-like, optional
        DCR values for the synthetic validation dataset. If None, only the synthetic training DCR histogram
        is plotted. Default is None.

    Returns
    -------
    None
        The function does not return any value. It plots the histograms using Streamlit.
    """
    # Long-form DataFrame: one DCR column plus a label column used for faceting.
    # Fix: the facet label previously read "Synthetic-Trainining DCR" (typo).
    df = pd.DataFrame({'DCR': train_data, 'Data': 'Synthetic-Training DCR'})
    if val_data is not None:
        df_val = pd.DataFrame({'DCR': val_data, 'Data': 'Synthetic-Validation DCR'})
        df = pd.concat([df, df_val])

    # One subfigure; seaborn.objects facets it per label into side-by-side
    # histograms.
    f = plt.figure(figsize=(8, 4))
    sf = f.subfigures(1, 1)
    (
        so.Plot(df, x="DCR")
        .facet("Data")
        .add(so.Bars(color="#6268ff"), so.Hist())
        .on(sf)
        .plot()
    )

    # Shrink tick/label/title fonts so both facets stay legible at 8x4 inches.
    for ax in sf.axes:
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right', fontsize=6)
        plt.setp(ax.get_yticklabels(), fontsize=8)
        ax.set_xlabel(ax.get_xlabel(), fontsize=6)  # x-axis label font size
        ax.set_ylabel(ax.get_ylabel(), fontsize=6)  # y-axis label font size
        ax.set_title(ax.get_title(), fontsize=8)    # title font size

    # Display the plot in Streamlit
    st.pyplot(f)
65 |
@st.cache_data
def dcr_stats_table(train_stats, val_stats=None):
    """
    Display a table with overall Distance to Closest Record (DCR) statistics.

    Parameters
    ----------
    train_stats : dict
        Dictionary containing statistics for the synthetic training dataset.
    val_stats : dict, optional
        Dictionary containing statistics for the synthetic validation dataset. Default is None.

    Returns
    -------
    None
        The function outputs a table using Streamlit.
    """
    table_df = pd.DataFrame.from_dict(train_stats, orient='index', columns=['Synth-Train'])
    if val_stats:
        df_val = pd.DataFrame.from_dict(val_stats, orient='index', columns=['Synth-Val'])
        # Merge the train and val DataFrames side by side
        table_df = pd.concat([table_df, df_val], axis=1)
    # Fix: the table was previously rendered only inside the `if val_stats`
    # branch, so nothing at all was shown when val_stats was None, even though
    # the parameter is documented as optional.
    st.table(table_df)
89 |
@st.cache_data
def dcr_validation(dcr_val, dcr_zero_train=None, dcr_zero_val=None):
    """
    Display the DCR share value and additional metrics for clones.

    Parameters
    ----------
    dcr_val : dict
        Dictionary containing DCR share values ("percentage") and "warnings".
    dcr_zero_train : int, optional
        Number of clones in the synthetic training dataset. Default is None.
    dcr_zero_val : int, optional
        Number of clones in the synthetic validation dataset. Default is None.

    Returns
    -------
    None
        The function outputs metrics and information using Streamlit.
    """
    perc = dcr_val["percentage"]
    cols = st.columns(2)
    with cols[0]:
        st.metric("DCR closer to training", str(round(perc,1))+"%", help="For each synthetic record we computed its DCR with respect to the training set as well as with respect to the validation set. The validation set was not used to train the generative model. If we can find that the DCR values between synthetic and training (histogram on the left) are not systematically smaller than the DCR values between synthetic and validation (histogram on the right), we gain evidence of a high level of privacy. The share of records that are then closer to a training than to a validation record serves us as our proposed privacy risk measure. If that resulting share is then close to (or smaller than) 50%, we gain empirical evidence of the training and validation data being interchangeable with respect to the synthetic data. This in turn allows to make a strong case for plausible deniability for any individual, as the synthetic data records do not allow to conjecture whether an individual was or was not contained in the training dataset.")
    with cols[1]:
        # Surface any warning string produced by the DCR validation step.
        if dcr_val["warnings"] != '':
            st.write(dcr_val["warnings"])

    cols2 = st.columns(2)
    # Fix: compare against None explicitly — a clone count of 0 (the ideal
    # value) is falsy and previously suppressed the metric entirely.
    if dcr_zero_train is not None:
        # Table with the number of DCR equal to zero
        st.metric("Clones Synth-Train", dcr_zero_train, help="The number of clones shows how many rows of the synthetic dataset have an identical match in the training dataset. A very low value indicates a low risk in terms of privacy. Ideally this value should be close to 0, but some peculiar characteristics of the training dataset (small size or low column cardinality) may lead to a higher value. The duplicates table shows the number of duplicates (identical rows) in the training dataset and the synthetic dataset: similar percentages mean higher utility.")
    if dcr_zero_val is not None:
        st.metric("Clones Synth-Val", dcr_zero_val, help="The number of clones shows how many rows of the synthetic dataset have an identical match in the validation dataset. A very low value indicates a low risk in terms of privacy. Ideally this value should be close to 0, but some peculiar characteristics of the training dataset (small size or low column cardinality) may lead to a higher value.")
127 |
def _MIA():
    """Render the Membership Inference Attack results stored in the session state."""
    cols = st.columns([1,3,2])
    with cols[0]:
        # Fix: round AFTER scaling to percent. The previous
        # `round(x, 3) * 100` produced float artifacts in the displayed
        # string (e.g. "12.299999999999999%").
        risk_pct = round(st.session_state["MIA_attack"]["membership_inference_mean_risk_score"] * 100, 1)
        st.metric("MI mean risk score", str(risk_pct)+"%", help="The MI Risk score is computed as (precision - 0.5) * 2.\n MI Risk Score smaller than 0.2 (20%) are considered to be very LOW RISK of disclosure due to membership inference.")
    with cols[1]:
        # Show thresholds/precisions (drop the scalar risk-score column) in
        # reverse row order.
        df_MIA = pd.DataFrame(st.session_state["MIA_attack"])
        st.dataframe(df_MIA.drop(columns=df_MIA.columns[-1]).iloc[::-1], hide_index=True)
135 |
def main():
    """
    Main function to configure and run the Streamlit application.

    The application provides a report on privacy metrics for a synthetic dataset,
    using the SURE library to visualize and interpret the data.

    Returns
    -------
    None
        The function runs the Streamlit app.
    """
    # Set app configurations
    st.set_page_config(layout="wide", page_title='SURE', page_icon=':large_purple_square:')

    # Header, subheader and description
    st.title('SURE')
    st.subheader('Synthetic Data: Utility, Regulatory compliance, and Ethical privacy')
    st.write(
        """This report provides a visual digest of the privacy metrics computed
        with the library [SURE](https://github.com/Clearbox-AI/SURE) on the synthetic dataset under test.""")

    ### PRIVACY
    st.header("Privacy", divider='violet')
    st.sidebar.markdown("# Privacy")

    ## Distance to closest record
    st.subheader("Distance to closest record", help="Distances-to-Closest-Record are individual-level distances of synthetic records with respect to their corresponding nearest neighboring records from the training dataset. A DCR of 0 corresponds to an identical match. These histograms are used to assess whether the synthetic data is a simple copy or minor perturbation of the training data, resulting in high risk of disclosure. There is one DCR histogram computed only on the Training Set (histogram on the left) and one computed for the Synthetic Set vs Training Set (histogram on the right). Ideally, the two histograms should have a similar shape and, above all, the histogram on the right should be far enough away from 0.")
    # DCR statistics
    st.write("DCR statistics")
    col1, buff, col2 = st.columns([3,0.5,3])
    with col1:
        # NOTE(review): only "dcr_synth_train_stats" is guarded; the *_val
        # keys below are assumed to exist whenever the *_train key does — a
        # missing "dcr_synth_val_stats" would raise KeyError here. Confirm
        # the report generator always writes both.
        if "dcr_synth_train_stats" in st.session_state:
            dcr_stats_table(st.session_state["dcr_synth_train_stats"],
                            st.session_state["dcr_synth_val_stats"])
    with col2:
        if "dcr_validation" in st.session_state:
            dcr_validation(st.session_state["dcr_validation"],
                           st.session_state["dcr_synth_train_num_of_zeros"],
                           st.session_state["dcr_synth_val_num_of_zeros"])

    # Synth-train DCR and synth-validation DCR histograms
    if "dcr_synth_train" in st.session_state:
        plot_DCR(st.session_state["dcr_synth_train"],
                 st.session_state["dcr_synth_val"])

    st.divider()

    ## Membership Inference Attack
    st.subheader("Membership Inference Attack", help="Membership inference attacks seek to infer membership of an individual record in the training set from which the synthetic data was generated. We consider a hypothetical adversary which has access to a subset of records K containing half instances from the training set and half instances from the validation set and that attempts a membership inference attack as follows: given an individual record k in K and the synthetic set S, the adversary identifies the closest record s in S; the adversary determines that k is part of the training set if d(k, s) is lower than a certain threshold. We evaluate the success rate of such attack strategy. Precision represents the number of correct decisions the adversary has made. Since 50% of the instances in K come from the training set and 50% come from the validation set, the baseline precision is 0.5, corresponding to a random choice and any value above that reflects an increasing levels of disclosure risk.")
    if "MIA_attack" in st.session_state:
        _MIA()
188 |
# Entry point when this page is executed directly (e.g. `streamlit run privacy.py`).
if __name__ == "__main__":
    main()
191 |
--------------------------------------------------------------------------------
/sure/report_generator/report_app.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | import pandas as pd
3 | import numpy as np
4 | import seaborn as sns
5 | import seaborn.objects as so
6 | import matplotlib as mpl
7 | import matplotlib.pyplot as plt
8 | import argparse
9 | import os
10 |
11 | from report_generator import _load_from_json, _convert_to_dataframe
12 |
13 |
def _plot_hist(real_data, synth_data):
    """Render an interactive per-feature histogram comparison of the real and
    synthetic datasets.

    A dropdown lets the user pick one column shared by both datasets; the
    chosen feature is then plotted as two side-by-side histogram facets
    ("Real" / "Synthetic"). Nothing is rendered until a feature is picked.
    (The previous one-line docstring described DCR histograms, which this
    function does not plot.)

    Parameters
    ----------
    real_data : pd.DataFrame
        The real dataset.
    synth_data : pd.DataFrame
        The synthetic dataset; assumed to share column names with real_data.
    """
    # Convert data to pandas DataFrame: tag each row with its origin so the
    # plot can facet on the "is_real" column.
    real_label = pd.DataFrame({'is_real': ['Real']*len(real_data)})
    df_real = pd.concat([real_data, real_label], axis=1)
    synth_label = pd.DataFrame({'is_real': ['Synthetic']*len(synth_data)})
    df_synth = pd.concat([synth_data, synth_label], axis=1)

    df = pd.concat([df_real, df_synth])
    cols = df.columns.to_list()
    cols.remove("is_real")

    # Create dropdown menu for feature selection
    selected_feature = st.selectbox(label = 'Select a feature from the dataset:',
                                    options = ["Select a feature..."] + cols,
                                    index = None,
                                    placeholder = "Select a feature...",
                                    label_visibility = "collapsed")

    if selected_feature and selected_feature!="Select a feature...":
        # Facet the chosen feature into one histogram per origin.
        f = plt.figure(figsize=(8, 4))
        sf = f.subfigures(1, 1)
        (
            so.Plot(df, x=selected_feature)
            .facet("is_real")
            .add(so.Bars(color="#3e42a8"), so.Hist())
            .on(sf)
            .plot()
        )

        # Shrink tick/label/title fonts so both facets stay legible.
        for ax in sf.axes:
            plt.setp(ax.get_xticklabels(), rotation=45, ha='right', fontsize=6)
            plt.setp(ax.get_yticklabels(), fontsize=8)
            ax.set_xlabel(ax.get_xlabel(), fontsize=6) # Set x-axis label font size
            ax.set_ylabel(ax.get_ylabel(), fontsize=6) # Set y-axis label font size
            ax.set_title(ax.get_title(), fontsize=8) # Set title font size

        # Display the plot in Streamlit
        st.pyplot(f)
53 |
54 |
def _plot_heatmap(data, title):
    """Draw a lower-triangular correlation heatmap for ``data`` in Streamlit.

    Parameters
    ----------
    data : mapping or array-like
        Square correlation structure; converted to a DataFrame whose column
        names label both axes.
    title : str
        Title placed above the heatmap.
    """
    corr = pd.DataFrame(data)

    # Hide the redundant upper triangle (strictly above the diagonal).
    upper_mask = np.triu(np.ones_like(corr, dtype=bool), 1)

    # Custom diverging palette, centred on zero below.
    palette = sns.diverging_palette(20, 265, as_cmap=True)

    fig, axis = plt.subplots(figsize=(13, 11))

    sns.heatmap(
        corr,
        cmap=palette,
        center=0,
        square=True,
        mask=upper_mask,
        linewidths=.5,
        cbar_kws={"shrink": .75},
        xticklabels=corr.columns,  # column names on the x-axis
        yticklabels=corr.columns,  # column names on the y-axis
    )

    # Slanted x labels for readability, then the title.
    plt.xticks(rotation=45, ha='right')
    axis.set_title(title, fontsize=16, pad=20)

    # Display the plot
    st.pyplot(fig)
80 |
def _display_feature_data(data):
    """Interactively display per-statistic comparison tables.

    Shows a dropdown with the statistic names in ``data``; once one is
    picked, renders a table with the original and synthetic values plus
    their difference, then recurses on the remaining statistics so the user
    can stack further selections.
    """
    placeholder = "Select a statistical quantity..."
    stat_names = list(data.keys())

    # Dropdown menu over the (remaining) statistic names.
    choice = st.selectbox(label = 'Select a statistical quantity:',
                          options = [placeholder] + stat_names,
                          index = None,
                          placeholder = placeholder,
                          label_visibility = "collapsed")

    # Nothing selected yet (or the sentinel entry): render nothing further.
    if not choice or choice == placeholder:
        return

    stats = data[choice]

    # One row per source, then a delta row (original minus synthetic).
    table = pd.concat([
        pd.DataFrame(stats['real'], index=["Original"]),
        pd.DataFrame(stats['synthetic'], index=["Synthetic"]),
    ])
    table.loc['Difference'] = table.iloc[0] - table.iloc[1]

    st.write(table)

    # Offer the statistics not yet shown via a nested dropdown.
    remaining = [name for name in stat_names if name != choice]
    if remaining:
        _display_feature_data({name: data[name] for name in remaining})
114 |
def _ml_utility():
    """Render the ML-utility section: metric tables for models trained on the
    real vs the synthetic dataset, with a multiselect to filter the rows."""
    # NOTE(review): these two callbacks are never attached to any widget
    # (the buttons below do not use on_click) — dead code candidates.
    def _select_all():
        st.session_state["selected_models"] = models_df.index.values
    def _deselect_all():
        st.session_state["selected_models"] = []

    if 'selected_models' not in st.session_state:
        st.session_state["selected_models"] = []

    # NOTE(review): unused — dead code candidate.
    container = st.container()

    cols = st.columns([1.2,1])
    # Default model subset pre-selected in the multiselect.
    default = ["LinearSVC", "Perceptron", "LogisticRegression", "XGBClassifier", "SVC", "BaggingClassifier", "SGDClassifier", "RandomForestClassifier", "AdaBoostClassifier", "KNeighborsClassifier", "DecisionTreeClassifier", "DummyClassifier"]

    # Build the combined metrics table: real and synthetic metrics side by
    # side, with their columns interleaved pairwise for easy comparison.
    models_real_df = _convert_to_dataframe(st.session_state["models"]).set_index(['Model'])
    models_synth_df = _convert_to_dataframe(st.session_state["models_synth"]).set_index(['Model'])
    models_df = pd.concat([models_real_df, models_synth_df], axis=1)
    interleaved_columns = [col for pair in zip(models_real_df.columns, models_synth_df.columns) for col in pair]
    models_df = models_df[interleaved_columns]
    models_delta_df = _convert_to_dataframe(st.session_state["models_delta"]).set_index(['Model'])

    # NOTE(review): unconditionally resets the selection to the default on
    # every rerun, overriding the guard above — confirm this is intended.
    st.session_state["selected_models"] = default

    options = st.multiselect(label="Select ML models to show in the table:",
                             options=models_df.index.values,
                             default= [x for x in st.session_state["selected_models"] if x in models_df.index.values],
                             placeholder="Select ML models...",
                             key="models_multiselect")
    subcols = st.columns([1,1,3])
    with subcols[0]:
        butt1 = st.button("Select all models")
    with subcols[1]:
        butt2 = st.button("Deselect all models")
    # The buttons override the multiselect value for this rerun only; they do
    # not update the widget state itself.
    if butt1:
        options = models_df.index.values
    if butt2:
        options = []

    st.text("") # vertical space
    st.text("") # vertical space
    st.text("ML Metrics")
    # Highlight the best (max) value per metric column, green.
    st.dataframe(models_df.loc[options].style.highlight_max(axis=0, subset=models_df.columns[:-2], color="#5cbd91"))
    st.text("") # vertical space
    st.text("") # vertical space
    st.text("ML Delta Metrics", help="Difference between the metrics of the original dataset and the synthetic dataset.")
    # Smallest absolute delta (best fidelity) green, largest red.
    st.dataframe(models_delta_df.abs().loc[options].style.highlight_min(axis=0, subset=models_delta_df.columns[:-1], color="#5cbd91").highlight_max(axis=0, subset=models_delta_df.columns[:-1], color="#c45454"))
161 |
162 | # def _ml_utility(models_df):
163 | # def _select_all():
164 | # st.session_state['selected_models'] = models_df.index.values
165 | # def _deselect_all():
166 | # st.session_state['selected_models'] = []
167 |
168 | # if 'selected_models' not in st.session_state:
169 | # st.session_state['selected_models'] = ["LinearSVC", "Perceptron", "LogisticRegression", "XGBClassifier", "SVC", "BaggingClassifier", "SGDClassifier", "RandomForestClassifier", "AdaBoostClassifier", "KNeighborsClassifier", "DecisionTreeClassifier", "DummyClassifier"]
170 |
171 | # container = st.container()
172 |
173 | # col1, col2, col3 = st.columns([1, 1, 1])
174 | # with col1:
175 | # st.button("Select all models", on_click=_select_all, key="butt1")
176 | # with col2:
177 | # st.button("Deselect all models", on_click=_deselect_all, key="butt2")
178 |
179 | # selected_models = container.multiselect(
180 | # "Select ML models to show in the table:",
181 | # models_df.index.values,
182 | # default=st.session_state['selected_models'],
183 | # key='selected_models'
184 | # )
185 |
186 | # with col1:
187 | # st.dataframe(models_df.loc[selected_models].style.highlight_max(axis=0, subset=models_df.columns[:-1], color="#99ffcc"))
188 | # with col2:
189 | # st.session_state['selected_models']
190 | # with col3:
191 | # selected_models
192 |
193 | # def _ml_utility():
194 | # def select_all():
195 | # st.session_state['selected_models'] = ['A', 'B', 'C', 'D']
196 | # def deselect_all():
197 | # st.session_state['selected_models'] = []
198 |
199 | # if 'selected_models' not in st.session_state:
200 | # st.session_state['selected_models'] = []
201 |
202 | # container = st.container()
203 |
204 | # col1, col2 = st.columns([1, 1])
205 | # with col1:
206 | # st.button("Select all", on_click=select_all, key="butt1")
207 | # with col2:
208 | # st.button("Deselect all", on_click=deselect_all, key="butt2")
209 |
210 | # selected_models = container.multiselect(
211 | # "Select one or more options:",
212 | # ['A', 'B', 'C', 'D'],
213 | # default=st.session_state['selected_models'],
214 | # key='selected_models'
215 | # )
216 |
217 | # st.markdown('##')
218 | # cols = st.columns(2)
219 | # with cols[0]:
220 | # "st.session_state['selected_models']:"
221 | # st.session_state['selected_models']
222 | # with cols[1]:
223 | # "selected_models:"
224 | # selected_models
225 |
def main(real_df, synth_df, path_to_json):
    """
    Main function to generate a Streamlit-based report for analyzing and visualizing
    utility metrics of a synthetic dataset using the SURE library.

    Parameters
    ----------
    real_df : str
        Path to the pickle file containing the real dataset.
    synth_df : str
        Path to the pickle file containing the synthetic dataset.
    path_to_json : str
        Path to the JSON file for loading session state data. If not provided, the default session state is used.

    Returns
    -------
    None
        The function initializes and runs a Streamlit app for visualizing and comparing utility metrics.
    """
    # Set app configurations
    st.set_page_config(layout="wide", page_title='SURE', page_icon='https://raw.githubusercontent.com/Clearbox-AI/SURE/main/docs/source/img/favicon.ico')

    # Header and subheader and description
    st.title('SURE')
    st.subheader('Synthetic Data: Utility, Regulatory compliance, and Ethical privacy')
    st.write(
        """This report provides a visual digest of the utility metrics computed
            with the library [SURE](https://github.com/Clearbox-AI/SURE) on the synthetic dataset under test.""")

    ### UTILITY
    st.header('Utility', divider='gray')
    st.sidebar.markdown("# Utility")

    # Load real and synthetic datasets from the pickle files written by report()
    real_df = pd.read_pickle(real_df)
    synth_df = pd.read_pickle(synth_df)

    ## Plot real and synthetic data distributions
    st.subheader("Dataset feature distribution", help="Distribution histogram plot of the selected feature for the real and synthetic dataset (before pre-processing).")
    _plot_hist(real_df, synth_df)

    # Load data in the session state, so that it is available in all the pages of the app
    if path_to_json:
        st.session_state = _load_from_json(path_to_json)
    else:
        st.session_state = _load_from_json("")

    ## Statistical similarity
    st.subheader("Statistical similarity", help="Statistical quantities computed for each feature of the real and synthetic dataset (after pre-processing).")

    # Features distribution
    # plot_distribution()

    # General statistics
    # Merge whichever of the numerical/categorical/temporal feature comparisons
    # are available. features_comparison starts as None so the final check can
    # never raise NameError when none of the three entries exist (previously it
    # relied on `locals()` and crashed in that case).
    features_comparison = None
    if "num_features_comparison" in st.session_state and st.session_state["num_features_comparison"]:
        features_comparison = st.session_state["num_features_comparison"]
    if "cat_features_comparison" in st.session_state and st.session_state["cat_features_comparison"]:
        if features_comparison:
            features_comparison = {**features_comparison, **st.session_state["cat_features_comparison"]}
        else:
            features_comparison = st.session_state["cat_features_comparison"]
    if "time_features_comparison" in st.session_state and st.session_state["time_features_comparison"]:
        if features_comparison:
            features_comparison = {**features_comparison, **st.session_state["time_features_comparison"]}
        else:
            features_comparison = st.session_state["time_features_comparison"]
    if features_comparison:
        _display_feature_data(features_comparison)

    st.markdown('###')

    # Correlation matrices, each behind a checkbox so the page stays compact
    st.subheader("Feature correlation", help="This matrix shows the correlation between ordinal and categorical features. These correlation coefficients are obtained using the mutual information metric. Mutual information describes relationships in terms of uncertainty.")
    if "real_corr" in st.session_state:
        cb_rea_corr = st.checkbox("Show original dataset correlation matrix", value=False)
        if cb_rea_corr:
            _plot_heatmap(st.session_state["real_corr"], 'Original Dataset Correlation Matrix Heatmap')
    if "synth_corr" in st.session_state:
        cb_synth_corr = st.checkbox("Show synthetic dataset correlation matrix", value=False)
        if cb_synth_corr:
            _plot_heatmap(st.session_state["synth_corr"], 'Synthetic Dataset Correlation Matrix Heatmap')
    if "diff_corr" in st.session_state:
        cb_diff_corr = st.checkbox("Show difference between the correlation matrix of the original dataset and the one of the synthetic dataset", value=False)
        if cb_diff_corr:
            _plot_heatmap(st.session_state["diff_corr"], 'Original-Synthetic Dataset Correlation Matrix Difference Heatmap')

    st.divider()

    ## ML Utility
    st.subheader("ML utility", help="The datasets were evaluated using various machine learning models. The metrics presented below reflect the performance of each model tested, applied to both the real and synthetic datasets.")
    if "models" in st.session_state:
        _ml_utility()
318 |
if __name__ == "__main__":
    # Parse the three positional CLI arguments passed by report(): the pickled
    # real dataframe, the pickled synthetic dataframe and the folder holding
    # the results JSON. nargs='?' makes each argument optional so the declared
    # defaults actually take effect (argparse ignores `default` on required
    # positional arguments).
    parser = argparse.ArgumentParser(description="This script runs the utility and privacy report app of the SURE library.")
    parser.add_argument('real_data', type=str, nargs='?', default="", help='real dataframe')
    parser.add_argument('synth_data', type=str, nargs='?', default="", help='synthetic dataframe')
    parser.add_argument('path', type=str, nargs='?', default="", help='path where the json file with the results is saved')
    args = parser.parse_args()

    main(args.real_data, args.synth_data, args.path)
--------------------------------------------------------------------------------
/sure/report_generator/report_generator.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import pkg_resources
3 | import sys
4 |
5 | import json
6 | import os
7 | import tempfile
8 |
9 | import pandas as pd
10 | import polars as pl
11 | import numpy as np
12 |
13 | # Function to run the streamlit app
14 | def report(df_real: pd.DataFrame | pl.DataFrame | pl.LazyFrame,
15 | df_synth: pd.DataFrame | pl.DataFrame | pl.LazyFrame,
16 | path_to_json:str = ""):
17 | """Generate the report app"""
18 | # Check dataframe type
19 | if isinstance(df_real, pd.DataFrame) and isinstance(df_synth, pd.DataFrame):
20 | pass
21 | elif isinstance(df_real, pl.DataFrame) and isinstance(df_synth, pl.DataFrame):
22 | df_real = df_real.to_pandas()
23 | df_synth = df_synth.to_pandas()
24 | elif isinstance(df_real, pl.LazyFrame) and isinstance(df_synth, pl.LazyFrame):
25 | df_real = df_real.collect().to_pandas()
26 | df_synth = df_synth.collect().to_pandas()
27 | else:
28 | sys.exit('ErrorType\nThe datatype provided is not supported or the two datasets have different types.')
29 |
30 | # Save the DataFrame to a temporary file (pickle format)
31 | # if df_real:
32 | with tempfile.NamedTemporaryFile(delete=False, suffix='.pkl') as tmpfile:
33 | df_path_real = tmpfile.name
34 | df_real.to_pickle(df_path_real) # Save the DataFrame as a pickle file
35 |
36 | with tempfile.NamedTemporaryFile(delete=False, suffix='.pkl') as tmpfile:
37 | df_path_synth = tmpfile.name
38 | df_synth.to_pickle(df_path_synth) # Save the DataFrame as a pickle file
39 |
40 | report_path = pkg_resources.resource_filename('sure.report_generator', 'report_app.py')
41 | process = subprocess.run(['streamlit', 'run', report_path, df_path_real, df_path_synth, path_to_json])
42 | return process
43 |
44 | def _convert_to_serializable(obj: object):
45 | """Recursively convert DataFrames and other non-serializable objects in a nested dictionary to serializable formats"""
46 | if isinstance(obj, (pd.DataFrame, pl.DataFrame, pl.LazyFrame)):
47 | if isinstance(obj, pl.DataFrame):
48 | obj = obj.to_pandas()
49 | if isinstance(obj, pl.LazyFrame):
50 | obj = obj.collect().to_pandas()
51 |
52 | # Convert index to column only if index is non-numerical
53 | if obj.index.dtype == 'object' or pd.api.types.is_string_dtype(obj.index):
54 | obj = obj.reset_index()
55 |
56 | # Convert datetime columns to string
57 | for col in obj.columns:
58 | if pd.api.types.is_datetime64_any_dtype(obj[col]):
59 | obj[col] = obj[col].astype(str)
60 |
61 | return obj.to_dict(orient='records')
62 | elif isinstance(obj, dict):
63 | return {k: _convert_to_serializable(v) for k, v in obj.items()}
64 | elif isinstance(obj, list):
65 | return [_convert_to_serializable(item) for item in obj]
66 | elif isinstance(obj, np.integer):
67 | return int(obj)
68 | elif isinstance(obj, np.floating):
69 | return float(obj)
70 | elif isinstance(obj, np.ndarray):
71 | return obj.tolist()
72 | else:
73 | return obj
74 |
def _save_to_json(data_name: str,
                  new_data: object,
                  path_to_json: str):
    """Persist `new_data` under the key `data_name` inside "data.json" in `path_to_json`."""
    target = os.path.join(path_to_json, "data.json")

    # Start from the current file contents when present and parseable,
    # otherwise from an empty dictionary
    existing = {}
    if os.path.exists(target):
        with open(target, 'r') as fh:
            try:
                existing = json.load(fh)
            except json.JSONDecodeError:
                existing = {}

    # Store a JSON-friendly version of the payload (DataFrames, numpy types,
    # nested containers) under the requested key
    existing[data_name] = _convert_to_serializable(new_data)

    # Write the merged dictionary back to disk
    with open(target, 'w') as fh:
        json.dump(existing, fh, indent=4)
100 |
101 | def _load_from_json(path_to_json: str,
102 | data_name: str = None
103 | ):
104 | """Load data from a JSON file "data.json" in the folder where the user is working"""
105 | # Check if the file exists
106 | path = os.path.join(path_to_json,"data.json")
107 | if not os.path.exists(path):
108 | raise FileNotFoundError("The data.json file does not exist.")
109 |
110 | # Read the data from the file
111 | with open(path, 'r') as file:
112 | try:
113 | data = json.load(file)
114 | except json.JSONDecodeError:
115 | raise ValueError("The data.json file is empty or invalid.")
116 |
117 | if data_name:
118 | # Extract the relevant data
119 | if data_name not in data:
120 | raise KeyError(f"Key '{data_name}' not found in dictionary.")
121 | data = data.get(data_name, None)
122 |
123 | return data
124 |
125 | def _convert_to_dataframe(obj):
126 | """Convert nested dictionaries back to DataFrames"""
127 | if isinstance(obj, list) and all(isinstance(item, dict) for item in obj):
128 | return pd.DataFrame(obj)
129 | elif isinstance(obj, dict):
130 | return {k: _convert_to_dataframe(v) for k, v in obj.items()}
131 | else:
132 | return obj
--------------------------------------------------------------------------------
/sure/utility.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Tuple, Callable, List
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import polars as pl
6 | import polars.selectors as cs
7 | import random
8 | from functools import reduce
9 | from sklearn.metrics import accuracy_score, roc_auc_score
10 | from sklearn.preprocessing import LabelEncoder
11 | from clearbox_preprocessor import Preprocessor
12 |
13 | from sure import _save_to_json
14 | from sure._lazypredict import LazyClassifier, LazyRegressor
15 |
def _to_polars_df(df):
    """Convert a pandas DataFrame, numpy array or polars LazyFrame into a pl.DataFrame.

    A pl.DataFrame is returned unchanged; any other type raises TypeError."""
    if isinstance(df, pl.DataFrame):
        return df
    if isinstance(df, pd.DataFrame):
        return pl.from_pandas(df)
    if isinstance(df, np.ndarray):
        return pl.from_numpy(df)
    if isinstance(df, pl.LazyFrame):
        return df.collect()
    raise TypeError("Invalid type for dataframe")
29 |
30 | def _to_numpy(data):
31 | ''' This functions transforms polars or pandas DataFrames or LazyFrames into numpy arrays'''
32 | if isinstance(data, pl.LazyFrame):
33 | return data.collect().to_numpy()
34 | elif isinstance(data, pl.DataFrame | pd.DataFrame| pl.Series | pd.Series):
35 | return data.to_numpy()
36 | elif isinstance(data, np.ndarray):
37 | return data
38 | else:
39 | print("The dataframe must be a Polars LazyFrame or a Pandas DataFrame")
40 | return
41 |
42 | def _drop_cols(synth, real):
43 | ''' This function returns the real dataset without the columns that are not present in the synthetic one
44 | '''
45 | if not isinstance(synth, np.ndarray) and not isinstance(real, np.ndarray):
46 | col_synth = set(synth.columns)
47 | col_real = set(real.columns)
48 | not_in_real = col_synth-col_real
49 | not_in_synth = col_real-col_synth
50 |
51 | # Drop columns that are present in the real dataset but are missing in the synthetic one
52 | if len(not_in_real)>0:
53 | print(f"""Warning: The following columns of the synthetic dataset are not present in the real dataset and were dropped to carry on with the computation:\n{not_in_real}\nIf you used only a subset of the dataset for computation, consider increasing the number of rows to ensure that all categorical values are adequately represented after one-hot-encoding.""")
54 | if isinstance(synth, pd.DataFrame):
55 | synth = synth.drop(columns=list(not_in_real))
56 | if isinstance(synth, pl.DataFrame):
57 | synth = synth.drop(list(not_in_real))
58 | if isinstance(synth, pl.LazyFrame):
59 | synth = synth.collect.drop(list(not_in_real))
60 |
61 | # Drop columns that are present in the real dataset but are missing in the synthetic one
62 | if isinstance(real, pd.DataFrame):
63 | real = real.drop(columns=list(not_in_synth))
64 | if isinstance(real, pl.DataFrame):
65 | real = real.drop(list(not_in_synth))
66 | if isinstance(real, pl.LazyFrame):
67 | real = real.collect.drop(list(not_in_synth))
68 | return synth, real
69 |
70 | # MODEL GARDEN MODULE
class ClassificationGarden:
    """
    Trains and evaluates a collection of classification models through LazyClassifier.

    Parameters
    ----------
    verbose : int, optional
        Verbosity level, by default 0.
    ignore_warnings : bool, optional
        Whether to ignore warnings, by default True.
    custom_metric : callable, optional
        Custom metric function for evaluation, by default None.
    predictions : bool, optional
        Whether to return predictions along with model performance, by default False.
    classifiers : str or list, optional
        List of classifiers to use, or "all" for using all available classifiers, by default "all".

    :meta private:
    """

    def __init__(
        self,
        verbose = 0,
        ignore_warnings = True,
        custom_metric = None,
        predictions = False,
        classifiers = "all",
    ):
        self.verbose = verbose
        self.ignore_warnings = ignore_warnings
        self.custom_metric = custom_metric
        self.predictions = predictions
        self.classifiers = classifiers

        # Underlying LazyPredict classifier that performs the actual fitting
        self.clf = LazyClassifier(verbose=verbose,
                                  ignore_warnings=ignore_warnings,
                                  custom_metric=custom_metric,
                                  predictions=predictions,
                                  classifiers=classifiers)

    def fit(
        self,
        X_train: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray,
        X_test: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray,
        y_train: pl.DataFrame | pl.LazyFrame | pl.Series | pd.Series | pd.DataFrame | np.ndarray,
        y_test: pl.DataFrame | pl.LazyFrame | pl.Series | pd.Series | pd.DataFrame | np.ndarray
    ) -> pd.DataFrame | np.ndarray:
        """
        Fit the configured classifiers on the training split and score them on the test split.

        Parameters
        ----------
        X_train : pl.DataFrame, pl.LazyFrame, pd.DataFrame, or np.ndarray
            Training features.
        X_test : pl.DataFrame, pl.LazyFrame, pd.DataFrame, or np.ndarray
            Test features.
        y_train : pl.DataFrame, pl.LazyFrame, pl.Series, pd.Series, pd.DataFrame, or np.ndarray
            Training labels.
        y_test : pl.DataFrame, pl.LazyFrame, pl.Series, pd.Series, pd.DataFrame, or np.ndarray
            Test labels.

        Returns
        -------
        pd.DataFrame or np.ndarray
            Model performance metrics and predictions if specified.

        :meta private:
        """
        # Normalise every split to a numpy array before handing it to LazyClassifier
        arrays = [_to_numpy(split) for split in (X_train, X_test, y_train, y_test)]
        models, predictions = self.clf.fit(*arrays)
        return models, predictions
147 |
class RegressionGarden():
    """
    Trains and evaluates a collection of regression models through LazyRegressor.

    Parameters
    ----------
    verbose : int, optional
        Verbosity level, by default 0.
    ignore_warnings : bool, optional
        Whether to ignore warnings, by default True.
    custom_metric : callable, optional
        Custom metric function for evaluation, by default None.
    predictions : bool, optional
        Whether to return predictions along with model performance, by default False.
    regressors : str or list, optional
        List of regressors to use, or "all" for using all available regressors, by default "all".

    :meta private:
    """
    def __init__(
        self,
        verbose = 0,
        ignore_warnings = True,
        custom_metric = None,
        predictions = False,
        regressors = "all"
    ):
        self.verbose = verbose
        self.ignore_warnings = ignore_warnings
        self.custom_metric = custom_metric
        self.predictions = predictions
        self.regressors = regressors

        # Underlying LazyPredict regressor that performs the actual fitting
        self.reg = LazyRegressor(verbose=verbose,
                                 ignore_warnings=ignore_warnings,
                                 custom_metric=custom_metric,
                                 predictions=predictions,
                                 regressors=regressors)

    def fit(
        self,
        X_train: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray,
        X_test: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray,
        y_train: pl.DataFrame | pl.LazyFrame | pl.Series | pd.Series | pd.DataFrame | np.ndarray,
        y_test: pl.DataFrame | pl.LazyFrame | pl.Series | pd.Series | pd.DataFrame | np.ndarray
    ) -> pd.DataFrame | np.ndarray:
        """
        Fit the configured regressors on the training split and score them on the test split.

        Parameters
        ----------
        X_train : pl.DataFrame, pl.LazyFrame, pd.DataFrame, or np.ndarray
            Training features.
        X_test : pl.DataFrame, pl.LazyFrame, pd.DataFrame, or np.ndarray
            Test features.
        y_train : pl.DataFrame, pl.LazyFrame, pl.Series, pd.Series, pd.DataFrame, or np.ndarray
            Training labels.
        y_test : pl.DataFrame, pl.LazyFrame, pl.Series, pd.Series, pd.DataFrame, or np.ndarray
            Test labels.

        Returns
        -------
        pd.DataFrame or np.ndarray
            Model performance metrics and predictions if specified.

        :meta private:
        """
        # Normalise every split to a numpy array before handing it to LazyRegressor
        arrays = [_to_numpy(split) for split in (X_train, X_test, y_train, y_test)]
        models, predictions = self.reg.fit(*arrays)
        return models, predictions
223 |
# ML UTILITY METRICS MODULE
def compute_utility_metrics_class(
                X_train: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray,
                X_synth: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray,
                X_test: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray,
                y_train: pl.DataFrame | pl.LazyFrame | pl.Series | pd.Series | pd.DataFrame | np.ndarray,
                y_synth: pl.DataFrame | pl.LazyFrame | pl.Series | pd.Series | pd.DataFrame | np.ndarray,
                y_test: pl.DataFrame | pl.LazyFrame | pl.Series | pd.Series | pd.DataFrame | np.ndarray,
                custom_metric: Callable = None,
                classifiers: List[Callable] = "all",
                predictions: bool = False,
                save_data = True,
                path_to_json: str = ""
                ):
    """
    Train and evaluate classification models on both real and synthetic datasets.

    Parameters
    ----------
    X_train : Union[pl.DataFrame, pl.LazyFrame, pd.DataFrame, np.ndarray]
        Training features for real data.
    X_synth : Union[pl.DataFrame, pl.LazyFrame, pd.DataFrame, np.ndarray]
        Training features for synthetic data.
    X_test : Union[pl.DataFrame, pl.LazyFrame, pd.DataFrame, np.ndarray]
        Test features for evaluation.
    y_train : Union[pl.DataFrame, pl.LazyFrame, pl.Series, pd.Series, pd.DataFrame, np.ndarray]
        Training labels for real data.
    y_synth : Union[pl.DataFrame, pl.LazyFrame, pl.Series, pd.Series, pd.DataFrame, np.ndarray]
        Training labels for synthetic data.
    y_test : Union[pl.DataFrame, pl.LazyFrame, pl.Series, pd.Series, pd.DataFrame, np.ndarray]
        Test labels for evaluation.
    custom_metric : Callable, optional
        Custom metric for model evaluation, by default None.
    classifiers : List[Callable], optional
        List of classifiers to use, or "all" for all available classifiers, by default "all".
    predictions : bool, optional
        If True, returns predictions along with model performance, by default False.
    save_data : bool
        If True, saves the DCR information into the JSON file used to generate the final report, by default True.
    path_to_json : str, optional
        Path to save the output JSON files, by default "".

    Returns
    -------
    Union[pd.DataFrame, pl.DataFrame, np.ndarray]
        Model performance metrics for both real and synthetic datasets, and optionally predictions.
    """
    # Store DataFrame type information for returning the same type
    was_pd = True   # NOTE(review): assigned but never read; pandas is the implicit default
    was_pl = False
    was_np = False

    if isinstance(X_train, pl.DataFrame) or isinstance(X_train, pl.LazyFrame):
        was_pl = True
    elif isinstance(X_train, np.ndarray):
        # NOTE(review): np.ndarray has no .columns, so the column-intersection
        # selections below would fail for ndarray inputs — confirm intended support
        was_np = True

    # Initialise ClassificationGarden class and start training.
    # Both real and synthetic models see only the columns shared with X_test,
    # so the two performance tables are computed on a common feature set.
    classifier = ClassificationGarden(predictions=predictions, classifiers=classifiers, custom_metric=custom_metric)
    print('Fitting original models:')
    models_train, pred_train = classifier.fit(X_train[[w for w in X_train.columns if w in X_test.columns]],
                                              X_test[[w for w in X_train.columns if w in X_test.columns]],
                                              y_train,
                                              y_test)

    print('Fitting synthetic models:')
    classifier_synth = ClassificationGarden(predictions=predictions, classifiers=classifiers, custom_metric=custom_metric)
    models_synth, pred_synth = classifier_synth.fit(X_synth[[w for w in X_synth.columns if w in X_test.columns]],
                                                    X_test[[w for w in X_synth.columns if w in X_test.columns]],
                                                    y_synth,
                                                    y_test)
    # Per-model difference between real and synthetic performance; the suffix
    # renames must happen after the subtraction so the columns still align
    delta = models_train-models_synth
    col_names_delta = [s + " Delta" for s in list(delta.columns)]
    delta.columns = col_names_delta
    col_names_train= [s + " Real" for s in list(models_train.columns)]
    models_train.columns = col_names_train
    col_names_synth = [s + " Synth" for s in list(models_synth.columns)]
    models_synth.columns = col_names_synth

    # Persist the three performance tables for the Streamlit report
    if save_data:
        _save_to_json("models", models_train, path_to_json)
        _save_to_json("models_synth", models_synth, path_to_json)
        _save_to_json("models_delta", delta, path_to_json)

    # Transform the output DataFrames into the type used for the input DataFrames
    if was_pl:
        models_train = pl.from_pandas(models_train)
        models_synth = pl.from_pandas(models_synth)
        delta = pl.from_pandas(delta)
    elif was_np:
        # NOTE(review): pl.from_numpy is applied to what appear to be pandas
        # DataFrames here — verify this branch against a numpy-input call
        models_train = pl.from_numpy(models_train)
        models_synth = pl.from_numpy(models_synth)
        delta = pl.from_numpy(delta)

    if predictions:
        # Prepend the ground-truth labels to the per-model prediction tables
        if isinstance(y_train, (pl.DataFrame, pl.LazyFrame, pl.Series)):
            # NOTE(review): pl.LazyFrame has no to_pandas(); this branch likely
            # needs .collect() first for LazyFrame inputs — confirm
            y_train = y_train.to_pandas()
            y_synth = y_synth.to_pandas()
        elif isinstance(y_train, np.ndarray):
            y_train = pd.DataFrame(y_train)
            y_synth = pd.DataFrame(y_synth)

        pred_train = pd.concat([y_train,pred_train], axis=1)
        pred_train.columns.values[0] = 'Ground truth'
        pred_synth = pd.concat([y_synth,pred_synth], axis=1)
        pred_synth.columns.values[0] = 'Ground truth'

        # Return predictions in the caller's original frame type as well
        if was_pl:
            pred_train = pl.from_pandas(pred_train)
            pred_synth = pl.from_pandas(pred_synth)
        elif was_np:
            pred_train = pred_train.to_numpy()
            pred_synth = pred_synth.to_numpy()
        return models_train, models_synth, delta, pred_train, pred_synth
    else:
        return models_train, models_synth, delta
340 |
def compute_utility_metrics_regr(
                X_train: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray,
                X_synth: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray,
                X_test: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray,
                y_train: pl.DataFrame | pl.LazyFrame | pl.Series | pd.Series | pd.DataFrame | np.ndarray,
                y_synth: pl.DataFrame | pl.LazyFrame | pl.Series | pd.Series | pd.DataFrame | np.ndarray,
                y_test: pl.DataFrame | pl.LazyFrame | pl.Series | pd.Series | pd.DataFrame | np.ndarray,
                custom_metric: Callable = None,
                regressors: List[Callable] = "all",
                predictions: bool = False,
                save_data = True,
                path_to_json: str = ""
                ):
    """
    Train and evaluate regression models on both real and synthetic datasets.

    Parameters
    ----------
    X_train : Union[pl.DataFrame, pl.LazyFrame, pd.DataFrame, np.ndarray]
        Training features for real data.
    X_synth : Union[pl.DataFrame, pl.LazyFrame, pd.DataFrame, np.ndarray]
        Training features for synthetic data.
    X_test : Union[pl.DataFrame, pl.LazyFrame, pd.DataFrame, np.ndarray]
        Test features for evaluation.
    y_train : Union[pl.DataFrame, pl.LazyFrame, pl.Series, pd.Series, pd.DataFrame, np.ndarray]
        Training labels for real data.
    y_synth : Union[pl.DataFrame, pl.LazyFrame, pl.Series, pd.Series, pd.DataFrame, np.ndarray]
        Training labels for synthetic data.
    y_test : Union[pl.DataFrame, pl.LazyFrame, pl.Series, pd.Series, pd.DataFrame, np.ndarray]
        Test labels for evaluation.
    custom_metric : Callable, optional
        Custom metric for model evaluation, by default None.
    regressors : List[Callable], optional
        List of regressors to use, or "all" for all available regressors, by default "all".
    predictions : bool, optional
        If True, returns predictions along with model performance, by default False.
    save_data : bool
        If True, saves the DCR information into the JSON file used to generate the final report, by default True.
    path_to_json : str, optional
        Path to save the output JSON files, by default "".

    Returns
    -------
    Union[pd.DataFrame, pl.DataFrame, np.ndarray]
        Model performance metrics for both real and synthetic datasets, and optionally predictions.
    """
    # Store DataFrame type information for returning the same type
    was_pl = isinstance(X_train, (pl.DataFrame, pl.LazyFrame))
    was_np = isinstance(X_train, np.ndarray)

    # Initialise RegressionGarden class and start training.
    # Both real and synthetic models see only the columns shared with X_test.
    regressor = RegressionGarden(predictions=predictions, regressors=regressors, custom_metric=custom_metric)
    print('Fitting original models:')
    train_cols = [w for w in X_train.columns if w in X_test.columns]
    models_train, pred_train = regressor.fit(X_train[train_cols],
                                             X_test[train_cols],
                                             y_train,
                                             y_test)

    print('Fitting synthetic models:')
    regressor_synth = RegressionGarden(predictions=predictions, regressors=regressors, custom_metric=custom_metric)
    synth_cols = [w for w in X_synth.columns if w in X_test.columns]
    models_synth, pred_synth = regressor_synth.fit(X_synth[synth_cols],
                                                   X_test[synth_cols],
                                                   y_synth,
                                                   y_test)

    # Per-model difference between real and synthetic performance; the suffix
    # renames must happen after the subtraction so the columns still align
    delta = models_train-models_synth
    delta.columns = [s + " Delta" for s in list(delta.columns)]
    models_train.columns = [s + " Real" for s in list(models_train.columns)]
    models_synth.columns = [s + " Synth" for s in list(models_synth.columns)]

    # Persist the three performance tables for the Streamlit report
    if save_data:
        _save_to_json("models", models_train, path_to_json)
        _save_to_json("models_synth", models_synth, path_to_json)
        _save_to_json("models_delta", delta, path_to_json)

    # Transform the output DataFrames into the type used for the input DataFrames
    if was_pl:
        models_train = pl.from_pandas(models_train)
        models_synth = pl.from_pandas(models_synth)
        delta = pl.from_pandas(delta)
    elif was_np:
        models_train = pl.from_numpy(models_train)
        models_synth = pl.from_numpy(models_synth)
        delta = pl.from_numpy(delta)

    if predictions:
        # Prepend the ground-truth labels to the per-model prediction tables.
        # Bug fix: the previous version called y_train.to_pandas() unconditionally
        # (crashing for pandas/numpy inputs even with predictions=False) and then
        # concatenated the ground truth a second time with a misspelled column
        # name ('Groud truth'). This now mirrors compute_utility_metrics_class.
        if isinstance(y_train, pl.LazyFrame):
            y_train = y_train.collect().to_pandas()
            y_synth = y_synth.collect().to_pandas()
        elif isinstance(y_train, (pl.DataFrame, pl.Series)):
            y_train = y_train.to_pandas()
            y_synth = y_synth.to_pandas()
        elif isinstance(y_train, np.ndarray):
            y_train = pd.DataFrame(y_train)
            y_synth = pd.DataFrame(y_synth)

        pred_train = pd.concat([y_train, pred_train], axis=1)
        pred_train.columns.values[0] = 'Ground truth'
        pred_synth = pd.concat([y_synth, pred_synth], axis=1)
        pred_synth.columns.values[0] = 'Ground truth'

        # Return predictions in the caller's original frame type as well
        if was_pl:
            pred_train = pl.from_pandas(pred_train)
            pred_synth = pl.from_pandas(pred_synth)
        elif was_np:
            pred_train = pred_train.to_numpy()
            pred_synth = pred_synth.to_numpy()
        return models_train, models_synth, delta, pred_train, pred_synth
    else:
        return models_train, models_synth, delta
461 |
462 | # STATISTICAL SIMILARITIES METRICS MODULE
463 | def _value_count(data: pl.DataFrame,
464 | features: List
465 | ) -> Dict:
466 | ''' This function returns the unique values count and frequency for each feature in a Polars DataFrame
467 | '''
468 | values = dict()
469 | for feature in data.select(pl.col(features)).columns:
470 | # Get the value counts for the specified feature
471 | value_counts_df = data.group_by(feature).len(name="count")
472 |
473 | # Calculate the frequency
474 | total_counts = value_counts_df['count'].sum()
475 | value_counts_df = value_counts_df.with_columns(
476 | (pl.col('count') / total_counts * 100).round(2).alias('freq_%')
477 | ).sort("freq_%",descending=True)
478 |
479 | values[feature] = value_counts_df
480 |
481 | return values
482 |
def _most_frequent_values(data: pl.DataFrame,
                          features: List
                          ) -> Dict:
    """Return the mode(s) of each feature as strings, or None when more than five values tie."""
    modes = dict()
    for col in features:
        # Mode(s) of the column, cast to strings for uniform reporting
        top_values = data[col].mode().cast(pl.String).to_list()

        # More than five tied modes is treated as "no meaningful mode"
        modes[col] = top_values if len(top_values) <= 5 else None

    return modes
500 |
501 | def compute_statistical_metrics(
502 | real_data: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray,
503 | synth_data: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray,
504 | save_data = True,
505 | path_to_json: str = ""
506 | ) -> Tuple[Dict, Dict, Dict]:
507 | """
508 | Compute statistical metrics for numerical, categorical, and temporal features
509 | in both real and synthetic datasets.
510 |
511 | Parameters
512 | ----------
513 | real_data : Union[pl.DataFrame, pl.LazyFrame, pd.DataFrame, np.ndarray]
514 | The real dataset containing numerical, categorical, and/or temporal features.
515 | synth_data : Union[pl.DataFrame, pl.LazyFrame, pd.DataFrame, np.ndarray]
516 | The synthetic dataset containing numerical, categorical, and/or temporal features.
517 | save_data : bool
518 | If True, saves the DCR information into the JSON file used to generate the final report, by default True.
519 | path_to_json : str, optional
520 | The file path to save the comparison metrics in JSON format, by default "".
521 |
522 | Returns
523 | -------
524 | Tuple[Dict, Dict, Dict]
525 | A tuple containing three dictionaries with statistical comparisons for
526 | numerical, categorical, and temporal features, respectively.
527 | """
528 | num_features_comparison = None
529 | cat_features_comparison = None
530 | time_features_comparison = None
531 |
532 | # Converting Real and Synthetic Dataset into pl.DataFrame
533 | real_data = _to_polars_df(real_data)
534 | synth_data = _to_polars_df(synth_data)
535 |
536 | # Drop columns that are present in the real dataset but not in the synthetic dataset and vice versa
537 | synth_data, real_data = _drop_cols(synth_data, real_data)
538 |
539 | # Check that the real features and the synthetic ones match
540 | if not real_data.columns==synth_data.columns:
541 | raise ValueError("The features from the real dataset and the synthetic one do not match.")
542 |
543 | # Transform boolean into int
544 | real_data = real_data.with_columns(pl.col(pl.Boolean).cast(pl.UInt8))
545 | synth_data = synth_data.with_columns(pl.col(pl.Boolean).cast(pl.UInt8))
546 |
547 | cat_features = cs.expand_selector(real_data, cs.string())
548 | num_features = cs.expand_selector(real_data, cs.numeric())
549 | time_features = cs.expand_selector(real_data, cs.temporal())
550 |
551 | # Numerical features
552 | if len(num_features) != 0:
553 | num_features_comparison = dict()
554 | num_features_comparison["null_count"] = { "real" : real_data.select(pl.col(num_features)).null_count(),
555 | "synthetic" : synth_data.select(pl.col(num_features)).null_count()}
556 | num_features_comparison["unique_val_number"]= { "real" : real_data.select(pl.n_unique(num_features)),
557 | "synthetic" : synth_data.select(pl.n_unique(num_features))}
558 | num_features_comparison["mean"] = { "real" : real_data.select(pl.mean(num_features)),
559 | "synthetic" : synth_data.select(pl.mean(num_features))}
560 | num_features_comparison["std"] = { "real" : real_data.select(num_features).std(),
561 | "synthetic" : synth_data.select(num_features).std()}
562 | num_features_comparison["min"] = { "real" : real_data.select(pl.min(num_features)),
563 | "synthetic" : synth_data.select(pl.min(num_features))}
564 | num_features_comparison["first_quartile"] = { "real" : real_data.select(num_features).quantile(0.25,"nearest"),
565 | "synthetic" : synth_data.select(num_features).quantile(0.25,"nearest")}
566 | num_features_comparison["second_quartile"] = { "real" : real_data.select(num_features).quantile(0.5,"nearest"),
567 | "synthetic" : synth_data.select(num_features).quantile(0.5,"nearest")}
568 | num_features_comparison["third_quartile"] = { "real" : real_data.select(num_features).quantile(0.75,"nearest"),
569 | "synthetic" : synth_data.select(num_features).quantile(0.75,"nearest")}
570 | num_features_comparison["max"] = { "real" : real_data.select(pl.max(num_features)),
571 | "synthetic" : synth_data.select(pl.max(num_features))}
572 | num_features_comparison["skewness"] = { "real" : real_data.select(pl.col(num_features).skew()),
573 | "synthetic" : synth_data.select(pl.col(num_features).skew())}
574 | num_features_comparison["kurtosis"] = { "real" : real_data.select(pl.col(num_features).kurtosis()),
575 | "synthetic" : synth_data.select(pl.col(num_features).kurtosis())}
576 |
577 | # Categorical features
578 | if len(cat_features) != 0:
579 | cat_features_comparison = dict()
580 | cat_features_comparison["null_count"] = { "real" : real_data.select(pl.col(cat_features)).null_count(),
581 | "synthetic" : synth_data.select(pl.col(cat_features)).null_count()}
582 | cat_features_comparison["unique_val_number"]= { "real" : real_data.select(pl.n_unique(cat_features)),
583 | "synthetic" : synth_data.select(pl.n_unique(cat_features))}
584 | cat_features_comparison["unique_val_stats"] = { "real" : _value_count(real_data, cat_features),
585 | "synthetic" : _value_count(synth_data, cat_features)}
586 |
587 | # Temporal features
588 | if len(time_features) != 0:
589 | time_features_comparison = dict()
590 | time_features_comparison["null_count"] = { "real" : real_data.select(pl.col(time_features)).null_count(),
591 | "synthetic" : synth_data.select(pl.col(time_features)).null_count()}
592 | time_features_comparison["unique_val_number"] = { "real" : real_data.select(pl.n_unique(time_features)),
593 | "synthetic" : synth_data.select(pl.n_unique(time_features))}
594 | time_features_comparison["min"] = { "real" : real_data.select(pl.min(time_features)),
595 | "synthetic" : synth_data.select(pl.min(time_features))}
596 | time_features_comparison["max"] = { "real" : real_data.select(pl.max(time_features)),
597 | "synthetic" : synth_data.select(pl.max(time_features))}
598 | time_features_comparison["most_frequent"] = { "real" : _most_frequent_values(real_data, time_features),
599 | "synthetic" : _most_frequent_values(synth_data, time_features)}
600 |
601 | if save_data:
602 | _save_to_json("num_features_comparison", num_features_comparison, path_to_json)
603 | _save_to_json("cat_features_comparison", cat_features_comparison, path_to_json)
604 | _save_to_json("time_features_comparison", time_features_comparison, path_to_json)
605 |
606 | return num_features_comparison, cat_features_comparison, time_features_comparison
607 |
def compute_mutual_info(
    real_data: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray,
    synth_data: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray,
    exclude_columns: List = [],
    save_data: bool = True,
    path_to_json: str = ""
) -> Tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
    """
    Compute the correlation matrices for both real and synthetic datasets, and
    calculate the difference between these matrices.

    NOTE(review): despite the function name, the matrices are produced by a
    plain correlation (``DataFrame.corr``) computed after label-encoding the
    string columns — confirm whether true mutual information was intended.

    Parameters
    ----------
    real_data : Union[pl.DataFrame, pl.LazyFrame, pd.DataFrame, np.ndarray]
        The real dataset, which can be in the form of a Polars DataFrame, LazyFrame,
        pandas DataFrame, or numpy ndarray.
    synth_data : Union[pl.DataFrame, pl.LazyFrame, pd.DataFrame, np.ndarray]
        The synthetic dataset, provided in the same format as `real_data`.
    exclude_columns : List, optional
        A list of columns to exclude from the computation of the correlation
        matrices, by default [].
    save_data : bool
        If True, saves the correlation matrices and their difference into the JSON
        file used to generate the final report, by default True.
    path_to_json : str, optional
        File path to save the correlation matrices and their differences in JSON format,
        by default "".

    Returns
    -------
    Tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]
        A tuple containing:
        - real_corr: Correlation matrix of the real dataset with column names included.
        - synth_corr: Correlation matrix of the synthetic dataset with column names included.
        - diff_corr: Difference between the correlation matrices of the real and synthetic
          datasets, with values smaller than 1e-5 substituted with 0.

    Raises
    ------
    KeyError
        If a column listed in `exclude_columns` is not present in the real dataset.
    ValueError
        If the features in the real and synthetic datasets do not match.
    """
    # Converting Real and Synthetic Dataset into pl.DataFrame
    real_data = _to_polars_df(real_data)
    synth_data = _to_polars_df(synth_data)

    # Validate the exclusion list against the real dataset before dropping.
    for col in exclude_columns:
        if col not in real_data.columns:
            raise KeyError(f"Column {col} not found in DataFrame.")

    real_data = real_data.drop(exclude_columns)
    synth_data = synth_data.drop(exclude_columns)

    # Drop columns that are present in the real dataset but not in the synthetic dataset and vice versa
    synth_data, real_data = _drop_cols(synth_data, real_data)

    # Check that the real features and the synthetic ones match
    if not real_data.columns==synth_data.columns:
        raise ValueError("The features from the real dataset and the synthetic one do not match.")

    # Convert Boolean and Temporal types to numerical so corr() can handle them
    real_data = real_data.with_columns(cs.boolean().cast(pl.UInt8))
    synth_data = synth_data.with_columns(cs.boolean().cast(pl.UInt8))
    real_data = real_data.with_columns(cs.temporal().as_expr().dt.timestamp('ms'))
    synth_data = synth_data.with_columns(cs.temporal().as_expr().dt.timestamp('ms'))

    # Label-encode categorical (string) features; note the encoder is fit
    # independently on each dataset, so codes are per-dataset, not shared.
    encoder = LabelEncoder()
    for col in real_data.select(cs.string()).columns:
        real_data = real_data.with_columns([
            pl.Series(col, encoder.fit_transform(real_data[col].to_list())),
        ])
        synth_data = synth_data.with_columns([
            pl.Series(col, encoder.fit_transform(synth_data[col].to_list())),
        ])

    # Real and Synthetic dataset correlation matrix
    real_corr = real_data.corr()
    synth_corr = synth_data.corr()

    # Difference between the correlation matrix of the real dataset and the correlation matrix of the synthetic dataset
    diff_corr = real_corr-synth_corr

    # Substitute elements with abs value lower than 1e-5 with 0 (noise floor)
    diff_corr = diff_corr.with_columns([pl.when(abs(pl.col(col)) < 1e-5).then(0).otherwise(pl.col(col)).alias(col) for col in diff_corr.columns])

    if save_data:
        _save_to_json("real_corr", real_corr, path_to_json)
        _save_to_json("synth_corr", synth_corr, path_to_json)
        _save_to_json("diff_corr", diff_corr, path_to_json)

    return real_corr, synth_corr, diff_corr
700 |
def detection(
    df_original: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray,
    df_synthetic: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray,
    preprocessor: Preprocessor = None,
    save_data: bool = True,
    path_to_json: str = ""
) -> Dict:
    """
    Computes the detection score by training an XGBoost model to differentiate between
    original and synthetic data. The lower the model's accuracy, the higher the quality
    of the synthetic data.

    Parameters
    ----------
    df_original : Union[pl.DataFrame, pl.LazyFrame, pd.DataFrame, np.ndarray]
        The original dataset containing real data.
    df_synthetic : Union[pl.DataFrame, pl.LazyFrame, pd.DataFrame, np.ndarray]
        The synthetic dataset to be evaluated.
    preprocessor : Preprocessor, optional
        A preprocessor object for transforming the datasets. If None, a new Preprocessor
        instance will be created. Defaults to None.
    features_to_hide : list, optional
        List of features to exclude from importance analysis. Defaults to an empty list.
    save_data : bool
        If True, saves the detection score into the JSON file used to generate the final report, by default True.
    path_to_json : str, optional
        Path to save the output JSON files, by default "".

    Returns
    -------
    dict
        A dictionary containing accuracy, ROC AUC score, detection score, and feature importances.

    Notes
    -----
    The method operates through the following steps:

    1. Prepares the datasets:

       - Samples the original dataset to match the size of the synthetic dataset.
       - Preprocesses both datasets to ensure consistent feature representation.
       - Labels original data as `0` and synthetic data as `1`.

    2. Builds a classification model:

       - Uses XGBoost to train a model that classifies data points as either real or synthetic.
       - Splits the data into training and test sets.
       - Trains the model using a 33% test split.

    3. Computes Evaluation Metrics:

       - Accuracy: Measures classification correctness.
       - ROC-AUC Score: Measures the model's discriminatory power.
       - Detection Score

    4. Extracts Feature Importances:

       - Identifies which features contribute most to distinguishing real vs. synthetic data.
       - Helps detect which synthetic features deviate from real-world patterns.


    The detection score is calculated as:

    .. code-block:: python

        detection_score["score"] = (1 - detection_score["ROC_AUC"]) * 2

    - If ``ROC_AUC <= 0.5``, the synthetic data is considered indistinguishable from the real dataset (``score = 1``).
    - A lower score means better synthetic data quality.
    - Feature importance analysis helps detect which synthetic features deviate most from real data.

    Examples
    --------
    Example of dictionary returned:

    .. code-block:: python

        >>> detection_results = detection_score(original_df, synthetic_df)
        >>> print(detection_results)
        {
            "accuracy": 0.85,  # How often the model classifies correctly
            "ROC_AUC": 0.90,   # The ability to distinguish real vs synthetic
            "score": 0.2,      # The final detection score (lower = better)
            "feature_importances": {"feature_1": 0.34, "feature_2": 0.21, ...}
        }

    """
    import xgboost as xgb
    from sklearn.model_selection import train_test_split

    if preprocessor is None:
        preprocessor = Preprocessor(df_original)

    df_original = _to_polars_df(df_original)
    df_synthetic = _to_polars_df(df_synthetic)

    # Sample from the original dataset to match the size of the synthetic dataset
    df_original = df_original.head(len(df_synthetic))

    # Replace minority labels in the original data with "other", mirroring what
    # the preprocessor discarded during fitting.
    for i in preprocessor.discarded[1].keys():
        list_minority_labels = preprocessor.discarded[1][i]
        for j in list_minority_labels:
            df_original = df_original.with_columns(
                pl.when(pl.col(i) == j)
                .then(pl.lit("other"))
                .otherwise(pl.col(i))
                .alias(i)
            )

    # Preprocess and label the original data (label 0 = real)
    preprocessed_df_original = preprocessor.transform(df_original)
    df_original = preprocessor.inverse_transform(preprocessed_df_original)
    df_original = df_original.with_columns([
        pl.Series("label", np.zeros(len(df_original)).astype(int)),
    ])

    # Preprocess and label the synthetic data (label 1 = synthetic)
    preprocessed_df_synthetic = preprocessor.transform(df_synthetic)
    df_synthetic = preprocessor.inverse_transform(preprocessed_df_synthetic)
    df_synthetic = df_synthetic.with_columns([
        pl.Series("label", np.ones(len(df_synthetic)).astype(int)),
    ])

    df = pl.concat([df_original, df_synthetic])
    preprocessor_ = Preprocessor(df, target_column = "label")
    df_preprocessed = preprocessor_.transform(df)

    X_train, X_test, y_train, y_test = train_test_split(
        df_preprocessed.select(pl.exclude("label")),
        df_preprocessed.select(pl.col("label")),
        test_size=0.33,
        random_state=42
    )

    # NOTE(review): `use_label_encoder` is deprecated/removed in xgboost >= 2.0;
    # confirm the pinned xgboost version before relying on this kwarg.
    model = xgb.XGBClassifier(max_depth=3, n_estimators=50, use_label_encoder=False, eval_metric="logloss")
    model.fit(X_train, y_train)

    # Make predictions and compute metrics
    y_pred = model.predict(X_test)
    detection_score = {}
    detection_score["accuracy"] = round(
        accuracy_score(y_true=y_test, y_pred=y_pred), 4
    )
    detection_score["ROC_AUC"] = round(
        roc_auc_score(
            y_true=y_test,
            y_score=model.predict_proba(X_test)[:, 1],
            average=None,
        ),
        4,
    )
    detection_score["score"] = (
        1
        if detection_score["ROC_AUC"] <= 0.5
        else (1 - detection_score["ROC_AUC"]) * 2
    )

    detection_score["feature_importances"] = {}

    # Map model importances back onto the original feature names. Only the
    # categorical sizes are needed: each categorical feature spans several
    # one-hot columns whose importances are summed.
    _, categorical_features_sizes = preprocessor.get_features_sizes()

    numerical_features = preprocessor.numerical_features
    categorical_features = preprocessor.categorical_features
    datetime_features = preprocessor.datetime_features

    # `index` walks through model.feature_importances_ in the column order
    # produced by the preprocessor: numerical, datetime, then categorical.
    index = 0
    for feature, importance in zip(numerical_features, model.feature_importances_):
        if feature not in features_to_hide:
            detection_score["feature_importances"][feature] = round(float(importance), 4)
        index += 1

    if len(datetime_features)>0:
        for feature, importance in zip(datetime_features, model.feature_importances_[index:]):
            if feature not in features_to_hide:
                detection_score["feature_importances"][feature] = round(float(importance), 4)
            index += 1

    for feature, feature_size in zip(categorical_features, categorical_features_sizes):
        importance = np.sum(model.feature_importances_[index : index + feature_size])
        index += feature_size
        if feature not in features_to_hide:
            detection_score["feature_importances"][feature] = round(float(importance), 4)

    if save_data:
        _save_to_json("detection_score", detection_score, path_to_json)

    return detection_score
891 |
def query_power(
    df_original: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray,
    df_synthetic: pl.DataFrame | pl.LazyFrame | pd.DataFrame | np.ndarray,
    preprocessor: Preprocessor = None,
    save_data: bool = True,
    path_to_json: str = ""
) -> dict:
    """
    Generates and runs queries to compare the original and synthetic datasets.

    This method creates random queries that filter data from both datasets.
    The similarity between the sizes of the filtered results is used to score
    the quality of the synthetic data.

    Parameters:
    -----------
    df_original (Union[pl.DataFrame, pl.LazyFrame, pd.DataFrame, np.ndarray]):
        The original dataset containing real data.
    df_synthetic (Union[pl.DataFrame, pl.LazyFrame, pd.DataFrame, np.ndarray]):
        The synthetic dataset to be evaluated.
    preprocessor (Preprocessor, optional):
        A preprocessor object for transforming the datasets. If None, a new Preprocessor
        instance will be created. Defaults to None.
    save_data (bool):
        If True, saves the query-power results into the JSON file used to generate the final report.
    path_to_json (str, optional):
        Path to save the output JSON files.
    Returns:
    --------
    dict: A dictionary containing query texts, the number of matches for each
          query in both datasets, and an overall score indicating the quality
          of the synthetic data. The score is 0.0 when no query could be built
          (fewer than two usable features).
    """
    def polars_query(feat_type, feature, op, value):
        # Build a Polars filter expression for one (feature, op, value) condition.
        if feat_type == 'num':
            if op == "<=":
                return pl.col(feature) <= value
            if op == ">=":
                return pl.col(feature) >= value
        elif feat_type == 'cat':
            if op == "==":
                return pl.col(feature) == value
            if op == "!=":
                return pl.col(feature) != value
        # Previously an unsupported combination raised UnboundLocalError; fail clearly instead.
        raise ValueError(f"Unsupported query condition: {feat_type!r} {op!r}")

    query_power = {"queries": []}

    if preprocessor is None:
        preprocessor = Preprocessor(df_original)

    df_original = _to_polars_df(df_original)
    df_synthetic = _to_polars_df(df_synthetic)

    # Round-trip both datasets through the preprocessor so the queried values
    # come from the same representation.
    df_original = df_original.sample(len(df_synthetic)).clone()
    df_original_preprocessed = preprocessor.transform(df_original)
    df_original = preprocessor.inverse_transform(df_original_preprocessed)

    df_synthetic_preprocessed = preprocessor.transform(df_synthetic)
    df_synthetic = preprocessor.inverse_transform(df_synthetic_preprocessed)

    # Extract feature types
    numerical_features = preprocessor.numerical_features
    categorical_features = preprocessor.categorical_features
    datetime_features = preprocessor.datetime_features
    boolean_features = preprocessor.boolean_features

    # Prepare the feature list, excluding datetime features
    features = list(set(df_original.columns) - set(datetime_features))

    # Define query parameters
    quantiles = [0.25, 0.5, 0.75]
    numerical_ops = ["<=", ">="]
    categorical_ops = ["==", "!="]
    logical_ops = ["and"]

    queries_score = []

    # Generate and run up to 5 queries
    while len(features) >= 2 and len(query_power["queries"]) < 5:
        # Randomly select two features for the query (each feature is used once)
        feats = [random.choice(features)]
        features.remove(feats[0])
        feats.append(random.choice(features))
        features.remove(feats[1])

        queries = []
        queries_text = []
        # Construct query conditions for each selected feature
        for feature in feats:
            if feature in numerical_features:
                feat_type = 'num'
                op = random.choice(numerical_ops)
                value = df_original.select(
                    pl.col(feature).quantile(
                        random.choice(quantiles), interpolation="nearest"
                    )).item()
            elif feature in categorical_features or feature in boolean_features:
                feat_type = 'cat'
                op = random.choice(categorical_ops)
                value = random.choice(df_original[feature].unique())
            else:
                continue

            queries_text.append(f"`{feature}` {op} `{value}`")
            queries.append(polars_query(feat_type, feature, op, value))

        # If both selected features were skipped, there is nothing to combine:
        # reduce() on an empty list would raise TypeError.
        if not queries:
            continue

        # Combine query conditions with a logical operator
        text = f" {random.choice(logical_ops)} ".join(queries_text)
        combined_query = reduce(lambda a, b: a & b, queries)

        try:
            query = {
                "text": text,
                "original_df": len(df_original.filter(combined_query)),
                "synthetic_df": len(df_synthetic.filter(combined_query)),
            }
        except Exception:
            query = {"text": "Invalid query", "original_df": 0, "synthetic_df": 0}

        # Append the query and calculate the score
        query_power["queries"].append(query)
        queries_score.append(
            1 - abs(query["original_df"] - query["synthetic_df"]) / len(df_original)
        )

    # Calculate the overall query power score; guard against zero queries,
    # which previously raised ZeroDivisionError.
    query_power["score"] = (
        round(float(sum(queries_score) / len(queries_score)), 4)
        if queries_score
        else 0.0
    )

    if save_data:
        _save_to_json("query_power", query_power, path_to_json)

    return query_power
--------------------------------------------------------------------------------
/tests/resources/synthetic_dataset.csv:
--------------------------------------------------------------------------------
1 | age,work_class,education,marital_status,occupation,relationship,race,sex,capital_gain,capital_loss,hours_per_week,native_country,income
2 | 34,Private,Some-college,Married-civ-spouse,Craft-repair,Husband,White,Male,0,0,48,United-States,>50K
3 | 49,Private,Some-college,Separated,Exec-managerial,Unmarried,Black,Female,0,0,60,United-States,<=50K
4 | 44,Federal-gov,Bachelors,Married-civ-spouse,Adm-clerical,Husband,Asian-Pac-Islander,Male,0,0,40,*,>50K
5 | 35,Private,HS-grad,Divorced,Craft-repair,Not-in-family,White,Male,0,0,40,United-States,<=50K
6 | 44,Private,HS-grad,Married-civ-spouse,Farming-fishing,Husband,*,Male,0,0,40,United-States,<=50K
7 | 28,Private,Assoc-voc,Never-married,Sales,Not-in-family,White,Male,0,0,50,United-States,<=50K
8 | 44,Private,HS-grad,Married-civ-spouse,Machine-op-inspct,Husband,White,Male,0,0,40,United-States,<=50K
9 | 54,Local-gov,10th,Married-civ-spouse,Protective-serv,Husband,White,Male,0,0,45,United-States,<=50K
10 | 33,Private,Bachelors,Married-civ-spouse,Exec-managerial,Husband,White,Male,0,0,50,United-States,>50K
11 | 30,Private,10th,Divorced,Other-service,Unmarried,White,Female,0,0,40,United-States,<=50K
12 | 50,Private,Bachelors,Married-civ-spouse,Exec-managerial,Husband,White,Male,0,1902,40,United-States,>50K
13 | 30,State-gov,Bachelors,Never-married,Prof-specialty,Other-relative,White,Male,0,0,35,United-States,<=50K
14 | 44,Private,Masters,Married-civ-spouse,Adm-clerical,Husband,White,Male,3998,0,50,United-States,<=50K
15 | 20,nan,Some-college,Never-married,nan,Own-child,White,Female,0,0,40,United-States,<=50K
16 | 36,Private,Bachelors,Divorced,Adm-clerical,Not-in-family,White,Male,0,0,54,United-States,>50K
17 | 54,Local-gov,Some-college,Married-civ-spouse,Transport-moving,Husband,Black,Male,0,0,30,United-States,<=50K
18 | 23,Private,HS-grad,Never-married,Adm-clerical,Not-in-family,White,Female,0,0,50,United-States,<=50K
19 | 26,Private,11th,Separated,Craft-repair,Other-relative,White,Male,3121,0,50,United-States,<=50K
20 | 20,nan,Some-college,Never-married,nan,Own-child,White,Female,0,0,40,United-States,<=50K
21 | 57,Private,Some-college,Married-civ-spouse,Craft-repair,Husband,White,Male,0,1709,45,United-States,<=50K
22 | 27,Local-gov,Bachelors,Never-married,Prof-specialty,Not-in-family,White,Female,0,0,40,United-States,<=50K
23 | 19,Private,Some-college,Never-married,Other-service,Own-child,White,Female,0,0,30,United-States,<=50K
24 | 41,Private,Bachelors,Never-married,Prof-specialty,Not-in-family,White,Female,0,0,45,United-States,<=50K
25 | 33,Private,HS-grad,Married-civ-spouse,Craft-repair,Husband,White,Male,0,0,35,United-States,<=50K
26 | 33,Private,*,Never-married,Other-service,Unmarried,Black,Female,0,0,14,*,<=50K
27 | 61,Private,11th,Divorced,*,Not-in-family,White,Female,0,0,12,nan,<=50K
28 | 40,Private,Bachelors,*,Transport-moving,Own-child,*,Male,0,0,45,nan,<=50K
29 | 29,Private,Bachelors,Never-married,Exec-managerial,Not-in-family,White,Female,0,0,45,United-States,>50K
30 | 18,Private,HS-grad,Never-married,Handlers-cleaners,Own-child,White,Male,0,0,40,United-States,<=50K
31 | 43,Self-emp-inc,Bachelors,Married-civ-spouse,Sales,Husband,White,Male,0,0,45,United-States,>50K
32 | 42,Self-emp-not-inc,Assoc-voc,Married-civ-spouse,Transport-moving,Husband,White,Male,4797,0,80,United-States,>50K
33 | 37,Private,Some-college,Married-civ-spouse,Sales,Husband,*,Male,0,0,45,United-States,>50K
34 | 29,Private,Assoc-voc,Never-married,Prof-specialty,Not-in-family,White,Male,0,0,44,United-States,<=50K
35 | 45,Private,HS-grad,Married-civ-spouse,Transport-moving,Husband,*,Male,4132,0,40,United-States,<=50K
36 | 36,Self-emp-not-inc,*,Married-civ-spouse,Prof-specialty,Husband,White,Male,0,0,35,United-States,<=50K
37 | 27,Private,HS-grad,Never-married,Machine-op-inspct,Not-in-family,White,Female,0,0,40,United-States,<=50K
38 | 18,Private,HS-grad,Never-married,Adm-clerical,Own-child,White,Male,0,0,20,United-States,<=50K
39 | 46,Private,HS-grad,*,Other-service,Not-in-family,White,Male,0,0,35,Mexico,<=50K
40 | 30,Self-emp-not-inc,Some-college,Never-married,Sales,Own-child,White,Male,0,0,25,United-States,<=50K
41 | 42,Private,Assoc-voc,Divorced,Adm-clerical,Unmarried,White,Female,0,0,40,United-States,<=50K
42 | 33,Private,*,*,Transport-moving,Unmarried,*,Male,0,0,20,*,<=50K
43 | 45,Private,HS-grad,Never-married,Exec-managerial,Not-in-family,Black,Female,0,0,40,United-States,<=50K
44 | 54,Private,Some-college,Married-civ-spouse,Craft-repair,Wife,White,Female,0,0,40,United-States,<=50K
45 | 37,Self-emp-inc,HS-grad,Never-married,Transport-moving,Not-in-family,White,Male,0,0,50,United-States,<=50K
46 | 37,Self-emp-not-inc,HS-grad,Married-civ-spouse,Transport-moving,Husband,White,Male,0,0,30,United-States,<=50K
47 | 23,Private,Assoc-acdm,Never-married,Sales,Own-child,Black,Female,0,0,36,United-States,<=50K
48 | 26,Private,*,Never-married,Handlers-cleaners,Unmarried,White,Male,0,0,40,United-States,<=50K
49 | 25,Private,*,Never-married,Adm-clerical,Not-in-family,White,Female,0,0,34,Mexico,<=50K
50 | 31,Private,Bachelors,Married-civ-spouse,Sales,Husband,White,Male,0,0,40,United-States,>50K
51 | 26,Local-gov,Bachelors,Never-married,Prof-specialty,Not-in-family,White,Female,0,0,40,United-States,<=50K
52 | 74,nan,Bachelors,Widowed,nan,Not-in-family,*,Female,0,0,35,United-States,<=50K
53 | 25,Private,HS-grad,Never-married,Craft-repair,Not-in-family,White,Male,0,0,35,United-States,<=50K
54 | 18,Private,HS-grad,Never-married,Sales,Own-child,White,Female,0,0,27,United-States,<=50K
55 | 40,Private,HS-grad,Married-civ-spouse,Farming-fishing,Husband,White,Male,0,0,40,United-States,<=50K
56 | 32,Private,Assoc-voc,Divorced,Exec-managerial,Not-in-family,White,Female,0,0,50,United-States,<=50K
57 | 54,Private,HS-grad,Married-civ-spouse,Tech-support,Husband,White,Male,7687,0,25,United-States,>50K
58 | 23,Private,Some-college,Never-married,Handlers-cleaners,Own-child,White,Male,2961,0,50,United-States,<=50K
59 | 56,Self-emp-not-inc,*,Divorced,Other-service,Unmarried,White,Female,0,0,40,United-States,<=50K
60 | 33,Local-gov,HS-grad,Widowed,Other-service,Unmarried,White,Female,0,0,40,United-States,<=50K
61 | 21,State-gov,Some-college,Never-married,Adm-clerical,Own-child,White,Female,0,0,30,United-States,<=50K
62 | 21,Private,HS-grad,Never-married,Sales,Not-in-family,White,Male,0,0,54,United-States,<=50K
63 | 27,Private,HS-grad,Widowed,Craft-repair,Unmarried,White,Female,0,1610,26,United-States,<=50K
64 | 47,Self-emp-not-inc,HS-grad,Married-civ-spouse,Exec-managerial,Husband,White,Male,0,2415,50,United-States,>50K
65 | 35,Private,Masters,Never-married,Exec-managerial,Not-in-family,White,Female,0,0,54,United-States,<=50K
66 | 39,Self-emp-not-inc,HS-grad,Married-civ-spouse,Machine-op-inspct,Husband,White,Male,0,0,70,nan,<=50K
67 | 19,Private,11th,Married-civ-spouse,Craft-repair,Husband,White,Male,0,0,60,United-States,<=50K
68 | 52,Private,HS-grad,Married-civ-spouse,Craft-repair,Husband,White,Male,0,0,40,United-States,>50K
69 | 50,Private,HS-grad,Married-civ-spouse,Other-service,Husband,White,Male,0,0,40,United-States,>50K
70 | 39,Private,Some-college,Never-married,Other-service,Unmarried,Black,Female,0,0,40,United-States,<=50K
71 | 22,nan,Some-college,Never-married,nan,Not-in-family,Asian-Pac-Islander,Female,0,0,12,*,<=50K
72 | 31,Private,HS-grad,Never-married,Other-service,Not-in-family,White,Male,0,0,50,United-States,<=50K
73 | 26,Self-emp-not-inc,Some-college,Never-married,Other-service,Not-in-family,White,Female,0,0,30,United-States,<=50K
74 | 26,Private,HS-grad,Married-civ-spouse,Exec-managerial,Wife,White,Female,5013,0,40,United-States,<=50K
75 | 30,Private,10th,Never-married,Craft-repair,Not-in-family,White,Male,0,0,50,United-States,<=50K
76 | 22,Private,Some-college,Married-civ-spouse,Tech-support,Husband,White,Male,0,0,40,United-States,<=50K
77 | 49,Private,HS-grad,Married-civ-spouse,Machine-op-inspct,Husband,White,Male,0,0,40,United-States,>50K
78 | 35,Private,Some-college,Divorced,Exec-managerial,Not-in-family,White,Female,0,0,40,United-States,>50K
79 | 25,Private,HS-grad,Married-civ-spouse,Transport-moving,Husband,White,Male,0,0,40,United-States,<=50K
80 | 42,Private,Bachelors,Married-civ-spouse,Craft-repair,Wife,Black,Female,0,0,35,United-States,>50K
81 | 28,Local-gov,Assoc-voc,Married-civ-spouse,Prof-specialty,Wife,White,Female,7687,0,37,United-States,>50K
82 | 20,Private,Some-college,Never-married,Other-service,Not-in-family,White,Male,0,0,40,United-States,<=50K
83 | 30,Self-emp-not-inc,HS-grad,Married-civ-spouse,Exec-managerial,Husband,White,Male,0,1887,50,United-States,>50K
84 | 35,Private,HS-grad,Married-civ-spouse,Handlers-cleaners,Husband,White,Male,0,0,40,Mexico,<=50K
85 | 56,Private,HS-grad,Widowed,Other-service,Unmarried,White,Female,0,0,40,United-States,<=50K
86 | 25,Private,Bachelors,Never-married,Exec-managerial,Own-child,White,Male,0,261,40,United-States,<=50K
87 | 39,Private,Masters,Married-civ-spouse,Machine-op-inspct,Husband,White,Male,0,0,40,United-States,<=50K
88 | 60,Self-emp-not-inc,*,Married-civ-spouse,Craft-repair,Husband,White,Male,0,0,40,United-States,<=50K
89 | 51,Local-gov,*,Divorced,Exec-managerial,Not-in-family,White,Female,0,0,45,United-States,>50K
90 | 53,Private,Masters,Never-married,Prof-specialty,Not-in-family,White,Female,0,1591,54,United-States,>50K
91 | 24,Private,HS-grad,Never-married,Transport-moving,Unmarried,White,Female,0,0,40,Mexico,<=50K
92 | 43,Private,Some-college,Divorced,Transport-moving,Not-in-family,White,Male,0,0,45,United-States,<=50K
93 | 51,Self-emp-not-inc,Masters,Divorced,Exec-managerial,Unmarried,White,Male,26552,0,50,United-States,>50K
94 | 31,Private,Assoc-acdm,Married-civ-spouse,Machine-op-inspct,Husband,White,Male,0,0,20,United-States,<=50K
95 | 60,Private,HS-grad,Married-civ-spouse,Handlers-cleaners,Husband,White,Male,0,0,40,United-States,<=50K
96 | 24,Private,Some-college,Never-married,Adm-clerical,Not-in-family,White,Female,0,0,40,United-States,<=50K
97 | 45,State-gov,Masters,Married-civ-spouse,Prof-specialty,Husband,White,Male,0,0,40,*,<=50K
98 | 23,Private,Some-college,Never-married,Sales,Not-in-family,White,Female,0,0,30,United-States,<=50K
99 | 34,Self-emp-not-inc,Assoc-acdm,Married-civ-spouse,Prof-specialty,Wife,White,Female,0,0,25,United-States,>50K
100 | 18,Private,Some-college,Never-married,Adm-clerical,Own-child,White,Female,0,0,38,United-States,<=50K
101 |
--------------------------------------------------------------------------------
/tests/resources/validation_dataset.csv:
--------------------------------------------------------------------------------
1 | age,work_class,education,marital_status,occupation,relationship,race,sex,capital_gain,capital_loss,hours_per_week,native_country,income
2 | 25,Private,11th,Never-married,Machine-op-inspct,Own-child,Black,Male,0,0,40,United-States,<=50K
3 | 38,Private,HS-grad,Married-civ-spouse,Farming-fishing,Husband,White,Male,0,0,50,United-States,<=50K
4 | 28,Local-gov,Assoc-acdm,Married-civ-spouse,Protective-serv,Husband,White,Male,0,0,40,United-States,>50K
5 | 44,Private,Some-college,Married-civ-spouse,Machine-op-inspct,Husband,Black,Male,7688,0,40,United-States,>50K
6 | 18,,Some-college,Never-married,,Own-child,White,Female,0,0,30,United-States,<=50K
7 | 34,Private,10th,Never-married,Other-service,Not-in-family,White,Male,0,0,30,United-States,<=50K
8 | 29,,HS-grad,Never-married,,Unmarried,Black,Male,0,0,40,United-States,<=50K
9 | 63,Self-emp-not-inc,Prof-school,Married-civ-spouse,Prof-specialty,Husband,White,Male,3103,0,32,United-States,>50K
10 | 24,Private,Some-college,Never-married,Other-service,Unmarried,White,Female,0,0,40,United-States,<=50K
11 | 55,Private,7th-8th,Married-civ-spouse,Craft-repair,Husband,White,Male,0,0,10,United-States,<=50K
12 | 65,Private,HS-grad,Married-civ-spouse,Machine-op-inspct,Husband,White,Male,6418,0,40,United-States,>50K
13 | 36,Federal-gov,Bachelors,Married-civ-spouse,Adm-clerical,Husband,White,Male,0,0,40,United-States,<=50K
14 | 26,Private,HS-grad,Never-married,Adm-clerical,Not-in-family,White,Female,0,0,39,United-States,<=50K
15 | 58,,HS-grad,Married-civ-spouse,,Husband,White,Male,0,0,35,United-States,<=50K
16 | 48,Private,HS-grad,Married-civ-spouse,Machine-op-inspct,Husband,White,Male,3103,0,48,United-States,>50K
17 | 43,Private,Masters,Married-civ-spouse,Exec-managerial,Husband,White,Male,0,0,50,United-States,>50K
18 | 20,State-gov,Some-college,Never-married,Other-service,Own-child,White,Male,0,0,25,United-States,<=50K
19 | 43,Private,HS-grad,Married-civ-spouse,Adm-clerical,Wife,White,Female,0,0,30,United-States,<=50K
20 | 37,Private,HS-grad,Widowed,Machine-op-inspct,Unmarried,White,Female,0,0,20,United-States,<=50K
21 | 40,Private,Doctorate,Married-civ-spouse,Prof-specialty,Husband,Asian-Pac-Islander,Male,0,0,45,,>50K
22 | 34,Private,Bachelors,Married-civ-spouse,Tech-support,Husband,White,Male,0,0,47,United-States,>50K
23 | 34,Private,Some-college,Never-married,Other-service,Own-child,Black,Female,0,0,35,United-States,<=50K
24 | 72,,7th-8th,Divorced,,Not-in-family,White,Female,0,0,6,United-States,<=50K
25 | 25,Private,Bachelors,Never-married,Prof-specialty,Not-in-family,White,Male,0,0,43,Peru,<=50K
26 | 25,Private,Bachelors,Married-civ-spouse,Prof-specialty,Husband,White,Male,0,0,40,United-States,<=50K
27 | 45,Self-emp-not-inc,HS-grad,Married-civ-spouse,Craft-repair,Husband,White,Male,7298,0,90,United-States,>50K
28 | 22,Private,HS-grad,Never-married,Adm-clerical,Own-child,White,Male,0,0,20,United-States,<=50K
29 | 23,Private,HS-grad,Separated,Machine-op-inspct,Unmarried,Black,Male,0,0,54,United-States,<=50K
30 | 54,Private,HS-grad,Married-civ-spouse,Craft-repair,Husband,White,Male,0,0,35,United-States,<=50K
31 | 32,Self-emp-not-inc,Some-college,Never-married,Prof-specialty,Not-in-family,White,Male,0,0,60,United-States,<=50K
32 | 46,State-gov,Some-college,Married-civ-spouse,Exec-managerial,Husband,Black,Male,7688,0,38,United-States,>50K
33 | 56,Self-emp-not-inc,11th,Widowed,Other-service,Unmarried,White,Female,0,0,50,United-States,<=50K
34 | 24,Self-emp-not-inc,Bachelors,Never-married,Sales,Not-in-family,White,Male,0,0,50,United-States,<=50K
35 | 23,Local-gov,Some-college,Married-civ-spouse,Protective-serv,Husband,White,Male,0,0,40,United-States,<=50K
36 | 26,Private,HS-grad,Divorced,Exec-managerial,Unmarried,White,Female,0,0,40,United-States,<=50K
37 | 65,,HS-grad,Married-civ-spouse,,Husband,White,Male,0,0,40,United-States,<=50K
38 | 36,Local-gov,Bachelors,Married-civ-spouse,Prof-specialty,Husband,White,Male,0,0,40,United-States,>50K
39 | 22,Private,5th-6th,Never-married,Priv-house-serv,Not-in-family,White,Male,0,0,50,Guatemala,<=50K
40 | 17,Private,10th,Never-married,Machine-op-inspct,Not-in-family,White,Male,0,0,40,United-States,<=50K
41 | 20,Private,HS-grad,Never-married,Craft-repair,Own-child,White,Male,0,0,40,United-States,<=50K
42 | 65,Private,Masters,Married-civ-spouse,Prof-specialty,Husband,White,Male,0,0,50,United-States,>50K
43 | 44,Self-emp-inc,Assoc-voc,Married-civ-spouse,Sales,Husband,White,Male,0,0,45,United-States,>50K
44 | 36,Private,HS-grad,Married-civ-spouse,Farming-fishing,Husband,White,Male,0,0,40,United-States,<=50K
45 | 29,Private,11th,Married-civ-spouse,Other-service,Husband,White,Male,0,0,40,United-States,<=50K
46 | 20,State-gov,Some-college,Never-married,Farming-fishing,Own-child,White,Male,0,0,32,United-States,<=50K
47 | 28,Private,Assoc-voc,Married-civ-spouse,Prof-specialty,Wife,White,Female,0,0,36,United-States,>50K
48 | 39,Private,7th-8th,Married-civ-spouse,Craft-repair,Husband,White,Male,0,0,40,Mexico,<=50K
49 | 54,Private,Some-college,Married-civ-spouse,Transport-moving,Husband,White,Male,3908,0,50,United-States,<=50K
50 | 52,Private,11th,Separated,Priv-house-serv,Not-in-family,Black,Female,0,0,18,United-States,<=50K
51 |
--------------------------------------------------------------------------------
/tests/test_sure.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import numpy as np
3 | import polars as pl
4 | import pandas as pd
5 | from pathlib import Path
6 | from sure import Preprocessor
7 | from sure.utility import (compute_statistical_metrics, compute_mutual_info,
8 | compute_utility_metrics_class)
9 | from sure.privacy import (distance_to_closest_record, dcr_stats, number_of_dcr_equal_to_zero, validation_dcr_test,
10 | adversary_dataset, membership_inference_test)
@pytest.fixture
def real_data():
    """Load the real (training) dataset used by the tests."""
    resources = Path(__file__).parent / "resources"
    return pl.read_csv(resources / "dataset.csv")
16 |
@pytest.fixture
def synthetic_data():
    """Load the synthetic dataset used by the tests."""
    resources = Path(__file__).parent / "resources"
    return pl.read_csv(resources / "synthetic_dataset.csv")
22 |
@pytest.fixture
def validation_data():
    """Load the hold-out validation dataset used by the tests."""
    resources = Path(__file__).parent / "resources"
    return pl.read_csv(resources / "validation_dataset.csv")
28 |
def test_statistical_metrics(real_data, synthetic_data):
    """Test computation of statistical metrics between real and synthetic data.

    Fix: the categorical-statistics result was unpacked but never checked,
    so a regression that broke categorical metrics would pass silently.
    """
    num_stats, cat_stats, _ = compute_statistical_metrics(real_data, synthetic_data)

    assert isinstance(num_stats, dict)
    # Previously unchecked: the categorical statistics must also be produced.
    assert cat_stats is not None
34 |
def test_mutual_info(real_data, synthetic_data):
    """Test computation of mutual information metrics.

    Fix: ``corr_diff`` was computed but never checked; assert it is shaped
    consistently with the correlation matrices it is derived from.
    """
    # Fit the preprocessor on real data and apply it to both datasets so the
    # correlation matrices are computed over the same encoded feature space.
    preprocessor = Preprocessor(real_data, num_fill_null='forward', scaling='standardize')
    real_preprocessed = preprocessor.transform(real_data)
    synth_preprocessed = preprocessor.transform(synthetic_data)

    corr_real, corr_synth, corr_diff = compute_mutual_info(real_preprocessed, synth_preprocessed)

    assert corr_real.shape == corr_synth.shape
    # Previously unchecked: the difference matrix must align with its inputs.
    assert corr_diff.shape == corr_real.shape
44 |
def test_privacy_metrics(real_data, synthetic_data, validation_data):
    """Test computation of privacy metrics (DCR, DCR stats, validation share).

    Fix: ``stats_train``/``stats_valid`` were computed but never asserted on
    (dead results); add minimal checks so a crash-free-but-empty result fails.
    """
    # DCR of each synthetic record to the training and validation sets
    dcr_synth_train = distance_to_closest_record("synth_train", synthetic_data, real_data)
    dcr_synth_valid = distance_to_closest_record("synth_val", synthetic_data, validation_data)

    # DCR summary statistics — previously computed but never checked
    stats_train = dcr_stats("synth_train", dcr_synth_train)
    stats_valid = dcr_stats("synth_val", dcr_synth_valid)
    assert stats_train is not None
    assert stats_valid is not None

    # Share of synthetic records closer to the training set than to validation
    share = validation_dcr_test(dcr_synth_train, dcr_synth_valid)
    assert isinstance(share['percentage'], float)
    assert 0 <= share['percentage'] <= 100
59 |
def test_membership_inference(real_data, synthetic_data, validation_data):
    """Test the membership inference attack simulation."""
    # Build the adversary dataset; it carries the membership ground-truth
    # flag in the "privacy_test_is_training" column.
    attack_df = adversary_dataset(real_data, validation_data)
    membership_labels = attack_df["privacy_test_is_training"]

    attack_results = membership_inference_test(attack_df, synthetic_data, membership_labels)

    assert isinstance(attack_results, dict)
68 |
--------------------------------------------------------------------------------
/tutorials/sure_tutorial.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# SURE library use case notebook\n",
8 | "Useful links:\n",
9 | "- [Github repo](https://github.com/Clearbox-AI/SURE)\n",
10 | "- [Documentation](https://dario-brunelli-clearbox-ai.notion.site/SURE-Documentation-2c17db370641488a8db5bce406032c1f)\n",
11 | "\n",
12 | "Download the datasets and try out the library with this guided use case.\n",
13 | "\n",
14 | "We would greatly appreciate your feedback to help us improve the library! \\\n",
15 | "If you encounter any issues, please open an issue on our [GitHub repository](https://github.com/Clearbox-AI/SURE).\n",
16 | "\n",
17 | "### Datasets description\n",
18 | "\n",
19 | "The three datasets provided are the following:\n",
20 | "\n",
21 | "- *census_dataset_training.csv* \\\n",
22 | " The original real dataset used to train the generative model from which *census_dataset_synthetic* was produced.\n",
23 | " \n",
24 | "- *census_dataset_validation.csv* \\\n",
25 | " This dataset was also part of the original real dataset, but it was NOT used to train the generative model that produced *census_dataset_synthetic*.\n",
26 | " \n",
27 | "- *census_dataset_synthetic.csv* \\\n",
28 | " The synthetic dataset produced with the generative model trained on *census_dataset_training.*\n",
29 | " \n",
30 | "\n",
31 | "The three census datasets include various demographic, social, economic, and housing characteristics of individuals. Every row of the datasets corresponds to an individual.\n",
32 | "\n",
33 | "The machine learning task related to these datasets is a classification task, where, based on all the features, a ML classifier model must decide whether the individual earns more than 50k dollars per year (label=1) or less (label=0).\\\n",
34 | "The column \"label\" in each dataset is the ground truth for this classification task."
35 | ]
36 | },
37 | {
38 | "cell_type": "markdown",
39 | "metadata": {},
40 | "source": [
41 | "## 0. Installing the library and importing dependencies "
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": null,
47 | "metadata": {},
48 | "outputs": [],
49 | "source": [
50 | "# install the SURE library \n",
51 | "%pip install clearbox-sure"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": null,
57 | "metadata": {},
58 | "outputs": [],
59 | "source": [
60 | "# importing dependencies\n",
61 | "import polars as pl # you can use polars or pandas for importing the datasets\n",
62 | "import pandas as pd\n",
63 | "import os\n",
64 | "\n",
65 | "from sure import Preprocessor, report\n",
66 | "from sure.utility import (compute_statistical_metrics, compute_mutual_info,\n",
67 | "\t\t\t \t\t\t compute_utility_metrics_class,\n",
68 | "\t\t\t\t\t\t detection,\n",
69 | "\t\t\t\t\t\t query_power)\n",
70 | "from sure.privacy import (distance_to_closest_record, dcr_stats, number_of_dcr_equal_to_zero, validation_dcr_test, \n",
71 | "\t\t\t adversary_dataset, membership_inference_test)"
72 | ]
73 | },
74 | {
75 | "cell_type": "markdown",
76 | "metadata": {},
77 | "source": [
78 | "## 1. Dataset import and preparation"
79 | ]
80 | },
81 | {
82 | "cell_type": "markdown",
83 | "metadata": {},
84 | "source": [
85 | "#### 1.1 Import the datasets"
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": null,
91 | "metadata": {},
92 | "outputs": [],
93 | "source": [
94 | "file_path = \"https://raw.githubusercontent.com/Clearbox-AI/SURE/main/examples/data/census_dataset\"\n",
95 | "\n",
96 | "real_data = pl.from_pandas(pd.read_csv(os.path.join(file_path,\"census_dataset_training.csv\"))).lazy()\n",
97 | "valid_data = pl.from_pandas(pd.read_csv(os.path.join(file_path,\"census_dataset_validation.csv\"))).lazy()\n",
98 | "synth_data = pl.from_pandas(pd.read_csv(os.path.join(file_path,\"census_dataset_synthetic.csv\"))).lazy()"
99 | ]
100 | },
101 | {
102 | "cell_type": "markdown",
103 | "metadata": {},
104 | "source": [
105 | "#### 1.2 Datasets preparation\n",
106 | "Apply a series of transformations to the raw dataset to prepare it for the subsequent steps."
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": null,
112 | "metadata": {},
113 | "outputs": [],
114 | "source": [
115 | "# Preprocessor initialization and query execution on the real, synthetic and validation datasets\n",
116 | "preprocessor = Preprocessor(real_data, num_fill_null='forward', scaling='standardize')\n",
117 | "\n",
118 | "real_data_preprocessed = preprocessor.transform(real_data)\n",
119 | "valid_data_preprocessed = preprocessor.transform(valid_data)\n",
120 | "synth_data_preprocessed = preprocessor.transform(synth_data)"
121 | ]
122 | },
123 | {
124 | "cell_type": "markdown",
125 | "metadata": {},
126 | "source": [
127 | "## 2. Utility assessment"
128 | ]
129 | },
130 | {
131 | "cell_type": "markdown",
132 | "metadata": {},
133 | "source": [
134 | "#### 2.1 Statistical properties and mutual information\n",
135 | "These functions compute general statistical features, the correlation matrices and the difference between the correlation matrix of the real and synthetic dataset."
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": 6,
141 | "metadata": {},
142 | "outputs": [],
143 | "source": [
144 | "# Compute statistical properties and features mutual information\n",
145 | "num_features_stats, cat_features_stats, temporal_feat_stats = compute_statistical_metrics(real_data, synth_data)\n",
146 | "corr_real, corr_synth, corr_difference = compute_mutual_info(real_data_preprocessed, synth_data_preprocessed)"
147 | ]
148 | },
149 | {
150 | "cell_type": "markdown",
151 | "metadata": {},
152 | "source": [
153 | "#### 2.2 ML utility - Train on Synthetic Test on Real\n",
154 | "The `compute_utility_metrics_class` trains multiple machine learning classification models on the synthetic dataset and evaluates their performance on the validation set.\n",
155 | "\n",
156 | "For comparison, it also trains the same models on the original training set and evaluates them on the same validation set. This allows a direct comparison between models trained on synthetic data and those trained on real data."
157 | ]
158 | },
159 | {
160 | "cell_type": "code",
161 | "execution_count": null,
162 | "metadata": {},
163 | "outputs": [],
164 | "source": [
165 | "# Assessing the machine learning utility of the synthetic dataset on the classification task\n",
166 | "\n",
167 | "# ML utility: TSTR - Train on Synthetic, Test on Real\n",
168 | "X_train = real_data_preprocessed.drop(\"label\") # Assuming the datasets have a “label” column for the machine learning task they are intended for\n",
169 | "y_train = real_data_preprocessed[\"label\"]\n",
170 | "X_synth = synth_data_preprocessed.drop(\"label\")\n",
171 | "y_synth = synth_data_preprocessed[\"label\"]\n",
172 | "X_test = valid_data_preprocessed.drop(\"label\").limit(10000) # Test the trained models on a portion of the original real dataset (first 10k rows)\n",
173 | "y_test = valid_data_preprocessed[\"label\"].limit(10000)\n",
174 | "TSTR_metrics = compute_utility_metrics_class(X_train, X_synth, X_test, y_train, y_synth, y_test)"
175 | ]
176 | },
177 | {
178 | "cell_type": "markdown",
179 | "metadata": {},
180 | "source": [
181 | "#### 2.3 Detection Score\n",
182 | "Computes the detection score by training an XGBoost model to differentiate between original and synthetic data. \n",
183 | "\n",
184 | "The lower the model's accuracy, the higher the quality of the synthetic data.\n",
185 | "\n",
186 | "\n",
187 | "The detection score is computed as\n",
188 | "\n",
189 | "detection_score = 2*(1 - ROC_AUC)\n",
190 | "\n",
191 | "So if ROC_AUC <= 0.5, the synthetic dataset is considered indistinguishable from the real dataset (detection score = 1).\n"
192 | ]
193 | },
194 | {
195 | "cell_type": "code",
196 | "execution_count": null,
197 | "metadata": {},
198 | "outputs": [],
199 | "source": [
200 | "detection_score = detection(real_data, synth_data, preprocessor)\n",
201 | "print(\"Detection accuracy: \", detection_score[\"accuracy\"])\n",
202 | "print(\"Detection ROC_AUC: \", detection_score[\"ROC_AUC\"])\n",
203 | "print(\"Detection score: \", detection_score[\"score\"])\n",
204 | "print(\"Detection feature importances: \", detection_score[\"feature_importances\"])"
205 | ]
206 | },
207 | {
208 | "cell_type": "markdown",
209 | "metadata": {},
210 | "source": [
211 | "#### Query Power\n",
212 | "Generates and runs queries to compare the original and synthetic datasets.\n",
213 | "\n",
214 | "This method creates random queries that filter data from both datasets.\n",
215 | "\n",
216 | "The similarity between the sizes of the filtered results is used to score the quality of the synthetic data."
217 | ]
218 | },
219 | {
220 | "cell_type": "code",
221 | "execution_count": null,
222 | "metadata": {},
223 | "outputs": [],
224 | "source": [
225 | "query_power_score = query_power(real_data, synth_data, preprocessor)\n",
226 | "\n",
227 | "print(\"Query Power score: \", query_power_score[\"score\"])\n",
228 | "for query in query_power_score[\"queries\"]:\n",
229 | " print(\"\\n\", query[\"text\"])\n",
230 | " print(\"Query result on real: \", query[\"original_df\"])\n",
231 | " print(\"Query result on synthetic: \", query[\"synthetic_df\"])"
232 | ]
233 | },
234 | {
235 | "cell_type": "markdown",
236 | "metadata": {},
237 | "source": [
238 | "## 3. Privacy assessment"
239 | ]
240 | },
241 | {
242 | "cell_type": "markdown",
243 | "metadata": {},
244 | "source": [
245 | "#### 3.1 Distance to closest record (DCR)"
246 | ]
247 | },
248 | {
249 | "cell_type": "code",
250 | "execution_count": 10,
251 | "metadata": {},
252 | "outputs": [],
253 | "source": [
254 | "# Compute the distances to closest record between the synthetic dataset and the real dataset\n",
255 | "# and the distances to closest record between the synthetic dataset and the validation dataset\n",
256 | "\n",
257 | "dcr_synth_train = distance_to_closest_record(\"synth_train\", synth_data, real_data)\n",
258 | "dcr_synth_valid = distance_to_closest_record(\"synth_val\", synth_data, valid_data)"
259 | ]
260 | },
261 | {
262 | "cell_type": "code",
263 | "execution_count": null,
264 | "metadata": {},
265 | "outputs": [],
266 | "source": [
267 | "# Check for any clones shared between the synthetic and real datasets (DCR=0).\n",
268 | "\n",
269 | "dcr_zero_synth_train = number_of_dcr_equal_to_zero(\"synth_train\", dcr_synth_train)\n",
270 | "dcr_zero_synth_valid = number_of_dcr_equal_to_zero(\"synth_val\", dcr_synth_valid)"
271 | ]
272 | },
273 | {
274 | "cell_type": "code",
275 | "execution_count": null,
276 | "metadata": {},
277 | "outputs": [],
278 | "source": [
279 | "# Compute some general statistics for the DCR arrays computed above\n",
280 | "\n",
281 | "dcr_stats_synth_train = dcr_stats(\"synth_train\", dcr_synth_train)\n",
282 | "dcr_stats_synth_valid = dcr_stats(\"synth_val\", dcr_synth_valid)"
283 | ]
284 | },
285 | {
286 | "cell_type": "code",
287 | "execution_count": null,
288 | "metadata": {},
289 | "outputs": [],
290 | "source": [
291 | "# Compute the share of records that are closer to the training set than to the validation set\n",
292 | "\n",
293 | "share = validation_dcr_test(dcr_synth_train, dcr_synth_valid)"
294 | ]
295 | },
296 | {
297 | "cell_type": "markdown",
298 | "metadata": {},
299 | "source": [
300 | "#### 3.2 Membership Inference Attack test"
301 | ]
302 | },
303 | {
304 | "cell_type": "code",
305 | "execution_count": null,
306 | "metadata": {},
307 | "outputs": [],
308 | "source": [
309 | "# Simulate a Membership Inference Attack on your synthetic dataset\n",
310 | "# To do so, you'll need to produce an adversary dataset and some labels as the adversary guesses' ground truth\n",
311 | "\n",
312 | "# The label is automatically produced by the function adversary_dataset and is added as a column named \n",
313 | "# \"privacy_test_is_training\" in the adversary dataset returned\n",
314 | "\n",
315 | "# ML privacy attack sandbox initialization and simulation\n",
316 | "adversary_df = adversary_dataset(real_data, valid_data)\n",
317 | "\n",
318 | "# The function adversary_dataset adds a column \"privacy_test_is_training\" to the adversary dataset, indicating whether the record was part of the training set or not\n",
319 | "adversary_guesses_ground_truth = adversary_df[\"privacy_test_is_training\"] \n",
320 | "MIA = membership_inference_test(adversary_df, synth_data, adversary_guesses_ground_truth)"
321 | ]
322 | },
323 | {
324 | "cell_type": "markdown",
325 | "metadata": {},
326 | "source": [
327 | "## 4. Utility-Privacy report"
328 | ]
329 | },
330 | {
331 | "cell_type": "code",
332 | "execution_count": null,
333 | "metadata": {},
334 | "outputs": [],
335 | "source": [
336 | "# Produce the utility privacy report with the information computed above\n",
337 | "\n",
338 | "report(real_data, synth_data)"
339 | ]
340 | }
341 | ],
342 | "metadata": {
343 | "kernelspec": {
344 | "display_name": "Python (test_sure)",
345 | "language": "python",
346 | "name": "test_sure"
347 | },
348 | "language_info": {
349 | "codemirror_mode": {
350 | "name": "ipython",
351 | "version": 3
352 | },
353 | "file_extension": ".py",
354 | "mimetype": "text/x-python",
355 | "name": "python",
356 | "nbconvert_exporter": "python",
357 | "pygments_lexer": "ipython3",
358 | "version": "3.10.12"
359 | }
360 | },
361 | "nbformat": 4,
362 | "nbformat_minor": 2
363 | }
364 |
--------------------------------------------------------------------------------