├── LICENSE ├── README.md └── Syncora_vs_Gretel_vs_MostlyAI_metrics_comparison.ipynb /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Syncora.ai - Agentic Synthetic Data Generation Tool 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # syncora-benchmarks 2 | 3 | Syncora Benchmarks is a plug-and-play toolkit for evaluating synthetic data quality. Add CSVs from any generator, name them according to a simple convention, and get fidelity metrics and visual comparisons. 4 | --- 5 | 6 | ## What’s Inside 7 | 8 | - **`Syncora_vs_Gretel_vs_MostlyAI_metrics_comparison.ipynb`** 9 | A Jupyter Notebook that: 10 | 1. Loads your real & synthetic datasets 11 | 2. 
Computes a suite of similarity & fidelity metrics 12 | 3. Visualizes comparative results 13 | 14 | - **`README.md`** 15 | This overview file. 16 | 17 | - **Raw / Synthetic Data Files** 18 | Place your CSVs here, following the naming convention `<generator>_synthetic.csv`. 19 | 20 | --- 21 | 22 | ## Template Usage 23 | 24 | 1. **Generate synthetic data** 25 | Use any platform or in-house model to produce a CSV. 26 | 27 | 2. **Name your output** 28 | Rename your file to: mygenerator_synthetic.csv 29 | 30 | _e.g._ `Syncora_synthetic.csv`, `Gretel_synthetic.csv`, etc. 31 | 32 | 3. **Drop it into this repo** 33 | Place your CSV in the same folder as the notebook. 34 | 35 | 4. **Edit & run the notebook** 36 | - Open `Syncora_vs_Gretel_vs_MostlyAI_metrics_comparison.ipynb`. 37 | - The code automatically discovers all `*_synthetic.csv` files. 38 | - Execute all cells to regenerate metrics and plots. 39 | 40 | --- 41 | 42 | ## Adding New Generators or Datasets 43 | 44 | 1. Generate your synthetic CSV and name it `<generator>_synthetic.csv`. 45 | 2. Optionally, add a short description in the notebook’s metadata. 46 | 3. Re-run the notebook; your new results will be appended to the comparison charts. 47 | 48 | --- 49 | 50 | ## Contributing 51 | 52 | 1. Fork this repository. 53 | 2. Create a feature branch (`git checkout -b feature/xyz`). 54 | 3. Ensure the notebook runs end-to-end without errors. 55 | 4. Submit a pull request with your updates. 56 | 57 | --- 58 | 59 | ## License 60 | 61 | This project is released under the [MIT License](LICENSE). 
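
---

## Example: Discovering `*_synthetic.csv` Files

The Template Usage section says the notebook picks up every `*_synthetic.csv` file in the folder. The snippet below is a minimal sketch of how such a discovery step could look, not the notebook's actual code. The helper name `discover_synthetic_csvs` and the `SUFFIX` constant are hypothetical and introduced only for illustration; the sketch assumes the generator name is whatever precedes `_synthetic.csv` in the file name.

```python
from pathlib import Path

# Naming convention from this README: <generator>_synthetic.csv
SUFFIX = "_synthetic.csv"


def discover_synthetic_csvs(folder="."):
    """Return {generator_name: path} for every *_synthetic.csv in `folder`.

    Hypothetical helper for illustration; not part of the notebook.
    """
    found = {}
    for path in sorted(Path(folder).glob("*" + SUFFIX)):
        # Strip the suffix to recover the generator name, e.g. "Gretel".
        name = path.name[: -len(SUFFIX)]
        if name:  # ignore a file named exactly "_synthetic.csv"
            found[name] = path
    return found
```

Each discovered path could then be loaded with `pd.read_csv(path)` and stored under its generator name, so that adding a new generator only requires dropping a correctly named CSV into the folder.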
62 | -------------------------------------------------------------------------------- /Syncora_vs_Gretel_vs_MostlyAI_metrics_comparison.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": { 7 | "id": "llZS4N1pjedS" 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import pandas as pd\n", 12 | "import numpy as np" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 40, 18 | "metadata": { 19 | "colab": { 20 | "base_uri": "https://localhost:8080/" 21 | }, 22 | "id": "AtD6Tn_OlLdf", 23 | "outputId": "7651a509-4c6f-4e68-e4e6-8d9fad588cc0" 24 | }, 25 | "outputs": [ 26 | { 27 | "output_type": "stream", 28 | "name": "stdout", 29 | "text": [ 30 | "Collecting sdmetrics\n", 31 | " Downloading sdmetrics-0.22.0-py3-none-any.whl.metadata (9.4 kB)\n", 32 | "Requirement already satisfied: numpy>=1.24.0 in /usr/local/lib/python3.11/dist-packages (from sdmetrics) (2.0.2)\n", 33 | "Requirement already satisfied: pandas>=1.5.0 in /usr/local/lib/python3.11/dist-packages (from sdmetrics) (2.2.2)\n", 34 | "Requirement already satisfied: scikit-learn>=1.1.3 in /usr/local/lib/python3.11/dist-packages (from sdmetrics) (1.6.1)\n", 35 | "Requirement already satisfied: scipy>=1.9.2 in /usr/local/lib/python3.11/dist-packages (from sdmetrics) (1.16.0)\n", 36 | "Collecting copulas>=0.12.1 (from sdmetrics)\n", 37 | " Downloading copulas-0.12.3-py3-none-any.whl.metadata (9.5 kB)\n", 38 | "Requirement already satisfied: tqdm>=4.29 in /usr/local/lib/python3.11/dist-packages (from sdmetrics) (4.67.1)\n", 39 | "Requirement already satisfied: plotly>=5.19.0 in /usr/local/lib/python3.11/dist-packages (from sdmetrics) (5.24.1)\n", 40 | "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas>=1.5.0->sdmetrics) (2.9.0.post0)\n", 41 | "Requirement already satisfied: pytz>=2020.1 in 
/usr/local/lib/python3.11/dist-packages (from pandas>=1.5.0->sdmetrics) (2025.2)\n", 42 | "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas>=1.5.0->sdmetrics) (2025.2)\n", 43 | "Requirement already satisfied: tenacity>=6.2.0 in /usr/local/lib/python3.11/dist-packages (from plotly>=5.19.0->sdmetrics) (8.5.0)\n", 44 | "Requirement already satisfied: packaging in /usr/local/lib/python3.11/dist-packages (from plotly>=5.19.0->sdmetrics) (25.0)\n", 45 | "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn>=1.1.3->sdmetrics) (1.5.1)\n", 46 | "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn>=1.1.3->sdmetrics) (3.6.0)\n", 47 | "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas>=1.5.0->sdmetrics) (1.17.0)\n", 48 | "Downloading sdmetrics-0.22.0-py3-none-any.whl (198 kB)\n", 50 | "Downloading copulas-0.12.3-py3-none-any.whl (52 kB)\n", 52 | "Installing collected packages: copulas, sdmetrics\n", 53 | "Successfully installed copulas-0.12.3 sdmetrics-0.22.0\n" 54 | ] 55 | } 56 | ], 57 | "source": [ 58 | "!pip install sdmetrics" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 17, 64 | "metadata": { 65 | "id": "2NfHINaxj0_C", 66 | "colab": { 67 | "base_uri": "https://localhost:8080/", 68 | "height": 424 69 | }, 70 | "outputId": "a89b1fe8-3460-47df-eeb8-cbbaa520123d" 71 | }, 72 | "outputs": [ 73 | { 74 | "output_type": "execute_result", 75 | "data": { 76 | 
"text/plain": [ 77 | " Age Gender Blood Type Medical Condition Billing Amount \\\n", 78 | "0 30 1.0 5.0 2.0 18856.281306 \n", 79 | "1 62 1.0 0.0 5.0 33643.327287 \n", 80 | "2 76 0.0 1.0 5.0 27955.096079 \n", 81 | "3 28 0.0 6.0 3.0 37909.782410 \n", 82 | "4 43 0.0 2.0 2.0 14238.317814 \n", 83 | "... ... ... ... ... ... \n", 84 | "55495 42 0.0 6.0 1.0 2650.714952 \n", 85 | "55496 61 0.0 3.0 5.0 31457.797307 \n", 86 | "55497 38 0.0 4.0 4.0 27620.764717 \n", 87 | "55498 43 1.0 7.0 0.0 32451.092358 \n", 88 | "55499 53 0.0 6.0 0.0 4010.134172 \n", 89 | "\n", 90 | " Admission Type Medication Test Results \n", 91 | "0 2.0 3.0 2.0 \n", 92 | "1 1.0 1.0 1.0 \n", 93 | "2 1.0 0.0 2.0 \n", 94 | "3 0.0 1.0 0.0 \n", 95 | "4 2.0 4.0 0.0 \n", 96 | "... ... ... ... \n", 97 | "55495 0.0 4.0 0.0 \n", 98 | "55496 0.0 0.0 2.0 \n", 99 | "55497 2.0 1.0 0.0 \n", 100 | "55498 0.0 1.0 0.0 \n", 101 | "55499 2.0 1.0 0.0 \n", 102 | "\n", 103 | "[55500 rows x 8 columns]" 104 | ], 105 | "text/html": [ 106 | "\n", 107 | "
\n", 108 | "
\n", 109 | "\n", 122 | "\n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | "
\n", 260 | "
\n", 261 | "
\n", 262 | "
\n", 263 | "\n", 264 | "
\n", 265 | " \n", 273 | "\n", 274 | " \n", 314 | "\n", 315 | " \n", 339 | "
\n", 340 | "\n", 341 | "\n", 342 | "
\n", 343 | " \n", 354 | "\n", 355 | "\n", 444 | "\n", 445 | " \n", 467 | "
\n", 468 | "\n", 469 | "
\n", 470 | " \n", 501 | " \n", 510 | " \n", 522 | "
\n", 523 | "\n", 524 | "
\n", 525 | "
\n" 526 | ], 527 | "application/vnd.google.colaboratory.intrinsic+json": { 528 | "type": "dataframe", 529 | "variable_name": "df_real", 530 | "summary": "{\n \"name\": \"df_real\",\n \"rows\": 55500,\n \"fields\": [\n {\n \"column\": \"Age\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 19,\n \"min\": 13,\n \"max\": 89,\n \"num_unique_values\": 77,\n \"samples\": [\n 43,\n 22,\n 72\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Gender\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.5000043175658112,\n \"min\": 0.0,\n \"max\": 1.0,\n \"num_unique_values\": 2,\n \"samples\": [\n 0.0,\n 1.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Blood Type\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 2.289699610900092,\n \"min\": 0.0,\n \"max\": 7.0,\n \"num_unique_values\": 8,\n \"samples\": [\n 0.0,\n 3.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Medical Condition\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1.7083359197155301,\n \"min\": 0.0,\n \"max\": 5.0,\n \"num_unique_values\": 6,\n \"samples\": [\n 2.0,\n 5.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Billing Amount\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 14211.45443086446,\n \"min\": -2008.4921398591305,\n \"max\": 52764.276736469175,\n \"num_unique_values\": 50000,\n \"samples\": [\n 41172.960486003554,\n 7672.233633429568\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Admission Type\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.8190475504400777,\n \"min\": 0.0,\n \"max\": 2.0,\n \"num_unique_values\": 3,\n \"samples\": [\n 2.0,\n 1.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Medication\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1.4132435830881946,\n \"min\": 0.0,\n \"max\": 
4.0,\n \"num_unique_values\": 5,\n \"samples\": [\n 1.0,\n 2.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Test Results\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.8180888655374859,\n \"min\": 0.0,\n \"max\": 2.0,\n \"num_unique_values\": 3,\n \"samples\": [\n 2.0,\n 1.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}" 531 | } 532 | }, 533 | "metadata": {}, 534 | "execution_count": 17 535 | } 536 | ], 537 | "source": [ 538 | "df_real = pd.read_csv('/content/healthcare_cleaned_data.csv')\n", 539 | "df_real" 540 | ] 541 | }, 542 | { 543 | "cell_type": "code", 544 | "execution_count": 15, 545 | "metadata": { 546 | "id": "Zc__Y5t2lwIa", 547 | "colab": { 548 | "base_uri": "https://localhost:8080/", 549 | "height": 424 550 | }, 551 | "outputId": "82bd5713-5598-4391-f927-a189a135b89e" 552 | }, 553 | "outputs": [ 554 | { 555 | "output_type": "execute_result", 556 | "data": { 557 | "text/plain": [ 558 | " Age Gender Blood Type Medical Condition Billing Amount \\\n", 559 | "0 52 1.0 7.0 4.0 19205.266739 \n", 560 | "1 75 0.0 4.0 1.0 1189.229029 \n", 561 | "2 62 1.0 3.0 4.0 8068.886263 \n", 562 | "3 61 0.0 4.0 3.0 7179.079255 \n", 563 | "4 65 0.0 3.0 3.0 12120.088272 \n", 564 | "... ... ... ... ... ... \n", 565 | "29993 74 1.0 0.0 4.0 27015.554780 \n", 566 | "29994 53 1.0 3.0 2.0 45501.646881 \n", 567 | "29995 61 1.0 0.0 3.0 36968.704333 \n", 568 | "29996 44 1.0 6.0 5.0 48874.126856 \n", 569 | "29997 61 1.0 2.0 0.0 25784.574781 \n", 570 | "\n", 571 | " Admission Type Medication Test Results \n", 572 | "0 1.0 3.0 1.0 \n", 573 | "1 0.0 1.0 1.0 \n", 574 | "2 2.0 2.0 0.0 \n", 575 | "3 0.0 3.0 1.0 \n", 576 | "4 1.0 4.0 1.0 \n", 577 | "... ... ... ... \n", 578 | "29993 0.0 0.0 0.0 \n", 579 | "29994 0.0 1.0 2.0 \n", 580 | "29995 2.0 0.0 1.0 \n", 581 | "29996 2.0 2.0 0.0 \n", 582 | "29997 2.0 3.0 1.0 \n", 583 | "\n", 584 | "[29998 rows x 8 columns]" 585 | ], 586 | "text/html": [ 587 | "\n", 588 | "
\n", 589 | "
\n", 590 | "\n", 603 | "\n", 604 | " \n", 605 | " \n", 606 | " \n", 607 | " \n", 608 | " \n", 609 | " \n", 610 | " \n", 611 | " \n", 612 | " \n", 613 | " \n", 614 | " \n", 615 | " \n", 616 | " \n", 617 | " \n", 618 | " \n", 619 | " \n", 620 | " \n", 621 | " \n", 622 | " \n", 623 | " \n", 624 | " \n", 625 | " \n", 626 | " \n", 627 | " \n", 628 | " \n", 629 | " \n", 630 | " \n", 631 | " \n", 632 | " \n", 633 | " \n", 634 | " \n", 635 | " \n", 636 | " \n", 637 | " \n", 638 | " \n", 639 | " \n", 640 | " \n", 641 | " \n", 642 | " \n", 643 | " \n", 644 | " \n", 645 | " \n", 646 | " \n", 647 | " \n", 648 | " \n", 649 | " \n", 650 | " \n", 651 | " \n", 652 | " \n", 653 | " \n", 654 | " \n", 655 | " \n", 656 | " \n", 657 | " \n", 658 | " \n", 659 | " \n", 660 | " \n", 661 | " \n", 662 | " \n", 663 | " \n", 664 | " \n", 665 | " \n", 666 | " \n", 667 | " \n", 668 | " \n", 669 | " \n", 670 | " \n", 671 | " \n", 672 | " \n", 673 | " \n", 674 | " \n", 675 | " \n", 676 | " \n", 677 | " \n", 678 | " \n", 679 | " \n", 680 | " \n", 681 | " \n", 682 | " \n", 683 | " \n", 684 | " \n", 685 | " \n", 686 | " \n", 687 | " \n", 688 | " \n", 689 | " \n", 690 | " \n", 691 | " \n", 692 | " \n", 693 | " \n", 694 | " \n", 695 | " \n", 696 | " \n", 697 | " \n", 698 | " \n", 699 | " \n", 700 | " \n", 701 | " \n", 702 | " \n", 703 | " \n", 704 | " \n", 705 | " \n", 706 | " \n", 707 | " \n", 708 | " \n", 709 | " \n", 710 | " \n", 711 | " \n", 712 | " \n", 713 | " \n", 714 | " \n", 715 | " \n", 716 | " \n", 717 | " \n", 718 | " \n", 719 | " \n", 720 | " \n", 721 | " \n", 722 | " \n", 723 | " \n", 724 | " \n", 725 | " \n", 726 | " \n", 727 | " \n", 728 | " \n", 729 | " \n", 730 | " \n", 731 | " \n", 732 | " \n", 733 | " \n", 734 | " \n", 735 | " \n", 736 | " \n", 737 | " \n", 738 | " \n", 739 | " \n", 740 | "
\n", 741 | "
\n", 742 | "
\n", 743 | "
\n", 744 | "\n", 745 | "
\n", 746 | " \n", 754 | "\n", 755 | " \n", 795 | "\n", 796 | " \n", 820 | "
\n", 821 | "\n", 822 | "\n", 823 | "
\n", 824 | " \n", 835 | "\n", 836 | "\n", 925 | "\n", 926 | " \n", 948 | "
\n", 949 | "\n", 950 | "
\n", 951 | " \n", 982 | " \n", 991 | " \n", 1003 | "
\n", 1004 | "\n", 1005 | "
\n", 1006 | "
\n" 1007 | ], 1008 | "application/vnd.google.colaboratory.intrinsic+json": { 1009 | "type": "dataframe", 1010 | "variable_name": "df_syncora", 1011 | "summary": "{\n \"name\": \"df_syncora\",\n \"rows\": 29998,\n \"fields\": [\n {\n \"column\": \"Age\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 19,\n \"min\": 10,\n \"max\": 92,\n \"num_unique_values\": 82,\n \"samples\": [\n 46,\n 52,\n 71\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Gender\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.5000074628547292,\n \"min\": 0.0,\n \"max\": 1.0,\n \"num_unique_values\": 2,\n \"samples\": [\n 0.0,\n 1.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Blood Type\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 2.297786847763739,\n \"min\": 0.0,\n \"max\": 7.0,\n \"num_unique_values\": 8,\n \"samples\": [\n 4.0,\n 5.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Medical Condition\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1.7031843744873485,\n \"min\": 0.0,\n \"max\": 5.0,\n \"num_unique_values\": 6,\n \"samples\": [\n 4.0,\n 1.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Billing Amount\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 14244.80230719376,\n \"min\": -2503.2441829154573,\n \"max\": 54927.96333269359,\n \"num_unique_values\": 29998,\n \"samples\": [\n 22686.23873928449,\n 23125.61129632902\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Admission Type\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.8212432357383687,\n \"min\": 0.0,\n \"max\": 2.0,\n \"num_unique_values\": 3,\n \"samples\": [\n 1.0,\n 0.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Medication\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1.4136687408242463,\n \"min\": 0.0,\n 
\"max\": 4.0,\n \"num_unique_values\": 5,\n \"samples\": [\n 1.0,\n 0.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Test Results\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.8145298486109528,\n \"min\": 0.0,\n \"max\": 2.0,\n \"num_unique_values\": 3,\n \"samples\": [\n 1.0,\n 0.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}" 1012 | } 1013 | }, 1014 | "metadata": {}, 1015 | "execution_count": 15 1016 | } 1017 | ], 1018 | "source": [ 1019 | "df_syncora = pd.read_csv('/content/syncora-healthcare.csv')\n", 1020 | "df_syncora" 1021 | ] 1022 | }, 1023 | { 1024 | "cell_type": "code", 1025 | "execution_count": 23, 1026 | "metadata": { 1027 | "id": "HFe0Dwj71SQV", 1028 | "colab": { 1029 | "base_uri": "https://localhost:8080/", 1030 | "height": 424 1031 | }, 1032 | "outputId": "365dff1e-dba5-42bf-bb61-db9e932a1980" 1033 | }, 1034 | "outputs": [ 1035 | { 1036 | "output_type": "execute_result", 1037 | "data": { 1038 | "text/plain": [ 1039 | " Age Gender Blood Type Medical Condition Billing Amount \\\n", 1040 | "0 57 1 1 4.0 9222.063822 \n", 1041 | "1 61 1 5 0.0 48199.843441 \n", 1042 | "2 79 0 2 5.0 34559.720382 \n", 1043 | "3 38 1 3 3.0 5152.106075 \n", 1044 | "4 20 0 5 5.0 47127.044982 \n", 1045 | "... ... ... ... ... ... \n", 1046 | "29995 67 0 5 3.0 24048.348990 \n", 1047 | "29996 67 1 0 5.0 306.935522 \n", 1048 | "29997 62 1 1 0.0 7144.839921 \n", 1049 | "29998 34 0 1 2.0 39901.103876 \n", 1050 | "29999 50 1 2 1.0 12641.611835 \n", 1051 | "\n", 1052 | " Admission Type Medication Test Results \n", 1053 | "0 2 2 2 \n", 1054 | "1 1 0 1 \n", 1055 | "2 2 0 1 \n", 1056 | "3 1 0 1 \n", 1057 | "4 0 0 2 \n", 1058 | "... ... ... ... \n", 1059 | "29995 1 1 1 \n", 1060 | "29996 2 4 2 \n", 1061 | "29997 2 3 0 \n", 1062 | "29998 1 1 0 \n", 1063 | "29999 2 2 0 \n", 1064 | "\n", 1065 | "[30000 rows x 8 columns]" 1066 | ], 1067 | "text/html": [ 1068 | "\n", 1069 | "
\n", 1070 | "
\n", 1071 | "\n", 1084 | "\n", 1085 | " \n", 1086 | " \n", 1087 | " \n", 1088 | " \n", 1089 | " \n", 1090 | " \n", 1091 | " \n", 1092 | " \n", 1093 | " \n", 1094 | " \n", 1095 | " \n", 1096 | " \n", 1097 | " \n", 1098 | " \n", 1099 | " \n", 1100 | " \n", 1101 | " \n", 1102 | " \n", 1103 | " \n", 1104 | " \n", 1105 | " \n", 1106 | " \n", 1107 | " \n", 1108 | " \n", 1109 | " \n", 1110 | " \n", 1111 | " \n", 1112 | " \n", 1113 | " \n", 1114 | " \n", 1115 | " \n", 1116 | " \n", 1117 | " \n", 1118 | " \n", 1119 | " \n", 1120 | " \n", 1121 | " \n", 1122 | " \n", 1123 | " \n", 1124 | " \n", 1125 | " \n", 1126 | " \n", 1127 | " \n", 1128 | " \n", 1129 | " \n", 1130 | " \n", 1131 | " \n", 1132 | " \n", 1133 | " \n", 1134 | " \n", 1135 | " \n", 1136 | " \n", 1137 | " \n", 1138 | " \n", 1139 | " \n", 1140 | " \n", 1141 | " \n", 1142 | " \n", 1143 | " \n", 1144 | " \n", 1145 | " \n", 1146 | " \n", 1147 | " \n", 1148 | " \n", 1149 | " \n", 1150 | " \n", 1151 | " \n", 1152 | " \n", 1153 | " \n", 1154 | " \n", 1155 | " \n", 1156 | " \n", 1157 | " \n", 1158 | " \n", 1159 | " \n", 1160 | " \n", 1161 | " \n", 1162 | " \n", 1163 | " \n", 1164 | " \n", 1165 | " \n", 1166 | " \n", 1167 | " \n", 1168 | " \n", 1169 | " \n", 1170 | " \n", 1171 | " \n", 1172 | " \n", 1173 | " \n", 1174 | " \n", 1175 | " \n", 1176 | " \n", 1177 | " \n", 1178 | " \n", 1179 | " \n", 1180 | " \n", 1181 | " \n", 1182 | " \n", 1183 | " \n", 1184 | " \n", 1185 | " \n", 1186 | " \n", 1187 | " \n", 1188 | " \n", 1189 | " \n", 1190 | " \n", 1191 | " \n", 1192 | " \n", 1193 | " \n", 1194 | " \n", 1195 | " \n", 1196 | " \n", 1197 | " \n", 1198 | " \n", 1199 | " \n", 1200 | " \n", 1201 | " \n", 1202 | " \n", 1203 | " \n", 1204 | " \n", 1205 | " \n", 1206 | " \n", 1207 | " \n", 1208 | " \n", 1209 | " \n", 1210 | " \n", 1211 | " \n", 1212 | " \n", 1213 | " \n", 1214 | " \n", 1215 | " \n", 1216 | " \n", 1217 | " \n", 1218 | " \n", 1219 | " \n", 1220 | " \n", 1221 | "
\n", 1222 | "
\n", 1223 | "
\n", 1224 | "
\n", 1225 | "\n", 1226 | "
\n", 1227 | " \n", 1235 | "\n", 1236 | " \n", 1276 | "\n", 1277 | " \n", 1301 | "
\n", 1302 | "\n", 1303 | "\n", 1304 | "
\n", 1305 | " \n", 1316 | "\n", 1317 | "\n", 1406 | "\n", 1407 | " \n", 1429 | "
\n", 1430 | "\n", 1431 | "
\n", 1432 | " \n", 1463 | " \n", 1472 | " \n", 1484 | "
\n", 1485 | "\n", 1486 | "
\n", 1487 | "
\n" 1488 | ], 1489 | "application/vnd.google.colaboratory.intrinsic+json": { 1490 | "type": "dataframe", 1491 | "variable_name": "df_gretel", 1492 | "summary": "{\n \"name\": \"df_gretel\",\n \"rows\": 30000,\n \"fields\": [\n {\n \"column\": \"Age\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 19,\n \"min\": 13,\n \"max\": 89,\n \"num_unique_values\": 77,\n \"samples\": [\n 20,\n 47,\n 28\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Gender\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0,\n \"min\": 0,\n \"max\": 1,\n \"num_unique_values\": 2,\n \"samples\": [\n 0,\n 1\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Blood Type\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 2,\n \"min\": 0,\n \"max\": 7,\n \"num_unique_values\": 8,\n \"samples\": [\n 5,\n 6\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Medical Condition\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1.6875826439204777,\n \"min\": 0.0,\n \"max\": 5.0,\n \"num_unique_values\": 6,\n \"samples\": [\n 4.0,\n 0.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Billing Amount\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 15707.36726458584,\n \"min\": -1956.0183883179,\n \"max\": 52755.0249810965,\n \"num_unique_values\": 30000,\n \"samples\": [\n 13652.8463977239,\n 14627.807964455\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Admission Type\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0,\n \"min\": 0,\n \"max\": 2,\n \"num_unique_values\": 3,\n \"samples\": [\n 2,\n 1\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Medication\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1,\n \"min\": 0,\n \"max\": 4,\n \"num_unique_values\": 5,\n \"samples\": [\n 0,\n 1\n ],\n \"semantic_type\": \"\",\n 
\"description\": \"\"\n }\n },\n {\n \"column\": \"Test Results\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0,\n \"min\": 0,\n \"max\": 2,\n \"num_unique_values\": 3,\n \"samples\": [\n 2,\n 1\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}" 1493 | } 1494 | }, 1495 | "metadata": {}, 1496 | "execution_count": 23 1497 | } 1498 | ], 1499 | "source": [ 1500 | "df_gretel = pd.read_csv('/content/gretel-healthcare.csv')\n", 1501 | "df_gretel\n" 1502 | ] 1503 | }, 1504 | { 1505 | "cell_type": "code", 1506 | "source": [ 1507 | "df_mostlyai = pd.read_csv('/content/mostlyai-healthcare.csv')\n", 1508 | "df_mostlyai" 1509 | ], 1510 | "metadata": { 1511 | "colab": { 1512 | "base_uri": "https://localhost:8080/", 1513 | "height": 424 1514 | }, 1515 | "id": "gZITpNXRLsYa", 1516 | "outputId": "66c7a807-d82e-4ef1-d797-3ee2f79bd0f5" 1517 | }, 1518 | "execution_count": 25, 1519 | "outputs": [ 1520 | { 1521 | "output_type": "execute_result", 1522 | "data": { 1523 | "text/plain": [ 1524 | " Age Gender Blood Type Medical Condition Billing Amount \\\n", 1525 | "0 19 1 7 5 34357.470891 \n", 1526 | "1 45 0 7 4 5015.284517 \n", 1527 | "2 29 1 2 3 45050.120972 \n", 1528 | "3 42 1 3 5 27874.820180 \n", 1529 | "4 83 0 7 5 14946.856473 \n", 1530 | "... ... ... ... ... ... \n", 1531 | "29995 76 1 0 1 12813.129371 \n", 1532 | "29996 23 0 3 5 5900.749086 \n", 1533 | "29997 34 1 1 5 44033.367695 \n", 1534 | "29998 34 0 0 3 48530.034546 \n", 1535 | "29999 51 1 0 4 13618.873465 \n", 1536 | "\n", 1537 | " Admission Type Medication Test Results \n", 1538 | "0 1 0 0 \n", 1539 | "1 0 4 0 \n", 1540 | "2 2 0 2 \n", 1541 | "3 1 3 2 \n", 1542 | "4 2 3 0 \n", 1543 | "... ... ... ... \n", 1544 | "29995 0 1 2 \n", 1545 | "29996 0 0 1 \n", 1546 | "29997 2 3 2 \n", 1547 | "29998 2 4 2 \n", 1548 | "29999 0 0 2 \n", 1549 | "\n", 1550 | "[30000 rows x 8 columns]" 1551 | ], 1552 | "text/html": [ 1553 | "\n", 1554 | "
\n", 1555 | "
\n", 1556 | "\n", 1569 | "\n", 1570 | " \n", 1571 | " \n", 1572 | " \n", 1573 | " \n", 1574 | " \n", 1575 | " \n", 1576 | " \n", 1577 | " \n", 1578 | " \n", 1579 | " \n", 1580 | " \n", 1581 | " \n", 1582 | " \n", 1583 | " \n", 1584 | " \n", 1585 | " \n", 1586 | " \n", 1587 | " \n", 1588 | " \n", 1589 | " \n", 1590 | " \n", 1591 | " \n", 1592 | " \n", 1593 | " \n", 1594 | " \n", 1595 | " \n", 1596 | " \n", 1597 | " \n", 1598 | " \n", 1599 | " \n", 1600 | " \n", 1601 | " \n", 1602 | " \n", 1603 | " \n", 1604 | " \n", 1605 | " \n", 1606 | " \n", 1607 | " \n", 1608 | " \n", 1609 | " \n", 1610 | " \n", 1611 | " \n", 1612 | " \n", 1613 | " \n", 1614 | " \n", 1615 | " \n", 1616 | " \n", 1617 | " \n", 1618 | " \n", 1619 | " \n", 1620 | " \n", 1621 | " \n", 1622 | " \n", 1623 | " \n", 1624 | " \n", 1625 | " \n", 1626 | " \n", 1627 | " \n", 1628 | " \n", 1629 | " \n", 1630 | " \n", 1631 | " \n", 1632 | " \n", 1633 | " \n", 1634 | " \n", 1635 | " \n", 1636 | " \n", 1637 | " \n", 1638 | " \n", 1639 | " \n", 1640 | " \n", 1641 | " \n", 1642 | " \n", 1643 | " \n", 1644 | " \n", 1645 | " \n", 1646 | " \n", 1647 | " \n", 1648 | " \n", 1649 | " \n", 1650 | " \n", 1651 | " \n", 1652 | " \n", 1653 | " \n", 1654 | " \n", 1655 | " \n", 1656 | " \n", 1657 | " \n", 1658 | " \n", 1659 | " \n", 1660 | " \n", 1661 | " \n", 1662 | " \n", 1663 | " \n", 1664 | " \n", 1665 | " \n", 1666 | " \n", 1667 | " \n", 1668 | " \n", 1669 | " \n", 1670 | " \n", 1671 | " \n", 1672 | " \n", 1673 | " \n", 1674 | " \n", 1675 | " \n", 1676 | " \n", 1677 | " \n", 1678 | " \n", 1679 | " \n", 1680 | " \n", 1681 | " \n", 1682 | " \n", 1683 | " \n", 1684 | " \n", 1685 | " \n", 1686 | " \n", 1687 | " \n", 1688 | " \n", 1689 | " \n", 1690 | " \n", 1691 | " \n", 1692 | " \n", 1693 | " \n", 1694 | " \n", 1695 | " \n", 1696 | " \n", 1697 | " \n", 1698 | " \n", 1699 | " \n", 1700 | " \n", 1701 | " \n", 1702 | " \n", 1703 | " \n", 1704 | " \n", 1705 | " \n", 1706 | "
\n", 1707 | "
\n", 1708 | "
\n", 1709 | "
\n", 1710 | "\n", 1711 | "
\n", 1712 | " \n", 1720 | "\n", 1721 | " \n", 1761 | "\n", 1762 | " \n", 1786 | "
\n", 1787 | "\n", 1788 | "\n", 1789 | "
\n", 1790 | " \n", 1801 | "\n", 1802 | "\n", 1891 | "\n", 1892 | " \n", 1914 | "
\n", 1915 | "\n", 1916 | "
\n", 1917 | " \n", 1948 | " \n", 1957 | " \n", 1969 | "
\n", 1970 | "\n", 1971 | "
\n", 1972 | "
\n" 1973 | ], 1974 | "application/vnd.google.colaboratory.intrinsic+json": { 1975 | "type": "dataframe", 1976 | "variable_name": "df_mostlyai", 1977 | "summary": "{\n \"name\": \"df_mostlyai\",\n \"rows\": 30000,\n \"fields\": [\n {\n \"column\": \"Age\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 19,\n \"min\": 13,\n \"max\": 89,\n \"num_unique_values\": 77,\n \"samples\": [\n 83,\n 62,\n 70\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Gender\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0,\n \"min\": 0,\n \"max\": 1,\n \"num_unique_values\": 2,\n \"samples\": [\n 0,\n 1\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Blood Type\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 2,\n \"min\": 0,\n \"max\": 7,\n \"num_unique_values\": 8,\n \"samples\": [\n 2,\n 1\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Medical Condition\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1,\n \"min\": 0,\n \"max\": 5,\n \"num_unique_values\": 6,\n \"samples\": [\n 5,\n 4\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Billing Amount\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 14212.245751873123,\n \"min\": -1310.2728947084124,\n \"max\": 52170.03685355641,\n \"num_unique_values\": 29981,\n \"samples\": [\n 12914.23721681,\n 16524.88569619\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Admission Type\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0,\n \"min\": 0,\n \"max\": 2,\n \"num_unique_values\": 3,\n \"samples\": [\n 1,\n 0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Medication\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1,\n \"min\": 0,\n \"max\": 4,\n \"num_unique_values\": 5,\n \"samples\": [\n 4,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n 
}\n },\n {\n \"column\": \"Test Results\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0,\n \"min\": 0,\n \"max\": 2,\n \"num_unique_values\": 3,\n \"samples\": [\n 0,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}" 1978 | } 1979 | }, 1980 | "metadata": {}, 1981 | "execution_count": 25 1982 | } 1983 | ] 1984 | }, 1985 | { 1986 | "cell_type": "markdown", 1987 | "source": [ 1988 | "The model training below is written specifically for this dataset: https://www.kaggle.com/datasets/prasad22/healthcare-dataset/code. Feel free to change the code if you want to try a different dataset." 1989 | ], 1990 | "metadata": { 1991 | "id": "tbn1MxPUJHHe" 1992 | } 1993 | }, 1994 | { 1995 | "cell_type": "code", 1996 | "execution_count": 26, 1997 | "metadata": { 1998 | "id": "FlkjgRx01okZ", 1999 | "colab": { 2000 | "base_uri": "https://localhost:8080/" 2001 | }, 2002 | "outputId": "9fab4875-4a7b-43a6-e4b9-e928a4af1f8c" 2003 | }, 2004 | "outputs": [ 2005 | { 2006 | "output_type": "stream", 2007 | "name": "stdout", 2008 | "text": [ 2009 | "Classification Report:\n", 2010 | " precision recall f1-score support\n", 2011 | "\n", 2012 | " 0.0 0.42 0.42 0.42 3754\n", 2013 | " 1.0 0.42 0.43 0.42 3617\n", 2014 | " 2.0 0.43 0.42 0.42 3729\n", 2015 | "\n", 2016 | " accuracy 0.42 11100\n", 2017 | " macro avg 0.42 0.42 0.42 11100\n", 2018 | "weighted avg 0.42 0.42 0.42 11100\n", 2019 | "\n", 2020 | "Accuracy: 0.4218018018018018\n" 2021 | ] 2022 | } 2023 | ], 2024 | "source": [ 2025 | "# Import necessary libraries\n", 2026 | "from sklearn.model_selection import train_test_split\n", 2027 | "from sklearn.ensemble import RandomForestClassifier\n", 2028 | "from sklearn.metrics import classification_report, accuracy_score\n", 2029 | "\n", 2030 | "# Define features (X) and target (y)\n", 2031 | "\n", 2032 | "X = df_real.drop(['Test Results'], axis=1)\n", 2033 | "y = df_real['Test Results']\n", 2034 | "\n", 2035 | "# Split data into training and testing sets\n", 2036 | 
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n", 2037 | "\n", 2038 | "# Initialize and train the Random Forest Classifier\n", 2039 | "model = RandomForestClassifier(n_estimators=100, random_state=42)\n", 2040 | "model.fit(X_train, y_train)\n", 2041 | "\n", 2042 | "# Make predictions on the test set\n", 2043 | "y_pred = model.predict(X_test)\n", 2044 | "\n", 2045 | "# Print the classification report\n", 2046 | "print(\"Classification Report:\")\n", 2047 | "print(classification_report(y_test, y_pred))\n", 2048 | "\n", 2049 | "# Print the accuracy score\n", 2050 | "print(\"Accuracy:\", accuracy_score(y_test, y_pred))" 2051 | ] 2052 | }, 2053 | { 2054 | "cell_type": "code", 2055 | "source": [ 2056 | "from sklearn.model_selection import train_test_split\n", 2057 | "from sklearn.ensemble import RandomForestClassifier\n", 2058 | "from sklearn.metrics import classification_report, accuracy_score\n", 2059 | "import pandas as pd\n", 2060 | "\n", 2061 | "# Step 1: Split df_real into training and test sets (20% held out for testing)\n", 2062 | "real_train, real_test = train_test_split(df_real, test_size=0.2, random_state=42)\n", 2063 | "\n", 2064 | "# Separate features and targets\n", 2065 | "X_real_train = real_train.drop(['Test Results'], axis=1)\n", 2066 | "y_real_train = real_train['Test Results']\n", 2067 | "\n", 2068 | "X_real_test = real_test.drop(['Test Results'], axis=1)\n", 2069 | "y_real_test = real_test['Test Results']\n", 2070 | "\n", 2071 | "# Step 2: Train model_real on only the real training data\n", 2072 | "model_real = RandomForestClassifier(n_estimators=100, random_state=42)\n", 2073 | "model_real.fit(X_real_train, y_real_train)\n", 2074 | "\n", 2075 | "# Step 3: Combine synthetic and real training data, then train model_combined\n", 2076 | "combined_train_df = pd.concat([df_gretel, real_train], ignore_index=True)\n", 2077 | "X_combined_train = combined_train_df.drop(['Test Results'], axis=1)\n", 2078 | 
"y_combined_train = combined_train_df['Test Results']\n", 2079 | "\n", 2080 | "model_combined = RandomForestClassifier(n_estimators=100, random_state=42)\n", 2081 | "model_combined.fit(X_combined_train, y_combined_train)\n", 2082 | "\n", 2083 | "# Step 4: Evaluate both models on the same real test set\n", 2084 | "y_pred_real = model_real.predict(X_real_test)\n", 2085 | "y_pred_combined = model_combined.predict(X_real_test)\n", 2086 | "\n", 2087 | "# Step 5: Print classification reports\n", 2088 | "print(\"=== Model Trained Only on Real Data ===\")\n", 2089 | "print(classification_report(y_real_test, y_pred_real))\n", 2090 | "print(\"Accuracy:\", accuracy_score(y_real_test, y_pred_real))\n", 2091 | "\n", 2092 | "print(\"\\n=== Model Trained on Real + Gretel Synthetic Data ===\")\n", 2093 | "print(classification_report(y_real_test, y_pred_combined))\n", 2094 | "print(\"Accuracy:\", accuracy_score(y_real_test, y_pred_combined))" 2095 | ], 2096 | "metadata": { 2097 | "colab": { 2098 | "base_uri": "https://localhost:8080/" 2099 | }, 2100 | "id": "frc6qYvt1peS", 2101 | "outputId": "db359022-05c2-4729-b69f-c68270adbc5d" 2102 | }, 2103 | "execution_count": 28, 2104 | "outputs": [ 2105 | { 2106 | "output_type": "stream", 2107 | "name": "stdout", 2108 | "text": [ 2109 | "=== Model Trained Only on Real Data ===\n", 2110 | " precision recall f1-score support\n", 2111 | "\n", 2112 | " 0.0 0.42 0.42 0.42 3754\n", 2113 | " 1.0 0.42 0.43 0.42 3617\n", 2114 | " 2.0 0.43 0.42 0.42 3729\n", 2115 | "\n", 2116 | " accuracy 0.42 11100\n", 2117 | " macro avg 0.42 0.42 0.42 11100\n", 2118 | "weighted avg 0.42 0.42 0.42 11100\n", 2119 | "\n", 2120 | "Accuracy: 0.4218018018018018\n", 2121 | "\n", 2122 | "=== Model Trained on Real + Gretel Synthetic Data ===\n", 2123 | " precision recall f1-score support\n", 2124 | "\n", 2125 | " 0.0 0.42 0.40 0.41 3754\n", 2126 | " 1.0 0.41 0.42 0.41 3617\n", 2127 | " 2.0 0.41 0.42 0.42 3729\n", 2128 | "\n", 2129 | " accuracy 0.41 11100\n", 2130 | " macro 
avg 0.41 0.41 0.41 11100\n", 2131 | "weighted avg 0.41 0.41 0.41 11100\n", 2132 | "\n", 2133 | "Accuracy: 0.41315315315315315\n" 2134 | ] 2135 | } 2136 | ] 2137 | }, 2138 | { 2139 | "cell_type": "code", 2140 | "source": [ 2141 | "from sklearn.model_selection import train_test_split\n", 2142 | "from sklearn.ensemble import RandomForestClassifier\n", 2143 | "from sklearn.metrics import classification_report, accuracy_score\n", 2144 | "import pandas as pd\n", 2145 | "\n", 2146 | "# Optional: Drop index columns if they exist\n", 2147 | "df_real = df_real.drop(columns=['Unnamed: 0'], errors='ignore')\n", 2148 | "df_mostlyai = df_mostlyai.drop(columns=['Unnamed: 0'], errors='ignore')\n", 2149 | "\n", 2150 | "# Step 1: Split df_real into training and test sets (20% held out for testing)\n", 2151 | "real_train, real_test = train_test_split(df_real, test_size=0.2, random_state=42)\n", 2152 | "\n", 2153 | "# Separate features and targets\n", 2154 | "X_real_train = real_train.drop(['Test Results'], axis=1)\n", 2155 | "y_real_train = real_train['Test Results']\n", 2156 | "\n", 2157 | "X_real_test = real_test.drop(['Test Results'], axis=1)\n", 2158 | "y_real_test = real_test['Test Results']\n", 2159 | "\n", 2160 | "# Step 2: Train model_real on only the real training data\n", 2161 | "model_real = RandomForestClassifier(n_estimators=100, random_state=42)\n", 2162 | "model_real.fit(X_real_train, y_real_train)\n", 2163 | "\n", 2164 | "# Step 3: Combine synthetic and real training data, then train model_combined\n", 2165 | "combined_train_df = pd.concat([df_mostlyai, real_train], ignore_index=True)\n", 2166 | "X_combined_train = combined_train_df.drop(['Test Results'], axis=1)\n", 2167 | "y_combined_train = combined_train_df['Test Results']\n", 2168 | "\n", 2169 | "model_combined = RandomForestClassifier(n_estimators=100, random_state=42)\n", 2170 | "model_combined.fit(X_combined_train, y_combined_train)\n", 2171 | "\n", 2172 | "# Step 4: Evaluate both models on the same real test 
set\n", 2173 | "y_pred_real = model_real.predict(X_real_test)\n", 2174 | "y_pred_combined = model_combined.predict(X_real_test)\n", 2175 | "\n", 2176 | "# Step 5: Print classification reports\n", 2177 | "print(\"=== Model Trained Only on Real Data ===\")\n", 2178 | "print(classification_report(y_real_test, y_pred_real))\n", 2179 | "print(\"Accuracy:\", accuracy_score(y_real_test, y_pred_real))\n", 2180 | "\n", 2181 | "print(\"\\n=== Model Trained on Real + MostlyAI Synthetic Data ===\")\n", 2182 | "print(classification_report(y_real_test, y_pred_combined))\n", 2183 | "print(\"Accuracy:\", accuracy_score(y_real_test, y_pred_combined))" 2184 | ], 2185 | "metadata": { 2186 | "colab": { 2187 | "base_uri": "https://localhost:8080/" 2188 | }, 2189 | "id": "mXbUrihBG1in", 2190 | "outputId": "83b689c4-3a40-479d-ec82-9a8fa0973725" 2191 | }, 2192 | "execution_count": 29, 2193 | "outputs": [ 2194 | { 2195 | "output_type": "stream", 2196 | "name": "stdout", 2197 | "text": [ 2198 | "=== Model Trained Only on Real Data ===\n", 2199 | " precision recall f1-score support\n", 2200 | "\n", 2201 | " 0.0 0.42 0.42 0.42 3754\n", 2202 | " 1.0 0.42 0.43 0.42 3617\n", 2203 | " 2.0 0.43 0.42 0.42 3729\n", 2204 | "\n", 2205 | " accuracy 0.42 11100\n", 2206 | " macro avg 0.42 0.42 0.42 11100\n", 2207 | "weighted avg 0.42 0.42 0.42 11100\n", 2208 | "\n", 2209 | "Accuracy: 0.4218018018018018\n", 2210 | "\n", 2211 | "=== Model Trained on Real + MostlyAI Synthetic Data ===\n", 2212 | " precision recall f1-score support\n", 2213 | "\n", 2214 | " 0.0 0.42 0.42 0.42 3754\n", 2215 | " 1.0 0.42 0.40 0.41 3617\n", 2216 | " 2.0 0.42 0.44 0.43 3729\n", 2217 | "\n", 2218 | " accuracy 0.42 11100\n", 2219 | " macro avg 0.42 0.42 0.42 11100\n", 2220 | "weighted avg 0.42 0.42 0.42 11100\n", 2221 | "\n", 2222 | "Accuracy: 0.42\n" 2223 | ] 2224 | } 2225 | ] 2226 | }, 2227 | { 2228 | "cell_type": "code", 2229 | "source": [ 2230 | "from sklearn.model_selection import train_test_split\n", 2231 | "from 
sklearn.ensemble import RandomForestClassifier\n", 2232 | "from sklearn.metrics import classification_report, accuracy_score\n", 2233 | "import pandas as pd\n", 2234 | "\n", 2235 | "\n", 2236 | "# Step 1: Split df_real into training and test sets (20% held out for testing)\n", 2237 | "real_train, real_test = train_test_split(df_real, test_size=0.2, random_state=42)\n", 2238 | "\n", 2239 | "# Separate features and targets\n", 2240 | "X_real_train = real_train.drop(['Test Results'], axis=1)\n", 2241 | "y_real_train = real_train['Test Results']\n", 2242 | "\n", 2243 | "X_real_test = real_test.drop(['Test Results'], axis=1)\n", 2244 | "y_real_test = real_test['Test Results']\n", 2245 | "\n", 2246 | "# Step 2: Train model_real on only the real training data\n", 2247 | "model_real = RandomForestClassifier(n_estimators=100, random_state=42)\n", 2248 | "model_real.fit(X_real_train, y_real_train)\n", 2249 | "\n", 2250 | "# Step 3: Combine synthetic and real training data, then train model_combined\n", 2251 | "combined_train_df = pd.concat([df_syncora, real_train], ignore_index=True)\n", 2252 | "X_combined_train = combined_train_df.drop(['Test Results'], axis=1)\n", 2253 | "y_combined_train = combined_train_df['Test Results']\n", 2254 | "\n", 2255 | "model_combined = RandomForestClassifier(n_estimators=100, random_state=42)\n", 2256 | "model_combined.fit(X_combined_train, y_combined_train)\n", 2257 | "\n", 2258 | "# Step 4: Evaluate both models on the same real test set\n", 2259 | "y_pred_real = model_real.predict(X_real_test)\n", 2260 | "y_pred_combined = model_combined.predict(X_real_test)\n", 2261 | "\n", 2262 | "# Step 5: Print classification reports\n", 2263 | "print(\"=== Model Trained Only on Real Data ===\")\n", 2264 | "print(classification_report(y_real_test, y_pred_real))\n", 2265 | "print(\"Accuracy:\", accuracy_score(y_real_test, y_pred_real))\n", 2266 | "\n", 2267 | "print(\"\\n=== Model Trained on Real + Syncora Synthetic Data ===\")\n", 2268 | 
"print(classification_report(y_real_test, y_pred_combined))\n", 2269 | "print(\"Accuracy:\", accuracy_score(y_real_test, y_pred_combined))" 2270 | ], 2271 | "metadata": { 2272 | "colab": { 2273 | "base_uri": "https://localhost:8080/" 2274 | }, 2275 | "id": "w4uB_iKPGkdO", 2276 | "outputId": "760a4f94-995e-47f2-c9c6-767599d3de6b" 2277 | }, 2278 | "execution_count": 30, 2279 | "outputs": [ 2280 | { 2281 | "output_type": "stream", 2282 | "name": "stdout", 2283 | "text": [ 2284 | "=== Model Trained Only on Real Data ===\n", 2285 | " precision recall f1-score support\n", 2286 | "\n", 2287 | " 0.0 0.42 0.42 0.42 3754\n", 2288 | " 1.0 0.42 0.43 0.42 3617\n", 2289 | " 2.0 0.43 0.42 0.42 3729\n", 2290 | "\n", 2291 | " accuracy 0.42 11100\n", 2292 | " macro avg 0.42 0.42 0.42 11100\n", 2293 | "weighted avg 0.42 0.42 0.42 11100\n", 2294 | "\n", 2295 | "Accuracy: 0.4218018018018018\n", 2296 | "\n", 2297 | "=== Model Trained on Real + Syncora Synthetic Data ===\n", 2298 | " precision recall f1-score support\n", 2299 | "\n", 2300 | " 0.0 0.63 0.62 0.62 3754\n", 2301 | " 1.0 0.61 0.63 0.62 3617\n", 2302 | " 2.0 0.63 0.62 0.62 3729\n", 2303 | "\n", 2304 | " accuracy 0.62 11100\n", 2305 | " macro avg 0.62 0.62 0.62 11100\n", 2306 | "weighted avg 0.62 0.62 0.62 11100\n", 2307 | "\n", 2308 | "Accuracy: 0.6222522522522522\n" 2309 | ] 2310 | } 2311 | ] 2312 | }, 2313 | { 2314 | "cell_type": "code", 2315 | "execution_count": 31, 2316 | "metadata": { 2317 | "id": "I9bmBoho2Ruw", 2318 | "colab": { 2319 | "base_uri": "https://localhost:8080/" 2320 | }, 2321 | "outputId": "7df3c8d1-d52a-4f1a-c2c7-54ca643699a1" 2322 | }, 2323 | "outputs": [ 2324 | { 2325 | "output_type": "stream", 2326 | "name": "stdout", 2327 | "text": [ 2328 | "Classification Report for MostlyAI:\n", 2329 | " precision recall f1-score support\n", 2330 | "\n", 2331 | " 0 0.34 0.35 0.34 2000\n", 2332 | " 1 0.30 0.27 0.29 1828\n", 2333 | " 2 0.36 0.39 0.38 2172\n", 2334 | "\n", 2335 | " accuracy 0.34 6000\n", 2336 | " macro 
avg 0.34 0.34 0.34 6000\n", 2337 | "weighted avg 0.34 0.34 0.34 6000\n", 2338 | "\n", 2339 | "Accuracy: 0.33916666666666667\n" 2340 | ] 2341 | } 2342 | ], 2343 | "source": [ 2344 | "# Import necessary libraries\n", 2345 | "from sklearn.model_selection import train_test_split\n", 2346 | "from sklearn.ensemble import RandomForestClassifier\n", 2347 | "from sklearn.metrics import classification_report, accuracy_score\n", 2348 | "\n", 2349 | "# Define features (X) and target (y)\n", 2350 | "X = df_mostlyai.drop(['Test Results'], axis=1)\n", 2351 | "y = df_mostlyai['Test Results']\n", 2352 | "\n", 2353 | "# Split data into training and testing sets\n", 2354 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n", 2355 | "\n", 2356 | "# Initialize and train the Random Forest Classifier\n", 2357 | "model_mostlyai = RandomForestClassifier(n_estimators=100, random_state=42)\n", 2358 | "model_mostlyai.fit(X_train, y_train)\n", 2359 | "\n", 2360 | "# Make predictions on the test set\n", 2361 | "y_pred = model_mostlyai.predict(X_test)\n", 2362 | "\n", 2363 | "# Print the classification report\n", 2364 | "print(\"Classification Report for MostlyAI:\")\n", 2365 | "print(classification_report(y_test, y_pred))\n", 2366 | "\n", 2367 | "# Print the accuracy score\n", 2368 | "print(\"Accuracy:\", accuracy_score(y_test, y_pred))" 2369 | ] 2370 | }, 2371 | { 2372 | "cell_type": "code", 2373 | "source": [ 2374 | "# Import necessary libraries\n", 2375 | "from sklearn.model_selection import train_test_split\n", 2376 | "from sklearn.ensemble import RandomForestClassifier\n", 2377 | "from sklearn.metrics import classification_report, accuracy_score\n", 2378 | "\n", 2379 | "# Define features (X) and target (y)\n", 2380 | "X = df_gretel.drop(['Test Results'], axis=1)\n", 2381 | "y = df_gretel['Test Results']\n", 2382 | "\n", 2383 | "# Split data into training and testing sets\n", 2384 | "X_train, X_test, y_train, y_test = train_test_split(X, y, 
test_size=0.2, random_state=42)\n", 2385 | "\n", 2386 | "# Initialize and train the Random Forest Classifier\n", 2387 | "model_gretel = RandomForestClassifier(n_estimators=100, random_state=42)\n", 2388 | "model_gretel.fit(X_train, y_train)\n", 2389 | "\n", 2390 | "# Make predictions on the test set\n", 2391 | "y_pred = model_gretel.predict(X_test)\n", 2392 | "\n", 2393 | "# Print the classification report\n", 2394 | "print(\"Classification Report for Gretel:\")\n", 2395 | "print(classification_report(y_test, y_pred))\n", 2396 | "\n", 2397 | "# Print the accuracy score\n", 2398 | "print(\"Accuracy:\", accuracy_score(y_test, y_pred))" 2399 | ], 2400 | "metadata": { 2401 | "colab": { 2402 | "base_uri": "https://localhost:8080/" 2403 | }, 2404 | "id": "GD96lqizN_Rj", 2405 | "outputId": "ec95b12a-6cdc-4c57-e691-9b11e1d5f8e2" 2406 | }, 2407 | "execution_count": 37, 2408 | "outputs": [ 2409 | { 2410 | "output_type": "stream", 2411 | "name": "stdout", 2412 | "text": [ 2413 | "Classification Report for Gretel:\n", 2414 | " precision recall f1-score support\n", 2415 | "\n", 2416 | " 0 0.33 0.33 0.33 1903\n", 2417 | " 1 0.33 0.33 0.33 2007\n", 2418 | " 2 0.35 0.35 0.35 2090\n", 2419 | "\n", 2420 | " accuracy 0.34 6000\n", 2421 | " macro avg 0.34 0.34 0.34 6000\n", 2422 | "weighted avg 0.34 0.34 0.34 6000\n", 2423 | "\n", 2424 | "Accuracy: 0.3358333333333333\n" 2425 | ] 2426 | } 2427 | ] 2428 | }, 2429 | { 2430 | "cell_type": "code", 2431 | "execution_count": 32, 2432 | "metadata": { 2433 | "id": "z002rkwm2tjx", 2434 | "colab": { 2435 | "base_uri": "https://localhost:8080/" 2436 | }, 2437 | "outputId": "a3c29ad8-185f-48b6-ce34-bdd2ce69c17f" 2438 | }, 2439 | "outputs": [ 2440 | { 2441 | "output_type": "stream", 2442 | "name": "stdout", 2443 | "text": [ 2444 | "Classification Report Syncora:\n", 2445 | " precision recall f1-score support\n", 2446 | "\n", 2447 | " 0.0 0.56 0.56 0.56 2008\n", 2448 | " 1.0 0.57 0.57 0.57 2034\n", 2449 | " 2.0 0.55 0.55 0.55 1958\n", 2450 | "\n", 
2451 | " accuracy 0.56 6000\n", 2452 | " macro avg 0.56 0.56 0.56 6000\n", 2453 | "weighted avg 0.56 0.56 0.56 6000\n", 2454 | "\n", 2455 | "Accuracy: 0.559\n" 2456 | ] 2457 | } 2458 | ], 2459 | "source": [ 2460 | "# Import necessary libraries\n", 2461 | "from sklearn.model_selection import train_test_split\n", 2462 | "from sklearn.ensemble import RandomForestClassifier\n", 2463 | "from sklearn.metrics import classification_report, accuracy_score\n", 2464 | "\n", 2465 | "# Define features (X) and target (y)\n", 2466 | "X = df_syncora.drop(['Test Results'], axis=1)\n", 2467 | "y = df_syncora['Test Results']\n", 2468 | "\n", 2469 | "# Split data into training and testing sets\n", 2470 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n", 2471 | "\n", 2472 | "# Initialize and train the Random Forest Classifier\n", 2473 | "model_syncora = RandomForestClassifier(n_estimators=100, random_state=42)\n", 2474 | "model_syncora.fit(X_train, y_train)\n", 2475 | "\n", 2476 | "# Make predictions on the test set\n", 2477 | "y_pred = model_syncora.predict(X_test)\n", 2478 | "\n", 2479 | "# Print the classification report\n", 2480 | "print(\"Classification Report Syncora:\")\n", 2481 | "print(classification_report(y_test, y_pred))\n", 2482 | "\n", 2483 | "# Print the accuracy score\n", 2484 | "print(\"Accuracy:\", accuracy_score(y_test, y_pred))" 2485 | ] 2486 | }, 2487 | { 2488 | "cell_type": "code", 2489 | "execution_count": 33, 2490 | "metadata": { 2491 | "id": "BmCznQhp1mYz", 2492 | "colab": { 2493 | "base_uri": "https://localhost:8080/" 2494 | }, 2495 | "outputId": "40691e0b-1d54-4dd0-80d6-808b4d96734b" 2496 | }, 2497 | "outputs": [ 2498 | { 2499 | "output_type": "stream", 2500 | "name": "stdout", 2501 | "text": [ 2502 | "Accuracy on df_real: 0.8843603603603604\n", 2503 | "\n", 2504 | "Classification Report on df_real:\n", 2505 | " precision recall f1-score support\n", 2506 | "\n", 2507 | " 0.0 0.88 0.88 0.88 18627\n", 2508 | " 1.0 
0.88 0.89 0.88 18356\n", 2509 | " 2.0 0.89 0.88 0.88 18517\n", 2510 | "\n", 2511 | " accuracy 0.88 55500\n", 2512 | " macro avg 0.88 0.88 0.88 55500\n", 2513 | "weighted avg 0.88 0.88 0.88 55500\n", 2514 | "\n" 2515 | ] 2516 | } 2517 | ], 2518 | "source": [ 2519 | "# Evaluate `model` on all of df_real. Note: `model` was fit on 80% of these rows, so this score includes training data and is not a held-out estimate.\n", 2520 | "\n", 2521 | "# Convert the entire df_real to test data\n", 2522 | "X_test_real = df_real.drop(['Test Results'], axis=1)\n", 2523 | "y_test_real = df_real['Test Results']\n", 2524 | "\n", 2525 | "# Make predictions on the entire df_real test data\n", 2526 | "y_pred_real = model.predict(X_test_real)\n", 2527 | "\n", 2528 | "# Print the accuracy score on the entire df_real data\n", 2529 | "print(\"Accuracy on df_real:\", accuracy_score(y_test_real, y_pred_real))\n", 2530 | "\n", 2531 | "# Print the classification report on the entire df_real data\n", 2532 | "print(\"\\nClassification Report on df_real:\")\n", 2533 | "print(classification_report(y_test_real, y_pred_real))" 2534 | ] 2535 | }, 2536 | { 2537 | "cell_type": "code", 2538 | "execution_count": 35, 2539 | "metadata": { 2540 | "id": "7Dlzyw9Q2f_I", 2541 | "colab": { 2542 | "base_uri": "https://localhost:8080/" 2543 | }, 2544 | "outputId": "bb71cb90-bd0e-4d72-9601-da9e51f7fdd0" 2545 | }, 2546 | "outputs": [ 2547 | { 2548 | "output_type": "stream", 2549 | "name": "stdout", 2550 | "text": [ 2551 | "Accuracy on df_mostlyai: 0.3367747747747748\n", 2552 | "\n", 2553 | "Classification Report on df_mostlyai:\n", 2554 | " precision recall f1-score support\n", 2555 | "\n", 2556 | " 0.0 0.34 0.35 0.35 18627\n", 2557 | " 1.0 0.34 0.27 0.30 18356\n", 2558 | " 2.0 0.33 0.38 0.36 18517\n", 2559 | "\n", 2560 | " accuracy 0.34 55500\n", 2561 | " macro avg 0.34 0.34 0.33 55500\n", 2562 | "weighted avg 0.34 0.34 0.34 55500\n", 2563 | "\n" 2564 | ] 2565 | } 2566 | ], 2567 | "source": [ 2568 | "# Make predictions on the entire df_real test data\n", 2569 
| "y_pred_mostlyai = model_mostlyai.predict(X_test_real)\n", 2570 | "\n", 2571 | "# Print the accuracy score on the entire df_real data\n", 2572 | "print(\"Accuracy on df_mostlyai:\", accuracy_score(y_test_real, y_pred_mostlyai))\n", 2573 | "\n", 2574 | "# Print the classification report on the entire df_real data\n", 2575 | "print(\"\\nClassification Report on df_mostlyai:\")\n", 2576 | "print(classification_report(y_test_real, y_pred_mostlyai))" 2577 | ] 2578 | }, 2579 | { 2580 | "cell_type": "code", 2581 | "execution_count": 36, 2582 | "metadata": { 2583 | "id": "b-K4ceOn2law", 2584 | "colab": { 2585 | "base_uri": "https://localhost:8080/" 2586 | }, 2587 | "outputId": "056eaf40-0a34-44f9-8d50-b60657d6f4eb" 2588 | }, 2589 | "outputs": [ 2590 | { 2591 | "output_type": "stream", 2592 | "name": "stdout", 2593 | "text": [ 2594 | "Accuracy on df_syncora: 0.57009009009009\n", 2595 | "\n", 2596 | "Classification Report on df_syncora:\n", 2597 | " precision recall f1-score support\n", 2598 | "\n", 2599 | " 0.0 0.57 0.56 0.57 18627\n", 2600 | " 1.0 0.56 0.58 0.57 18356\n", 2601 | " 2.0 0.57 0.57 0.57 18517\n", 2602 | "\n", 2603 | " accuracy 0.57 55500\n", 2604 | " macro avg 0.57 0.57 0.57 55500\n", 2605 | "weighted avg 0.57 0.57 0.57 55500\n", 2606 | "\n" 2607 | ] 2608 | } 2609 | ], 2610 | "source": [ 2611 | "# Make predictions on the entire df_real test data\n", 2612 | "y_pred_syncora = model_syncora.predict(X_test_real)\n", 2613 | "\n", 2614 | "# Print the accuracy score on the entire df_real data\n", 2615 | "print(\"Accuracy on df_syncora:\", accuracy_score(y_test_real, y_pred_syncora))\n", 2616 | "\n", 2617 | "# Print the classification report on the entire df_real data\n", 2618 | "print(\"\\nClassification Report on df_syncora:\")\n", 2619 | "print(classification_report(y_test_real, y_pred_syncora))" 2620 | ] 2621 | }, 2622 | { 2623 | "cell_type": "code", 2624 | "source": [ 2625 | "# Make predictions on the entire df_real test data\n", 2626 | "y_pred_gretel = 
model_gretel.predict(X_test_real)\n", 2627 | "\n", 2628 | "# Print the accuracy score on the entire df_real data\n", 2629 | "print(\"Accuracy on df_gretel:\", accuracy_score(y_test_real, y_pred_gretel))\n", 2630 | "\n", 2631 | "# Print the classification report on the entire df_real data\n", 2632 | "print(\"\\nClassification Report on df_gretel:\")\n", 2633 | "print(classification_report(y_test_real, y_pred_gretel))" 2634 | ], 2635 | "metadata": { 2636 | "colab": { 2637 | "base_uri": "https://localhost:8080/" 2638 | }, 2639 | "id": "eAdeWh-EORnj", 2640 | "outputId": "12781903-b026-4655-c2b3-877b3036170e" 2641 | }, 2642 | "execution_count": 38, 2643 | "outputs": [ 2644 | { 2645 | "output_type": "stream", 2646 | "name": "stdout", 2647 | "text": [ 2648 | "Accuracy on df_gretel: 0.33535135135135136\n", 2649 | "\n", 2650 | "Classification Report on df_gretel:\n", 2651 | " precision recall f1-score support\n", 2652 | "\n", 2653 | " 0.0 0.34 0.33 0.33 18627\n", 2654 | " 1.0 0.33 0.33 0.33 18356\n", 2655 | " 2.0 0.34 0.35 0.34 18517\n", 2656 | "\n", 2657 | " accuracy 0.34 55500\n", 2658 | " macro avg 0.34 0.34 0.34 55500\n", 2659 | "weighted avg 0.34 0.34 0.34 55500\n", 2660 | "\n" 2661 | ] 2662 | } 2663 | ] 2664 | }, 2665 | { 2666 | "cell_type": "code", 2667 | "execution_count": 41, 2668 | "metadata": { 2669 | "id": "-dbuURPJk1Gj", 2670 | "colab": { 2671 | "base_uri": "https://localhost:8080/" 2672 | }, 2673 | "outputId": "df4bfd6e-4c8c-4fda-ed7c-56a0cd4d8ff4" 2674 | }, 2675 | "outputs": [ 2676 | { 2677 | "output_type": "stream", 2678 | "name": "stdout", 2679 | "text": [ 2680 | "KS Complement (Continuous Columns):\n", 2681 | "Syncora vs Real: {'Age': np.float64(0.9890484218467417), 'Billing Amount': np.float64(0.996284565517581)}\n", 2682 | "Gretel vs Real: {'Age': np.float64(0.9857360360360361), 'Billing Amount': np.float64(0.829954054054054)}\n", 2683 | "MostlyAI vs Real: {'Age': np.float64(0.9907990990990991), 'Billing Amount': np.float64(0.9946549549549549)}\n", 2684 
| "\n", 2685 | "TV Complement (Discrete Columns):\n", 2686 | "Syncora vs Real: {'Gender': 0.9986341720078636, 'Blood Type': 0.9938083187527817, 'Medical Condition': 0.995611438360155, 'Admission Type': 0.995003384610025, 'Medication': 0.9946757231262865, 'Test Results': 0.9941829683540464}\n", 2687 | "Gretel vs Real: {'Gender': 0.9641324324324324, 'Blood Type': 0.9745486486486487, 'Medical Condition': 0.9740297297297297, 'Admission Type': 0.994536036036036, 'Medication': 0.9823783783783784, 'Test Results': 0.9860450450450451}\n", 2688 | "MostlyAI vs Real: {'Gender': 0.9987342342342342, 'Blood Type': 0.9771702702702703, 'Medical Condition': 0.9870252252252252, 'Admission Type': 0.9906954954954955, 'Medication': 0.9888261261261261, 'Test Results': 0.9769945945945946}\n" 2689 | ] 2690 | } 2691 | ], 2692 | "source": [ 2693 | "# prompt: Using dataframe df_real: using sdmetrics, compute two metrics for df_syncora vs df_real, df_gretel vs df_real, and df_mostlyai vs df_real: KS Complement (for continuous values) and TV Complement (for discrete values).\n", 2694 | "\n", 2695 | "from sdmetrics.single_column import KSComplement, TVComplement\n", 2696 | "\n", 2697 | "# Assuming df_syncora, df_gretel, and df_mostlyai are also loaded DataFrames\n", 2698 | "\n", 2699 | "# List of continuous columns (excluding the identifier 'Unnamed: 0')\n", 2700 | "continuous_cols = ['Age', 'Billing Amount']\n", 2701 | "\n", 2702 | "# List of discrete columns\n", 2703 | "discrete_cols = ['Gender', 'Blood Type', 'Medical Condition', 'Admission Type', 'Medication', 'Test Results']\n", 2704 | "\n", 2705 | "# Calculate KS Complement for continuous columns\n", 2706 | "ks_syncora_real = {col: KSComplement.compute(df_real[col], df_syncora[col]) for col in continuous_cols}\n", 2707 | "ks_gretel_real = {col: KSComplement.compute(df_real[col], df_gretel[col]) for col in continuous_cols}\n", 2708 | "ks_mostlyai_real = {col: KSComplement.compute(df_real[col], df_mostlyai[col]) for col in continuous_cols}\n", 2709 
| "\n", 2710 | "# Calculate TV Complement for discrete columns\n", 2711 | "tv_syncora_real = {col: TVComplement.compute(df_real[col], df_syncora[col]) for col in discrete_cols}\n", 2712 | "tv_gretel_real = {col: TVComplement.compute(df_real[col], df_gretel[col]) for col in discrete_cols}\n", 2713 | "tv_mostlyai_real = {col: TVComplement.compute(df_real[col], df_mostlyai[col]) for col in discrete_cols}\n", 2714 | "\n", 2715 | "# Print the results\n", 2716 | "print(\"KS Complement (Continuous Columns):\")\n", 2717 | "print(\"Syncora vs Real:\", ks_syncora_real)\n", 2718 | "print(\"Gretel vs Real:\", ks_gretel_real)\n", 2719 | "print(\"MostlyAI vs Real:\", ks_mostlyai_real)\n", 2720 | "print(\"\\nTV Complement (Discrete Columns):\")\n", 2721 | "print(\"Syncora vs Real:\", tv_syncora_real)\n", 2722 | "print(\"Gretel vs Real:\", tv_gretel_real)\n", 2723 | "print(\"MostlyAI vs Real:\", tv_mostlyai_real)" 2724 | ] 2725 | } 2726 | ], 2727 | "metadata": { 2728 | "colab": { 2729 | "provenance": [], 2730 | "gpuType": "T4" 2731 | }, 2732 | "kernelspec": { 2733 | "display_name": "Python 3", 2734 | "name": "python3" 2735 | }, 2736 | "language_info": { 2737 | "name": "python" 2738 | }, 2739 | "accelerator": "GPU" 2740 | }, 2741 | "nbformat": 4, 2742 | "nbformat_minor": 0 2743 | } --------------------------------------------------------------------------------