├── .gitignore ├── LICENSE ├── README.md ├── binary_diffusion_tabular ├── __init__.py ├── dataset.py ├── diffusion.py ├── model.py ├── trainer.py ├── transformation.py └── utils.py ├── configs ├── adult.yaml ├── diabetes.yaml ├── heloc.yaml ├── housing.yaml ├── sick.yaml └── travel.yaml ├── data ├── adult.csv ├── adult_test.csv ├── adult_train.csv ├── diabetes.csv ├── diabetes_test.csv ├── diabetes_train.csv ├── heloc.csv ├── heloc_test.csv ├── heloc_train.csv ├── housing.csv ├── housing_test.csv ├── housing_train.csv ├── sick.csv ├── sick_test.csv ├── sick_train.csv ├── travel.csv ├── travel_test.csv └── travel_train.csv ├── environment.yml ├── pyproject.toml ├── requirements.txt ├── sample.py ├── tests ├── test_fixed_size_binary_table_transformation.py └── test_model.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/python,pycharm,jupyternotebooks,images 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,pycharm,jupyternotebooks,images 3 | 4 | # directory with cache for metrics evaluation 5 | cache/ 6 | plots/ 7 | 8 | .idea/ 9 | data/cifar-10-batches-py 10 | results/ 11 | wandb/ 12 | checkpoint/ 13 | temp/ 14 | 15 | # pytorch checkpoints and weights 16 | *.pth 17 | *.pt 18 | *.pkl 19 | *.ckpt 20 | 21 | ### Images ### 22 | # JPEG 23 | *.jpg 24 | *.jpeg 25 | *.jpe 26 | *.jif 27 | *.jfif 28 | *.jfi 29 | 30 | # JPEG 2000 31 | *.jp2 32 | *.j2k 33 | *.jpf 34 | *.jpx 35 | *.jpm 36 | *.mj2 37 | 38 | # JPEG XR 39 | *.jxr 40 | *.hdp 41 | *.wdp 42 | 43 | # Graphics Interchange Format 44 | *.gif 45 | 46 | # RAW 47 | *.raw 48 | 49 | # Web P 50 | *.webp 51 | 52 | # Portable Network Graphics 53 | *.png 54 | 55 | # Animated Portable Network Graphics 56 | *.apng 57 | 58 | # Multiple-image Network Graphics 59 | *.mng 60 | 61 | # Tagged Image File Format 62 | *.tiff 63 | *.tif 64 | 65 | # Scalable Vector Graphics 66 | *.svg 67 | *.svgz 68 | 69 | # Portable Document Format 70 | *.pdf 71 | 72 | # X BitMap 73 | *.xbm 74 | 75 | # BMP 76 | *.bmp 77 | *.dib 78 | 79 | # ICO 80 | *.ico 81 | 82 | # 3D Images 83 | *.3dm 84 | *.max 85 | 86 | ### JupyterNotebooks ### 87 | # gitignore template for Jupyter Notebooks 88 | # website: http://jupyter.org/ 89 | 90 | .ipynb_checkpoints 91 | */.ipynb_checkpoints/* 92 | 93 | # IPython 94 | profile_default/ 95 | ipython_config.py 96 | 97 | # Remove previous ipynb_checkpoints 98 | # git rm -r .ipynb_checkpoints/ 99 | 100 | ### PyCharm ### 101 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 102 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 103 | 104 | # User-specific stuff 105 | .idea/**/workspace.xml 106 | .idea/**/tasks.xml 107 | .idea/**/usage.statistics.xml 108 | .idea/**/dictionaries 109 | .idea/**/shelf 110 | 111 | # AWS User-specific 112 | .idea/**/aws.xml 113 | 114 | # Generated files 115 | .idea/**/contentModel.xml 116 | 117 | # Sensitive or high-churn files 118 | .idea/**/dataSources/ 119 | .idea/**/dataSources.ids 120 | .idea/**/dataSources.local.xml 121 | .idea/**/sqlDataSources.xml 122 | .idea/**/dynamic.xml 123 | .idea/**/uiDesigner.xml 124 | .idea/**/dbnavigator.xml 125 | 126 | # Gradle 127 | .idea/**/gradle.xml 128 | .idea/**/libraries 129 | 130 | # Gradle and Maven with auto-import 131 | # When using Gradle or Maven with auto-import, you should exclude module files, 132 | # since they will be recreated, and may cause churn. Uncomment if using 133 | # auto-import. 134 | # .idea/artifacts 135 | # .idea/compiler.xml 136 | # .idea/jarRepositories.xml 137 | # .idea/modules.xml 138 | # .idea/*.iml 139 | # .idea/modules 140 | # *.iml 141 | # *.ipr 142 | 143 | # CMake 144 | cmake-build-*/ 145 | 146 | # Mongo Explorer plugin 147 | .idea/**/mongoSettings.xml 148 | 149 | # File-based project format 150 | *.iws 151 | 152 | # IntelliJ 153 | out/ 154 | 155 | # mpeltonen/sbt-idea plugin 156 | .idea_modules/ 157 | 158 | # JIRA plugin 159 | atlassian-ide-plugin.xml 160 | 161 | # Cursive Clojure plugin 162 | .idea/replstate.xml 163 | 164 | # SonarLint plugin 165 | .idea/sonarlint/ 166 | 167 | # Crashlytics plugin (for Android Studio and IntelliJ) 168 | com_crashlytics_export_strings.xml 169 | crashlytics.properties 170 | crashlytics-build.properties 171 | fabric.properties 172 | 173 | # Editor-based Rest Client 174 | .idea/httpRequests 175 | 176 | # Android studio 3.1+ serialized cache file 177 | .idea/caches/build_file_checksums.ser 178 | 179 | ### PyCharm Patch ### 180 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 181 | 182 | # *.iml 183 | # modules.xml 184 | # .idea/misc.xml 185 | # *.ipr 186 | 187 | # Sonarlint plugin 188 | # https://plugins.jetbrains.com/plugin/7973-sonarlint 189 | .idea/**/sonarlint/ 190 | 191 | # SonarQube Plugin 192 | # https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin 193 | .idea/**/sonarIssues.xml 194 | 195 | # Markdown Navigator plugin 196 | # https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced 197 | .idea/**/markdown-navigator.xml 198 | .idea/**/markdown-navigator-enh.xml 199 | .idea/**/markdown-navigator/ 200 | 201 | # Cache file creation bug 202 | # See https://youtrack.jetbrains.com/issue/JBR-2257 203 | .idea/$CACHE_FILE$ 204 | 205 | # CodeStream plugin 206 | # https://plugins.jetbrains.com/plugin/12206-codestream 207 | .idea/codestream.xml 208 | 209 | # Azure Toolkit for IntelliJ plugin 210 | # https://plugins.jetbrains.com/plugin/8053-azure-toolkit-for-intellij 211 | .idea/**/azureSettings.xml 212 | 213 | ### Python ### 214 | # Byte-compiled / optimized / DLL files 215 | __pycache__/ 216 | *.py[cod] 217 | *$py.class 218 | 219 | # C extensions 220 | *.so 221 | 222 | # Distribution / packaging 223 | .Python 224 | build/ 225 | develop-eggs/ 226 | dist/ 227 | downloads/ 228 | eggs/ 229 | .eggs/ 230 | lib/ 231 | lib64/ 232 | parts/ 233 | sdist/ 234 | var/ 235 | wheels/ 236 | share/python-wheels/ 237 | *.egg-info/ 238 | .installed.cfg 239 | *.egg 240 | MANIFEST 241 | 242 | # PyInstaller 243 | # Usually these files are written by a python script from a template 244 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 245 | *.manifest 246 | *.spec 247 | 248 | # Installer logs 249 | pip-log.txt 250 | pip-delete-this-directory.txt 251 | 252 | # Unit test / coverage reports 253 | htmlcov/ 254 | .tox/ 255 | .nox/ 256 | .coverage 257 | .coverage.* 258 | .cache 259 | nosetests.xml 260 | coverage.xml 261 | *.cover 262 | *.py,cover 263 | .hypothesis/ 264 | .pytest_cache/ 265 | cover/ 266 | 267 | # Translations 268 | *.mo 269 | *.pot 270 | 271 | # Django stuff: 272 | *.log 273 | local_settings.py 274 | db.sqlite3 275 | db.sqlite3-journal 276 | 277 | # Flask stuff: 278 | instance/ 279 | .webassets-cache 280 | 281 | # Scrapy stuff: 282 | .scrapy 283 | 284 | # Sphinx documentation 285 | docs/_build/ 286 | 287 | # PyBuilder 288 | .pybuilder/ 289 | target/ 290 | 291 | # Jupyter Notebook 292 | 293 | # IPython 294 | 295 | # pyenv 296 | # For a library or package, you might want to ignore these files since the code is 297 | # intended to run in multiple environments; otherwise, check them in: 298 | # .python-version 299 | 300 | # pipenv 301 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 302 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 303 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 304 | # install all needed dependencies. 305 | #Pipfile.lock 306 | 307 | # poetry 308 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 309 | # This is especially recommended for binary packages to ensure reproducibility, and is more 310 | # commonly ignored for libraries. 311 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 312 | #poetry.lock 313 | 314 | # pdm 315 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 316 | #pdm.lock 317 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 318 | # in version control. 319 | # https://pdm.fming.dev/#use-with-ide 320 | .pdm.toml 321 | 322 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 323 | __pypackages__/ 324 | 325 | # Celery stuff 326 | celerybeat-schedule 327 | celerybeat.pid 328 | 329 | # SageMath parsed files 330 | *.sage.py 331 | 332 | # Environments 333 | .env 334 | .venv 335 | env/ 336 | venv/ 337 | ENV/ 338 | env.bak/ 339 | venv.bak/ 340 | 341 | # Spyder project settings 342 | .spyderproject 343 | .spyproject 344 | 345 | # Rope project settings 346 | .ropeproject 347 | 348 | # mkdocs documentation 349 | /site 350 | 351 | # mypy 352 | .mypy_cache/ 353 | .dmypy.json 354 | dmypy.json 355 | 356 | # Pyre type checker 357 | .pyre/ 358 | 359 | # pytype static type analyzer 360 | .pytype/ 361 | 362 | # Cython debug symbols 363 | cython_debug/ 364 | 365 | # PyCharm 366 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 367 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 368 | # and can be added to the global gitignore or merged into this file. For a more nuclear 369 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 370 | #.idea/ 371 | 372 | # End of https://www.toptal.com/developers/gitignore/api/python,pycharm,jupyternotebooks,images 373 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Kathrin Seßler and Vadim Borisov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/tabular-data-generation-using-binary/tabular-data-generation-on-adult-census)](https://paperswithcode.com/sota/tabular-data-generation-on-adult-census?p=tabular-data-generation-using-binary) 2 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/tabular-data-generation-using-binary/tabular-data-generation-on-diabetes)](https://paperswithcode.com/sota/tabular-data-generation-on-diabetes?p=tabular-data-generation-using-binary) 3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/tabular-data-generation-using-binary/tabular-data-generation-on-travel)](https://paperswithcode.com/sota/tabular-data-generation-on-travel?p=tabular-data-generation-using-binary) 4 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/tabular-data-generation-using-binary/tabular-data-generation-on-sick)](https://paperswithcode.com/sota/tabular-data-generation-on-sick?p=tabular-data-generation-using-binary) 5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/tabular-data-generation-using-binary/tabular-data-generation-on-california-housing)](https://paperswithcode.com/sota/tabular-data-generation-on-california-housing?p=tabular-data-generation-using-binary) 6 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/tabular-data-generation-using-binary/tabular-data-generation-on-heloc)](https://paperswithcode.com/sota/tabular-data-generation-on-heloc?p=tabular-data-generation-using-binary) 7 | 8 | # Tabular Data Generation using Binary Diffusion 9 | 10 | This repository contains the official implementation of the paper "[Tabular Data Generation using Binary Diffusion](https://arxiv.org/abs/2409.13882)", 11 | accepted to [3rd Table Representation Learning Workshop @ NeurIPS 2024](https://table-representation-learning.github.io/). 12 | 13 | # Abstract 14 | 15 | Generating synthetic tabular data is critical in machine learning, especially when real data is limited or sensitive. 16 | Traditional generative models often face challenges due to the unique characteristics of tabular data, such as mixed 17 | data types and varied distributions, and require complex preprocessing or large pretrained models. In this paper, we 18 | introduce a novel, lossless binary transformation method that converts any tabular data into fixed-size binary 19 | representations, and a corresponding new generative model called Binary Diffusion, specifically designed for binary 20 | data. Binary Diffusion leverages the simplicity of XOR operations for noise addition and removal and employs binary 21 | cross-entropy loss for training. Our approach eliminates the need for extensive preprocessing, complex noise parameter 22 | tuning, and pretraining on large datasets. We evaluate our model on several popular tabular benchmark datasets, 23 | demonstrating that Binary Diffusion outperforms existing state-of-the-art models on Travel, Adult Income, and Diabetes 24 | datasets while being significantly smaller in size. 25 | 26 | # Installation 27 | 28 | ## Pip install from repository 29 | 30 | ```bash 31 | pip install git+https://github.com/vkinakh/binary-diffusion-tabular.git 32 | ```` 33 | 34 | ## Local conda installation 35 | 36 | Conda environment was tested on Ubuntu 22.04 LTS and on Mac OS Sonoma 14.6 with M3 chip. 37 | 38 | ```bash 39 | conda env create -f environment.yml 40 | ``` 41 | 42 | ## Local pip installation 43 | 44 | ```bash 45 | pip install -r requirements.txt 46 | ``` 47 | 48 | # Quickstart 49 | 50 | ## Run tests 51 | ```bash 52 | python -m unittest discover -s ./tests 53 | ``` 54 | 55 | ## Training 56 | 57 | To use [train.py](train.py) script, first fill configuration fill. See examples in [configs](configs). Then run: 58 | ```bash 59 | python train.py -c= 60 | ``` 61 | 62 | ## Example training script 63 | ```python 64 | import pandas as pd 65 | import wandb 66 | 67 | from binary_diffusion_tabular import ( 68 | FixedSizeBinaryTableDataset, 69 | SimpleTableGenerator, 70 | BinaryDiffusion1D, 71 | FixedSizeTableBinaryDiffusionTrainer, 72 | drop_fill_na 73 | ) 74 | 75 | df = pd.read_csv("./data/adult_train.csv") 76 | columns_numerical = [ 77 | "age", 78 | "fnlwgt", 79 | "education-num", 80 | "capital-gain", 81 | "capital-loss", 82 | "hours-per-week", 83 | ] 84 | 85 | columns_categorical = [ 86 | "workclass", 87 | "education", 88 | "marital-status", 89 | "occupation", 90 | "relationship", 91 | "race", 92 | "sex", 93 | "native-country" 94 | ] 95 | 96 | task = "classification" 97 | column_target = "label" 98 | 99 | df = drop_fill_na( 100 | df=df, 101 | columns_numerical=columns_numerical, 102 | columns_categorical=columns_categorical, 103 | dropna=True, 104 | fillna=False 105 | ) 106 | 107 | dataset = FixedSizeBinaryTableDataset( 108 | table=df, 109 | target_column=column_target, # conditional generation 110 | split_feature_target=True, 111 | task=task, 112 | numerical_columns=columns_numerical, 113 | categorical_columns=columns_categorical 114 | ) 115 | 116 | classifier_free_guidance = True 117 | target_diffusion = "two_way" 118 | 119 | dim = 256 120 | n_res_blocks = 3 121 | device = "cuda" 122 | 123 | model = SimpleTableGenerator( 124 | data_dim=dataset.row_size, 125 | dim=dim, 126 | n_res_blocks=n_res_blocks, 127 | out_dim=( 128 | dataset.row_size * 2 129 | if target_diffusion == "two_way" 130 | else dataset.row_size 131 | ), 132 | task=task, 133 | conditional=dataset.conditional, 134 | n_classes=0 if task == "regression" else dataset.n_classes, 135 | classifier_free_guidance=classifier_free_guidance, 136 | ).to(device) 137 | 138 | schedule = "quad" 139 | n_timesteps = 1000 140 | 141 | diffusion = BinaryDiffusion1D( 142 | denoise_model=model, 143 | schedule="quad", 144 | n_timesteps=n_timesteps, 145 | target=target_diffusion 146 | ).to(device) 147 | 148 | logger = wandb.init( 149 | project="your-project", 150 | config={"key": "value"}, 151 | name="adult_CFG" 152 | ) 153 | 154 | num_training_steps = 500_000 155 | log_every = 100 156 | save_every = 10_000 157 | save_num_samples = 64 158 | ema_decay = 0.995 159 | ema_update_every = 10 160 | lr = 1e-4 161 | opt_type = "adam" 162 | batch_size = 256 163 | n_workers = 16 164 | zero_token_probability = 0.1 165 | results_folder = "./results/adult_CFG" 166 | 167 | trainer = FixedSizeTableBinaryDiffusionTrainer( 168 | diffusion=diffusion, 169 | dataset=dataset, 170 | train_num_steps=num_training_steps, 171 | log_every=log_every, 172 | save_every=save_every, 173 | save_num_samples=save_num_samples, 174 | max_grad_norm=None, 175 | gradient_accumulate_every=1, 176 | ema_decay=ema_decay, 177 | ema_update_every=ema_update_every, 178 | lr=lr, 179 | opt_type=opt_type, 180 | batch_size=batch_size, 181 | dataloader_workers=n_workers, 182 | classifier_free_guidance=classifier_free_guidance, 183 | zero_token_probability=zero_token_probability, 184 | logger=logger, 185 | results_folder=results_folder 186 | ) 187 | 188 | trainer.train() 189 | ``` 190 | 191 | ## Sampling 192 | 193 | To use [sample.py](sample.py) script, you need a pretrained model and data transformation. Then run 194 | ```bash 195 | python sample.py \ 196 | --ckpt= \ 197 | --ckpt_transformation= \ 198 | --n_timesteps= \ 199 | --out= \ 200 | --n_samples= \ 201 | --batch_size= \ 202 | --threshold= \ # 0.5 default 203 | --strategy= \ # target or mask 204 | --seed= \ # default no seed 205 | --guidance_scale= \ # 0 default, no classifier free guidance 206 | --target_column_name= \ # name of target column, in case of conditional generation 207 | --device= \ 208 | --use_ema # whether to use EMA diffusion model 209 | ``` 210 | 211 | # Results 212 | 213 | # Results 214 | 215 | The table below presents the **Binary Diffusion** results across various datasets and models. Performance metrics are shown as **mean ± standard deviation**. 216 | 217 | | **Dataset** | **LR (Binary Diffusion)** | **DT (Binary Diffusion)** | **RF (Binary Diffusion)** | **Params** | **Model link** | 218 | |-------------------------|---------------------------|---------------------------|---------------------------|------------|-------------------------------------------------------------------------------------| 219 | | **Travel** | **83.79 ± 0.08** | **88.90 ± 0.57** | **89.95 ± 0.44** | **1.1M** | [Link](https://huggingface.co/vitaliykinakh/binary-ddpm-tabular/tree/main/travel) | 220 | | **Sick** | 96.14 ± 0.63 | **97.07 ± 0.24** | 96.59 ± 0.55 | **1.4M** | [Link](https://huggingface.co/vitaliykinakh/binary-ddpm-tabular/tree/main/sick) | 221 | | **HELOC** | 71.76 ± 0.30 | 70.25 ± 0.43 | 70.47 ± 0.32 | **2.6M** | [Link](https://huggingface.co/vitaliykinakh/binary-ddpm-tabular/tree/main/heloc) | 222 | | **Adult Income** | **85.45 ± 0.11** | **85.27 ± 0.11** | **85.74 ± 0.11** | **1.4M** | [Link](https://huggingface.co/vitaliykinakh/binary-ddpm-tabular/tree/main/adult) | 223 | | **Diabetes** | **57.75 ± 0.04** | **57.13 ± 0.15** | 57.52 ± 0.12 | **1.8M** | [Link](https://huggingface.co/vitaliykinakh/binary-ddpm-tabular/tree/main/diabetes) | 224 | | **California Housing** | *0.55 ± 0.00* | 0.45 ± 0.00 | 0.39 ± 0.00 | **1.5M** | [Link](https://huggingface.co/vitaliykinakh/binary-ddpm-tabular/tree/main/housing) | 225 | 226 | --- 227 | 228 | # Citation 229 | ``` 230 | @article{kinakh2024tabular, 231 | title={Tabular Data Generation using Binary Diffusion}, 232 | author={Kinakh, Vitaliy and Voloshynovskiy, Slava}, 233 | journal={arXiv preprint arXiv:2409.13882}, 234 | year={2024} 235 | } 236 | ``` -------------------------------------------------------------------------------- /binary_diffusion_tabular/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | from .transformation import * 3 | from .dataset import * 4 | from .model import * 5 | from .diffusion import * 6 | from .trainer import * 7 | -------------------------------------------------------------------------------- /binary_diffusion_tabular/dataset.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union, Optional, Dict 2 | 3 | import pandas as pd 4 | 5 | import torch 6 | from torch.utils.data import Dataset 7 | 8 | from binary_diffusion_tabular.transformation import ( 9 | FixedSizeBinaryTableTransformation, 10 | TASK, 11 | ) 12 | 13 | 14 | __all__ = ["FixedSizeBinaryTableDataset", "drop_fill_na"] 15 | 16 | 17 | def drop_fill_na( 18 | df: pd.DataFrame, 19 | columns_numerical: List[str], 20 | columns_categorical: List[str], 21 | dropna: bool, 22 | fillna: bool, 23 | ) -> pd.DataFrame: 24 | """Drops or fills NaN values in a dataframe 25 | 26 | Args: 27 | df: dataframe 28 | columns_numerical: numerical column names 29 | columns_categorical: categorical column names 30 | dropna: if True, drops NaN values 31 | fillna: if True, fills NaN values. Numerical columns are replaced with mean. Categorical columns are replaced 32 | with mode. 33 | 34 | Returns: 35 | pd.DataFrame: dataframe with NaN values dropped/filled 36 | """ 37 | 38 | if dropna and fillna: 39 | raise ValueError("Cannot have both dropna and fillna") 40 | 41 | if dropna: 42 | df = df.dropna() 43 | 44 | if fillna: 45 | for col in columns_numerical: 46 | df[col] = df[col].fillna(df[col].mean()) 47 | 48 | # replace na for categorical columns with mode 49 | for col in columns_categorical: 50 | df[col] = df[col].fillna(df[col].mode()[0]) 51 | 52 | return df 53 | 54 | 55 | class FixedSizeBinaryTableDataset(Dataset): 56 | """Pytorch dataset for fixed size binary tables.""" 57 | 58 | def __init__( 59 | self, 60 | *, 61 | table: pd.DataFrame, 62 | target_column: Optional[str] = None, 63 | split_feature_target: bool, 64 | task: TASK, 65 | numerical_columns: List[str] = None, 66 | categorical_columns: List[str] = None, 67 | ): 68 | """ 69 | Args: 70 | table: pandas dataframe with categorical and numerical columns. Dataframe should not have nan 71 | target_column: name of the target column. Optional. Should be provided if split_feature_target is True. 72 | split_feature_target: split features columns and target column 73 | task: task for which dataset is used. Can be 'classification' or 'regression' 74 | numerical_columns: list of columns with numerical values 75 | categorical_columns: list of columns with categorical values 76 | """ 77 | 78 | if numerical_columns is None: 79 | numerical_columns = [] 80 | 81 | if categorical_columns is None: 82 | categorical_columns = [] 83 | 84 | self.table = table 85 | self.target_column = target_column 86 | self.split_feature_target = split_feature_target 87 | self.task = task 88 | self.numerical_columns = numerical_columns 89 | self.categorical_columns = categorical_columns 90 | 91 | self.transformation = FixedSizeBinaryTableTransformation( 92 | task, numerical_columns, categorical_columns 93 | ) 94 | 95 | if self.split_feature_target: 96 | target = self.table[self.target_column] 97 | features = self.table.drop(columns=[self.target_column]) 98 | 99 | self.features_binary, self.targets_binary = ( 100 | self.transformation.fit_transform(features, target) 101 | ) 102 | else: 103 | self.features_binary = self.transformation.fit_transform(self.table) 104 | 105 | @classmethod 106 | def from_config(cls, config: Dict) -> "FixedSizeBinaryTableDataset": 107 | path_table = config["path_table"] 108 | df = pd.read_csv(path_table) 109 | dropna = config["dropna"] 110 | fillna = config["fillna"] 111 | columns_numerical = config["numerical_columns"] 112 | columns_categorical = config["categorical_columns"] 113 | columns_to_drop = config["columns_to_drop"] 114 | task = config["task"] 115 | 116 | if columns_to_drop: 117 | df = df.drop(columns=columns_to_drop) 118 | df = drop_fill_na(df, columns_numerical, columns_categorical, dropna, fillna) 119 | 120 | return cls( 121 | table=df, 122 | target_column=config["target_column"], 123 | task=task, 124 | split_feature_target=config["split_feature_target"], 125 | numerical_columns=config["numerical_columns"], 126 | categorical_columns=config["categorical_columns"], 127 | ) 128 | 129 | @property 130 | def n_classes(self) -> int: 131 | return self.transformation.n_classes 132 | 133 | @property 134 | def row_size(self) -> int: 135 | return self.transformation.row_size 136 | 137 | @property 138 | def conditional(self) -> bool: 139 | return self.split_feature_target 140 | 141 | def __len__(self) -> int: 142 | return len(self.features_binary) 143 | 144 | def __getitem__( 145 | self, idx: int 146 | ) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: 147 | row = self.features_binary[idx] 148 | 149 | if self.split_feature_target: 150 | target = self.targets_binary[idx] 151 | return row, target 152 | return row 153 | -------------------------------------------------------------------------------- /binary_diffusion_tabular/diffusion.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Literal, Optional, Callable, Dict, Tuple 3 | from functools import partial 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torchmetrics.functional import accuracy 9 | 10 | from binary_diffusion_tabular.model import BaseModel, SimpleTableGenerator 11 | 12 | 13 | __all__ = [ 14 | "BinaryDiffusion1D", 15 | "BaseDiffusion", 16 | "SCHEDULE", 17 | "DENOISING_TARGET", 18 | "SAMPLING_STRATEGY", 19 | "make_beta_schedule", 20 | "get_mask_torch", 21 | "flip_values", 22 | ] 23 | 24 | 25 | SCHEDULE = Literal["linear", "quad", "sigmoid"] 26 | DENOISING_TARGET = Literal["mask", "target", "two_way"] 27 | SAMPLING_STRATEGY = Literal["mask", "target", "two_way"] 28 | 29 | 30 | def make_beta_schedule( 31 | schedule: SCHEDULE = "linear", 32 | n_timesteps: int = 1000, 33 | start: float = 1e-5, 34 | end: float = 0.5, 35 | ) -> torch.Tensor: 36 | """Make a beta schedule. 37 | 38 | Args: 39 | schedule: type of schedule to use. Can be "linear", "quad", "sigmoid". 40 | n_timesteps: number of timesteps to use. 41 | start: start value. Defaults to 1e-5. Should be generally close to 0 42 | end: end value. Defaults to 0.5. Should be close to 0.5 43 | 44 | Returns: 45 | torch.Tensor:beta schedule. 46 | """ 47 | 48 | if schedule == "linear": 49 | betas = torch.linspace(start, end, n_timesteps) 50 | elif schedule == "quad": 51 | betas = torch.linspace(start**0.5, end**0.5, n_timesteps) ** 2 52 | elif schedule == "sigmoid": 53 | betas = torch.linspace(-6, 6, n_timesteps) 54 | betas = torch.sigmoid(betas) * (end - start) + start 55 | else: 56 | raise ValueError("Incorrect beta schedule type") 57 | return betas 58 | 59 | 60 | def get_mask_torch(betas: torch.Tensor, shape, device="cuda") -> torch.Tensor: 61 | """Returns masks for a list of betas, each with a given percentage of 1 values 62 | 63 | Args: 64 | betas: tensor containing percentages of 1 values for each mask 65 | shape: shape of each mask 66 | device: the device for the operation, e.g., 'cuda' or 'cpu' 67 | 68 | Returns: 69 | Tensor of masks, one for each beta value 70 | """ 71 | 72 | # Move tensors to the specified device 73 | betas = betas.to(device) 74 | 75 | num_masks = betas.shape[0] 76 | flattened_shape = torch.prod(torch.tensor(shape)).item() 77 | 78 | random_values = torch.rand((num_masks, flattened_shape), device=device) 79 | masks = (random_values < betas.unsqueeze(-1)).int() 80 | return masks.reshape(num_masks, *shape) 81 | 82 | 83 | def flip_values(val): 84 | """Function that changes 0 to 1 and 1 to 0""" 85 | return 1 - val 86 | 87 | 88 | class BaseDiffusion(nn.Module, ABC): 89 | 90 | def __init__( 91 | self, 92 | denoise_model: BaseModel, 93 | ): 94 | super().__init__() 95 | self.model = denoise_model 96 | self.device = next(self.model.parameters()).device 97 | 98 | @classmethod 99 | @abstractmethod 100 | def from_config(cls, denoise_model: BaseModel, config: Dict) -> "BaseDiffusion": 101 | pass 102 | 103 | @property 104 | @abstractmethod 105 | def config(self) -> Dict: 106 | pass 107 | 108 | @abstractmethod 109 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict, Dict]: 110 | pass 111 | 112 | @abstractmethod 113 | def sample( 114 | self, 115 | *, 116 | model_fn: Optional[Callable] = None, 117 | y: Optional[torch.Tensor] = None, 118 | n: int, 119 | ) -> torch.Tensor: 120 | pass 121 | 122 | 123 | class BinaryDiffusion1D(BaseDiffusion): 124 | """Binary Diffusion 1D model.""" 125 | 126 | def __init__( 127 | self, 128 | denoise_model: SimpleTableGenerator, 129 | *, 130 | schedule: SCHEDULE = "linear", 131 | n_timesteps: int, 132 | target: DENOISING_TARGET = "mask", 133 | ): 134 | """ 135 | Args: 136 | denoise_model: denoiser model to use. 137 | schedule: beta schedule to use. Can be "linear", "quad", "sigmoid". 138 | n_timesteps: number of timesteps to use. 139 | size: size of the 1d data 140 | target: what denoiser model predictions to use. Can be "mask", "target", "two_way". 141 | two_way: predict both mask and denoiser target 142 | """ 143 | 144 | super().__init__(denoise_model) 145 | self.size = denoise_model.data_dim 146 | 147 | if target not in ["mask", "target", "two_way"]: 148 | raise ValueError("Incorrect target type") 149 | 150 | if target == "two_way" and self.model.out_dim != 2 * self.size: 151 | raise ValueError( 152 | "Incorrect target size. For `two_way` diffusion output should be 2*size" 153 | ) 154 | 155 | self.target = target 156 | 157 | self.n_timesteps = n_timesteps 158 | self.schedule = schedule 159 | 160 | self.loss = F.binary_cross_entropy_with_logits 161 | self.betas = make_beta_schedule(schedule, n_timesteps, start=1 / self.size).to( 162 | self.device 163 | ) 164 | self.flip_values = flip_values 165 | self.pred_postproc = torch.sigmoid 166 | 167 | @classmethod 168 | def from_config( 169 | cls, denoise_model: SimpleTableGenerator, config: Dict 170 | ) -> "BaseDiffusion": 171 | return cls( 172 | denoise_model, 173 | **config, 174 | ) 175 | 176 | @property 177 | def config(self) -> Dict: 178 | return { 179 | "schedule": self.schedule, 180 | "n_timesteps": self.n_timesteps, 181 | "target": self.target, 182 | } 183 | 184 | @property 185 | def conditional(self) -> bool: 186 | return self.model.conditional 187 | 188 | @property 189 | def classifier_free_guidance(self) -> bool: 190 | return self.model.classifier_free_guidance 191 | 192 | @property 193 | def n_classes(self) -> int: 194 | return self.model.n_classes 195 | 196 | def q_sample( 197 | self, x_0: torch.Tensor, t: torch.Tensor, mask: Optional[torch.Tensor] = None 198 | ) -> torch.Tensor: 199 | if mask is None: 200 | shape = x_0.shape 201 | beta = torch.tensor([self.betas[t]] * shape[0]).to(self.device) 202 | mask = get_mask_torch(beta, shape[1:]).to(self.device) 203 | mask = mask.to(bool).to(self.device) 204 | x_copy = x_0.clone().to(self.device) 205 | x_copy[mask] = self.flip_values(x_copy[mask]) 206 | return x_copy 207 | 208 | def p_sample(self, x: torch.Tensor, mask_pred: torch.Tensor): 209 | mask_pred = mask_pred.to(bool) 210 | x_out = x.clone() 211 | x_out[mask_pred] = self.flip_values(x_out[mask_pred]) 212 | return x_out 213 | 214 | def _apply_sampling_strategy( 215 | self, 216 | x_t: torch.Tensor, 217 | pred_target: torch.Tensor, 218 | pred_mask: torch.Tensor, 219 | t: int, 220 | strategy: SAMPLING_STRATEGY = "target", 221 | ) -> torch.Tensor: 222 | if strategy == "target": 223 | return pred_target.float() 224 | elif strategy == "mask": 225 | return self.p_sample(x_t, pred_mask) 226 | elif strategy == "half-half": 227 | return ( 228 | self.p_sample(x_t, pred_mask) 229 | if t < self.n_timesteps // 2 230 | else pred_target.float() 231 | ) 232 | 233 | @torch.inference_mode() 234 | def p_sample_loop( 235 | self, 236 | n: int, 237 | model_fn: Optional[Callable] = None, 238 | y: Optional[torch.Tensor] = None, 239 | timesteps: Optional[int] = None, 240 | threshold: float = 0.5, 241 | strategy: Optional[SAMPLING_STRATEGY] = None, 242 | ) -> torch.Tensor: 243 | if self.target == "two_way" and strategy is None: 244 | strategy = "target" 245 | 246 | if strategy not in ["target", "mask", "half-half"]: 247 | raise ValueError("Incorrect strategy type") 248 | 249 | if strategy is not None and self.target != "two_way": 250 | raise ValueError("Strategy can only be used with two_way target") 251 | 252 | if timesteps is None: 253 | timesteps = list(range(self.n_timesteps)) 254 | 255 | if model_fn is None: 256 | model_fn = self.model 257 | else: 258 | model_fn = partial(model_fn, model=self.model) 259 | 260 | x_t = torch.randint(0, 2, size=(n, self.size)).float().to(self.device) 261 | for t in reversed(timesteps): 262 | ts = torch.tensor([t] * n).to(self.device) 263 | 264 | pred = model_fn(x_t, ts, y=y) 265 | 266 | if self.target == "two_way": 267 | pred_target, pred_mask = pred.chunk(2, dim=1) 268 | 269 | pred_mask = self.pred_postproc(pred_mask) 270 | pred_target = self.pred_postproc(pred_target) 271 | 272 | pred_mask = pred_mask > threshold 273 | pred_target = pred_target > threshold 274 | 275 | x_t = self._apply_sampling_strategy( 276 | x_t, pred_target, pred_mask, t, strategy 277 | ) 278 | elif self.target == "target": 279 | pred = self.pred_postproc(pred) 280 | pred = pred > threshold 281 | x_t = pred.float() 282 | else: 283 | pred = self.pred_postproc(pred) 284 | pred = pred > threshold 285 | x_t = self.p_sample(x_t, pred) 286 | 287 | if t != 0: 288 | beta = torch.tensor([self.betas[t]] * n).to(self.device) 289 | mask = get_mask_torch(beta, x_t.shape[1:], self.device) 290 | x_t = self.q_sample(x_t, t, mask) 291 | 292 | return x_t 293 | 294 | @torch.inference_mode() 295 | def sample( 296 | self, 297 | *, 298 | model_fn: Optional[Callable] = None, 299 | y: Optional[torch.Tensor] = None, 300 | n: int, 301 | timesteps: Optional[int] = None, 302 | threshold: float = 0.5, 303 | strategy: SAMPLING_STRATEGY = "target", 304 | ) -> torch.Tensor: 305 | """Samples data 306 | 307 | Args: 308 | model_fn: denoising model to use 309 | y: optional conditioning to use 310 | n: number of samples to generate 311 | timesteps: number of timesteps to use during sampling 312 | threshold: threshold to use for sampling 313 | strategy: sampling strategy to use. Choices: target, mask, half-half 314 | 315 | Returns: 316 | torch.Tensor: sampled data 317 | """ 318 | 319 | x = self.p_sample_loop( 320 | n=n, 321 | model_fn=model_fn, 322 | y=y, 323 | timesteps=timesteps, 324 | threshold=threshold, 325 | strategy=strategy, 326 | ) 327 | return x 328 | 329 | def forward( 330 | self, x: torch.Tensor, y: Optional[torch.Tensor] = None 331 | ) -> Tuple[torch.Tensor, Dict, Dict]: 332 | """Runs binary diffusion model training step 333 | 334 | Model selects random timesteps, adds binary noise to data samples, runs denoiser and computed losses and 335 | accuracies 336 | 337 | Args: 338 | x: input data. Shape (BS, data_dim) 339 | 340 | y: optional conditioning. Shape (BS, ...) 341 | 342 | Returns: 343 | torch.Tensor, Dict, Dict: training loss, losses to log, accuracies to log 344 | """ 345 | 346 | bs = x.shape[0] 347 | sample_shape = x.shape[1:] 348 | 349 | # select random t step 350 | t = torch.randint(0, self.n_timesteps, size=(bs,)).to(self.device) 351 | 352 | # sample mask 353 | beta = self.betas[t].to(self.device) 354 | mask = get_mask_torch(beta, sample_shape, self.device) 355 | 356 | x_t = self.q_sample(x, t, mask) 357 | pred = self.model(x_t, t, y=y) 358 | 359 | if self.target == "mask": 360 | loss = self.loss(pred, mask.float()) 361 | acc = accuracy(self.pred_postproc(pred), mask, task="binary") 362 | losses = {"loss": loss} 363 | accs = {"acc": acc} 364 | elif self.target == "target": 365 | loss = self.loss(pred, x) 366 | acc = accuracy(self.pred_postproc(pred), x, task="binary") 367 | losses = {"loss": loss} 368 | accs = {"acc": acc} 369 | else: 370 | pred_target, pred_mask = pred.chunk(2, dim=1) 371 | loss_target = self.loss(pred_target, x) 372 | loss_mask = self.loss(pred_mask, mask.float()) 373 | loss = loss_target + loss_mask 374 | acc_target = accuracy(self.pred_postproc(pred_target), x, task="binary") 375 | acc_mask = accuracy(self.pred_postproc(pred_mask), mask, task="binary") 376 | 377 | losses = {"loss_target": loss_target, "loss_mask": loss_mask} 378 | accs = {"acc_target": acc_target, "acc_mask": acc_mask} 379 | 380 | return loss, losses, accs 381 | -------------------------------------------------------------------------------- /binary_diffusion_tabular/model.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Optional, Dict 3 | import math 4 | 5 | import torch 6 | from torch import nn 7 | 8 | from binary_diffusion_tabular import TASK 9 | 10 | 11 | __all__ = ["BaseModel", "SimpleTableGenerator"] 12 | 13 | 14 | class SinusoidalPosEmb(nn.Module): 15 | def __init__(self, dim: int): 16 | super().__init__() 17 | self.dim = dim 18 | 19 | def forward(self, x: torch.Tensor) -> torch.Tensor: 20 | device = x.device 21 | half_dim = self.dim // 2 22 | emb = math.log(10000) / (half_dim - 1) 23 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 24 | emb = x[:, None] * emb[None, :] 25 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 26 | return emb 27 | 28 | 29 | class Residual(nn.Module): 30 | """Residual layer with timestep embedding.""" 31 | 32 | def __init__( 33 | self, i: int, o: int, time_emb_dim: Optional[int] = None, use_bias: bool = True 34 | ): 35 | super(Residual, self).__init__() 36 | self.fc = nn.Linear(i, o) 37 | self.bn = nn.BatchNorm1d(o) 38 | self.relu = nn.ReLU() 39 | self.use_bias = use_bias 40 | 41 | # Timestep embedding MLP 42 | if time_emb_dim is not None: 43 | self.mlp = nn.Sequential( 44 | nn.SiLU(), nn.Linear(time_emb_dim, o * 2, bias=use_bias) 45 | ) 46 | else: 47 | self.mlp = None 48 | 49 | def forward( 50 | self, x: torch.Tensor, time_emb: Optional[torch.Tensor] = None 51 | ) -> torch.Tensor: 52 | out = self.fc(x) 53 | out = self.bn(out) 54 | out = self.relu(out) 55 | 56 | # Apply timestep embedding if available 57 | if self.mlp is not None and time_emb is not None: 58 | time_emb = self.mlp(time_emb) 59 | scale, shift = time_emb.chunk(2, dim=1) 60 | out = out * scale + shift 61 | 62 | return torch.cat([out, x], dim=1) 63 | 64 | 65 | class BaseModel(nn.Module, ABC): 66 | 67 | def __init__( 68 | self, 69 | data_dim: int, 70 | out_dim: int, 71 | ): 72 | super().__init__() 73 | self.data_dim = data_dim 74 | self.out_dim = out_dim 75 | 76 | @classmethod 77 | @abstractmethod 78 | def from_config(cls, config: Dict) -> "BaseModel": 79 | pass 80 | 81 | @property 82 | @abstractmethod 83 | def config(self) -> Dict: 84 | pass 85 | 86 | @abstractmethod 87 | def forward( 88 | self, 89 | x: torch.Tensor, 90 | t: torch.Tensor, 91 | ) -> torch.Tensor: 92 | pass 93 | 94 | 95 | class SimpleTableGenerator(BaseModel): 96 | """Simple denoiser model for table generation 97 | 98 | Model works with 1d signals of fixed size""" 99 | 100 | def __init__( 101 | self, 102 | data_dim: int, 103 | dim: int, 104 | n_res_blocks: int, 105 | out_dim: int, 106 | task: TASK, 107 | conditional: bool = False, 108 | n_classes: int = 0, 109 | classifier_free_guidance: bool = False, 110 | ): 111 | """ 112 | Args: 113 | data_dim: dimension of data 114 | dim: internal dimensionality 115 | n_res_blocks: number of residual blocks 116 | out_dim: number of output dimensions 117 | task: task to generate data for. Options: classification, regression 118 | conditional: if True, generative model is conditional 119 | n_classes: number of classes for classification 120 | classifier_free_guidance: if True, classifier free guidance is used during sampling 121 | """ 122 | 123 | if task not in ["classification", "regression"]: 124 | raise ValueError(f"Invalid task: {task}") 125 | 126 | if task == "classification" and conditional and n_classes <= 0: 127 | raise ValueError("n_classes must be greater than 0 for classification") 128 | 129 | super(SimpleTableGenerator, self).__init__(data_dim, out_dim) 130 | 131 | self.dim = dim 132 | self.n_res_blocks = n_res_blocks 133 | self.n_classes = n_classes 134 | self.classifier_free_guidance = classifier_free_guidance 135 | self.conditional = conditional 136 | self.task = task 137 | 138 | time_dim = dim 139 | self.time_mlp = nn.Sequential( 140 | SinusoidalPosEmb(time_dim), 141 | nn.Linear(time_dim, time_dim, bias=True), 142 | nn.GELU(), 143 | nn.Linear(time_dim, time_dim, bias=True), 144 | ) 145 | 146 | if self.task == "classification": 147 | if self.conditional: 148 | if n_classes > 0: 149 | if self.classifier_free_guidance: 150 | self.class_emb = nn.Linear(n_classes, dim, bias=True) 151 | else: 152 | self.class_emb = nn.Embedding(n_classes, dim) 153 | else: 154 | if self.conditional: 155 | # Regression task 156 | self.cond_emb = nn.Linear(1, dim, bias=True) 157 | 158 | self.data_proj = nn.Linear(self.data_dim, dim, bias=True) 159 | 160 | item = dim 161 | self.blocks = nn.ModuleList([]) 162 | for _ in range(n_res_blocks): 163 | self.blocks.append(Residual(dim, item, time_emb_dim=time_dim)) 164 | dim += item 165 | 166 | self.out = nn.Linear(dim, self.out_dim, bias=True) 167 | 168 | @property 169 | def config(self) -> Dict: 170 | """Returns model configuration in dictionary 171 | 172 | Returns: 173 | dict: model configuration with the following keys: 174 | data_dim: dimension of data 175 | dim: internal dimensionality 176 | n_res_blocks: number of residual blocks 177 | out_dim: number of output dimensions 178 | task: task to generate data for. Options: classification, regression 179 | conditional: if True, generative model is conditional 180 | n_classes: number of classes for classification 181 | classifier_free_guidance: if True, classifier free guidance is used during sampling 182 | """ 183 | 184 | return { 185 | "data_dim": self.data_dim, 186 | "dim": self.dim, 187 | "n_res_blocks": self.n_res_blocks, 188 | "out_dim": self.out_dim, 189 | "task": self.task, 190 | "conditional": self.conditional, 191 | "n_classes": self.n_classes, 192 | "classifier_free_guidance": self.classifier_free_guidance, 193 | } 194 | 195 | @classmethod 196 | def from_config(cls, config: Dict) -> "SimpleTableGenerator": 197 | return cls(**config) 198 | 199 | def forward( 200 | self, 201 | x: torch.Tensor, 202 | t: torch.Tensor, 203 | y: Optional[torch.Tensor] = None, 204 | *args, 205 | **kwargs, 206 | ) -> torch.Tensor: 207 | """Run denoising step for table generation 208 | 209 | Args: 210 | x: noisy input data (BS, data_dim) 211 | t: timesteps (BS,) 212 | y: conditional input data (BS,) or (BS, n_classes) if classifier free guidance 213 | *args: 214 | **kwargs: 215 | 216 | Returns: 217 | torch.Tensor: denoised data 218 | """ 219 | 220 | t = self.time_mlp(t) 221 | 222 | if y is not None and hasattr(self, "cond_emb"): 223 | y = self.cond_emb(y) 224 | t = t + y 225 | 226 | if y is not None and hasattr(self, "class_emb"): 227 | y = self.class_emb(y) 228 | t = t + y 229 | 230 | x = self.data_proj(x) 231 | 232 | for block in self.blocks: 233 | x = block(x, t) 234 | return self.out(x) 235 | -------------------------------------------------------------------------------- /binary_diffusion_tabular/trainer.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict, Optional, Literal, Any 3 | from pathlib import Path 4 | from collections import defaultdict 5 | 6 | import accelerate 7 | from tqdm.auto import tqdm 8 | 9 | import torch 10 | import torch.optim as optim 11 | from torch.utils.data import DataLoader 12 | import torch.nn.functional as F 13 | from accelerate import Accelerator 14 | from ema_pytorch import EMA 15 | import wandb 16 | 17 | from binary_diffusion_tabular.model import SimpleTableGenerator 18 | from binary_diffusion_tabular.diffusion import BaseDiffusion, BinaryDiffusion1D 19 | from binary_diffusion_tabular.dataset import FixedSizeBinaryTableDataset 20 | from binary_diffusion_tabular.utils import ( 21 | PathOrStr, 22 | exists, 23 | cycle, 24 | zero_out_randomly, 25 | get_base_model, 26 | save_config, 27 | get_random_labels, 28 | ) 29 | 30 | 31 | __all__ = ["BaseTrainer", "FixedSizeTableBinaryDiffusionTrainer"] 32 | 33 | 34 | OPTIMIZERS = Literal["adam", "adamw"] 35 | 36 | 37 | class BaseTrainer(ABC): 38 | """Base class for training.""" 39 | 40 | def __init__( 41 | self, 42 | *, 43 | diffusion: BaseDiffusion, 44 | train_num_steps: int = 200_000, 45 | log_every: int = 100, 46 | save_every: int = 10_000, 47 | save_num_samples: int = 64, 48 | max_grad_norm: Optional[float] = None, 49 | gradient_accumulate_every: int = 1, 50 | ema_decay: float = 0.995, 51 | ema_update_every: int = 10, 52 | lr: float = 3e-4, 53 | opt_type: OPTIMIZERS, 54 | opt_params: Dict[str, Any] = None, 55 | batch_size: int = 256, 56 | dataloader_workers: int = 16, 57 | logger, 58 | results_folder: PathOrStr, 59 | ): 60 | """ 61 | Args: 62 | diffusion: diffusion model to train, should be a BaseDiffusion subclass 63 | train_num_steps: number of training steps. Default is 200000 64 | log_every: log every n steps. Default is 100 65 | save_every: saving generated samples frequency. Default is 10000 66 | save_num_samples: number of samples save. Default is 64 67 | max_grad_norm: norm to clip gradients. Defaults to None, which means no clipping 68 | gradient_accumulate_every: gradient accumulation frequency. Defaults to 1 69 | ema_decay: decay factor for EMA updates. Defaults to 0.995 70 | ema_update_every: ema update frequency. Defaults to 10 71 | lr: learning rate. Defaults to 3e-4 72 | opt_type: optimizer type. Can be "adam" or "adamw" 73 | opt_params: optimizer parameters. See each optimizer parameters 74 | batch_size: batch size. Defaults to 256 75 | dataloader_workers: number of dataloader workers. Defaults to 16 76 | logger: wandb logger to use 77 | results_folder: results folder, where to save samples and trained checkpoints 78 | """ 79 | 80 | self.diffusion = diffusion 81 | 82 | self.train_num_steps = train_num_steps 83 | self.log_every = log_every 84 | self.save_every = save_every 85 | self.save_num_samples = save_num_samples 86 | self.max_grad_norm = max_grad_norm 87 | self.gradient_accumulate_every = gradient_accumulate_every 88 | self.ema_decay = ema_decay 89 | self.ema_update_every = ema_update_every 90 | self.lr = lr 91 | self.opt_type = opt_type 92 | self.opt_params = {} if opt_params is None else opt_params 93 | self.batch_size = batch_size 94 | self.dataloader_workers = dataloader_workers 95 | self.accelerator = Accelerator( 96 | gradient_accumulation_steps=self.gradient_accumulate_every 97 | ) 98 | self.device = self.accelerator.device 99 | self.ema = EMA( 100 | self.diffusion, beta=self.ema_decay, update_every=self.ema_update_every 101 | ).to(self.device) 102 | 103 | self.opt = self._create_optimizer() 104 | 105 | if self.accelerator.is_main_process: 106 | self.results_folder = Path(results_folder) 107 | self.results_folder.mkdir(exist_ok=True, parents=True) 108 | self.logger = logger 109 | 110 | self.step = 0 111 | 112 | @classmethod 113 | @abstractmethod 114 | def from_config(cls, config: Dict[str, Any]) -> "BaseTrainer": 115 | pass 116 | 117 | @classmethod 118 | @abstractmethod 119 | def from_checkpoint(cls, checkpoint: PathOrStr) -> "BaseTrainer": 120 | pass 121 | 122 | @abstractmethod 123 | def train(self) -> None: 124 | pass 125 | 126 | def load_checkpoint(self, path_checkpoint: PathOrStr) -> None: 127 | ckpt = torch.load(path_checkpoint) 128 | 129 | model = self.accelerator.unwrap_model(self.diffusion) 130 | model.load_state_dict(ckpt["model"]) 131 | 132 | self.step = ckpt["step"] 133 | self.opt.load_state_dict(ckpt["opt"]) 134 | 135 | try: 136 | self.ema.load_state_dict(ckpt["ema"]) 137 | except: 138 | for name, param in ckpt["ema"].items(): 139 | if name == "initted" or name == "step": 140 | ckpt["ema"][name] = param.unsqueeze( 141 | 0 142 | ) # Convert from shape [] to [1] 143 | 144 | # Load the adjusted state dict 145 | self.ema.load_state_dict(ckpt["ema"]) 146 | 147 | if exists(self.accelerator.scaler) and exists(ckpt["scaler"]): 148 | self.accelerator.scaler.load_state_dict(ckpt["scaler"]) 149 | 150 | self.diffusion, self.opt = self.accelerator.prepare(model, self.opt) 151 | print(f"Loaded model from {path_checkpoint}") 152 | 153 | def save_checkpoint(self, milestone) -> None: 154 | if not self.accelerator.is_local_main_process: 155 | return 156 | 157 | config = { 158 | "train_num_steps": self.train_num_steps, 159 | "log_every": self.log_every, 160 | "save_every": self.save_every, 161 | "save_num_samples": self.save_num_samples, 162 | "max_grad_norm": self.max_grad_norm, 163 | "gradient_accumulate_every": self.gradient_accumulate_every, 164 | "ema_decay": self.ema_decay, 165 | "ema_update_every": self.ema_update_every, 166 | "lr": self.lr, 167 | "opt_type": self.opt_type, 168 | "opt_params": self.opt_params, 169 | } 170 | 171 | data = { 172 | "step": self.step, 173 | "diffusion": self.accelerator.get_state_dict(self.diffusion), 174 | # save diffusion and model configs for easy loading without dataset preprocessing 175 | "config_diffusion": self.diffusion.config, 176 | "config_model": self.diffusion.model.config, 177 | "opt": self.opt.state_dict(), 178 | "diffusion_ema": self.ema.state_dict(), 179 | "scaler": ( 180 | self.accelerator.scaler.state_dict() 181 | if exists(self.accelerator.scaler) 182 | else None 183 | ), 184 | # save train config as well 185 | "config_train": config, 186 | } 187 | 188 | torch.save(data, str(self.results_folder / f"model-{milestone}.pt")) 189 | 190 | @abstractmethod 191 | def sample_save_samples(self, milestone, *args, **kwargs) -> None: 192 | pass 193 | 194 | @abstractmethod 195 | def _create_dataloader(self): 196 | pass 197 | 198 | def _create_optimizer(self): 199 | if hasattr(self.diffusion, "model"): 200 | params = self.diffusion.model.parameters() 201 | else: 202 | params = self.diffusion.parameters() 203 | 204 | if self.opt_type == "adam": 205 | opt = optim.Adam(params, lr=self.lr, **self.opt_params) 206 | elif self.opt_type == "adamw": 207 | opt = optim.AdamW(params, lr=self.lr, **self.opt_params) 208 | else: 209 | raise ValueError(f"Unknown optimizer type: {self.opt_type}") 210 | return opt 211 | 212 | 213 | class FixedSizeTableBinaryDiffusionTrainer(BaseTrainer): 214 | """Trainer for binary diffusion""" 215 | 216 | def __init__( 217 | self, 218 | *, 219 | diffusion: BinaryDiffusion1D, 220 | dataset: FixedSizeBinaryTableDataset, 221 | train_num_steps: int = 200_000, 222 | log_every: int = 100, 223 | save_every: int = 10_000, 224 | save_num_samples: int = 64, 225 | max_grad_norm: Optional[float] = None, 226 | gradient_accumulate_every: int = 1, 227 | ema_decay: float = 0.995, 228 | ema_update_every: int = 10, 229 | lr: float = 3e-4, 230 | opt_type: OPTIMIZERS, 231 | opt_params: Dict[str, Any] = None, 232 | batch_size: int = 256, 233 | dataloader_workers: int = 16, 234 | classifier_free_guidance: bool, 235 | zero_token_probability: float = 0.0, 236 | logger, 237 | results_folder: PathOrStr, 238 | ): 239 | """ 240 | Args: 241 | diffusion: diffusion model to train, should be a BinaryDiffusion1D 242 | dataset: dataset to train on, should be a FixedSizeBinaryTableDataset 243 | train_num_steps: number of training steps. Default is 200000 244 | log_every: log every n steps. Default is 100 245 | save_every: saving generated samples frequency. Default is 10000 246 | save_num_samples: number of samples save. Default is 64 247 | max_grad_norm: norm to clip gradients. Defaults to None, which means no clipping 248 | gradient_accumulate_every: gradient accumulation frequency. Defaults to 1 249 | ema_decay: decay factor for EMA updates. Defaults to 0.995 250 | ema_update_every: ema update frequency. Defaults to 10 251 | lr: learning rate. Defaults to 3e-4 252 | opt_type: optimizer type. Can be "adam" or "adamw" 253 | opt_params: optimizer parameters. See each optimizer parameters 254 | batch_size: batch size. Defaults to 256 255 | dataloader_workers: number of dataloader workers. Defaults to 16 256 | classifier_free_guidance: if True classifier free guidance is applied, when training 257 | zero_token_probability: zero token probability for classifier free guidance training. Defaults to 0.0 258 | logger: wandb logger to use 259 | results_folder: results folder, where to save samples and trained checkpoints 260 | """ 261 | 262 | if not (dataset.split_feature_target == diffusion.conditional): 263 | raise ValueError( 264 | "split_feature_target must be same as diffusion.conditional" 265 | ) 266 | 267 | if classifier_free_guidance and zero_token_probability == 0: 268 | raise ValueError( 269 | "zero_token_probability must be non-zero when classifier_free_guidance is True" 270 | ) 271 | 272 | if not (diffusion.classifier_free_guidance == classifier_free_guidance): 273 | raise ValueError( 274 | "classifier_free_guidance must be same as diffusion.classifier_free_guidance" 275 | ) 276 | self.conditional = diffusion.conditional 277 | self.classifier_free_guidance = classifier_free_guidance 278 | self.n_classes = diffusion.n_classes 279 | self.task = dataset.task 280 | self.zero_token_probability = zero_token_probability 281 | 282 | super().__init__( 283 | diffusion=diffusion, 284 | train_num_steps=train_num_steps, 285 | log_every=log_every, 286 | save_every=save_every, 287 | save_num_samples=save_num_samples, 288 | max_grad_norm=max_grad_norm, 289 | gradient_accumulate_every=gradient_accumulate_every, 290 | ema_decay=ema_decay, 291 | ema_update_every=ema_update_every, 292 | lr=lr, 293 | opt_type=opt_type, 294 | opt_params=opt_params, 295 | batch_size=batch_size, 296 | dataloader_workers=dataloader_workers, 297 | logger=logger, 298 | results_folder=results_folder, 299 | ) 300 | 301 | self.dataset = dataset 302 | self.transformation = dataset.transformation 303 | self.dataloader = self.accelerator.prepare(self._create_dataloader()) 304 | self.dataloader = cycle(self.dataloader) 305 | 306 | if self.task == "classification" and not ( 307 | self.dataset.n_classes == self.diffusion.n_classes 308 | ): 309 | raise RuntimeError("dataset.n_classes must equal diffusion.n_classes") 310 | 311 | # save transformation in joblib format 312 | self.dataset.transformation.save_checkpoint(self.results_folder / "transformation.joblib") 313 | 314 | @classmethod 315 | def from_checkpoint( 316 | cls, path_checkpoint: PathOrStr 317 | ) -> "FixedSizeTableBinaryDiffusionTrainer": 318 | """Loads trainer from checkpoint. 319 | 320 | Args: 321 | path_checkpoint: path to the checkpoint 322 | 323 | Returns: 324 | FixedSizeTableBinaryDiffusionTrainer: trainer 325 | """ 326 | 327 | ckpt = torch.load(path_checkpoint) 328 | config = ckpt["config_train"] 329 | logger = wandb.init( 330 | project="binary-diffusion-tabular", config=config, name=config["comment"] 331 | ) 332 | trainer = FixedSizeTableBinaryDiffusionTrainer.from_config(config, logger) 333 | 334 | trainer.load_checkpoint(path_checkpoint) 335 | return trainer 336 | 337 | @classmethod 338 | def from_config( 339 | cls, config: Dict[str, Any], logger 340 | ) -> "FixedSizeTableBinaryDiffusionTrainer": 341 | """Builds trainer, model, diffusion and dataset from config. 342 | 343 | Config should have the following structure: 344 | 345 | data: 346 | path_table: path to the csv file with table data 347 | numerical_columns: list of numerical column names 348 | categorical_columns: list of categorical column names 349 | columns_to_drop: list of column names to drop 350 | dropna: if True, will drop columns with NaNs 351 | fillna: if True, will fill NaNs. Numerical replaced with mean, categorical replaced with mode 352 | target_column: optional target column name, should be provided for conditional training 353 | split_feature_target: if True, will split the feature target into training and test sets, should be True for conditional training 354 | task: task for which dataset is used. Options: classification, regression 355 | 356 | model: 357 | dim: internal dimension of model 358 | n_res_blocks: number of residual blocks to use 359 | 360 | other parameters are filled from dataset 361 | 362 | diffusion: 363 | schedule: noise schedule for diffusion. Options: linear, quad, sigmoid 364 | n_timesteps: number of diffusion steps 365 | target: target for diffusion. Options: mask, target, two_way 366 | 367 | trainer: 368 | train_num_steps: number of training steps 369 | log_every: log every n steps 370 | save_every: saving generated samples frequency 371 | save_num_samples: number of samples save 372 | max_grad_norm: norm to clip gradients. If None, no clipping 373 | gradient_accumulate_every: gradient accumulation frequency 374 | ema_decay: decay factor for EMA updates 375 | ema_update_every: ema update frequency 376 | lr: learning rate 377 | opt_type: optimizer type. Can be "adam" or "adamw" 378 | opt_params: optimizer parameters. See each optimizer parameters 379 | batch_size: batch size 380 | dataloader_workers: number of dataloader workers 381 | classifier_free_guidance: if True classifier free guidance is applied, when training 382 | zero_token_probability: zero token probability for classifier free guidance training 383 | 384 | comment: 385 | 386 | Args: 387 | config: config with parameters for dataset, denoising model, diffusion and trainer 388 | logger: wandb logger. Create by `wandb.init(project=)` 389 | 390 | Returns: 391 | FixedSizeTableBinaryDiffusionTrainer: trainer 392 | """ 393 | 394 | config_data = config["data"] 395 | config_model = config["model"] 396 | config_diffusion = config["diffusion"] 397 | config_trainer = config["trainer"] 398 | 399 | task = config_data["task"] 400 | dataset = FixedSizeBinaryTableDataset.from_config(config_data) 401 | 402 | classifier_free_guidance = config_trainer["classifier_free_guidance"] 403 | diffusion_target = config_diffusion["target"] 404 | 405 | device = accelerate.Accelerator().device 406 | 407 | # row_size is given from FixedSizeBinaryTableDataset 408 | # later SimpleTableGenerator can be loaded from config 409 | model = SimpleTableGenerator( 410 | data_dim=dataset.row_size, 411 | out_dim=( 412 | dataset.row_size * 2 413 | if diffusion_target == "two_way" 414 | else dataset.row_size 415 | ), 416 | task=dataset.task, 417 | conditional=dataset.conditional, 418 | n_classes=0 if task == "regression" else dataset.n_classes, 419 | classifier_free_guidance=classifier_free_guidance, 420 | **config_model, 421 | ).to(device) 422 | 423 | diffusion = BinaryDiffusion1D( 424 | denoise_model=model, 425 | **config_diffusion, 426 | ).to(device) 427 | 428 | comment = config["comment"] 429 | results_folder = Path(f"results/{comment}") 430 | results_folder.mkdir(parents=True, exist_ok=True) 431 | 432 | # save config as yaml file in results_folder 433 | save_config(config, results_folder / "config.yaml") 434 | 435 | if logger is None: 436 | logger = wandb.init( 437 | project="binary-diffusion-tabular", config=config, name=comment 438 | ) 439 | 440 | return cls( 441 | diffusion=diffusion, 442 | dataset=dataset, 443 | results_folder=results_folder, 444 | logger=logger, 445 | **config_trainer, 446 | ) 447 | 448 | def train(self) -> None: 449 | self.diffusion.to(self.device) 450 | self.diffusion.train() 451 | 452 | with tqdm( 453 | initial=self.step, 454 | total=self.train_num_steps, 455 | disable=not self.accelerator.is_main_process, 456 | ) as pbar: 457 | while self.step < self.train_num_steps: 458 | total_loss = defaultdict(float) 459 | total_acc = defaultdict(float) 460 | 461 | with self.accelerator.accumulate(self.diffusion): 462 | inp = next(self.dataloader) 463 | if self.conditional: 464 | data, label = inp 465 | label = self._preprocess_labels(label) 466 | else: 467 | data = inp 468 | label = None 469 | 470 | with self.accelerator.autocast(): 471 | loss, losses, accs = self.diffusion(x=data, y=label) 472 | loss = loss / self.gradient_accumulate_every 473 | 474 | gathered_losses = {} 475 | gathered_accs = {} 476 | 477 | for key in losses: 478 | gathered_losses[key] = self.accelerator.gather( 479 | losses[key].detach() 480 | ) 481 | 482 | for key in accs: 483 | gathered_accs[key] = self.accelerator.gather( 484 | accs[key].detach() 485 | ) 486 | 487 | self.accelerator.wait_for_everyone() 488 | 489 | if self.max_grad_norm is not None: 490 | self.accelerator.clip_grad_norm_( 491 | self.diffusion.parameters(), self.max_grad_norm 492 | ) 493 | 494 | self.accelerator.backward(loss) 495 | self.opt.step() 496 | self.opt.zero_grad() 497 | 498 | if self.accelerator.is_main_process: 499 | message = f"Loss: {loss.item():.4f}" 500 | for key in gathered_accs: 501 | acc_val = torch.mean(gathered_accs[key]).item() 502 | total_acc[key] += acc_val 503 | message += f" | {key}: {acc_val:.4f}" 504 | 505 | for key in gathered_losses: 506 | loss_val = torch.mean(gathered_losses[key]).item() 507 | total_loss[key] += loss_val 508 | 509 | pbar.set_description(message) 510 | 511 | self.ema.update() 512 | if self.step % self.log_every == 0: 513 | log_dict = {} 514 | 515 | for key in total_loss: 516 | log_dict[key] = total_loss[key] 517 | 518 | for key in total_acc: 519 | log_dict[key] = total_acc[key] 520 | 521 | self.logger.log(log_dict) 522 | 523 | if self.step != 0 and self.step % self.save_every == 0: 524 | milestone = self.step // self.save_every 525 | self.sample_save_samples(milestone) 526 | self.accelerator.wait_for_everyone() 527 | self.save_checkpoint(milestone) 528 | 529 | self.step += 1 530 | pbar.update(1) 531 | 532 | if self.accelerator.is_main_process: 533 | self.sample_save_samples("final") 534 | # save final model 535 | self.accelerator.wait_for_everyone() 536 | self.save_checkpoint("final") 537 | 538 | @torch.inference_mode() 539 | def sample_save_samples(self, milestone) -> None: 540 | base_model = get_base_model(self.diffusion) 541 | base_model.eval() 542 | base_model_ema = get_base_model(self.ema.ema_model) 543 | base_model_ema.eval() 544 | 545 | with self.accelerator.autocast(): 546 | labels_val = get_random_labels( 547 | conditional=self.conditional, 548 | task=self.task, 549 | n_classes=self.n_classes, 550 | classifier_free_guidance=self.classifier_free_guidance, 551 | n_labels=self.save_num_samples, 552 | device=self.device, 553 | ) 554 | 555 | # sampling without 556 | rows = base_model.sample(n=self.save_num_samples, y=labels_val) 557 | rows_ema = base_model_ema.sample(n=self.save_num_samples, y=labels_val) 558 | 559 | if self.conditional: 560 | if self.classifier_free_guidance: 561 | labels_val = torch.argmax(labels_val, dim=1).detach() 562 | 563 | rows_df, labels_df = self.transformation.inverse_transform(rows, labels_val) 564 | rows_ema_df, labels_ema_df = self.transformation.inverse_transform( 565 | rows_ema, labels_val 566 | ) 567 | rows_df[self.dataset.target_column] = labels_df 568 | rows_ema_df[self.dataset.target_column] = labels_ema_df 569 | else: 570 | rows_df = self.transformation.transform(rows) 571 | rows_ema_df = self.transformation.transform(rows_ema) 572 | 573 | rows_df.to_csv(self.results_folder / f"samples_{milestone}.csv", index=False) 574 | rows_ema_df.to_csv( 575 | self.results_folder / f"samples_{milestone}_ema.csv", index=False 576 | ) 577 | 578 | def _preprocess_labels(self, label: torch.Tensor) -> torch.Tensor: 579 | if self.task == "regression" and len(label.shape) == 1: 580 | label = label.unsqueeze(1) 581 | 582 | if self.classifier_free_guidance: 583 | if self.task == "classification": 584 | label = F.one_hot(label.long(), num_classes=self.n_classes).to( 585 | torch.float 586 | ) 587 | label = zero_out_randomly(label, self.zero_token_probability) 588 | else: 589 | # regression 590 | # -1 is zero-token for regression 591 | mask = torch.rand_like(label) < self.zero_token_probability 592 | label[mask] = -1 593 | 594 | return label.to(self.device) 595 | 596 | def _create_dataloader(self) -> DataLoader: 597 | return DataLoader( 598 | self.dataset, 599 | batch_size=self.batch_size, 600 | num_workers=self.dataloader_workers, 601 | pin_memory=True, 602 | shuffle=True, 603 | ) 604 | -------------------------------------------------------------------------------- /binary_diffusion_tabular/transformation.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Literal, Optional, Dict, Union 2 | from joblib import Parallel, delayed 3 | import math 4 | import joblib 5 | 6 | import numpy as np 7 | import pandas as pd 8 | from sklearn.preprocessing import LabelEncoder, MinMaxScaler 9 | 10 | import torch 11 | 12 | from binary_diffusion_tabular import TASK, PathOrStr 13 | 14 | 15 | __all__ = [ 16 | "FixedSizeBinaryTableTransformation", 17 | ] 18 | 19 | 20 | COLUMN_DTYPE = Literal["numerical", "categorical"] 21 | LABELS = Union[np.ndarray, pd.Series, torch.Tensor] 22 | 23 | 24 | def column_to_fixed_size_binary( 25 | column: pd.Series, 26 | dtype: COLUMN_DTYPE, 27 | metadata: Optional[Dict] = None, 28 | size: Optional[int] = None, 29 | ) -> Tuple[pd.Series, Dict, int]: 30 | """ 31 | Convert a pandas DataFrame column to fixed-size binary representation with automatic size calculation, 32 | and return relevant metadata. 33 | 34 | Args: 35 | column (pd.Series): DataFrame column to be converted. 36 | dtype (COLUMN_DTYPE): The type of data ('numerical', 'categorical'). 37 | metadata (dict): Metadata necessary for conversion (min-max for numerical, mapping for categorical). 38 | size (int): The size of the fixed size binary. 39 | 40 | Returns: 41 | tuple: A tuple containing the converted column and metadata (min-max for numerical, mapping for categorical). 42 | 43 | Notes: 44 | size is calculated automatically. For categorical data, the size is calculated as the log2 of the number of 45 | unique values. For numerical data, the size is set to 32. 46 | """ 47 | 48 | if dtype == "categorical": 49 | unique_values = len(column.unique()) 50 | if size is None: 51 | size = math.ceil(math.log2(unique_values)) if unique_values > 1 else 1 52 | else: 53 | # default size for numerical columns 54 | size = 32 55 | 56 | def numerical_to_binary(val, min_val: float, max_val: float): 57 | return format( 58 | int((val - min_val) / (max_val - min_val) * (2**size - 1)), 59 | f"0{size}b", 60 | ) 61 | 62 | if dtype == "numerical": 63 | if not pd.api.types.is_numeric_dtype(column): 64 | raise ValueError( 65 | "Column must contain numeric values for numerical data type." 66 | ) 67 | if metadata is None: 68 | metadata = {"min": column.min(), "max": column.max()} 69 | 70 | min_val = metadata["min"] 71 | max_val = metadata["max"] 72 | 73 | converted_column = column.apply(numerical_to_binary, args=(min_val, max_val)) 74 | else: 75 | if metadata is None: 76 | category_map = { 77 | category: index for index, category in enumerate(column.unique()) 78 | } 79 | metadata = {"category_map": category_map} 80 | else: 81 | category_map = metadata["category_map"] 82 | 83 | def get_category_index(val): 84 | cat_idx = category_map.get(val, None) 85 | if cat_idx is not None: 86 | return format(cat_idx, f"0{size}b") 87 | else: 88 | return None 89 | 90 | converted_column = column.apply(get_category_index) 91 | 92 | return converted_column, metadata, size 93 | 94 | 95 | def column_from_fixed_size_binary( 96 | binary_column: pd.Series, metadata: Dict, dtype: COLUMN_DTYPE 97 | ) -> pd.Series: 98 | """ 99 | Convert a binary representation back to its original form using provided metadata. 100 | 101 | Args: 102 | binary_column (pd.Series): Column with binary representations. 103 | metadata (dict): Metadata necessary for conversion (min-max for numerical, mapping for categorical). 104 | dtype (COLUMN_DTYPE): The type of data ('numerical', 'categorical'). 105 | 106 | Returns: 107 | pd.Series: A column with original values. 108 | """ 109 | 110 | if dtype == "numerical": 111 | min_val = metadata["min"] 112 | max_val = metadata["max"] 113 | return binary_column.apply( 114 | lambda x: int(x, 2) / (2 ** len(x) - 1) * (max_val - min_val) + min_val 115 | ) 116 | 117 | elif dtype == "categorical": 118 | category_map = metadata["category_map"] 119 | inverse_map = {v: k for k, v in category_map.items()} 120 | return binary_column.apply(lambda x: inverse_map.get(int(x, 2), None)) 121 | else: 122 | raise ValueError( 123 | "Data type not recognized. Choose 'numerical' or 'categorical'." 124 | ) 125 | 126 | 127 | def pandas_row_to_tensor(row: pd.Series) -> torch.Tensor: 128 | row_str = "".join(row.astype(str)) 129 | row_np = np.array(list(row_str)) 130 | row_np = row_np.astype(int) 131 | row_binary = torch.tensor(row_np, dtype=torch.float) 132 | return row_binary 133 | 134 | 135 | class FixedSizeBinaryTableTransformation: 136 | """Transformation to convert pandas dataframe to fixed size binary tensor and back""" 137 | 138 | def __init__( 139 | self, 140 | task: TASK, 141 | numerical_columns: List[str], 142 | categorical_columns: List[str], 143 | parallel: bool = False, 144 | ): 145 | self.task = task 146 | self.numerical_columns = numerical_columns 147 | self.categorical_columns = categorical_columns 148 | self.parallel = parallel 149 | self.label_encoder = ( 150 | LabelEncoder() if self.task == "classification" else MinMaxScaler() 151 | ) 152 | 153 | self.fitted = False 154 | self.fitted_label = False 155 | self.metadata = None 156 | self.size = None 157 | 158 | def save_checkpoint(self, path_checkpoint: PathOrStr) -> None: 159 | """ 160 | Save the current state of the transformation to a file. 161 | 162 | Args: 163 | path_checkpoint: Path to the file where the state will be saved. 164 | """ 165 | joblib.dump(self, path_checkpoint) 166 | 167 | @classmethod 168 | def from_checkpoint( 169 | cls, path_checkpoint: PathOrStr 170 | ) -> "FixedSizeBinaryTableTransformation": 171 | """Loads the transformation from a .joblib file. 172 | 173 | Args: 174 | path_checkpoint: Path to the .joblib file. 175 | 176 | Returns: 177 | FixedSizeBinaryTableTransformation: The loaded transformation. 178 | """ 179 | 180 | transformer = joblib.load(path_checkpoint) 181 | return transformer 182 | 183 | @property 184 | def row_size(self) -> int: 185 | if not self.fitted: 186 | raise RuntimeError( 187 | "FixedSizeBinaryTableTransformation has not been fitted." 188 | ) 189 | return sum(self.size.values()) 190 | 191 | @property 192 | def n_classes(self) -> int: 193 | if not self.fitted: 194 | raise RuntimeError( 195 | "FixedSizeBinaryTableTransformation has not been fitted." 196 | ) 197 | 198 | if self.task != "classification": 199 | raise ValueError("Task must be 'classification'.") 200 | 201 | return len(self.label_encoder.classes_) 202 | 203 | def fit_transform(self, X: pd.DataFrame, y: Optional[pd.Series] = None): 204 | """Fits transformation and transforms the input dataframe 205 | 206 | Transformation doesn't handle the empty values 207 | All handling of empty values, dropping columns, etc should be done beforehand. 208 | 209 | Args: 210 | X: input dataframe. 211 | y: target dataframe. Can be None. 212 | 213 | Returns: 214 | Tuple[torch.Tensor, torch.Tensor]: transformed X and y. If y is provided 215 | 216 | OR 217 | 218 | torch.Tensor: transformed X, if y is not provided 219 | """ 220 | 221 | if self.fitted: 222 | raise RuntimeError( 223 | "Transformation already fitted. Use transform() instead." 224 | ) 225 | 226 | x_binary, metadata, size = self._convert_df_to_fixed_size_binary_tensor(X) 227 | self.metadata = metadata 228 | self.size = size 229 | 230 | self.fitted = True 231 | 232 | if y is not None: 233 | y_trans = self.fit_transform_label(y) 234 | self.fitted_label = True 235 | y_trans = torch.tensor(y_trans, dtype=torch.float) 236 | return x_binary, y_trans 237 | 238 | return x_binary 239 | 240 | def transform(self, X: pd.DataFrame, y: Optional[pd.Series] = None): 241 | """Transforms the input dataframe into a fixed size binary tensor. 242 | 243 | Args: 244 | X: input dataframe. 245 | y: target dataframe. Can be None. 246 | 247 | Returns: 248 | Tuple[torch.Tensor, torch.Tensor]: transformed X and y. If y is provided 249 | 250 | torch.Tensor: transformed X, if y is not provided 251 | """ 252 | 253 | if not self.fitted: 254 | raise RuntimeError("Fit before transform. Use fit_transform() instead.") 255 | 256 | x_binary, *_ = self._convert_df_to_fixed_size_binary_tensor(X) 257 | 258 | if y is not None: 259 | if not self.fitted_label: 260 | raise RuntimeError("Label encoder not fitted.") 261 | 262 | y_trans = self.transform_label(y) 263 | return x_binary, y_trans 264 | 265 | return x_binary 266 | 267 | def fit_transform_label(self, y: LABELS) -> torch.Tensor: 268 | """Fits encoder for labels and transforms the labels 269 | 270 | Args: 271 | y: labels to transform. Can be np.ndarray, pd.Series. 272 | 273 | Returns: 274 | torch.Tensor: transformed labels 275 | """ 276 | 277 | if isinstance(y, pd.Series): 278 | y = y.values 279 | elif isinstance(y, torch.Tensor): 280 | y = y.detach().cpu().numpy() 281 | 282 | y_trans = self.label_encoder.fit_transform(y.reshape(-1, 1)) 283 | y_trans = torch.tensor(y_trans, dtype=torch.float) 284 | 285 | self.fitted_label = True 286 | return y_trans 287 | 288 | def transform_label(self, y: LABELS) -> torch.Tensor: 289 | """Transforms the labels 290 | 291 | Args: 292 | y: labels to transform. Can be np.ndarray, pd.Series. 293 | 294 | Returns: 295 | torch.Tensor: transformed labels 296 | """ 297 | 298 | if not self.fitted_label: 299 | raise RuntimeError("Label encoder not fitted.") 300 | 301 | if isinstance(y, pd.Series): 302 | y = y.values 303 | elif isinstance(y, torch.Tensor): 304 | y = y.detach().cpu().numpy() 305 | 306 | y_trans = self.label_encoder.transform(y.reshape(-1, 1)) 307 | y_trans = torch.tensor(y_trans, dtype=torch.float) 308 | return y_trans 309 | 310 | def inverse_transform(self, X: torch.Tensor, y: Optional[torch.Tensor] = None): 311 | """Inverse transformation to convert binary fixed size tensor and labels back to original dataframe. 312 | 313 | Args: 314 | X: input binary fixed size tensor 315 | 316 | y: target dataframe. Can be None. 317 | 318 | Returns: 319 | pd.DataFrame: transformed dataframe, if y is not provided 320 | 321 | or 322 | 323 | pd.DataFrame, np.ndarray: transformed dataframe and labels, if y is provided 324 | """ 325 | 326 | if not self.fitted: 327 | raise RuntimeError("Fit before transform. Use fit_transform() instead.") 328 | 329 | df = self._convert_fixed_size_binary_tensor_to_df(X) 330 | 331 | if y is not None: 332 | if not self.fitted_label: 333 | raise RuntimeError("Label encoder not fitted.") 334 | 335 | y_trans = self.inverse_transform_label(y) 336 | return df, y_trans 337 | 338 | return df 339 | 340 | def inverse_transform_label(self, y: LABELS) -> np.ndarray: 341 | """Inverse transformation for labels 342 | 343 | Args: 344 | y: labels to transform. Can be np.ndarray, pd.Series or torch.Tensor. 345 | 346 | Returns: 347 | np.ndarray: transformed labels 348 | """ 349 | 350 | if not self.fitted_label: 351 | raise RuntimeError("Label encoder not fitted.") 352 | 353 | if isinstance(y, pd.Series): 354 | y = y.values 355 | elif isinstance(y, torch.Tensor): 356 | y = y.detach().cpu().numpy() 357 | 358 | if self.task == "classification": 359 | y = y.astype(int) 360 | 361 | y_trans = self.label_encoder.inverse_transform(y.reshape(-1, 1)) 362 | return y_trans 363 | 364 | def _convert_fixed_size_binary_tensor_to_df( 365 | self, rows_binary: torch.Tensor 366 | ) -> pd.DataFrame: 367 | rows_np = rows_binary.detach().cpu().numpy().astype(int) 368 | rows_str = rows_np.astype(str) 369 | 370 | df_bin = {} 371 | start = 0 372 | for col, size in self.size.items(): 373 | end = start + size 374 | df_bin[col] = ["".join(row) for row in rows_str[:, start:end]] 375 | start = end 376 | 377 | df_bin = pd.DataFrame(df_bin) 378 | 379 | df = pd.DataFrame() 380 | for col in df_bin.columns: 381 | df[col] = column_from_fixed_size_binary( 382 | df_bin[col], 383 | metadata=self.metadata[col], 384 | dtype="numerical" if col in self.numerical_columns else "categorical", 385 | ) 386 | 387 | return df 388 | 389 | def _convert_df_to_fixed_size_binary_tensor(self, df: pd.DataFrame): 390 | df_binary = pd.DataFrame() 391 | metadata = {} 392 | size = {} 393 | 394 | columns = df.columns 395 | for col in columns: 396 | col_binary, metadat_col, size_col = column_to_fixed_size_binary( 397 | column=df[col], 398 | dtype="numerical" if col in self.numerical_columns else "categorical", 399 | metadata=None if self.metadata is None else self.metadata[col], 400 | size=None if self.size is None else self.size[col], 401 | ) 402 | df_binary[col] = col_binary 403 | metadata[col] = metadat_col 404 | size[col] = size_col 405 | 406 | if self.parallel: 407 | # rows_binary = parallelize_dataframe(df, pandas_row_to_tensor, 4) 408 | n_jobs = -1 409 | rows_binary = Parallel(n_jobs=n_jobs)( 410 | delayed(pandas_row_to_tensor)(row) for _, row in df_binary.iterrows() 411 | ) 412 | else: 413 | rows_binary = df_binary.apply(pandas_row_to_tensor, axis=1).tolist() 414 | 415 | rows_binary = torch.stack(rows_binary, dim=0) 416 | return rows_binary, metadata, size 417 | -------------------------------------------------------------------------------- /binary_diffusion_tabular/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Union, Dict 2 | from pathlib import Path 3 | import yaml 4 | import random 5 | import os 6 | 7 | import numpy as np 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | 13 | __all__ = [ 14 | "TASK", 15 | "exists", 16 | "default", 17 | "PathOrStr", 18 | "cycle", 19 | "zero_out_randomly", 20 | "get_base_model", 21 | "get_config", 22 | "save_config", 23 | "select_equally_distributed_numbers", 24 | "get_random_labels", 25 | "seed_everything" 26 | ] 27 | 28 | 29 | TASK = Literal["classification", "regression"] 30 | 31 | PathOrStr = Union[str, Path] 32 | 33 | 34 | def exists(x): 35 | return x is not None 36 | 37 | 38 | def default(val, d): 39 | if exists(val): 40 | return val 41 | return d() if callable(d) else d 42 | 43 | 44 | def cycle(dl): 45 | while True: 46 | for data in dl: 47 | yield data 48 | 49 | 50 | def zero_out_randomly( 51 | tensor: torch.Tensor, probability: float, dim: int = 0 52 | ) -> torch.Tensor: 53 | """Zero out randomly selected elements of a tensor with a given probability at a given dimension 54 | 55 | Args: 56 | tensor: tensor to zero out 57 | probability: probability of zeroing out an element 58 | dim: dimension along which to zero out elements 59 | 60 | Returns: 61 | torch.Tensor: tensor with randomly zeroed out elements 62 | """ 63 | 64 | mask = torch.rand(tensor.shape[dim]) < probability 65 | tensor[mask] = 0 66 | return tensor 67 | 68 | 69 | def get_base_model(model): 70 | if hasattr(model, "module"): 71 | return model.module 72 | return model 73 | 74 | 75 | def get_config(config): 76 | with open(config, "r") as stream: 77 | return yaml.load(stream, Loader=yaml.FullLoader) 78 | 79 | 80 | def save_config(config: Dict, yaml_file_path: PathOrStr) -> None: 81 | """save config to yaml file 82 | 83 | Args: 84 | config: config to save 85 | yaml_file_path: path to yaml file 86 | """ 87 | 88 | try: 89 | with open(yaml_file_path, "w") as file: 90 | yaml.dump(config, file, sort_keys=False, default_flow_style=False) 91 | except Exception as e: 92 | print(f"Error saving dictionary to YAML file: {e}") 93 | 94 | 95 | def select_equally_distributed_numbers(N: int, K: int) -> np.ndarray: 96 | if N % K == 0: 97 | return np.arange(0, N, N // K) 98 | 99 | step = (N - 1) // (K - 1) 100 | return np.arange(0, N, step)[:K] 101 | 102 | 103 | def get_random_labels( 104 | *, 105 | conditional: bool, 106 | task: TASK, 107 | n_classes: int, 108 | classifier_free_guidance: bool, 109 | n_labels: int, 110 | device, 111 | ) -> torch.Tensor | None: 112 | """Get random labels for a given task 113 | 114 | Args: 115 | conditional: if conditional generate labels, if not return None 116 | task: task to generate labels for 117 | n_classes: number of classes for classification task 118 | classifier_free_guidance: if True, classification labels will be one-hot encoded 119 | n_labels: number of labels to generate 120 | device: device to use 121 | 122 | Returns: 123 | torch.Tensor | None: labels to generate or None 124 | """ 125 | 126 | if not conditional: 127 | return None 128 | 129 | if task == "classification": 130 | labels = torch.randint(0, n_classes, size=(n_labels,), device=device) 131 | 132 | if classifier_free_guidance: 133 | labels = F.one_hot(labels, num_classes=n_classes).to(device=device).float() 134 | else: 135 | labels = torch.rand((n_labels, 1), device=device) 136 | 137 | return labels 138 | 139 | 140 | def seed_everything(seed: int) -> None: 141 | random.seed(seed) 142 | os.environ["PYTHONHASHSEED"] = str(seed) 143 | np.random.seed(seed) 144 | torch.manual_seed(seed) 145 | torch.cuda.manual_seed(seed) 146 | torch.backends.cudnn.deterministic = True 147 | torch.backends.cudnn.benchmark = False 148 | print(f"Setting all seeds to be {seed} to reproduce...") 149 | -------------------------------------------------------------------------------- /configs/adult.yaml: -------------------------------------------------------------------------------- 1 | data: # dataset parameters 2 | path_table: ./data/adult_train.csv # path to dataset in csv format 3 | numerical_columns: # list of numerical columns 4 | - age 5 | - fnlwgt 6 | - education-num 7 | - capital-gain 8 | - capital-loss 9 | - hours-per-week 10 | categorical_columns: # list of categorical columns 11 | - workclass 12 | - education 13 | - marital-status 14 | - occupation 15 | - relationship 16 | - race 17 | - sex 18 | - native-country 19 | columns_to_drop: # list of columns to drop 20 | dropna: True # if True, rows with nan values are dropped 21 | fillna: False # if True, numerical nan values are replaced with mean, categorical are replaced with mode. Either dropna or fillna can be True 22 | target_column: label # target column, if conditional generation. If None, unconditional generation 23 | split_feature_target: True # should be True for conditional generation 24 | task: classification # table task, can be `classification` or `regression` 25 | 26 | model: # denoiser model parameters 27 | dim: 256 # dimensionality of internal blocks 28 | n_res_blocks: 3 # number of residual blocks 29 | 30 | diffusion: # diffusion parameters 31 | schedule: quad # noise schedule, can be `linear`, `quad`, `sigmoid` 32 | n_timesteps: 1000 # number of denoising steps in denoiser pretraining 33 | target: two_way # denoiser prediction target: `mask`, `target`, `two_way` 34 | 35 | trainer: # trainer parameters 36 | train_num_steps: 500000 # number of training steps 37 | log_every: 100 # logging frequency 38 | save_every: 10000 # model saving frequency 39 | save_num_samples: 64 # number of generated samples to save during evaluation 40 | max_grad_norm: # maximum gradient norm, if None gradient clipping is not applied 41 | gradient_accumulate_every: 1 # gradient accumulation steps 42 | ema_decay: 0.995 # EMA decay 43 | ema_update_every: 10 # EMA update frequency 44 | lr: 0.0001 # learning rate 45 | opt_type: adam # type of optimizer to use. Options: `adam`, `adamw` 46 | opt_params: # optimizer parameters. See optimizer implementation for details 47 | batch_size: 256 # training batch size 48 | dataloader_workers: 16 # number of dataloader workers 49 | classifier_free_guidance: True # if True, classifier free guidance is used 50 | zero_token_probability: 0.1 # zero token probability in classifier free guidance. If classifier_free_guidance False, not use 51 | 52 | fine_tune_from: # path to model checkpoint to fine tune from 53 | 54 | comment: adult_CFG # comment for the results folder and logging -------------------------------------------------------------------------------- /configs/diabetes.yaml: -------------------------------------------------------------------------------- 1 | data: # dataset parameters 2 | path_table: ./data/diabetes_train.csv # path to dataset in csv format 3 | numerical_columns: # list of numerical columns 4 | - num_lab_procedures 5 | - num_procedures 6 | - num_medications 7 | - number_outpatient 8 | - number_emergency 9 | - number_inpatient 10 | - number_diagnoses 11 | - time_in_hospital 12 | categorical_columns: # list of categorical columns 13 | - race 14 | - gender 15 | - age 16 | - weight 17 | - admission_type_id 18 | - discharge_disposition_id 19 | - admission_source_id 20 | - payer_code 21 | - medical_specialty 22 | - diag_1 23 | - diag_2 24 | - diag_3 25 | - max_glu_serum 26 | - A1Cresult 27 | - metformin 28 | - repaglinide 29 | - nateglinide 30 | - chlorpropamide 31 | - glimepiride 32 | - acetohexamide 33 | - glipizide 34 | - glyburide 35 | - tolbutamide 36 | - pioglitazone 37 | - rosiglitazone 38 | - acarbose 39 | - miglitol 40 | - troglitazone 41 | - tolazamide 42 | - examide 43 | - citoglipton 44 | - insulin 45 | - glyburide-metformin 46 | - glipizide-metformin 47 | - glimepiride-pioglitazone 48 | - metformin-rosiglitazone 49 | - metformin-pioglitazone 50 | - change 51 | - diabetesMed 52 | columns_to_drop: # list of columns to drop 53 | - encounter_id 54 | - patient_nbr 55 | dropna: True # if True, rows with nan values are dropped 56 | fillna: False # if True, numerical nan values are replaced with mean, categorical are replaced with mode. Either dropna or fillna can be True 57 | target_column: readmitted # target column, if conditional generation. If None, unconditional generation 58 | split_feature_target: True # should be True for conditional generation 59 | task: classification # table task, can be `classification` or `regression` 60 | 61 | model: # denoiser model parameters 62 | dim: 256 # dimensionality of internal blocks 63 | n_res_blocks: 3 # number of residual blocks 64 | 65 | diffusion: # diffusion parameters 66 | schedule: quad # noise schedule, can be `linear`, `quad`, `sigmoid` 67 | n_timesteps: 1000 # number of denoising steps in denoiser pretraining 68 | target: two_way # denoiser prediction target: `mask`, `target`, `two_way` 69 | 70 | trainer: # trainer parameters 71 | train_num_steps: 500000 # number of training steps 72 | log_every: 100 # logging frequency 73 | save_every: 10000 # model saving frequency 74 | save_num_samples: 64 # number of generated samples to save during evaluation 75 | max_grad_norm: # maximum gradient norm, if None gradient clipping is not applied 76 | gradient_accumulate_every: 1 # gradient accumulation steps 77 | ema_decay: 0.995 # EMA decay 78 | ema_update_every: 10 # EMA update frequency 79 | lr: 0.0001 # learning rate 80 | opt_type: adam # type of optimizer to use. Options: `adam`, `adamw` 81 | opt_params: # optimizer parameters. See optimizer implementation for details 82 | batch_size: 256 # training batch size 83 | dataloader_workers: 16 # number of dataloader workers 84 | classifier_free_guidance: True # if True, classifier free guidance is used 85 | zero_token_probability: 0.1 # zero token probability in classifier free guidance. If classifier_free_guidance False, not use 86 | 87 | fine_tune_from: # path to model checkpoint to fine tune from 88 | 89 | comment: diabetes_CFG # comment for the results folder and logging -------------------------------------------------------------------------------- /configs/heloc.yaml: -------------------------------------------------------------------------------- 1 | data: # dataset parameters 2 | path_table: ./data/heloc_train.csv # path to dataset in csv format 3 | numerical_columns: # list of numerical columns 4 | - ExternalRiskEstimate 5 | - MSinceOldestTradeOpen 6 | - MSinceMostRecentTradeOpen 7 | - AverageMInFile 8 | - NumSatisfactoryTrades 9 | - NumTrades60Ever2DerogPubRec 10 | - NumTrades90Ever2DerogPubRec 11 | - PercentTradesNeverDelq 12 | - MSinceMostRecentDelq 13 | - MaxDelq2PublicRecLast12M 14 | - MaxDelqEver 15 | - NumTotalTrades 16 | - NumTradesOpeninLast12M 17 | - PercentInstallTrades 18 | - MSinceMostRecentInqexcl7days 19 | - NumInqLast6M 20 | - NumInqLast6Mexcl7days 21 | - NetFractionRevolvingBurden 22 | - NetFractionInstallBurden 23 | - NumRevolvingTradesWBalance 24 | - NumInstallTradesWBalance 25 | - NumBank2NatlTradesWHighUtilization 26 | - PercentTradesWBalance 27 | categorical_columns: # list of categorical columns 28 | - RiskPerformance 29 | columns_to_drop: # list of columns to drop 30 | dropna: True # if True, rows with nan values are dropped 31 | fillna: False # if True, numerical nan values are replaced with mean, categorical are replaced with mode. Either dropna or fillna can be True 32 | target_column: RiskPerformance # target column, if conditional generation. If None, unconditional generation 33 | split_feature_target: True # should be True for conditional generation 34 | task: classification # table task, can be `classification` or `regression` 35 | 36 | model: # denoiser model parameters 37 | dim: 256 # dimensionality of internal blocks 38 | n_res_blocks: 3 # number of residual blocks 39 | 40 | diffusion: # diffusion parameters 41 | schedule: quad # noise schedule, can be `linear`, `quad`, `sigmoid` 42 | n_timesteps: 1000 # number of denoising steps in denoiser pretraining 43 | target: two_way # denoiser prediction target: `mask`, `target`, `two_way` 44 | 45 | trainer: # trainer parameters 46 | train_num_steps: 500000 # number of training steps 47 | log_every: 100 # logging frequency 48 | save_every: 10000 # model saving frequency 49 | save_num_samples: 64 # number of generated samples to save during evaluation 50 | max_grad_norm: # maximum gradient norm, if None gradient clipping is not applied 51 | gradient_accumulate_every: 1 # gradient accumulation steps 52 | ema_decay: 0.995 # EMA decay 53 | ema_update_every: 10 # EMA update frequency 54 | lr: 0.0001 # learning rate 55 | opt_type: adam # type of optimizer to use. Options: `adam`, `adamw` 56 | opt_params: # optimizer parameters. See optimizer implementation for details 57 | batch_size: 256 # training batch size 58 | dataloader_workers: 16 # number of dataloader workers 59 | classifier_free_guidance: True # if True, classifier free guidance is used 60 | zero_token_probability: 0.1 # zero token probability in classifier free guidance. If classifier_free_guidance False, not use 61 | 62 | fine_tune_from: # path to model checkpoint to fine tune from 63 | 64 | comment: heloc_CFG # comment for the results folder and logging -------------------------------------------------------------------------------- /configs/housing.yaml: -------------------------------------------------------------------------------- 1 | data: # dataset parameters 2 | path_table: ./data/housing_train.csv # path to dataset in csv format 3 | numerical_columns: # list of numerical columns 4 | - longitude 5 | - latitude 6 | - housing_median_age 7 | - total_rooms 8 | - total_bedrooms 9 | - population 10 | - households 11 | - median_income 12 | - median_house_value 13 | categorical_columns: # list of categorical columns 14 | - ocean_proximity 15 | columns_to_drop: # list of columns to drop 16 | dropna: True # if True, rows with nan values are dropped 17 | fillna: False # if True, numerical nan values are replaced with mean, categorical are replaced with mode. Either dropna or fillna can be True 18 | target_column: median_house_value # target column, if conditional generation. If None, unconditional generation 19 | split_feature_target: True # should be True for conditional generation 20 | task: regression # table task, can be `classification` or `regression` 21 | 22 | model: # denoiser model parameters 23 | dim: 256 # dimensionality of internal blocks 24 | n_res_blocks: 3 # number of residual blocks 25 | 26 | diffusion: # diffusion parameters 27 | schedule: quad # noise schedule, can be `linear`, `quad`, `sigmoid` 28 | n_timesteps: 1000 # number of denoising steps in denoiser pretraining 29 | target: two_way # denoiser prediction target: `mask`, `target`, `two_way` 30 | 31 | trainer: # trainer parameters 32 | train_num_steps: 500000 # number of training steps 33 | log_every: 100 # logging frequency 34 | save_every: 10000 # model saving frequency 35 | save_num_samples: 64 # number of generated samples to save during evaluation 36 | max_grad_norm: # maximum gradient norm, if None gradient clipping is not applied 37 | gradient_accumulate_every: 1 # gradient accumulation steps 38 | ema_decay: 0.995 # EMA decay 39 | ema_update_every: 10 # EMA update frequency 40 | lr: 0.0001 # learning rate 41 | opt_type: adam # type of optimizer to use. Options: `adam`, `adamw` 42 | opt_params: # optimizer parameters. See optimizer implementation for details 43 | batch_size: 256 # training batch size 44 | dataloader_workers: 16 # number of dataloader workers 45 | classifier_free_guidance: True # if True, classifier free guidance is used 46 | zero_token_probability: 0.1 # zero token probability in classifier free guidance. If classifier_free_guidance False, not use 47 | 48 | fine_tune_from: # path to model checkpoint to fine tune from 49 | 50 | comment: housing_CFG # comment for the results folder and logging -------------------------------------------------------------------------------- /configs/sick.yaml: -------------------------------------------------------------------------------- 1 | data: # dataset parameters 2 | path_table: ./data/sick_train.csv # path to dataset in csv format 3 | numerical_columns: # list of numerical columns 4 | - age 5 | - TSH 6 | - T3 7 | - TT4 8 | - T4U 9 | - FTI 10 | categorical_columns: # list of categorical columns 11 | - Class 12 | - sex 13 | - on_thyroxine 14 | - query_on_thyroxine 15 | - on_antithyroid_medication 16 | - sick 17 | - pregnant 18 | - thyroid_surgery 19 | - I131_treatment 20 | - query_hypothyroid 21 | - query_hyperthyroid 22 | - lithium 23 | - goitre 24 | - tumor 25 | - hypopituitary 26 | - psych 27 | - TSH_measured 28 | - T3_measured 29 | - TT4_measured 30 | - T4U_measured 31 | - FTI_measured 32 | - referral_source 33 | columns_to_drop: # list of columns to drop 34 | - TBG 35 | - TBG_measured 36 | dropna: True # if True, rows with nan values are dropped 37 | fillna: False # if True, numerical nan values are replaced with mean, categorical are replaced with mode. Either dropna or fillna can be True 38 | target_column: Class # target column, if conditional generation. If None, unconditional generation 39 | split_feature_target: True # should be True for conditional generation 40 | task: classification # table task, can be `classification` or `regression` 41 | 42 | model: # denoiser model parameters 43 | dim: 256 # dimensionality of internal blocks 44 | n_res_blocks: 3 # number of residual blocks 45 | 46 | diffusion: # diffusion parameters 47 | schedule: quad # noise schedule, can be `linear`, `quad`, `sigmoid` 48 | n_timesteps: 1000 # number of denoising steps in denoiser pretraining 49 | target: two_way # denoiser prediction target: `mask`, `target`, `two_way` 50 | 51 | trainer: # trainer parameters 52 | train_num_steps: 500000 # number of training steps 53 | log_every: 100 # logging frequency 54 | save_every: 10000 # model saving frequency 55 | save_num_samples: 64 # number of generated samples to save during evaluation 56 | max_grad_norm: # maximum gradient norm, if None gradient clipping is not applied 57 | gradient_accumulate_every: 1 # gradient accumulation steps 58 | ema_decay: 0.995 # EMA decay 59 | ema_update_every: 10 # EMA update frequency 60 | lr: 0.0001 # learning rate 61 | opt_type: adam # type of optimizer to use. Options: `adam`, `adamw` 62 | opt_params: # optimizer parameters. See optimizer implementation for details 63 | batch_size: 256 # training batch size 64 | dataloader_workers: 16 # number of dataloader workers 65 | classifier_free_guidance: True # if True, classifier free guidance is used 66 | zero_token_probability: 0.1 # zero token probability in classifier free guidance. If classifier_free_guidance False, not use 67 | 68 | fine_tune_from: # path to model checkpoint to fine tune from 69 | 70 | comment: sick_CFG # comment for the results folder and logging -------------------------------------------------------------------------------- /configs/travel.yaml: -------------------------------------------------------------------------------- 1 | data: # dataset parameters 2 | path_table: ./data/travel_train.csv # path to dataset in csv format 3 | numerical_columns: # list of numerical columns 4 | - Age 5 | - ServicesOpted 6 | categorical_columns: # list of categorical columns 7 | - Target 8 | - FrequentFlyer 9 | - AnnualIncomeClass 10 | - AccountSyncedToSocialMedia 11 | - BookedHotelOrNot 12 | columns_to_drop: # list of columns to drop 13 | dropna: True # if True, rows with nan values are dropped 14 | fillna: False # if True, numerical nan values are replaced with mean, categorical are replaced with mode. Either dropna or fillna can be True 15 | target_column: Target # target column, if conditional generation. If None, unconditional generation 16 | split_feature_target: True # should be True for conditional generation 17 | task: classification # table task, can be `classification` or `regression` 18 | 19 | model: # denoiser model parameters 20 | dim: 256 # dimensionality of internal blocks 21 | n_res_blocks: 3 # number of residual blocks 22 | 23 | diffusion: # diffusion parameters 24 | schedule: quad # noise schedule, can be `linear`, `quad`, `sigmoid` 25 | n_timesteps: 1000 # number of denoising steps in denoiser pretraining 26 | target: two_way # denoiser prediction target: `mask`, `target`, `two_way` 27 | 28 | trainer: # trainer parameters 29 | train_num_steps: 500000 # number of training steps 30 | log_every: 100 # logging frequency 31 | save_every: 10000 # model saving frequency 32 | save_num_samples: 64 # number of generated samples to save during evaluation 33 | max_grad_norm: # maximum gradient norm, if None gradient clipping is not applied 34 | gradient_accumulate_every: 1 # gradient accumulation steps 35 | ema_decay: 0.995 # EMA decay 36 | ema_update_every: 10 # EMA update frequency 37 | lr: 0.0001 # learning rate 38 | opt_type: adam # type of optimizer to use. Options: `adam`, `adamw` 39 | opt_params: # optimizer parameters. See optimizer implementation for details 40 | batch_size: 256 # training batch size 41 | dataloader_workers: 16 # number of dataloader workers 42 | classifier_free_guidance: True # if True, classifier free guidance is used 43 | zero_token_probability: 0.1 # zero token probability in classifier free guidance. If classifier_free_guidance False, not use 44 | 45 | fine_tune_from: # path to model checkpoint to fine tune from 46 | 47 | comment: travel_CFG # comment for the results folder and logging -------------------------------------------------------------------------------- /data/travel.csv: -------------------------------------------------------------------------------- 1 | Age,FrequentFlyer,AnnualIncomeClass,ServicesOpted,AccountSyncedToSocialMedia,BookedHotelOrNot,Target 2 | 34,No,Middle Income,6,No,Yes,0 3 | 34,Yes,Low Income,5,Yes,No,1 4 | 37,No,Middle Income,3,Yes,No,0 5 | 30,No,Middle Income,2,No,No,0 6 | 30,No,Low Income,1,No,No,0 7 | 27,Yes,High Income,1,No,Yes,1 8 | 34,No,Middle Income,4,Yes,Yes,0 9 | 34,No,Low Income,2,Yes,No,1 10 | 30,No,Low Income,3,No,Yes,0 11 | 36,Yes,High Income,1,No,No,1 12 | 34,No,Low Income,1,Yes,Yes,0 13 | 28,No,Middle Income,2,No,No,1 14 | 35,No Record,Middle Income,1,Yes,Yes,0 15 | 34,Yes,Low Income,4,No,No,0 16 | 34,No,Middle Income,5,No,No,0 17 | 37,Yes,Low Income,6,No,Yes,0 18 | 30,No,Low Income,1,Yes,Yes,0 19 | 30,Yes,High Income,1,Yes,No,0 20 | 31,No,Middle Income,1,No,Yes,0 21 | 37,No,Low Income,2,Yes,No,1 22 | 30,No,Middle Income,4,No,Yes,0 23 | 31,Yes,High Income,1,No,No,1 24 | 34,Yes,Low Income,1,Yes,No,0 25 | 30,No Record,Middle Income,2,No,No,0 26 | 34,No,Middle Income,1,No,Yes,0 27 | 38,Yes,Low Income,1,No,Yes,0 28 | 37,No,Middle Income,3,Yes,No,0 29 | 30,No,Middle Income,5,Yes,No,0 30 | 28,No,Low Income,1,Yes,No,0 31 | 34,Yes,High Income,1,No,No,0 32 | 33,No,Middle Income,6,No,Yes,0 33 | 34,No,Low Income,2,No,No,0 34 | 27,No,Middle Income,3,Yes,No,0 35 | 35,Yes,High Income,1,No,No,1 36 | 30,No Record,Low Income,4,No,No,0 37 | 36,No,Middle Income,2,No,Yes,0 38 | 34,Yes,Low Income,1,Yes,Yes,0 39 | 37,Yes,Low Income,1,Yes,No,1 40 | 37,No,Middle Income,3,No,No,0 41 | 36,No,Middle Income,2,No,No,0 42 | 27,No,Low Income,5,No,Yes,0 43 | 36,Yes,High Income,4,No,No,0 44 | 28,No,Middle Income,1,Yes,Yes,0 45 | 30,No,Low Income,2,No,No,0 46 | 27,No,Middle Income,3,No,No,0 47 | 37,Yes,High Income,6,No,Yes,1 48 | 27,No,Low Income,1,Yes,No,0 49 | 38,No,Middle Income,2,Yes,No,0 50 | 30,No,Middle Income,4,No,Yes,0 51 | 34,Yes,Low Income,1,No,No,0 52 | 34,Yes,Low Income,3,No,Yes,1 53 | 31,No,Middle Income,2,No,No,0 54 | 34,No,Low Income,1,Yes,No,0 55 | 30,Yes,High Income,5,No,No,0 56 | 31,No,Middle Income,1,No,Yes,0 57 | 28,No,Low Income,4,Yes,Yes,1 58 | 30,No Record,Middle Income,3,Yes,Yes,0 59 | 37,Yes,High Income,1,Yes,No,1 60 | 36,No,Low Income,1,No,No,0 61 | 36,No,Middle Income,2,No,No,0 62 | 34,No,Middle Income,6,No,Yes,0 63 | 35,Yes,Low Income,1,No,No,0 64 | 30,No,Middle Income,4,Yes,No,0 65 | 29,No,Middle Income,2,No,No,0 66 | 33,Yes,Low Income,1,Yes,Yes,0 67 | 28,Yes,High Income,1,No,Yes,1 68 | 33,No,Middle Income,5,Yes,Yes,0 69 | 37,No Record,Low Income,2,Yes,No,1 70 | 31,No,Middle Income,3,No,No,0 71 | 34,Yes,High Income,4,No,No,1 72 | 37,No,Low Income,1,No,Yes,0 73 | 30,No,Low Income,2,No,No,0 74 | 30,No,Middle Income,1,Yes,Yes,0 75 | 30,Yes,Low Income,1,Yes,No,1 76 | 30,No,Middle Income,3,No,No,0 77 | 37,No,Middle Income,6,No,Yes,0 78 | 31,No,Low Income,4,Yes,No,0 79 | 34,Yes,High Income,1,Yes,No,0 80 | 34,Yes,Low Income,1,No,Yes,0 81 | 34,No,Low Income,5,No,No,0 82 | 28,No,Middle Income,3,No,Yes,0 83 | 27,Yes,High Income,1,No,No,1 84 | 30,No,Low Income,1,Yes,No,0 85 | 37,No,Middle Income,4,No,No,0 86 | 35,No,Middle Income,1,No,Yes,0 87 | 27,Yes,Low Income,1,No,Yes,1 88 | 35,No,Middle Income,3,Yes,No,0 89 | 30,No,Middle Income,2,Yes,No,0 90 | 37,No,Low Income,1,No,Yes,0 91 | 30,Yes,High Income,1,No,No,0 92 | 29,No,Middle Income,6,No,Yes,0 93 | 30,No,Low Income,2,Yes,No,1 94 | 30,No,Low Income,5,Yes,No,0 95 | 36,Yes,High Income,1,No,No,1 96 | 37,No,Low Income,1,No,No,0 97 | 28,No,Middle Income,2,No,Yes,1 98 | 30,No,Middle Income,1,Yes,Yes,0 99 | 31,Yes,Low Income,4,Yes,No,1 100 | 31,No,Middle Income,3,No,No,0 101 | 31,No,Low Income,2,No,No,0 102 | 30,No Record,Low Income,1,Yes,Yes,0 103 | 36,Yes,High Income,1,No,No,0 104 | 31,No,Middle Income,1,Yes,Yes,0 105 | 28,No,Low Income,2,No,No,1 106 | 30,No,Middle Income,4,No,Yes,0 107 | 30,Yes,High Income,6,No,Yes,1 108 | 37,Yes,Low Income,1,Yes,No,0 109 | 37,No,Middle Income,2,Yes,No,0 110 | 28,No,Middle Income,1,No,Yes,0 111 | 27,Yes,Low Income,1,Yes,No,1 112 | 34,No,Middle Income,3,No,Yes,0 113 | 30,No Record,Middle Income,4,No,No,0 114 | 33,No,Low Income,1,Yes,Yes,0 115 | 31,Yes,High Income,1,No,No,0 116 | 31,No,Middle Income,1,No,Yes,0 117 | 31,No,Low Income,2,No,Yes,0 118 | 30,No,Middle Income,3,Yes,No,0 119 | 34,Yes,High Income,1,Yes,No,1 120 | 34,No,Low Income,5,Yes,No,0 121 | 34,No,Middle Income,2,No,No,0 122 | 30,No,Low Income,6,No,Yes,0 123 | 28,Yes,Low Income,1,No,No,1 124 | 35,No Record,Middle Income,3,Yes,No,0 125 | 29,No,Middle Income,2,No,No,0 126 | 35,No,Low Income,1,No,No,0 127 | 31,Yes,High Income,4,No,Yes,0 128 | 35,No,Middle Income,1,Yes,Yes,0 129 | 29,No,Low Income,2,Yes,No,1 130 | 34,No,Middle Income,3,No,Yes,0 131 | 31,Yes,High Income,1,No,No,1 132 | 30,No,Low Income,1,No,Yes,0 133 | 30,No,Middle Income,5,No,No,0 134 | 30,No,Middle Income,4,Yes,Yes,0 135 | 27,Yes,Low Income,1,No,No,1 136 | 30,No,Low Income,3,No,No,0 137 | 30,No,Middle Income,6,No,Yes,0 138 | 31,No,Low Income,1,Yes,Yes,0 139 | 31,Yes,High Income,1,Yes,No,0 140 | 30,No,Middle Income,1,No,Yes,0 141 | 38,No,Low Income,4,No,No,0 142 | 30,No,Middle Income,3,No,Yes,0 143 | 27,Yes,High Income,1,No,No,1 144 | 35,No,Low Income,1,Yes,No,0 145 | 28,No,Middle Income,2,No,No,1 146 | 34,No Record,Middle Income,5,No,Yes,0 147 | 30,Yes,Low Income,1,Yes,Yes,1 148 | 33,No,Middle Income,4,Yes,No,0 149 | 34,No,Middle Income,2,Yes,No,0 150 | 27,No,Low Income,1,No,No,0 151 | 30,Yes,High Income,1,No,No,0 152 | 30,No,Middle Income,6,No,Yes,0 153 | 31,No,Low Income,2,No,No,0 154 | 30,No,Middle Income,3,Yes,Yes,0 155 | 27,Yes,High Income,4,No,No,1 156 | 30,No,Low Income,1,Yes,No,0 157 | 37,Yes,Low Income,2,No,Yes,0 158 | 30,No,Middle Income,1,Yes,Yes,0 159 | 30,Yes,Low Income,5,Yes,No,1 160 | 34,No,Middle Income,3,No,No,0 161 | 29,No,Middle Income,2,No,No,0 162 | 34,No,Low Income,4,No,Yes,0 163 | 37,Yes,High Income,1,No,No,0 164 | 37,Yes,Low Income,1,Yes,Yes,0 165 | 35,No,Low Income,2,Yes,No,1 166 | 30,No,Middle Income,3,No,No,0 167 | 36,Yes,High Income,6,No,Yes,1 168 | 30,No Record,Low Income,1,Yes,No,0 169 | 34,No,Middle Income,4,Yes,No,0 170 | 29,No,Middle Income,1,No,Yes,0 171 | 27,Yes,Low Income,1,No,No,1 172 | 34,No,Middle Income,5,No,Yes,0 173 | 30,No,Middle Income,2,No,No,0 174 | 29,No,Low Income,1,Yes,No,0 175 | 28,Yes,High Income,1,No,No,1 176 | 37,No,Middle Income,4,No,Yes,0 177 | 33,No,Low Income,2,No,Yes,0 178 | 37,Yes,Low Income,3,Yes,Yes,1 179 | 30,Yes,High Income,1,Yes,No,1 180 | 35,No,Low Income,1,No,No,0 181 | 28,No,Middle Income,2,No,No,1 182 | 30,No,Middle Income,6,No,Yes,0 183 | 28,Yes,Low Income,4,Yes,No,1 184 | 38,No,Middle Income,3,Yes,No,0 185 | 27,No,Low Income,5,No,No,1 186 | 27,No,Low Income,1,No,No,0 187 | 30,Yes,High Income,1,No,Yes,0 188 | 30,No,Middle Income,1,Yes,Yes,0 189 | 30,No,Low Income,2,Yes,No,1 190 | 28,No Record,Middle Income,4,No,No,0 191 | 28,Yes,High Income,1,No,No,1 192 | 29,No,Low Income,1,Yes,Yes,0 193 | 33,No,Middle Income,2,No,No,0 194 | 30,No,Middle Income,1,Yes,Yes,0 195 | 37,Yes,Low Income,1,No,No,0 196 | 36,No,Middle Income,3,No,No,0 197 | 27,No,Middle Income,6,No,Yes,1 198 | 37,No,Low Income,5,Yes,No,1 199 | 31,Yes,High Income,1,Yes,No,0 200 | 27,No,Middle Income,1,No,Yes,0 201 | 33,No Record,Low Income,2,Yes,No,1 202 | 29,No,Middle Income,3,No,Yes,0 203 | 34,Yes,High Income,1,No,No,1 204 | 30,No,Low Income,4,Yes,No,0 205 | 30,No,Middle Income,2,No,No,0 206 | 35,Yes,Low Income,1,No,Yes,0 207 | 31,Yes,Low Income,1,No,Yes,0 208 | 27,No,Middle Income,3,Yes,No,0 209 | 34,No,Middle Income,2,Yes,No,0 210 | 31,No,Low Income,1,Yes,Yes,0 211 | 27,Yes,High Income,5,No,No,1 212 | 30,No Record,Middle Income,6,No,Yes,0 213 | 29,No,Low Income,2,No,No,0 214 | 30,No,Middle Income,3,Yes,No,0 215 | 28,Yes,High Income,1,No,No,1 216 | 29,No,Low Income,1,No,No,0 217 | 36,No,Middle Income,2,No,Yes,0 218 | 37,No,Middle Income,4,Yes,Yes,0 219 | 30,Yes,Low Income,1,Yes,No,1 220 | 28,No,Low Income,3,No,No,0 221 | 34,No,Middle Income,2,No,No,0 222 | 38,No,Low Income,1,No,Yes,0 223 | 30,Yes,High Income,1,No,No,0 224 | 30,No,Middle Income,5,Yes,Yes,0 225 | 33,No,Low Income,4,No,No,0 226 | 34,No,Middle Income,3,No,Yes,0 227 | 34,Yes,High Income,6,No,Yes,1 228 | 37,No,Low Income,1,Yes,No,0 229 | 37,No,Middle Income,2,Yes,No,0 230 | 37,No,Middle Income,1,No,Yes,0 231 | 37,Yes,Low Income,1,No,No,0 232 | 31,No,Middle Income,4,No,Yes,0 233 | 34,No,Middle Income,2,No,No,0 234 | 30,No Record,Low Income,1,Yes,Yes,0 235 | 34,Yes,High Income,1,No,No,0 236 | 37,No,Middle Income,1,No,Yes,0 237 | 33,No,Low Income,5,Yes,Yes,1 238 | 28,No,Middle Income,3,Yes,No,0 239 | 33,Yes,High Income,4,Yes,No,1 240 | 31,No,Low Income,1,No,No,0 241 | 28,No,Low Income,2,No,No,1 242 | 29,No,Middle Income,6,No,Yes,0 243 | 38,Yes,Low Income,1,No,No,0 244 | 36,No,Middle Income,3,Yes,No,0 245 | 28,No Record,Middle Income,2,No,No,1 246 | 29,No,Low Income,4,Yes,No,0 247 | 30,Yes,High Income,1,No,Yes,0 248 | 34,Yes,Low Income,1,Yes,Yes,0 249 | 29,No,Low Income,2,Yes,No,1 250 | 37,No,Middle Income,5,No,Yes,1 251 | 30,Yes,High Income,1,No,No,1 252 | 37,No,Low Income,1,No,Yes,0 253 | 30,No,Middle Income,4,No,No,0 254 | 31,No,Middle Income,1,Yes,Yes,0 255 | 30,Yes,Low Income,1,Yes,No,1 256 | 30,No Record,Middle Income,3,No,No,0 257 | 37,No,Middle Income,6,No,Yes,0 258 | 37,No,Low Income,1,Yes,Yes,0 259 | 36,Yes,High Income,1,Yes,No,0 260 | 31,No,Middle Income,4,No,Yes,0 261 | 36,No,Low Income,2,No,No,0 262 | 28,No,Low Income,3,No,Yes,0 263 | 30,Yes,High Income,5,No,No,1 264 | 28,No,Low Income,1,Yes,No,0 265 | 38,No,Middle Income,2,No,No,0 266 | 31,No,Middle Income,1,No,Yes,0 267 | 29,Yes,Low Income,4,No,Yes,0 268 | 34,No,Middle Income,3,Yes,No,0 269 | 36,Yes,Low Income,2,Yes,No,0 270 | 37,No,Low Income,1,No,No,0 271 | 29,Yes,High Income,1,No,No,0 272 | 37,No,Middle Income,6,No,Yes,0 273 | 38,No,Low Income,2,Yes,No,1 274 | 30,No,Middle Income,4,Yes,Yes,0 275 | 37,Yes,High Income,1,No,No,1 276 | 35,Yes,Low Income,5,No,No,1 277 | 28,No,Middle Income,2,No,Yes,1 278 | 31,No Record,Middle Income,1,Yes,Yes,0 279 | 34,Yes,Low Income,1,Yes,No,1 280 | 35,No,Middle Income,3,No,No,0 281 | 35,No,Middle Income,4,No,No,0 282 | 35,No,Low Income,1,Yes,Yes,0 283 | 28,Yes,High Income,1,No,No,1 284 | 28,No,Middle Income,1,Yes,Yes,0 285 | 28,No,Low Income,2,No,No,1 286 | 30,No,Middle Income,3,No,No,0 287 | 36,Yes,High Income,6,No,Yes,1 288 | 30,No,Low Income,4,Yes,No,0 289 | 31,No Record,Middle Income,5,Yes,No,0 290 | 31,No,Low Income,1,No,Yes,0 291 | 28,Yes,Low Income,1,Yes,No,1 292 | 35,No,Middle Income,3,No,Yes,0 293 | 37,No,Middle Income,2,No,No,0 294 | 29,No,Low Income,1,Yes,No,0 295 | 34,Yes,High Income,4,No,No,0 296 | 31,No,Middle Income,1,No,Yes,0 297 | 30,No,Low Income,2,No,Yes,0 298 | 35,No,Middle Income,3,Yes,Yes,0 299 | 38,Yes,High Income,1,Yes,No,1 300 | 27,No Record,Low Income,1,Yes,No,0 301 | 29,No,Middle Income,2,No,No,0 302 | 35,No,Middle Income,6,No,Yes,1 303 | 29,Yes,Low Income,1,No,No,0 304 | 37,Yes,Low Income,3,Yes,No,1 305 | 34,No,Middle Income,2,No,No,0 306 | 37,No,Low Income,1,No,Yes,0 307 | 28,Yes,High Income,1,No,Yes,1 308 | 36,No,Middle Income,1,Yes,Yes,0 309 | 30,No,Low Income,4,Yes,No,1 310 | 37,No,Middle Income,3,No,No,0 311 | 36,Yes,High Income,1,No,No,1 312 | 33,No,Low Income,1,No,Yes,0 313 | 37,No,Middle Income,2,No,No,0 314 | 35,No,Middle Income,1,Yes,Yes,0 315 | 27,Yes,Low Income,5,No,No,1 316 | 28,No,Middle Income,4,No,No,0 317 | 30,No,Middle Income,6,No,Yes,0 318 | 34,Yes,Low Income,1,Yes,No,0 319 | 35,Yes,High Income,1,Yes,No,0 320 | 37,No,Middle Income,1,No,Yes,0 321 | 27,No,Low Income,2,No,No,1 322 | 38,No Record,Middle Income,3,No,Yes,0 323 | 30,Yes,High Income,4,No,No,1 324 | 30,No,Low Income,1,Yes,No,0 325 | 36,Yes,Low Income,2,No,No,0 326 | 30,No,Middle Income,1,No,Yes,0 327 | 34,Yes,Low Income,1,Yes,Yes,1 328 | 30,No,Middle Income,5,Yes,No,0 329 | 37,No,Middle Income,2,Yes,No,0 330 | 37,No,Low Income,4,No,Yes,0 331 | 30,Yes,High Income,1,No,No,0 332 | 38,Yes,Low Income,6,No,Yes,0 333 | 36,No Record,Low Income,2,No,No,0 334 | 34,No,Middle Income,3,Yes,No,0 335 | 34,Yes,High Income,1,No,No,1 336 | 31,No,Low Income,1,Yes,No,0 337 | 30,No,Middle Income,4,No,Yes,0 338 | 35,No,Middle Income,1,Yes,Yes,0 339 | 33,Yes,Low Income,1,Yes,No,1 340 | 30,No,Middle Income,3,No,No,0 341 | 31,No,Middle Income,5,No,No,0 342 | 29,No,Low Income,1,No,Yes,0 343 | 30,Yes,High Income,1,No,No,0 344 | 30,No Record,Middle Income,4,Yes,Yes,0 345 | 30,No,Low Income,2,Yes,No,1 346 | 35,Yes,Low Income,3,No,Yes,1 347 | 31,Yes,High Income,6,No,Yes,1 348 | 27,No,Low Income,1,Yes,No,0 349 | 37,No,Middle Income,2,Yes,No,0 350 | 34,No,Middle Income,1,No,Yes,0 351 | 30,Yes,Low Income,4,No,No,0 352 | 34,No,Middle Income,3,No,Yes,0 353 | 31,No,Low Income,2,No,No,0 354 | 34,No,Low Income,5,Yes,Yes,0 355 | 34,Yes,High Income,1,No,No,0 356 | 27,No,Middle Income,1,No,Yes,0 357 | 30,No,Low Income,2,No,Yes,0 358 | 37,No,Middle Income,4,Yes,No,0 359 | 37,Yes,High Income,1,Yes,No,1 360 | 37,Yes,Low Income,1,No,No,0 361 | 27,No,Middle Income,2,No,No,1 362 | 30,No,Middle Income,6,No,Yes,0 363 | 31,Yes,Low Income,1,Yes,No,1 364 | 30,No,Middle Income,3,Yes,No,0 365 | 37,No,Middle Income,4,No,No,0 366 | 34,No Record,Low Income,1,No,No,0 367 | 27,Yes,High Income,5,No,Yes,1 368 | 30,No,Middle Income,1,Yes,Yes,0 369 | 34,No,Low Income,2,Yes,No,1 370 | 38,No,Middle Income,3,No,Yes,0 371 | 35,Yes,High Income,1,No,No,1 372 | 34,No,Low Income,4,Yes,Yes,0 373 | 29,No,Middle Income,2,No,No,0 374 | 31,No,Low Income,1,Yes,Yes,0 375 | 30,Yes,Low Income,1,No,No,0 376 | 30,No,Middle Income,3,No,No,0 377 | 30,No Record,Middle Income,6,No,Yes,0 378 | 34,No,Low Income,1,Yes,Yes,0 379 | 29,Yes,High Income,4,Yes,No,0 380 | 27,No,Middle Income,5,No,Yes,0 381 | 29,No,Low Income,2,Yes,No,1 382 | 30,No,Middle Income,3,No,Yes,0 383 | 31,Yes,High Income,1,No,No,1 384 | 34,No,Low Income,1,Yes,No,0 385 | 31,No,Middle Income,2,No,No,0 386 | 36,No,Middle Income,4,No,Yes,0 387 | 37,Yes,Low Income,1,No,Yes,0 388 | 30,No Record,Low Income,3,Yes,No,0 389 | 30,No,Middle Income,2,Yes,No,0 390 | 30,No,Low Income,1,Yes,No,0 391 | 34,Yes,High Income,1,No,No,0 392 | 36,No,Middle Income,6,No,Yes,0 393 | 30,No,Low Income,5,No,No,0 394 | 34,No,Middle Income,3,Yes,Yes,0 395 | 31,Yes,High Income,1,No,No,1 396 | 30,No,Low Income,1,No,No,0 397 | 35,No,Middle Income,2,No,Yes,0 398 | 36,No,Middle Income,1,Yes,Yes,0 399 | 27,Yes,Low Income,1,Yes,No,1 400 | 30,No,Middle Income,4,No,No,0 401 | 38,No,Middle Income,2,No,No,0 402 | 31,No,Low Income,1,No,Yes,0 403 | 37,Yes,High Income,1,No,No,0 404 | 34,No,Middle Income,1,Yes,Yes,0 405 | 30,No,Low Income,2,No,No,0 406 | 36,No,Middle Income,5,No,No,1 407 | 35,Yes,High Income,6,No,Yes,1 408 | 30,No,Low Income,1,Yes,No,0 409 | 27,No,Low Income,2,Yes,No,1 410 | 34,No Record,Middle Income,1,No,Yes,0 411 | 30,Yes,Low Income,1,No,No,0 412 | 30,No,Middle Income,3,No,Yes,0 413 | 30,No,Middle Income,2,No,No,0 414 | 30,No,Low Income,4,Yes,No,0 415 | 27,Yes,High Income,1,No,No,1 416 | 36,Yes,Low Income,1,No,Yes,0 417 | 30,No,Low Income,2,Yes,Yes,1 418 | 34,No,Middle Income,3,Yes,Yes,0 419 | 34,Yes,High Income,5,Yes,No,1 420 | 37,No,Low Income,1,No,No,0 421 | 27,No Record,Middle Income,4,No,No,1 422 | 37,No,Middle Income,6,No,Yes,0 423 | 30,Yes,Low Income,1,No,No,0 424 | 37,No,Middle Income,3,Yes,No,0 425 | 27,No,Middle Income,2,No,No,1 426 | 31,No,Low Income,1,Yes,Yes,0 427 | 31,Yes,High Income,1,No,Yes,0 428 | 27,No,Middle Income,4,Yes,Yes,0 429 | 27,No,Low Income,2,Yes,No,1 430 | 30,No,Low Income,3,No,No,0 431 | 35,Yes,High Income,1,No,No,1 432 | 28,No Record,Low Income,5,No,Yes,0 433 | 30,No,Middle Income,2,No,No,0 434 | 37,No,Middle Income,1,Yes,Yes,0 435 | 30,Yes,Low Income,4,Yes,No,1 436 | 37,No,Middle Income,3,No,No,0 437 | 37,Yes,Low Income,6,No,Yes,0 438 | 30,No,Low Income,1,Yes,No,0 439 | 29,Yes,High Income,1,Yes,No,0 440 | 36,No,Middle Income,1,No,Yes,0 441 | 30,No,Low Income,2,No,No,0 442 | 37,No,Middle Income,4,No,Yes,0 443 | 37,Yes,High Income,1,No,No,1 444 | 31,No,Low Income,1,Yes,No,0 445 | 27,No,Middle Income,5,No,No,1 446 | 30,No,Middle Income,1,No,Yes,0 447 | 29,Yes,Low Income,1,No,Yes,0 448 | 30,No,Middle Income,3,Yes,No,0 449 | 31,No,Middle Income,4,Yes,No,0 450 | 36,No,Low Income,1,No,Yes,0 451 | 34,Yes,High Income,1,No,No,0 452 | 34,No,Middle Income,6,No,Yes,0 453 | 30,No,Low Income,2,Yes,No,1 454 | 37,No Record,Middle Income,3,Yes,No,0 455 | 34,Yes,High Income,1,No,No,1 456 | 30,No,Low Income,4,No,No,0 457 | 38,No,Middle Income,2,No,Yes,0 458 | 28,No,Low Income,5,Yes,Yes,0 459 | 27,Yes,Low Income,1,Yes,No,1 460 | 28,No,Middle Income,3,No,No,0 461 | 37,No,Middle Income,2,No,No,0 462 | 27,No,Low Income,1,Yes,Yes,0 463 | 38,Yes,High Income,4,No,No,0 464 | 28,No,Middle Income,1,Yes,Yes,0 465 | 36,Yes,Low Income,2,No,No,0 466 | 37,No,Middle Income,3,No,Yes,0 467 | 28,Yes,High Income,6,No,Yes,1 468 | 30,No,Low Income,1,Yes,No,0 469 | 28,No,Middle Income,2,Yes,No,1 470 | 30,No,Middle Income,4,No,Yes,0 471 | 27,Yes,Low Income,5,Yes,No,1 472 | 30,No,Low Income,3,No,Yes,0 473 | 28,No,Middle Income,2,No,No,1 474 | 30,No,Low Income,1,Yes,Yes,0 475 | 27,Yes,High Income,1,No,No,1 476 | 34,No Record,Middle Income,1,No,Yes,0 477 | 37,No,Low Income,4,No,Yes,0 478 | 36,No,Middle Income,3,Yes,No,0 479 | 37,Yes,High Income,1,Yes,No,1 480 | 31,No,Low Income,1,Yes,No,0 481 | 37,No,Middle Income,2,No,No,0 482 | 37,No,Middle Income,6,No,Yes,0 483 | 31,Yes,Low Income,1,No,No,0 484 | 35,No,Middle Income,5,Yes,No,1 485 | 30,No,Middle Income,2,No,No,0 486 | 36,Yes,Low Income,1,No,No,0 487 | 29,Yes,High Income,1,No,Yes,0 488 | 36,No,Middle Income,1,Yes,Yes,0 489 | 37,No,Low Income,2,Yes,No,1 490 | 30,No,Middle Income,3,No,Yes,0 491 | 33,Yes,High Income,4,No,No,1 492 | 31,No,Low Income,1,No,Yes,0 493 | 30,No,Low Income,2,No,No,0 494 | 37,No,Middle Income,1,Yes,Yes,0 495 | 31,Yes,Low Income,1,No,No,0 496 | 37,No,Middle Income,3,No,No,0 497 | 34,No,Middle Income,6,No,Yes,0 498 | 28,No Record,Low Income,4,Yes,Yes,0 499 | 30,Yes,High Income,1,Yes,No,0 500 | 31,No,Low Income,1,No,Yes,0 501 | 35,No,Low Income,2,No,No,0 502 | 28,No,Middle Income,3,No,Yes,0 503 | 34,Yes,High Income,1,No,No,1 504 | 30,No,Low Income,1,Yes,No,0 505 | 34,No,Middle Income,4,No,No,0 506 | 31,No,Middle Income,1,No,Yes,0 507 | 30,Yes,Low Income,1,Yes,Yes,1 508 | 30,No,Middle Income,3,Yes,No,0 509 | 35,No Record,Middle Income,2,Yes,No,0 510 | 30,No,Low Income,5,No,No,0 511 | 30,Yes,High Income,1,No,No,0 512 | 34,No,Middle Income,6,No,Yes,0 513 | 34,No,Low Income,2,No,No,0 514 | 31,No,Low Income,3,Yes,Yes,0 515 | 29,Yes,High Income,1,No,No,1 516 | 30,No,Low Income,1,Yes,No,0 517 | 29,No,Middle Income,2,No,Yes,0 518 | 34,No,Middle Income,1,Yes,Yes,0 519 | 34,Yes,Low Income,4,Yes,No,1 520 | 31,No Record,Middle Income,3,No,No,0 521 | 27,No,Low Income,2,No,No,1 522 | 38,No,Low Income,1,No,Yes,0 523 | 30,Yes,High Income,5,No,No,0 524 | 38,No,Middle Income,1,Yes,Yes,0 525 | 36,No,Low Income,2,Yes,No,1 526 | 35,No,Middle Income,4,No,No,0 527 | 28,Yes,High Income,6,No,Yes,1 528 | 34,Yes,Low Income,1,Yes,No,0 529 | 30,No,Middle Income,2,Yes,No,0 530 | 27,No,Middle Income,1,No,Yes,0 531 | 30,Yes,Low Income,1,No,No,0 532 | 31,No,Middle Income,3,No,Yes,0 533 | 30,No,Middle Income,4,No,No,0 534 | 34,No,Low Income,1,Yes,No,0 535 | 31,Yes,High Income,1,No,No,0 536 | 29,No,Middle Income,5,No,Yes,0 537 | 34,No,Low Income,2,No,Yes,0 538 | 33,No,Middle Income,3,Yes,Yes,0 539 | 31,Yes,High Income,1,Yes,No,1 540 | 30,No,Low Income,4,No,No,0 541 | 36,No,Middle Income,2,No,No,0 542 | 28,No Record,Low Income,6,No,Yes,0 543 | 31,Yes,Low Income,1,Yes,No,1 544 | 29,No,Middle Income,3,Yes,No,0 545 | 37,No,Middle Income,2,No,No,0 546 | 34,No,Low Income,1,No,Yes,0 547 | 28,Yes,High Income,4,No,Yes,1 548 | 30,No,Middle Income,1,Yes,Yes,0 549 | 28,No,Low Income,5,Yes,No,1 550 | 29,No,Middle Income,3,No,No,0 551 | 30,Yes,High Income,1,No,No,1 552 | 37,No,Low Income,1,Yes,Yes,0 553 | 27,No Record,Middle Income,2,No,No,1 554 | 29,No,Middle Income,4,Yes,Yes,0 555 | 36,Yes,Low Income,1,No,No,0 556 | 35,Yes,Low Income,3,No,No,1 557 | 28,No,Middle Income,6,No,Yes,1 558 | 36,No,Low Income,1,Yes,No,0 559 | 29,Yes,High Income,1,Yes,No,0 560 | 37,No,Middle Income,1,No,Yes,0 561 | 31,No,Low Income,4,Yes,No,1 562 | 28,No,Middle Income,5,No,Yes,0 563 | 31,Yes,High Income,1,No,No,1 564 | 36,No Record,Low Income,1,Yes,No,0 565 | 36,No,Middle Income,2,No,No,0 566 | 31,No,Middle Income,1,No,Yes,0 567 | 30,Yes,Low Income,1,No,Yes,0 568 | 29,No,Middle Income,4,Yes,No,0 569 | 38,No,Middle Income,2,Yes,No,0 570 | 38,Yes,Low Income,1,Yes,Yes,0 571 | 29,Yes,High Income,1,No,No,0 572 | 27,No,Middle Income,6,No,Yes,0 573 | 35,No,Low Income,2,No,No,0 574 | 30,No,Middle Income,3,Yes,No,0 575 | 29,Yes,High Income,5,No,No,1 576 | 31,No,Low Income,1,No,No,0 577 | 31,No,Low Income,2,No,Yes,0 578 | 34,No,Middle Income,1,Yes,Yes,0 579 | 29,Yes,Low Income,1,Yes,No,1 580 | 37,No,Middle Income,3,No,No,0 581 | 36,No,Middle Income,2,No,No,0 582 | 34,No,Low Income,4,No,Yes,0 583 | 30,Yes,High Income,1,No,No,0 584 | 31,No,Low Income,1,Yes,Yes,0 585 | 31,No,Low Income,2,No,No,0 586 | 31,No Record,Middle Income,3,No,Yes,0 587 | 37,Yes,High Income,6,No,Yes,1 588 | 30,No,Low Income,5,Yes,No,0 589 | 31,No,Middle Income,4,Yes,No,0 590 | 38,No,Middle Income,1,No,Yes,0 591 | 28,Yes,Low Income,1,No,No,1 592 | 30,No,Middle Income,3,No,Yes,0 593 | 27,No,Middle Income,2,No,No,1 594 | 27,No,Low Income,1,Yes,No,0 595 | 30,Yes,High Income,1,No,No,0 596 | 36,No,Middle Income,4,No,Yes,0 597 | 27,No Record,Low Income,2,Yes,Yes,1 598 | 34,Yes,Low Income,3,Yes,No,1 599 | 37,Yes,High Income,1,Yes,No,1 600 | 38,No,Low Income,1,No,No,0 601 | 37,No,Middle Income,5,No,No,1 602 | 37,No,Middle Income,6,No,Yes,0 603 | 37,Yes,Low Income,4,No,No,0 604 | 37,No,Middle Income,3,Yes,No,0 605 | 36,Yes,Low Income,2,No,No,0 606 | 35,No,Low Income,1,Yes,No,0 607 | 30,Yes,High Income,1,No,Yes,0 608 | 30,No Record,Middle Income,1,Yes,Yes,0 609 | 30,No,Low Income,2,Yes,No,1 610 | 30,No,Middle Income,4,No,Yes,0 611 | 30,Yes,High Income,1,No,No,1 612 | 34,Yes,Low Income,1,No,Yes,0 613 | 30,No,Middle Income,2,No,No,0 614 | 30,No,Middle Income,5,Yes,Yes,0 615 | 34,Yes,Low Income,1,Yes,No,1 616 | 37,No,Middle Income,3,No,No,0 617 | 30,No,Middle Income,6,No,Yes,0 618 | 29,No,Low Income,1,Yes,Yes,0 619 | 33,Yes,High Income,1,Yes,No,0 620 | 34,No,Middle Income,1,No,Yes,0 621 | 38,No,Low Income,2,No,No,0 622 | 31,No,Middle Income,3,No,Yes,0 623 | 30,Yes,High Income,1,No,No,1 624 | 35,No,Low Income,4,Yes,No,0 625 | 30,No,Middle Income,2,No,No,0 626 | 37,Yes,Low Income,1,No,Yes,0 627 | 34,Yes,Low Income,5,No,Yes,0 628 | 30,No,Middle Income,3,Yes,No,0 629 | 29,No,Middle Income,2,Yes,No,0 630 | 30,No Record,Low Income,1,No,No,0 631 | 27,Yes,High Income,4,No,No,1 632 | 30,No,Middle Income,6,No,Yes,0 633 | 36,Yes,Low Income,2,Yes,No,1 634 | 37,No,Middle Income,3,Yes,Yes,0 635 | 30,Yes,High Income,1,No,No,1 636 | 37,No,Low Income,1,No,No,0 637 | 30,No,Middle Income,2,No,Yes,0 638 | 27,No,Middle Income,4,Yes,Yes,0 639 | 37,Yes,Low Income,1,Yes,No,1 640 | 28,No,Low Income,5,No,No,0 641 | 34,No Record,Middle Income,2,No,No,0 642 | 29,No,Low Income,1,Yes,Yes,0 643 | 37,Yes,High Income,1,No,No,0 644 | 28,No,Middle Income,1,Yes,Yes,0 645 | 30,No,Low Income,4,No,No,0 646 | 35,No,Middle Income,3,No,No,0 647 | 30,Yes,High Income,6,No,Yes,1 648 | 30,No,Low Income,1,Yes,No,0 649 | 37,No,Middle Income,2,Yes,No,0 650 | 30,No,Middle Income,1,No,Yes,0 651 | 30,Yes,Low Income,1,Yes,No,1 652 | 28,No Record,Middle Income,4,No,Yes,0 653 | 36,No,Middle Income,5,No,No,1 654 | 36,Yes,Low Income,1,Yes,No,0 655 | 29,Yes,High Income,1,No,No,0 656 | 35,No,Middle Income,1,No,Yes,0 657 | 30,No,Low Income,2,No,Yes,0 658 | 34,No,Middle Income,3,Yes,Yes,0 659 | 30,Yes,High Income,4,Yes,No,1 660 | 33,No,Low Income,1,Yes,No,0 661 | 37,Yes,Low Income,2,No,No,0 662 | 30,No,Middle Income,6,No,Yes,0 663 | 30,Yes,Low Income,1,No,No,0 664 | 28,No,Middle Income,3,Yes,No,0 665 | 35,No,Middle Income,2,No,No,0 666 | 37,No,Low Income,5,No,Yes,1 667 | 30,Yes,High Income,1,No,Yes,0 668 | 30,No,Low Income,1,Yes,Yes,0 669 | 28,No,Low Income,2,Yes,No,1 670 | 36,No,Middle Income,3,No,No,0 671 | 28,Yes,High Income,1,No,No,1 672 | 30,No,Low Income,1,No,Yes,0 673 | 31,No,Middle Income,4,No,No,0 674 | 29,No Record,Middle Income,1,Yes,Yes,0 675 | 33,Yes,Low Income,1,No,No,0 676 | 29,No,Middle Income,3,No,No,0 677 | 29,No,Middle Income,6,No,Yes,0 678 | 31,No,Low Income,1,Yes,No,0 679 | 29,Yes,High Income,5,Yes,No,0 680 | 27,No,Middle Income,4,No,Yes,0 681 | 28,No,Low Income,2,No,No,1 682 | 31,No,Low Income,3,No,Yes,0 683 | 30,Yes,High Income,1,No,No,1 684 | 28,No,Low Income,1,Yes,No,0 685 | 30,No Record,Middle Income,2,No,No,0 686 | 27,No,Middle Income,1,No,Yes,0 687 | 37,Yes,Low Income,4,Yes,Yes,1 688 | 36,No,Middle Income,3,Yes,No,0 689 | 30,No,Low Income,2,Yes,No,0 690 | 36,No,Low Income,1,No,Yes,0 691 | 29,Yes,High Income,1,No,No,0 692 | 36,No,Middle Income,6,No,Yes,1 693 | 34,No,Low Income,2,No,No,0 694 | 34,No,Middle Income,4,Yes,No,0 695 | 29,Yes,High Income,1,No,No,1 696 | 37,Yes,Low Income,1,Yes,No,0 697 | 30,No,Middle Income,2,No,Yes,0 698 | 35,No,Middle Income,1,Yes,Yes,0 699 | 27,Yes,Low Income,1,Yes,No,1 700 | 36,No,Middle Income,3,No,No,0 701 | 35,No,Middle Income,4,No,No,0 702 | 30,No,Low Income,1,No,Yes,0 703 | 31,Yes,High Income,1,No,No,0 704 | 30,No,Middle Income,1,Yes,Yes,0 705 | 29,No,Low Income,5,Yes,No,1 706 | 36,No,Middle Income,3,No,Yes,0 707 | 29,Yes,High Income,6,No,Yes,1 708 | 37,No,Low Income,4,Yes,No,0 709 | 30,No,Middle Income,2,Yes,No,0 710 | 37,Yes,Low Income,1,No,Yes,0 711 | 36,Yes,Low Income,1,No,No,0 712 | 28,No,Middle Income,3,No,Yes,0 713 | 37,No,Middle Income,2,No,No,0 714 | 30,No,Low Income,1,Yes,Yes,0 715 | 30,Yes,High Income,4,No,No,0 716 | 30,No,Middle Income,1,No,Yes,0 717 | 30,No,Low Income,2,No,Yes,0 718 | 31,No Record,Middle Income,5,Yes,No,0 719 | 30,Yes,High Income,1,Yes,No,1 720 | 29,No,Low Income,1,No,No,0 721 | 37,No,Middle Income,2,No,No,0 722 | 29,No,Middle Income,6,No,Yes,0 723 | 37,Yes,Low Income,1,Yes,No,1 724 | 36,Yes,Low Income,3,Yes,No,1 725 | 31,No,Middle Income,2,No,No,0 726 | 27,No,Low Income,1,No,No,0 727 | 33,Yes,High Income,1,No,Yes,0 728 | 27,No,Middle Income,1,Yes,Yes,0 729 | 37,No Record,Low Income,4,Yes,No,1 730 | 30,No,Middle Income,3,No,Yes,0 731 | 29,Yes,High Income,5,No,No,1 732 | 34,No,Low Income,1,Yes,Yes,0 733 | 30,No,Middle Income,2,No,No,0 734 | 33,No,Middle Income,1,Yes,Yes,0 735 | 29,Yes,Low Income,1,No,No,0 736 | 37,No,Middle Income,4,No,No,0 737 | 30,No,Middle Income,6,No,Yes,0 738 | 33,Yes,Low Income,1,Yes,Yes,0 739 | 30,Yes,High Income,1,Yes,No,0 740 | 30,No Record,Middle Income,1,No,Yes,0 741 | 36,No,Low Income,2,Yes,No,1 742 | 37,No,Middle Income,3,No,Yes,0 743 | 37,Yes,High Income,4,No,No,1 744 | 37,No,Low Income,5,Yes,No,1 745 | 31,No,Low Income,2,No,No,0 746 | 31,No,Middle Income,1,No,Yes,0 747 | 30,Yes,Low Income,1,No,Yes,0 748 | 29,No,Middle Income,3,Yes,No,0 749 | 37,No,Middle Income,2,Yes,No,0 750 | 30,No,Low Income,4,Yes,No,0 751 | 37,Yes,High Income,1,No,No,0 752 | 38,Yes,Low Income,6,No,Yes,0 753 | 30,No,Low Income,2,No,No,0 754 | 30,No,Middle Income,3,Yes,Yes,0 755 | 36,Yes,High Income,1,No,No,1 756 | 35,No,Low Income,1,No,No,0 757 | 31,No,Middle Income,5,No,Yes,0 758 | 30,No,Middle Income,1,Yes,Yes,0 759 | 29,Yes,Low Income,1,Yes,No,1 760 | 36,No,Middle Income,3,No,No,0 761 | 37,No,Middle Income,2,No,No,0 762 | 29,No Record,Low Income,1,No,Yes,0 763 | 28,Yes,High Income,1,No,No,1 764 | 30,No,Middle Income,4,Yes,Yes,0 765 | 30,No,Low Income,2,No,No,0 766 | 29,No,Low Income,3,No,No,0 767 | 31,Yes,High Income,6,No,Yes,1 768 | 37,No,Low Income,1,Yes,No,0 769 | 30,No,Middle Income,2,Yes,No,0 770 | 37,No,Middle Income,5,No,Yes,1 771 | 33,Yes,Low Income,4,No,No,0 772 | 30,No,Middle Income,3,No,Yes,0 773 | 31,No Record,Low Income,2,No,No,0 774 | 38,No,Low Income,1,Yes,No,0 775 | 34,Yes,High Income,1,No,No,0 776 | 30,No,Middle Income,1,No,Yes,0 777 | 35,No,Low Income,2,Yes,Yes,1 778 | 27,No,Middle Income,4,Yes,No,0 779 | 29,Yes,High Income,1,Yes,No,1 780 | 29,No,Low Income,1,No,No,0 781 | 27,No,Middle Income,2,No,No,1 782 | 34,No,Middle Income,6,No,Yes,0 783 | 29,Yes,Low Income,5,No,No,0 784 | 30,No Record,Middle Income,3,Yes,No,0 785 | 31,No,Middle Income,4,No,No,0 786 | 34,No,Low Income,1,Yes,Yes,0 787 | 31,Yes,High Income,1,No,Yes,0 788 | 31,No,Middle Income,1,Yes,Yes,0 789 | 37,No,Low Income,2,Yes,No,1 790 | 37,No,Middle Income,3,No,No,0 791 | 30,Yes,High Income,1,No,No,1 792 | 35,No,Low Income,4,No,Yes,0 793 | 38,No,Middle Income,2,No,No,0 794 | 30,No,Low Income,1,Yes,Yes,0 795 | 30,Yes,Low Income,1,Yes,No,1 796 | 28,No,Middle Income,5,No,No,0 797 | 35,No,Middle Income,6,No,Yes,0 798 | 30,No,Low Income,1,Yes,No,0 799 | 30,Yes,High Income,4,Yes,No,0 800 | 37,No,Middle Income,1,No,Yes,0 801 | 34,Yes,Low Income,2,No,No,0 802 | 34,No,Middle Income,3,No,Yes,0 803 | 34,Yes,High Income,1,No,No,1 804 | 33,No,Low Income,1,Yes,No,0 805 | 31,No,Middle Income,2,No,No,0 806 | 29,No Record,Middle Income,4,No,Yes,0 807 | 31,Yes,Low Income,1,No,Yes,0 808 | 33,Yes,Low Income,3,Yes,No,1 809 | 29,No,Middle Income,5,Yes,No,0 810 | 37,No,Low Income,1,No,Yes,0 811 | 33,Yes,High Income,1,No,No,0 812 | 31,No,Middle Income,6,No,Yes,0 813 | 28,No,Low Income,4,Yes,No,1 814 | 36,No,Middle Income,3,Yes,No,0 815 | 33,Yes,High Income,1,No,No,1 816 | 31,No,Low Income,1,No,No,0 817 | 27,No Record,Middle Income,2,No,Yes,1 818 | 37,No,Middle Income,1,Yes,Yes,0 819 | 29,Yes,Low Income,1,Yes,No,1 820 | 30,No,Middle Income,4,No,No,0 821 | 30,No,Middle Income,2,No,No,0 822 | 27,No,Low Income,5,Yes,Yes,0 823 | 37,Yes,High Income,1,No,No,0 824 | 31,No,Middle Income,1,Yes,Yes,0 825 | 30,No,Low Income,2,No,No,0 826 | 35,No,Middle Income,3,No,Yes,0 827 | 31,Yes,High Income,6,No,Yes,1 828 | 30,No Record,Low Income,1,Yes,No,0 829 | 35,Yes,Low Income,2,Yes,No,0 830 | 37,No,Middle Income,1,No,Yes,0 831 | 28,Yes,Low Income,1,Yes,No,1 832 | 37,No,Middle Income,3,No,Yes,0 833 | 29,No,Middle Income,2,No,No,0 834 | 37,No,Low Income,4,Yes,Yes,0 835 | 34,Yes,High Income,5,No,No,0 836 | 36,Yes,Low Income,1,No,Yes,0 837 | 38,No,Low Income,2,No,Yes,0 838 | 30,No,Middle Income,3,Yes,No,0 839 | 31,Yes,High Income,1,Yes,No,1 840 | 31,No,Low Income,1,Yes,No,0 841 | 28,No,Middle Income,4,No,No,1 842 | 30,No,Middle Income,6,No,Yes,0 843 | 37,Yes,Low Income,1,No,No,0 844 | 31,No,Middle Income,3,Yes,No,0 845 | 31,No,Middle Income,2,No,No,0 846 | 36,No,Low Income,1,No,No,0 847 | 30,Yes,High Income,1,No,Yes,0 848 | 30,No,Middle Income,5,Yes,Yes,0 849 | 31,No,Low Income,2,Yes,No,1 850 | 38,Yes,Low Income,3,No,Yes,1 851 | 37,Yes,High Income,1,No,No,1 852 | 35,No,Low Income,1,No,Yes,0 853 | 30,No,Middle Income,2,No,No,0 854 | 29,No,Middle Income,1,Yes,Yes,0 855 | 34,Yes,Low Income,4,No,No,0 856 | 35,No,Middle Income,3,No,No,0 857 | 28,No,Low Income,6,No,Yes,1 858 | 37,No,Low Income,1,Yes,Yes,0 859 | 30,Yes,High Income,1,Yes,No,0 860 | 29,No,Middle Income,1,No,Yes,0 861 | 30,No Record,Low Income,5,No,No,0 862 | 37,No,Middle Income,4,No,Yes,0 863 | 33,Yes,High Income,1,No,No,1 864 | 30,No,Low Income,1,Yes,No,0 865 | 31,No,Middle Income,2,No,No,0 866 | 28,No,Middle Income,1,No,Yes,0 867 | 34,Yes,Low Income,1,Yes,Yes,1 868 | 34,No,Middle Income,3,Yes,No,0 869 | 36,No,Middle Income,4,Yes,No,0 870 | 30,No,Low Income,1,No,No,0 871 | 36,Yes,High Income,1,No,No,0 872 | 38,No Record,Middle Income,6,No,Yes,0 873 | 27,No,Low Income,2,No,No,1 874 | 29,No,Middle Income,5,Yes,Yes,0 875 | 35,Yes,High Income,1,No,No,1 876 | 30,No,Low Income,4,Yes,No,0 877 | 31,No,Middle Income,2,No,Yes,0 878 | 30,No,Low Income,1,Yes,Yes,0 879 | 28,Yes,Low Income,1,Yes,No,1 880 | 28,No,Middle Income,3,No,No,0 881 | 31,No,Middle Income,2,No,No,0 882 | 30,No,Low Income,1,No,Yes,0 883 | 31,Yes,High Income,4,No,No,0 884 | 28,No,Middle Income,1,Yes,Yes,0 885 | 30,No,Low Income,2,Yes,No,1 886 | 38,No,Middle Income,3,No,No,0 887 | 35,Yes,High Income,6,No,Yes,1 888 | 34,No,Low Income,1,Yes,No,0 889 | 30,No,Middle Income,2,Yes,No,0 890 | 30,No,Middle Income,4,No,Yes,0 891 | 38,Yes,Low Income,1,No,No,0 892 | 36,Yes,Low Income,3,No,Yes,1 893 | 31,No,Middle Income,2,No,No,0 894 | 31,No Record,Low Income,1,Yes,No,0 895 | 34,Yes,High Income,1,No,No,0 896 | 33,No,Middle Income,1,No,Yes,0 897 | 30,No,Low Income,4,No,Yes,0 898 | 36,No,Middle Income,3,Yes,Yes,0 899 | 34,Yes,High Income,1,Yes,No,1 900 | 34,No,Low Income,5,No,No,0 901 | 27,No,Middle Income,2,No,No,1 902 | 31,No,Middle Income,6,No,Yes,0 903 | 36,Yes,Low Income,1,Yes,No,1 904 | 37,No,Middle Income,4,Yes,No,0 905 | 30,No Record,Middle Income,2,No,No,0 906 | 30,No,Low Income,1,No,Yes,0 907 | 30,Yes,High Income,1,No,Yes,0 908 | 37,No,Middle Income,1,Yes,Yes,0 909 | 28,No,Low Income,2,Yes,No,1 910 | 37,No,Middle Income,3,No,No,0 911 | 28,Yes,High Income,4,No,No,1 912 | 37,No,Low Income,1,Yes,Yes,0 913 | 29,No,Low Income,5,No,No,0 914 | 34,No,Middle Income,1,Yes,Yes,0 915 | 34,Yes,Low Income,1,No,No,0 916 | 35,No Record,Middle Income,3,No,No,0 917 | 30,No,Middle Income,6,No,Yes,0 918 | 28,No,Low Income,4,Yes,No,0 919 | 30,Yes,High Income,1,Yes,No,0 920 | 36,Yes,Low Income,1,No,Yes,0 921 | 28,No,Low Income,2,Yes,No,1 922 | 35,No,Middle Income,3,No,Yes,0 923 | 36,Yes,High Income,1,No,No,1 924 | 31,No,Low Income,1,Yes,No,0 925 | 30,No,Middle Income,4,No,No,0 926 | 34,No,Middle Income,5,No,Yes,0 927 | 33,Yes,Low Income,1,No,Yes,0 928 | 37,No,Middle Income,3,Yes,No,0 929 | 36,No,Middle Income,2,Yes,No,0 930 | 28,No,Low Income,1,Yes,Yes,0 931 | 28,Yes,High Income,1,No,No,1 932 | 37,No,Middle Income,6,No,Yes,0 933 | 30,No,Low Income,2,No,No,0 934 | 29,No,Low Income,3,Yes,No,0 935 | 30,Yes,High Income,1,No,No,1 936 | 30,No,Low Income,1,No,No,0 937 | 29,No,Middle Income,2,No,Yes,0 938 | 36,No Record,Middle Income,1,Yes,Yes,0 939 | 30,Yes,Low Income,5,Yes,No,1 940 | 30,No,Middle Income,3,No,No,0 941 | 34,Yes,Low Income,2,No,No,0 942 | 27,No,Low Income,1,No,Yes,0 943 | 37,Yes,High Income,1,No,No,0 944 | 27,No,Middle Income,1,Yes,Yes,0 945 | 31,No,Low Income,2,No,No,0 946 | 36,No,Middle Income,4,No,Yes,0 947 | 30,Yes,High Income,6,No,Yes,1 948 | 27,No,Low Income,1,Yes,No,0 949 | 38,No Record,Middle Income,2,Yes,No,0 950 | 31,No,Middle Income,1,No,Yes,0 951 | 31,Yes,Low Income,1,No,No,0 952 | 30,No,Middle Income,5,No,Yes,0 953 | 37,No,Middle Income,4,No,No,0 954 | 30,No,Low Income,1,Yes,Yes,0 955 | 31,Yes,High Income,1,No,No,0 956 | -------------------------------------------------------------------------------- /data/travel_test.csv: -------------------------------------------------------------------------------- 1 | Age,FrequentFlyer,AnnualIncomeClass,ServicesOpted,AccountSyncedToSocialMedia,BookedHotelOrNot,Target 2 | 36,Yes,High Income,1,No,No,1 3 | 28,No,Middle Income,2,No,No,1 4 | 37,No,Low Income,2,Yes,No,1 5 | 30,No Record,Middle Income,2,No,No,0 6 | 28,No,Low Income,1,Yes,No,0 7 | 27,No,Middle Income,3,Yes,No,0 8 | 34,Yes,Low Income,1,Yes,Yes,0 9 | 36,Yes,High Income,4,No,No,0 10 | 28,No,Middle Income,1,Yes,Yes,0 11 | 30,No,Low Income,2,No,No,0 12 | 38,No,Middle Income,2,Yes,No,0 13 | 30,Yes,High Income,5,No,No,0 14 | 37,Yes,High Income,1,Yes,No,1 15 | 36,No,Low Income,1,No,No,0 16 | 37,No,Low Income,1,No,Yes,0 17 | 30,No,Middle Income,1,Yes,Yes,0 18 | 28,No,Middle Income,3,No,Yes,0 19 | 30,No,Low Income,1,Yes,No,0 20 | 35,No,Middle Income,1,No,Yes,0 21 | 35,No,Middle Income,3,Yes,No,0 22 | 30,No,Middle Income,2,Yes,No,0 23 | 30,No,Low Income,2,Yes,No,1 24 | 37,No,Low Income,1,No,No,0 25 | 28,No,Middle Income,2,No,Yes,1 26 | 31,No,Middle Income,3,No,No,0 27 | 31,No,Low Income,2,No,No,0 28 | 31,No,Low Income,2,No,Yes,0 29 | 34,No,Middle Income,2,No,No,0 30 | 29,No,Middle Income,2,No,No,0 31 | 34,No,Middle Income,3,No,Yes,0 32 | 30,No,Low Income,1,No,Yes,0 33 | 30,No,Middle Income,5,No,No,0 34 | 38,No,Low Income,4,No,No,0 35 | 28,No,Middle Income,2,No,No,1 36 | 34,No,Middle Income,2,Yes,No,0 37 | 27,No,Low Income,1,No,No,0 38 | 30,Yes,High Income,1,No,No,0 39 | 31,No,Low Income,2,No,No,0 40 | 30,No,Middle Income,3,Yes,Yes,0 41 | 27,Yes,Low Income,1,No,No,1 42 | 37,No,Middle Income,4,No,Yes,0 43 | 30,Yes,High Income,1,Yes,No,1 44 | 30,No,Middle Income,6,No,Yes,0 45 | 38,No,Middle Income,3,Yes,No,0 46 | 27,No,Low Income,5,No,No,1 47 | 27,No,Low Income,1,No,No,0 48 | 30,No,Middle Income,1,Yes,Yes,0 49 | 31,Yes,High Income,1,Yes,No,0 50 | 34,Yes,High Income,1,No,No,1 51 | 30,No,Middle Income,2,No,No,0 52 | 27,Yes,High Income,5,No,No,1 53 | 37,No,Low Income,1,Yes,No,0 54 | 37,No,Middle Income,2,Yes,No,0 55 | 29,No,Low Income,4,Yes,No,0 56 | 37,No,Middle Income,5,No,Yes,1 57 | 37,No,Low Income,1,Yes,Yes,0 58 | 36,Yes,High Income,1,Yes,No,0 59 | 28,No,Low Income,3,No,Yes,0 60 | 29,Yes,Low Income,4,No,Yes,0 61 | 29,Yes,High Income,1,No,No,0 62 | 37,Yes,High Income,1,No,No,1 63 | 28,No,Middle Income,2,No,Yes,1 64 | 34,Yes,Low Income,1,Yes,No,1 65 | 35,No,Low Income,1,Yes,Yes,0 66 | 30,No,Low Income,4,Yes,No,0 67 | 31,No,Low Income,1,No,Yes,0 68 | 37,No,Middle Income,2,No,No,0 69 | 29,No,Low Income,1,Yes,No,0 70 | 35,No,Middle Income,3,Yes,Yes,0 71 | 28,Yes,High Income,1,No,Yes,1 72 | 30,No,Low Income,4,Yes,No,1 73 | 28,No,Middle Income,4,No,No,0 74 | 30,Yes,High Income,4,No,No,1 75 | 36,Yes,Low Income,2,No,No,0 76 | 30,No,Middle Income,1,No,Yes,0 77 | 37,No,Low Income,4,No,Yes,0 78 | 30,No,Middle Income,4,No,Yes,0 79 | 27,No,Middle Income,2,No,No,1 80 | 38,No,Middle Income,3,No,Yes,0 81 | 34,No,Low Income,4,Yes,Yes,0 82 | 29,No,Middle Income,2,No,No,0 83 | 30,Yes,Low Income,1,No,No,0 84 | 29,Yes,High Income,4,Yes,No,0 85 | 31,No,Middle Income,2,No,No,0 86 | 30,No,Middle Income,2,Yes,No,0 87 | 30,No,Low Income,1,Yes,No,0 88 | 36,No,Middle Income,1,Yes,Yes,0 89 | 30,No,Middle Income,4,No,No,0 90 | 30,No,Middle Income,3,No,Yes,0 91 | 30,No,Middle Income,2,No,No,0 92 | 27,No,Middle Income,2,No,No,1 93 | 28,No Record,Low Income,5,No,Yes,0 94 | 30,No,Middle Income,2,No,No,0 95 | 31,No,Low Income,1,Yes,No,0 96 | 30,No,Middle Income,1,No,Yes,0 97 | 36,No,Low Income,1,No,Yes,0 98 | 34,No,Middle Income,6,No,Yes,0 99 | 37,No,Middle Income,2,No,No,0 100 | 38,Yes,High Income,4,No,No,0 101 | 30,No,Low Income,1,Yes,Yes,0 102 | 36,No,Middle Income,1,Yes,Yes,0 103 | 30,No,Middle Income,3,No,Yes,0 104 | 34,No,Middle Income,6,No,Yes,0 105 | 31,No,Low Income,3,Yes,Yes,0 106 | 29,No,Middle Income,2,No,Yes,0 107 | 27,No,Middle Income,1,No,Yes,0 108 | 31,Yes,High Income,1,Yes,No,1 109 | 37,No,Middle Income,2,No,No,0 110 | 34,No,Low Income,1,No,Yes,0 111 | 37,No,Low Income,1,Yes,Yes,0 112 | 27,No Record,Middle Income,2,No,No,1 113 | 35,Yes,Low Income,3,No,No,1 114 | 28,No,Middle Income,6,No,Yes,1 115 | 31,No,Low Income,4,Yes,No,1 116 | 28,No,Middle Income,5,No,Yes,0 117 | 36,No Record,Low Income,1,Yes,No,0 118 | 30,Yes,Low Income,1,No,Yes,0 119 | 31,No,Low Income,1,No,No,0 120 | 30,Yes,High Income,1,No,No,0 121 | 28,Yes,Low Income,1,No,No,1 122 | 36,No,Middle Income,4,No,Yes,0 123 | 38,No,Low Income,1,No,No,0 124 | 37,No,Middle Income,5,No,No,1 125 | 37,No,Middle Income,6,No,Yes,0 126 | 30,No Record,Middle Income,1,Yes,Yes,0 127 | 30,No,Low Income,2,Yes,No,1 128 | 34,Yes,Low Income,1,No,Yes,0 129 | 30,No,Middle Income,2,No,No,0 130 | 33,Yes,High Income,1,Yes,No,0 131 | 30,Yes,High Income,1,No,No,1 132 | 27,Yes,High Income,4,No,No,1 133 | 30,Yes,High Income,1,No,No,1 134 | 37,Yes,Low Income,1,Yes,No,1 135 | 34,No Record,Middle Income,2,No,No,0 136 | 37,Yes,Low Income,2,No,No,0 137 | 30,No,Middle Income,6,No,Yes,0 138 | 29,No,Middle Income,6,No,Yes,0 139 | 29,Yes,High Income,5,Yes,No,0 140 | 27,No,Middle Income,1,No,Yes,0 141 | 36,No,Middle Income,6,No,Yes,1 142 | 35,No,Middle Income,1,Yes,Yes,0 143 | 27,Yes,Low Income,1,Yes,No,1 144 | 35,No,Middle Income,4,No,No,0 145 | 29,Yes,High Income,6,No,Yes,1 146 | 30,No,Middle Income,2,Yes,No,0 147 | 36,Yes,Low Income,1,No,No,0 148 | 30,No,Middle Income,1,No,Yes,0 149 | 37,No,Middle Income,2,No,No,0 150 | 31,No,Middle Income,2,No,No,0 151 | 37,No,Middle Income,4,No,No,0 152 | 36,No,Low Income,2,Yes,No,1 153 | 37,Yes,High Income,1,No,No,0 154 | 35,No,Low Income,1,No,No,0 155 | 31,No,Middle Income,5,No,Yes,0 156 | 30,No,Middle Income,1,Yes,Yes,0 157 | 30,No,Low Income,2,No,No,0 158 | 30,No,Middle Income,3,No,Yes,0 159 | 30,No,Middle Income,1,No,Yes,0 160 | 29,Yes,High Income,1,Yes,No,1 161 | 29,No,Low Income,1,No,No,0 162 | 27,No,Middle Income,2,No,No,1 163 | 30,No Record,Middle Income,3,Yes,No,0 164 | 38,No,Middle Income,2,No,No,0 165 | 30,Yes,High Income,4,Yes,No,0 166 | 33,No,Low Income,1,Yes,No,0 167 | 31,No,Middle Income,2,No,No,0 168 | 29,No Record,Middle Income,4,No,Yes,0 169 | 37,No,Middle Income,1,Yes,Yes,0 170 | 29,No,Middle Income,2,No,No,0 171 | 36,Yes,Low Income,1,No,Yes,0 172 | 38,No,Low Income,2,No,Yes,0 173 | 30,Yes,High Income,1,No,Yes,0 174 | 30,No,Middle Income,5,Yes,Yes,0 175 | 31,No,Low Income,2,Yes,No,1 176 | 37,Yes,High Income,1,No,No,1 177 | 35,No,Low Income,1,No,Yes,0 178 | 30,Yes,High Income,1,Yes,No,0 179 | 38,No Record,Middle Income,6,No,Yes,0 180 | 28,No,Middle Income,1,Yes,Yes,0 181 | 30,No,Middle Income,2,Yes,No,0 182 | 36,Yes,Low Income,3,No,Yes,1 183 | 31,No,Middle Income,2,No,No,0 184 | 34,No,Low Income,5,No,No,0 185 | 37,No,Middle Income,4,Yes,No,0 186 | 30,No,Low Income,1,No,Yes,0 187 | 28,No,Low Income,4,Yes,No,0 188 | 35,No,Middle Income,3,No,Yes,0 189 | 31,No,Low Income,1,Yes,No,0 190 | 37,No,Middle Income,6,No,Yes,0 191 | 30,No,Low Income,1,No,No,0 192 | 30,Yes,Low Income,5,Yes,No,1 193 | -------------------------------------------------------------------------------- /data/travel_train.csv: -------------------------------------------------------------------------------- 1 | Age,FrequentFlyer,AnnualIncomeClass,ServicesOpted,AccountSyncedToSocialMedia,BookedHotelOrNot,Target 2 | 30,No,Low Income,1,Yes,Yes,0 3 | 31,No,Middle Income,6,No,Yes,0 4 | 34,No,Middle Income,5,No,No,0 5 | 30,No,Middle Income,1,Yes,Yes,0 6 | 28,No,Low Income,4,Yes,Yes,1 7 | 31,Yes,High Income,1,No,Yes,0 8 | 29,No,Middle Income,3,No,Yes,0 9 | 37,No,Middle Income,6,No,Yes,0 10 | 36,Yes,Low Income,2,Yes,No,0 11 | 33,Yes,High Income,1,No,No,0 12 | 34,No,Low Income,2,No,No,0 13 | 29,No,Middle Income,2,No,No,0 14 | 37,Yes,High Income,1,Yes,No,1 15 | 27,No,Low Income,1,Yes,No,0 16 | 28,Yes,High Income,1,No,No,1 17 | 34,Yes,High Income,1,Yes,No,0 18 | 34,No,Middle Income,6,No,Yes,0 19 | 30,Yes,Low Income,1,Yes,No,1 20 | 36,Yes,High Income,1,No,No,1 21 | 31,No,Middle Income,1,No,Yes,0 22 | 30,Yes,High Income,5,No,No,1 23 | 27,No,Middle Income,1,Yes,Yes,0 24 | 30,Yes,High Income,1,No,No,1 25 | 31,No,Low Income,2,No,No,0 26 | 34,No,Low Income,2,Yes,No,1 27 | 37,No,Middle Income,3,No,No,0 28 | 35,No,Low Income,1,Yes,No,0 29 | 27,Yes,High Income,1,No,No,1 30 | 30,No,Middle Income,4,No,No,0 31 | 37,No,Middle Income,3,Yes,No,0 32 | 30,No,Middle Income,2,No,No,0 33 | 33,Yes,Low Income,3,Yes,No,1 34 | 29,No,Middle Income,5,Yes,No,0 35 | 30,No,Middle Income,2,Yes,No,0 36 | 30,Yes,Low Income,1,Yes,Yes,1 37 | 38,Yes,Low Income,1,Yes,Yes,0 38 | 28,No,Middle Income,3,No,No,0 39 | 28,No,Middle Income,1,Yes,Yes,0 40 | 34,No,Low Income,1,Yes,No,0 41 | 29,Yes,Low Income,5,No,No,0 42 | 34,No,Middle Income,3,Yes,No,0 43 | 35,Yes,Low Income,1,No,Yes,0 44 | 30,No Record,Middle Income,1,No,Yes,0 45 | 36,Yes,Low Income,1,No,No,0 46 | 38,No,Middle Income,3,No,No,0 47 | 38,No Record,Middle Income,3,No,Yes,0 48 | 30,Yes,Low Income,1,No,No,0 49 | 37,No,Middle Income,3,No,No,0 50 | 36,No,Middle Income,6,No,Yes,0 51 | 28,No,Middle Income,4,No,No,1 52 | 35,No,Low Income,2,No,No,0 53 | 37,No,Low Income,1,No,No,0 54 | 37,Yes,High Income,1,No,No,0 55 | 30,Yes,High Income,1,No,No,1 56 | 33,No,Low Income,2,No,Yes,0 57 | 30,No,Low Income,2,No,No,0 58 | 29,Yes,High Income,1,No,No,1 59 | 31,No,Low Income,3,No,Yes,0 60 | 31,Yes,High Income,6,No,Yes,1 61 | 30,Yes,High Income,1,Yes,No,1 62 | 27,No,Low Income,1,Yes,Yes,0 63 | 31,Yes,Low Income,1,No,No,0 64 | 37,No,Middle Income,3,No,No,0 65 | 37,Yes,High Income,6,No,Yes,1 66 | 27,Yes,High Income,1,No,No,1 67 | 27,No,Low Income,1,Yes,No,0 68 | 37,No,Low Income,5,Yes,No,1 69 | 30,No Record,Low Income,5,No,No,0 70 | 34,No,Low Income,5,Yes,Yes,0 71 | 30,Yes,High Income,1,No,No,1 72 | 34,No,Middle Income,3,Yes,Yes,0 73 | 28,Yes,High Income,1,No,Yes,1 74 | 36,No,Low Income,1,No,No,0 75 | 29,No,Low Income,5,No,No,0 76 | 30,No,Middle Income,2,No,Yes,0 77 | 33,No,Middle Income,1,No,Yes,0 78 | 37,No,Middle Income,3,Yes,No,0 79 | 29,Yes,Low Income,1,Yes,No,1 80 | 31,Yes,Low Income,4,Yes,No,1 81 | 30,No Record,Middle Income,2,No,No,0 82 | 28,No,Middle Income,3,No,No,0 83 | 35,No,Middle Income,3,No,Yes,0 84 | 30,Yes,Low Income,1,No,Yes,0 85 | 34,No,Middle Income,6,No,Yes,0 86 | 30,Yes,Low Income,1,Yes,Yes,1 87 | 28,Yes,Low Income,1,Yes,No,1 88 | 30,No,Low Income,1,Yes,No,0 89 | 31,No,Middle Income,4,Yes,No,0 90 | 34,Yes,Low Income,1,Yes,Yes,1 91 | 30,No,Middle Income,4,No,No,0 92 | 34,No,Middle Income,3,No,Yes,0 93 | 28,No,Low Income,2,Yes,No,1 94 | 35,No Record,Middle Income,3,Yes,No,0 95 | 37,Yes,Low Income,1,No,Yes,0 96 | 35,No,Middle Income,6,No,Yes,0 97 | 30,No Record,Low Income,3,Yes,No,0 98 | 35,No,Middle Income,3,No,No,0 99 | 31,No,Middle Income,4,No,Yes,0 100 | 35,No,Middle Income,3,No,No,0 101 | 30,No,Low Income,3,No,Yes,0 102 | 29,No,Middle Income,5,Yes,Yes,0 103 | 36,Yes,High Income,1,No,No,0 104 | 35,Yes,High Income,1,Yes,No,0 105 | 31,No,Low Income,2,No,No,0 106 | 38,No,Low Income,2,Yes,No,1 107 | 37,No,Middle Income,1,Yes,Yes,0 108 | 37,No,Middle Income,6,No,Yes,0 109 | 34,No Record,Low Income,1,No,No,0 110 | 31,No Record,Low Income,2,No,No,0 111 | 37,No,Middle Income,4,Yes,No,0 112 | 35,No,Low Income,1,Yes,No,0 113 | 30,No Record,Low Income,4,No,No,0 114 | 30,Yes,High Income,1,No,No,1 115 | 37,No,Low Income,1,No,Yes,0 116 | 28,No,Low Income,2,No,No,1 117 | 30,No Record,Middle Income,4,Yes,Yes,0 118 | 31,No,Low Income,1,No,No,0 119 | 36,No,Middle Income,4,No,Yes,0 120 | 37,Yes,Low Income,6,No,Yes,0 121 | 29,Yes,Low Income,1,Yes,No,1 122 | 37,No,Middle Income,4,No,No,0 123 | 30,No,Middle Income,5,Yes,No,0 124 | 35,No,Middle Income,1,No,Yes,0 125 | 36,No,Middle Income,3,No,Yes,0 126 | 33,No,Middle Income,6,No,Yes,0 127 | 31,No,Middle Income,1,Yes,Yes,0 128 | 30,No,Low Income,1,Yes,No,0 129 | 36,No,Middle Income,2,No,Yes,0 130 | 30,No,Low Income,2,No,No,0 131 | 30,No,Middle Income,3,No,No,0 132 | 30,No,Low Income,2,No,Yes,0 133 | 30,No Record,Low Income,1,Yes,No,0 134 | 27,No,Low Income,1,No,Yes,0 135 | 38,Yes,Low Income,6,No,Yes,0 136 | 37,No,Middle Income,5,No,Yes,1 137 | 34,No,Middle Income,2,No,No,0 138 | 31,Yes,High Income,1,No,Yes,0 139 | 37,No,Middle Income,2,No,No,0 140 | 35,No,Middle Income,3,No,No,0 141 | 29,No,Middle Income,3,No,No,0 142 | 34,No,Middle Income,3,Yes,No,0 143 | 29,No,Middle Income,6,No,Yes,0 144 | 28,No,Low Income,2,No,No,1 145 | 30,No,Middle Income,3,Yes,No,0 146 | 34,No Record,Middle Income,5,No,Yes,0 147 | 33,No,Middle Income,1,Yes,Yes,0 148 | 34,Yes,High Income,1,No,No,1 149 | 30,No,Low Income,2,No,Yes,0 150 | 30,Yes,Low Income,1,No,No,0 151 | 29,No,Low Income,2,Yes,No,1 152 | 31,Yes,High Income,1,No,No,0 153 | 34,Yes,Low Income,1,Yes,No,0 154 | 37,Yes,Low Income,1,Yes,No,1 155 | 37,No,Low Income,4,Yes,No,0 156 | 36,No,Middle Income,4,No,Yes,0 157 | 27,Yes,High Income,1,No,Yes,1 158 | 36,No Record,Low Income,2,No,No,0 159 | 28,No,Middle Income,3,No,Yes,0 160 | 29,Yes,High Income,1,No,No,0 161 | 34,No,Low Income,1,Yes,No,0 162 | 28,No,Low Income,2,No,No,1 163 | 30,No,Middle Income,6,No,Yes,0 164 | 30,No,Middle Income,4,Yes,Yes,0 165 | 30,No,Low Income,4,Yes,No,0 166 | 27,No,Middle Income,4,Yes,Yes,0 167 | 31,No,Low Income,1,Yes,Yes,0 168 | 29,No,Middle Income,4,Yes,No,0 169 | 34,No,Middle Income,1,Yes,Yes,0 170 | 37,No,Low Income,5,Yes,No,1 171 | 28,No,Middle Income,2,No,No,1 172 | 30,No,Middle Income,3,Yes,No,0 173 | 31,No,Middle Income,2,No,No,0 174 | 31,Yes,High Income,1,No,No,0 175 | 36,Yes,Low Income,1,No,Yes,0 176 | 37,Yes,Low Income,1,Yes,No,1 177 | 28,No,Low Income,6,No,Yes,1 178 | 34,No Record,Middle Income,1,No,Yes,0 179 | 34,No,Middle Income,3,No,No,0 180 | 30,No,Low Income,1,Yes,No,0 181 | 27,No,Low Income,5,No,Yes,0 182 | 36,No,Middle Income,2,Yes,No,0 183 | 30,Yes,High Income,1,No,Yes,0 184 | 29,No,Low Income,1,No,No,0 185 | 30,No,Low Income,1,Yes,No,0 186 | 30,No,Middle Income,1,Yes,Yes,0 187 | 30,Yes,High Income,1,No,Yes,0 188 | 37,Yes,High Income,1,No,No,0 189 | 31,No,Middle Income,1,No,Yes,0 190 | 35,Yes,Low Income,2,Yes,No,0 191 | 30,No,Middle Income,2,Yes,No,0 192 | 30,No,Middle Income,5,No,Yes,0 193 | 30,No,Low Income,4,Yes,No,0 194 | 31,No,Middle Income,2,No,No,0 195 | 30,No,Middle Income,4,Yes,No,0 196 | 34,No,Low Income,5,No,No,0 197 | 30,Yes,Low Income,1,Yes,No,1 198 | 37,Yes,Low Income,1,No,No,0 199 | 28,Yes,Low Income,4,Yes,No,1 200 | 34,Yes,Low Income,1,Yes,No,0 201 | 33,Yes,High Income,1,No,Yes,0 202 | 36,No,Middle Income,1,Yes,Yes,0 203 | 30,Yes,High Income,6,No,Yes,1 204 | 35,No,Middle Income,1,Yes,Yes,0 205 | 34,Yes,High Income,4,No,No,0 206 | 37,No,Middle Income,3,Yes,No,0 207 | 37,No,Middle Income,1,Yes,Yes,0 208 | 38,No,Middle Income,2,No,No,0 209 | 28,No,Low Income,1,Yes,No,0 210 | 30,No,Middle Income,1,Yes,Yes,0 211 | 27,Yes,Low Income,1,No,Yes,1 212 | 36,No Record,Middle Income,1,Yes,Yes,0 213 | 37,No,Middle Income,6,No,Yes,0 214 | 30,No,Low Income,2,Yes,Yes,1 215 | 28,No,Middle Income,3,Yes,No,0 216 | 33,Yes,Low Income,1,Yes,Yes,0 217 | 38,Yes,Low Income,6,No,Yes,0 218 | 29,No,Low Income,1,No,No,0 219 | 36,No,Middle Income,3,Yes,No,0 220 | 34,Yes,Low Income,5,Yes,No,1 221 | 31,Yes,Low Income,1,No,No,0 222 | 33,Yes,Low Income,1,Yes,No,1 223 | 30,No,Low Income,1,No,Yes,0 224 | 27,No,Middle Income,2,No,No,1 225 | 34,No,Middle Income,4,Yes,No,0 226 | 30,No,Middle Income,3,Yes,No,0 227 | 38,No,Low Income,1,No,Yes,0 228 | 30,No,Middle Income,3,No,Yes,0 229 | 34,No,Middle Income,3,Yes,No,0 230 | 29,Yes,High Income,1,No,No,0 231 | 37,No,Middle Income,6,No,Yes,0 232 | 29,No,Middle Income,2,Yes,No,0 233 | 31,No,Middle Income,3,No,Yes,0 234 | 34,Yes,Low Income,1,Yes,No,1 235 | 33,No,Low Income,1,No,Yes,0 236 | 30,Yes,Low Income,1,No,No,0 237 | 38,No,Middle Income,2,Yes,No,0 238 | 31,No,Low Income,2,No,No,0 239 | 34,Yes,Low Income,3,No,Yes,1 240 | 37,Yes,Low Income,1,Yes,Yes,0 241 | 35,No,Low Income,4,Yes,No,0 242 | 31,No,Low Income,1,Yes,Yes,0 243 | 37,No Record,Middle Income,3,Yes,No,0 244 | 31,No,Middle Income,1,No,Yes,0 245 | 30,No,Low Income,2,No,No,0 246 | 35,No,Middle Income,4,No,No,0 247 | 31,No,Middle Income,1,Yes,Yes,0 248 | 37,Yes,Low Income,1,No,No,0 249 | 36,Yes,Low Income,1,No,No,0 250 | 31,No,Middle Income,1,No,Yes,0 251 | 37,Yes,High Income,6,No,Yes,1 252 | 29,No,Middle Income,3,Yes,No,0 253 | 29,No,Low Income,1,Yes,No,0 254 | 31,No,Middle Income,4,Yes,No,0 255 | 27,No Record,Low Income,1,Yes,No,0 256 | 27,No,Low Income,2,No,No,1 257 | 35,No,Middle Income,4,No,No,0 258 | 37,Yes,High Income,1,Yes,No,1 259 | 30,No,Middle Income,2,No,No,0 260 | 36,No,Middle Income,2,No,No,0 261 | 31,No,Low Income,2,No,No,0 262 | 30,No,Middle Income,1,Yes,Yes,0 263 | 30,No,Middle Income,5,Yes,Yes,0 264 | 30,No Record,Middle Income,6,No,Yes,0 265 | 30,No,Low Income,1,Yes,Yes,0 266 | 35,No Record,Middle Income,3,No,No,0 267 | 34,No,Middle Income,4,Yes,No,0 268 | 30,No,Middle Income,4,No,No,0 269 | 27,No,Low Income,2,Yes,No,1 270 | 37,No,Middle Income,2,No,No,0 271 | 34,No,Low Income,1,Yes,Yes,0 272 | 30,Yes,High Income,6,No,Yes,1 273 | 31,No,Middle Income,3,No,No,0 274 | 36,Yes,Low Income,3,Yes,No,1 275 | 30,No,Middle Income,2,No,No,0 276 | 37,No,Low Income,2,Yes,No,1 277 | 31,No,Middle Income,4,No,Yes,0 278 | 30,No,Low Income,4,No,No,0 279 | 30,Yes,High Income,5,No,No,0 280 | 33,Yes,High Income,1,No,No,1 281 | 34,Yes,High Income,1,No,No,0 282 | 31,Yes,High Income,4,No,No,0 283 | 27,No,Low Income,2,No,No,1 284 | 30,No,Middle Income,4,No,Yes,0 285 | 28,No,Middle Income,1,No,Yes,0 286 | 27,No,Middle Income,5,No,No,1 287 | 27,No,Middle Income,5,No,Yes,0 288 | 36,Yes,High Income,6,No,Yes,1 289 | 30,No,Low Income,2,No,No,0 290 | 34,Yes,Low Income,1,No,No,0 291 | 35,No,Middle Income,5,Yes,No,1 292 | 30,No,Middle Income,3,Yes,No,0 293 | 31,No,Middle Income,1,No,Yes,0 294 | 30,No,Middle Income,3,Yes,No,0 295 | 31,No,Low Income,4,Yes,No,0 296 | 30,No,Middle Income,4,No,Yes,0 297 | 35,No,Low Income,2,No,No,0 298 | 38,No,Middle Income,2,No,No,0 299 | 38,No,Low Income,1,Yes,No,0 300 | 30,No,Low Income,1,Yes,Yes,0 301 | 37,No,Middle Income,4,No,Yes,0 302 | 33,Yes,Low Income,1,Yes,Yes,0 303 | 30,No,Middle Income,4,Yes,Yes,0 304 | 34,No,Low Income,1,Yes,No,0 305 | 30,No,Low Income,1,No,No,0 306 | 35,No,Low Income,1,No,No,0 307 | 38,No,Low Income,2,No,No,0 308 | 27,No,Middle Income,4,No,Yes,0 309 | 29,No,Low Income,2,Yes,No,1 310 | 28,No,Low Income,2,Yes,No,1 311 | 30,Yes,High Income,1,Yes,No,0 312 | 34,No,Low Income,5,Yes,No,0 313 | 35,No Record,Middle Income,1,Yes,Yes,0 314 | 30,Yes,Low Income,5,Yes,No,1 315 | 29,No,Low Income,2,Yes,No,1 316 | 33,Yes,High Income,4,No,No,1 317 | 33,No,Low Income,5,Yes,Yes,1 318 | 28,No,Middle Income,3,Yes,No,0 319 | 38,Yes,Low Income,1,No,No,0 320 | 35,No,Middle Income,2,No,Yes,0 321 | 30,No Record,Middle Income,2,No,No,0 322 | 34,Yes,Low Income,1,No,No,0 323 | 31,No,Middle Income,6,No,Yes,0 324 | 27,No,Middle Income,4,Yes,No,0 325 | 29,No,Middle Income,1,No,Yes,0 326 | 37,No,Middle Income,2,Yes,No,0 327 | 34,No,Middle Income,6,No,Yes,0 328 | 30,Yes,High Income,1,Yes,No,0 329 | 30,No,Middle Income,1,Yes,Yes,0 330 | 37,No,Middle Income,4,No,No,0 331 | 30,No,Middle Income,3,No,No,0 332 | 34,Yes,High Income,1,No,No,0 333 | 28,Yes,High Income,6,No,Yes,1 334 | 36,No,Middle Income,3,Yes,No,0 335 | 30,No,Low Income,2,Yes,No,1 336 | 34,No,Middle Income,5,No,Yes,0 337 | 30,No,Low Income,2,Yes,No,1 338 | 37,No,Middle Income,4,No,Yes,0 339 | 30,Yes,High Income,1,Yes,No,0 340 | 29,Yes,Low Income,1,No,No,0 341 | 36,No,Middle Income,3,No,No,0 342 | 37,No,Middle Income,3,No,Yes,0 343 | 37,No,Middle Income,1,No,Yes,0 344 | 34,Yes,Low Income,1,No,Yes,0 345 | 37,No,Middle Income,2,No,No,0 346 | 27,No,Middle Income,1,No,Yes,0 347 | 30,No,Low Income,2,Yes,No,1 348 | 35,Yes,High Income,1,No,No,1 349 | 30,No,Low Income,1,Yes,Yes,0 350 | 30,No,Low Income,2,No,No,0 351 | 31,Yes,High Income,1,No,No,1 352 | 37,No,Middle Income,3,No,No,0 353 | 37,Yes,Low Income,4,Yes,Yes,1 354 | 37,No,Low Income,1,No,Yes,0 355 | 28,No Record,Middle Income,2,No,No,1 356 | 28,No,Low Income,1,Yes,Yes,0 357 | 36,No,Middle Income,2,No,Yes,0 358 | 27,Yes,High Income,1,No,No,1 359 | 29,No,Middle Income,2,No,No,0 360 | 34,Yes,High Income,1,No,No,1 361 | 34,Yes,Low Income,2,No,No,0 362 | 31,Yes,Low Income,1,No,No,0 363 | 29,No,Low Income,1,Yes,Yes,0 364 | 27,Yes,Low Income,1,Yes,No,1 365 | 27,No,Low Income,5,Yes,Yes,0 366 | 29,Yes,High Income,1,Yes,No,0 367 | 34,No,Middle Income,4,No,No,0 368 | 30,No,Low Income,5,Yes,No,0 369 | 34,No Record,Middle Income,1,No,Yes,0 370 | 30,No,Middle Income,5,Yes,Yes,0 371 | 27,No,Low Income,2,No,No,1 372 | 29,No,Middle Income,3,No,No,0 373 | 34,Yes,Low Income,2,No,No,0 374 | 30,No,Middle Income,3,No,Yes,0 375 | 29,No,Low Income,2,No,No,0 376 | 30,Yes,High Income,1,No,No,1 377 | 36,Yes,High Income,6,No,Yes,1 378 | 31,No,Low Income,1,No,No,0 379 | 28,No Record,Middle Income,4,No,No,0 380 | 36,No,Low Income,2,Yes,No,1 381 | 30,No,Middle Income,3,No,Yes,0 382 | 35,Yes,Low Income,3,No,Yes,1 383 | 29,Yes,High Income,1,No,No,1 384 | 30,No,Middle Income,4,No,No,0 385 | 30,No,Middle Income,6,No,Yes,0 386 | 30,Yes,Low Income,1,Yes,No,1 387 | 37,Yes,Low Income,2,No,Yes,0 388 | 30,No,Low Income,2,No,No,0 389 | 31,No,Middle Income,2,No,No,0 390 | 29,No,Middle Income,6,No,Yes,0 391 | 31,Yes,High Income,4,No,Yes,0 392 | 33,No,Middle Income,5,Yes,Yes,0 393 | 34,No,Low Income,4,No,Yes,0 394 | 28,No,Middle Income,1,Yes,Yes,0 395 | 34,No,Middle Income,3,No,Yes,0 396 | 37,No,Low Income,1,No,No,0 397 | 37,Yes,High Income,1,No,No,0 398 | 33,No,Low Income,1,Yes,No,0 399 | 36,No,Low Income,1,No,Yes,0 400 | 34,No,Low Income,2,No,No,0 401 | 29,No Record,Low Income,1,No,Yes,0 402 | 37,No,Low Income,5,No,Yes,1 403 | 30,No Record,Low Income,1,Yes,Yes,0 404 | 34,No,Low Income,2,No,No,0 405 | 27,No,Low Income,1,Yes,No,0 406 | 37,Yes,Low Income,1,No,No,0 407 | 31,No,Low Income,1,Yes,No,0 408 | 34,Yes,Low Income,3,Yes,No,1 409 | 28,Yes,High Income,1,No,No,1 410 | 30,Yes,High Income,4,Yes,No,1 411 | 34,No,Middle Income,3,Yes,Yes,0 412 | 28,Yes,Low Income,1,Yes,No,1 413 | 38,No,Low Income,1,No,Yes,0 414 | 30,No,Low Income,5,No,No,0 415 | 30,No,Low Income,4,Yes,No,0 416 | 36,Yes,High Income,1,No,No,1 417 | 37,No,Middle Income,3,No,Yes,0 418 | 37,No,Low Income,1,Yes,Yes,0 419 | 37,Yes,Low Income,6,No,Yes,0 420 | 36,No,Middle Income,3,No,No,0 421 | 30,Yes,High Income,1,No,Yes,0 422 | 28,Yes,High Income,1,No,No,1 423 | 34,No,Middle Income,4,Yes,Yes,0 424 | 31,No,Low Income,1,Yes,Yes,0 425 | 34,Yes,High Income,1,No,No,1 426 | 27,Yes,Low Income,1,Yes,No,1 427 | 37,No Record,Low Income,4,Yes,No,1 428 | 31,Yes,High Income,6,No,Yes,1 429 | 30,No,Middle Income,4,No,Yes,0 430 | 34,No,Middle Income,2,No,No,0 431 | 30,No,Low Income,4,Yes,No,0 432 | 31,No,Low Income,1,Yes,Yes,0 433 | 37,Yes,Low Income,3,Yes,No,1 434 | 31,No,Middle Income,2,No,No,0 435 | 29,No,Middle Income,6,No,Yes,0 436 | 27,No,Middle Income,1,No,Yes,0 437 | 37,Yes,Low Income,4,No,No,0 438 | 34,Yes,Low Income,1,Yes,No,0 439 | 30,No,Middle Income,3,No,No,0 440 | 31,No Record,Middle Income,3,No,No,0 441 | 28,No,Low Income,3,No,No,0 442 | 29,No,Low Income,5,Yes,No,1 443 | 28,No,Low Income,5,Yes,Yes,0 444 | 31,No,Low Income,1,Yes,No,0 445 | 36,Yes,Low Income,1,Yes,No,1 446 | 34,No,Middle Income,1,Yes,Yes,0 447 | 37,No,Middle Income,3,No,No,0 448 | 31,Yes,Low Income,1,No,Yes,0 449 | 35,No,Low Income,4,No,Yes,0 450 | 29,No,Low Income,1,Yes,Yes,0 451 | 31,Yes,Low Income,1,Yes,No,1 452 | 34,Yes,Low Income,1,Yes,Yes,1 453 | 30,No Record,Middle Income,6,No,Yes,0 454 | 30,Yes,High Income,1,No,Yes,0 455 | 31,Yes,High Income,6,No,Yes,1 456 | 36,No,Middle Income,5,No,No,1 457 | 36,No,Middle Income,3,No,No,0 458 | 36,Yes,Low Income,2,Yes,No,1 459 | 37,Yes,High Income,1,No,No,0 460 | 28,No Record,Low Income,4,Yes,Yes,0 461 | 34,Yes,Low Income,5,No,Yes,0 462 | 37,Yes,High Income,1,No,No,1 463 | 30,No,Middle Income,4,Yes,Yes,0 464 | 34,Yes,High Income,1,No,No,0 465 | 28,Yes,High Income,1,No,No,1 466 | 35,No,Low Income,1,No,No,0 467 | 27,Yes,Low Income,5,No,No,1 468 | 30,No,Middle Income,6,No,Yes,0 469 | 30,No,Low Income,3,No,No,0 470 | 37,No,Middle Income,1,No,Yes,0 471 | 30,No,Low Income,1,No,Yes,0 472 | 36,Yes,High Income,1,No,No,1 473 | 27,No,Low Income,1,No,No,0 474 | 27,No,Middle Income,3,Yes,No,0 475 | 37,No,Middle Income,3,Yes,Yes,0 476 | 34,No,Low Income,2,No,Yes,0 477 | 30,No,Middle Income,1,Yes,Yes,0 478 | 31,Yes,High Income,1,No,No,1 479 | 30,No,Middle Income,2,No,No,0 480 | 38,No,Middle Income,1,Yes,Yes,0 481 | 30,No,Middle Income,3,No,No,0 482 | 30,Yes,High Income,1,No,No,0 483 | 37,No,Middle Income,3,No,No,0 484 | 28,No,Middle Income,2,No,No,1 485 | 38,No Record,Middle Income,2,Yes,No,0 486 | 34,No,Middle Income,6,No,Yes,0 487 | 27,No,Low Income,1,Yes,No,0 488 | 30,No,Middle Income,2,No,No,0 489 | 34,No,Middle Income,3,Yes,Yes,0 490 | 36,No,Middle Income,3,Yes,No,0 491 | 29,Yes,High Income,1,No,Yes,0 492 | 31,No,Middle Income,1,No,Yes,0 493 | 36,Yes,Low Income,1,Yes,No,0 494 | 28,No,Middle Income,2,Yes,No,1 495 | 27,Yes,Low Income,1,Yes,No,1 496 | 36,No,Middle Income,5,No,No,1 497 | 37,No,Middle Income,2,Yes,No,0 498 | 33,Yes,Low Income,1,No,Yes,0 499 | 27,Yes,Low Income,1,No,No,1 500 | 36,Yes,Low Income,2,No,No,0 501 | 33,No,Low Income,4,No,No,0 502 | 30,No,Middle Income,1,No,Yes,0 503 | 37,No,Middle Income,2,Yes,No,0 504 | 34,Yes,High Income,1,Yes,No,1 505 | 31,No,Middle Income,1,Yes,Yes,0 506 | 31,No,Middle Income,4,No,No,0 507 | 28,No,Middle Income,1,No,Yes,0 508 | 37,No,Low Income,4,No,Yes,0 509 | 28,Yes,High Income,1,No,No,1 510 | 30,No,Middle Income,6,No,Yes,0 511 | 30,No,Middle Income,6,No,Yes,0 512 | 37,No,Middle Income,2,No,No,0 513 | 34,Yes,Low Income,1,Yes,Yes,0 514 | 34,No,Middle Income,5,No,Yes,0 515 | 36,No,Middle Income,3,Yes,No,0 516 | 38,No,Middle Income,1,No,Yes,0 517 | 34,No,Middle Income,1,Yes,Yes,0 518 | 27,Yes,High Income,5,No,Yes,1 519 | 35,No,Low Income,2,Yes,Yes,1 520 | 37,No,Middle Income,3,Yes,No,0 521 | 36,No,Middle Income,4,Yes,No,0 522 | 34,Yes,High Income,5,Yes,No,1 523 | 29,No Record,Middle Income,1,Yes,Yes,0 524 | 28,Yes,High Income,4,No,Yes,1 525 | 30,No,Middle Income,2,No,No,0 526 | 30,No,Low Income,3,No,No,0 527 | 30,No,Low Income,2,No,Yes,0 528 | 31,No,Low Income,1,Yes,No,0 529 | 29,Yes,Low Income,1,No,No,0 530 | 37,No,Middle Income,1,No,Yes,0 531 | 31,Yes,High Income,1,No,No,0 532 | 34,No,Middle Income,1,Yes,Yes,0 533 | 34,Yes,High Income,6,No,Yes,1 534 | 31,No,Middle Income,4,No,No,0 535 | 35,Yes,High Income,1,No,No,1 536 | 28,No Record,Middle Income,4,No,Yes,0 537 | 31,No,Middle Income,5,No,No,0 538 | 27,No Record,Middle Income,4,No,No,1 539 | 30,No,Low Income,1,Yes,Yes,0 540 | 37,No,Middle Income,1,Yes,Yes,0 541 | 30,No,Low Income,4,No,No,0 542 | 29,No,Middle Income,2,No,No,0 543 | 30,No,Middle Income,6,No,Yes,0 544 | 33,Yes,High Income,4,Yes,No,1 545 | 33,No,Low Income,1,Yes,Yes,0 546 | 30,Yes,High Income,1,No,No,1 547 | 34,No,Low Income,1,Yes,No,0 548 | 35,No,Middle Income,1,Yes,Yes,0 549 | 36,No,Low Income,2,No,No,0 550 | 37,Yes,Low Income,1,No,Yes,0 551 | 31,No,Low Income,1,Yes,No,0 552 | 31,Yes,High Income,1,Yes,No,0 553 | 34,No,Low Income,2,Yes,No,1 554 | 30,No,Low Income,1,Yes,No,0 555 | 28,No,Middle Income,5,No,No,0 556 | 34,No,Low Income,1,Yes,Yes,0 557 | 33,Yes,High Income,1,No,No,1 558 | 27,No,Low Income,2,Yes,No,1 559 | 37,No,Low Income,1,Yes,Yes,0 560 | 34,Yes,High Income,1,No,No,0 561 | 27,Yes,High Income,4,No,No,1 562 | 29,No,Middle Income,3,Yes,No,0 563 | 30,No,Low Income,2,Yes,No,1 564 | 38,Yes,Low Income,3,No,Yes,1 565 | 33,Yes,Low Income,4,No,No,0 566 | 37,No,Middle Income,3,No,No,0 567 | 30,No Record,Middle Income,3,Yes,Yes,0 568 | 31,No,Low Income,1,No,Yes,0 569 | 27,Yes,High Income,1,No,No,1 570 | 30,No,Middle Income,4,No,Yes,0 571 | 31,No,Middle Income,3,No,Yes,0 572 | 28,No,Low Income,5,No,No,0 573 | 31,No,Middle Income,1,No,Yes,0 574 | 30,No,Middle Income,4,No,Yes,0 575 | 34,No,Middle Income,1,No,Yes,0 576 | 30,No,Middle Income,4,No,Yes,0 577 | 35,No,Middle Income,6,No,Yes,1 578 | 30,No,Low Income,4,No,No,0 579 | 37,No,Low Income,1,No,No,0 580 | 27,No,Middle Income,6,No,Yes,0 581 | 35,No,Low Income,2,Yes,No,1 582 | 35,No,Middle Income,1,Yes,Yes,0 583 | 36,No,Middle Income,2,No,No,0 584 | 34,No,Low Income,1,Yes,Yes,0 585 | 38,Yes,Low Income,1,No,No,0 586 | 31,No,Low Income,1,No,Yes,0 587 | 30,No,Middle Income,3,Yes,Yes,0 588 | 34,Yes,High Income,1,No,No,0 589 | 30,No,Low Income,1,No,Yes,0 590 | 29,Yes,High Income,1,Yes,No,0 591 | 36,No,Middle Income,2,No,No,0 592 | 30,No,Low Income,2,Yes,No,0 593 | 30,No,Middle Income,6,No,Yes,0 594 | 34,No,Middle Income,2,No,No,0 595 | 37,No,Middle Income,2,Yes,No,0 596 | 30,No,Low Income,5,Yes,No,0 597 | 28,No,Middle Income,1,Yes,Yes,0 598 | 34,No,Middle Income,1,No,Yes,0 599 | 29,Yes,High Income,1,No,No,0 600 | 30,No,Low Income,1,No,No,0 601 | 31,No,Middle Income,1,Yes,Yes,0 602 | 36,Yes,Low Income,2,No,No,0 603 | 37,Yes,High Income,4,No,No,1 604 | 30,No,Middle Income,3,Yes,No,0 605 | 30,No,Middle Income,6,No,Yes,0 606 | 34,Yes,Low Income,4,Yes,No,1 607 | 29,Yes,High Income,5,No,No,1 608 | 29,Yes,Low Income,1,Yes,No,1 609 | 34,Yes,Low Income,4,No,No,0 610 | 29,No,Middle Income,5,No,Yes,0 611 | 30,No,Low Income,1,Yes,No,0 612 | 31,No Record,Middle Income,1,Yes,Yes,0 613 | 30,No,Middle Income,3,No,No,0 614 | 30,No,Low Income,3,No,Yes,0 615 | 27,No,Middle Income,4,Yes,Yes,0 616 | 28,Yes,High Income,1,No,No,1 617 | 30,No,Low Income,6,No,Yes,0 618 | 30,Yes,High Income,1,No,No,0 619 | 29,No,Middle Income,2,No,Yes,0 620 | 34,Yes,Low Income,4,No,No,0 621 | 37,Yes,Low Income,1,No,Yes,0 622 | 38,No,Middle Income,2,No,Yes,0 623 | 31,No Record,Middle Income,3,No,Yes,0 624 | 28,Yes,High Income,4,No,No,1 625 | 34,No,Low Income,4,No,Yes,0 626 | 30,No,Low Income,2,No,No,0 627 | 30,No,Low Income,1,Yes,No,0 628 | 28,No,Low Income,1,Yes,No,0 629 | 28,No,Middle Income,3,No,Yes,0 630 | 27,No,Middle Income,6,No,Yes,1 631 | 33,No,Middle Income,2,No,No,0 632 | 30,No Record,Low Income,1,No,No,0 633 | 33,No,Middle Income,3,Yes,Yes,0 634 | 37,Yes,Low Income,1,No,No,0 635 | 30,No,Middle Income,3,No,No,0 636 | 37,Yes,Low Income,1,Yes,No,0 637 | 30,No,Low Income,1,Yes,Yes,0 638 | 31,No,Low Income,2,No,Yes,0 639 | 30,No,Low Income,4,No,Yes,0 640 | 28,No,Low Income,2,No,No,1 641 | 28,No Record,Low Income,6,No,Yes,0 642 | 30,Yes,High Income,1,No,No,0 643 | 37,No,Middle Income,6,No,Yes,0 644 | 31,No,Low Income,1,No,Yes,0 645 | 37,No,Low Income,4,Yes,Yes,0 646 | 35,Yes,High Income,6,No,Yes,1 647 | 30,No,Middle Income,3,No,Yes,0 648 | 36,Yes,High Income,1,No,No,1 649 | 30,Yes,Low Income,4,Yes,No,1 650 | 37,No,Middle Income,4,No,No,0 651 | 31,Yes,Low Income,1,No,Yes,0 652 | 31,No,Middle Income,1,No,Yes,0 653 | 36,Yes,Low Income,1,No,Yes,0 654 | 29,No,Low Income,3,Yes,No,0 655 | 30,Yes,High Income,1,No,Yes,0 656 | 30,No,Middle Income,2,No,Yes,0 657 | 36,Yes,High Income,1,No,No,0 658 | 30,No,Low Income,2,No,Yes,0 659 | 30,No,Low Income,1,Yes,No,0 660 | 36,No,Low Income,1,Yes,No,0 661 | 37,Yes,High Income,1,Yes,No,1 662 | 34,Yes,High Income,4,No,No,1 663 | 30,No,Low Income,1,Yes,No,0 664 | 30,Yes,High Income,1,No,No,0 665 | 29,No,Low Income,1,No,Yes,0 666 | 30,Yes,High Income,1,No,No,0 667 | 35,Yes,High Income,6,No,Yes,1 668 | 33,No,Middle Income,4,Yes,No,0 669 | 28,Yes,Low Income,1,Yes,No,1 670 | 34,Yes,High Income,1,No,No,0 671 | 31,No,Middle Income,1,No,Yes,0 672 | 30,No,Low Income,5,No,No,0 673 | 34,Yes,High Income,1,No,No,0 674 | 37,No,Middle Income,3,Yes,No,0 675 | 37,Yes,Low Income,3,Yes,Yes,1 676 | 29,No,Middle Income,1,No,Yes,0 677 | 37,No,Middle Income,2,Yes,No,0 678 | 34,No,Low Income,1,Yes,Yes,0 679 | 29,Yes,High Income,5,No,No,1 680 | 34,Yes,High Income,5,No,No,0 681 | 27,No Record,Middle Income,2,No,Yes,1 682 | 30,Yes,High Income,1,Yes,No,0 683 | 28,Yes,High Income,1,No,No,1 684 | 31,No,Low Income,1,Yes,Yes,0 685 | 30,No,Middle Income,3,Yes,No,0 686 | 34,Yes,High Income,1,No,No,1 687 | 30,No Record,Middle Income,3,No,No,0 688 | 31,No,Middle Income,3,Yes,No,0 689 | 35,No,Middle Income,3,No,Yes,0 690 | 36,No,Middle Income,3,Yes,Yes,0 691 | 37,No,Low Income,1,No,Yes,0 692 | 30,No Record,Low Income,1,Yes,Yes,0 693 | 30,Yes,High Income,6,No,Yes,1 694 | 35,Yes,High Income,1,No,No,1 695 | 37,No,Low Income,1,No,Yes,0 696 | 27,No,Middle Income,3,No,No,0 697 | 30,Yes,High Income,1,No,No,0 698 | 35,Yes,Low Income,1,No,No,0 699 | 37,No,Middle Income,3,No,No,0 700 | 33,No Record,Low Income,2,Yes,No,1 701 | 35,Yes,High Income,1,No,No,1 702 | 30,Yes,High Income,4,No,No,0 703 | 30,No,Low Income,1,No,No,0 704 | 38,Yes,High Income,1,Yes,No,1 705 | 30,Yes,Low Income,1,Yes,No,1 706 | 31,Yes,High Income,1,No,No,1 707 | 28,No,Low Income,5,Yes,No,1 708 | 36,No,Middle Income,2,No,No,0 709 | 31,Yes,Low Income,1,Yes,No,1 710 | 31,Yes,High Income,1,Yes,No,1 711 | 30,No,Low Income,1,Yes,No,0 712 | 31,No Record,Low Income,1,Yes,No,0 713 | 35,No,Middle Income,2,No,No,0 714 | 30,Yes,Low Income,1,Yes,No,1 715 | 36,No,Middle Income,2,No,No,0 716 | 29,No,Middle Income,1,Yes,Yes,0 717 | 30,No,Middle Income,2,No,No,0 718 | 36,No,Middle Income,3,No,No,0 719 | 37,No,Low Income,2,Yes,No,1 720 | 30,Yes,Low Income,1,No,No,0 721 | 37,Yes,Low Income,1,Yes,No,0 722 | 30,No,Middle Income,1,No,Yes,0 723 | 30,No,Middle Income,3,Yes,No,0 724 | 29,No,Low Income,1,Yes,Yes,0 725 | 37,Yes,High Income,1,No,No,0 726 | 34,No,Middle Income,3,No,Yes,0 727 | 37,No,Middle Income,1,No,Yes,0 728 | 28,No,Low Income,2,Yes,No,1 729 | 31,Yes,High Income,1,No,No,0 730 | 27,No,Middle Income,1,Yes,Yes,0 731 | 31,No,Middle Income,2,No,Yes,0 732 | 31,No Record,Middle Income,5,Yes,No,0 733 | 37,No,Middle Income,1,No,Yes,0 734 | 37,No Record,Low Income,2,Yes,No,1 735 | 34,No,Middle Income,1,No,Yes,0 736 | 31,Yes,High Income,1,No,No,1 737 | 37,No,Middle Income,4,Yes,Yes,0 738 | 30,No,Middle Income,6,No,Yes,0 739 | 31,Yes,High Income,1,No,No,1 740 | 30,Yes,Low Income,4,No,No,0 741 | 30,No Record,Middle Income,4,No,No,0 742 | 30,No Record,Low Income,1,Yes,No,0 743 | 34,No,Middle Income,2,Yes,No,0 744 | 36,No,Middle Income,1,No,Yes,0 745 | 29,No,Middle Income,4,Yes,Yes,0 746 | 35,Yes,Low Income,5,No,No,1 747 | 28,No,Low Income,4,Yes,No,1 748 | 27,No,Middle Income,2,No,No,1 749 | 29,No,Low Income,3,No,No,0 750 | 37,No,Low Income,1,Yes,No,0 751 | 28,Yes,High Income,6,No,Yes,1 752 | 31,No Record,Middle Income,5,Yes,No,0 753 | 27,Yes,Low Income,5,Yes,No,1 754 | 30,No,Middle Income,5,Yes,No,0 755 | 28,Yes,Low Income,1,No,No,1 756 | 35,No Record,Middle Income,2,Yes,No,0 757 | 37,No,Middle Income,1,No,Yes,0 758 | 33,Yes,Low Income,1,No,No,0 759 | 27,No Record,Low Income,2,Yes,Yes,1 760 | 29,Yes,Low Income,1,No,Yes,0 761 | 34,Yes,High Income,1,Yes,No,1 762 | 37,No,Middle Income,3,No,Yes,0 763 | 38,Yes,Low Income,1,No,Yes,0 764 | 34,No,Middle Income,3,No,Yes,0 765 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: binary-diffusion-tabular 2 | channels: 3 | - defaults 4 | - conda-forge 5 | dependencies: 6 | - pip=24.2 7 | - python=3.11.10 8 | - wheel=0.44.0 9 | - pip: 10 | - accelerate==1.1.1 11 | - ema-pytorch==0.7.6 12 | - numpy==2.1.3 13 | - pandas==2.2.3 14 | - pyyaml==6.0.2 15 | - scikit-learn==1.5.2 16 | - torch==2.5.0 17 | - torchmetrics==1.6.0 18 | - tqdm==4.67.1 19 | - wandb==0.19.0 20 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "binary-diffusion-tabular" 7 | version = "0.1.0" 8 | description = "Binary diffusion for tabular data" 9 | readme = "README.md" 10 | license = {text = "MIT"} 11 | authors = [ 12 | {name = "Vitaliy Kinakh", email = "kinakh.vitalii@gmail.com"} 13 | ] 14 | dependencies = [ 15 | "accelerate==1.1.1", 16 | "ema-pytorch==0.7.6", 17 | "numpy==2.1.3", 18 | "pandas==2.2.3", 19 | "PyYAML==6.0.2", 20 | "scikit-learn==1.5.2", 21 | "torch==2.5.0", 22 | "torchmetrics==1.6.0", 23 | "tqdm==4.67.1", 24 | "wandb==0.19.0" 25 | ] 26 | classifiers = [ 27 | "Programming Language :: Python :: 3.9", 28 | "Programming Language :: Python :: 3.10", 29 | "Programming Language :: Python :: 3.11", 30 | "License :: OSI Approved :: MIT License", 31 | "Operating System :: OS Independent", 32 | "Intended Audience :: Developers", 33 | "Topic :: Scientific/Engineering :: Artificial Intelligence" 34 | ] 35 | requires-python = ">=3.10" 36 | 37 | [tool.setuptools] 38 | packages = ["binary_diffusion_tabular"] 39 | 40 | [project.urls] 41 | Homepage = "https://github.com/vkinakh/binary-diffusion-tabular" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==1.1.1 2 | ema-pytorch==0.7.6 3 | numpy==2.1.3 4 | pandas==2.2.3 5 | PyYAML==6.0.2 6 | scikit-learn==1.5.2 7 | torch==2.5.0 8 | torchmetrics==1.6.0 9 | tqdm==4.67.1 10 | wandb==0.19.0 -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | from functools import partial 4 | 5 | import pandas as pd 6 | from tqdm.auto import tqdm 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | from binary_diffusion_tabular import ( 12 | BinaryDiffusion1D, 13 | SimpleTableGenerator, 14 | FixedSizeBinaryTableTransformation, 15 | select_equally_distributed_numbers, 16 | TASK, 17 | get_random_labels, 18 | seed_everything 19 | ) 20 | 21 | 22 | def get_sampling_args_parser() -> argparse.ArgumentParser: 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument( 25 | "--ckpt", type=str, required=True, help="Path to checkpoint file" 26 | ) 27 | parser.add_argument("--ckpt_transformation", type=str, help="Path to transformation checkpoint file") 28 | parser.add_argument( 29 | "--n_timesteps", "-t", type=int, required=True, help="Number of sampling steps" 30 | ) 31 | parser.add_argument( 32 | "--out", 33 | "-o", 34 | type=str, 35 | required=True, 36 | help="Path to output folder, where to save samples", 37 | ) 38 | parser.add_argument( 39 | "--n_samples", 40 | "-n", 41 | type=int, 42 | required=True, 43 | help="Number of samples to generate", 44 | ) 45 | parser.add_argument( 46 | "--batch_size", "-b", type=int, required=True, help="Batch size for sampling" 47 | ) 48 | parser.add_argument( 49 | "--threshold", type=float, default=0.5, help="Threshold for binarization" 50 | ) 51 | parser.add_argument( 52 | "--strategy", 53 | type=str, 54 | default="target", 55 | choices=["target", "mask"], 56 | help="Sampling strategy to use", 57 | ) 58 | parser.add_argument("--seed", "-s", type=int, help="Random seed", required=False) 59 | parser.add_argument( 60 | "--guidance_scale", "-g", type=float, default=0.0, help="Guidance scale" 61 | ) 62 | parser.add_argument("--target_column_name", type=str, help="Target column name", required=False) 63 | parser.add_argument("--device", "-d", type=str, default="cuda", help="Device") 64 | parser.add_argument("--use_ema", "-e", action="store_true", help="Use EMA") 65 | parser.add_argument("--dropna", action="store_true", help="Whether to drop rows with nan during sampling") 66 | 67 | return parser 68 | 69 | 70 | def cfg_model_fn( 71 | x_t: torch.Tensor, 72 | ts: torch.Tensor, 73 | y: torch.Tensor, 74 | model: nn.Module, 75 | guidance_scale: float, 76 | task: TASK, 77 | *args, 78 | **kwargs 79 | ) -> torch.Tensor: 80 | """Classifier free guidance sampling function 81 | 82 | Args: 83 | x_t: noisy sample 84 | ts: timesteps 85 | y: conditioning 86 | model: denoising model 87 | guidance_scale: guidance scale in classifier free guidance 88 | task: dataset task 89 | 90 | Returns: 91 | torch.Tensor: denoiser output 92 | """ 93 | 94 | combine = torch.cat([x_t, x_t], dim=0) 95 | combine_ts = torch.cat([ts, ts], dim=0) 96 | 97 | if task == "classification": 98 | y_other = torch.zeros_like(y) 99 | elif task == "regression": 100 | # for regression, zero-token is -1, since values are minmax normalized to [0, 1] range 101 | y_other = torch.ones_like(y) * -1 102 | 103 | combine_y = torch.cat([y, y_other], dim=0) 104 | model_out = model(combine, combine_ts, y=combine_y) 105 | cond_eps, uncod_eps = torch.split(model_out, [y.shape[0], y.shape[0]], dim=0) 106 | eps = uncod_eps + guidance_scale * (cond_eps - uncod_eps) 107 | return eps 108 | 109 | 110 | if __name__ == "__main__": 111 | parser = get_sampling_args_parser() 112 | cli_args = parser.parse_args() 113 | 114 | if cli_args.seed: 115 | seed_everything(cli_args.seed) 116 | 117 | path_out = Path(cli_args.out) 118 | path_out.mkdir(parents=True, exist_ok=True) 119 | 120 | ckpt = torch.load(cli_args.ckpt) 121 | device = cli_args.device 122 | batch_size = int(cli_args.batch_size) 123 | guidance_scale = cli_args.guidance_scale 124 | threshold = cli_args.threshold 125 | strategy = cli_args.strategy 126 | target_column_name = cli_args.target_column_name 127 | 128 | denoising_model = SimpleTableGenerator.from_config(ckpt["config_model"]).to(device) 129 | denoising_model.eval() 130 | 131 | diffusion = BinaryDiffusion1D.from_config( 132 | denoise_model=denoising_model, 133 | config=ckpt["config_diffusion"], 134 | ).to(device) 135 | diffusion.eval() 136 | 137 | transformation = FixedSizeBinaryTableTransformation.from_checkpoint(cli_args.ckpt_transformation) 138 | 139 | if cli_args.use_ema: 140 | diffusion.load_ema(ckpt["diffusion_ema"]) 141 | else: 142 | diffusion.load_state_dict(ckpt["diffusion"]) 143 | 144 | n_total_timesteps = diffusion.n_timesteps 145 | timesteps_sampling = select_equally_distributed_numbers( 146 | n_total_timesteps, 147 | cli_args.n_timesteps, 148 | ) 149 | task = denoising_model.task 150 | conditional = denoising_model.conditional 151 | n_classes = denoising_model.n_classes 152 | classifier_free_guidance = denoising_model.classifier_free_guidance 153 | 154 | n_generated = 0 155 | n_samples = cli_args.n_samples 156 | pbar = tqdm(total=n_samples) 157 | dfs = [] 158 | 159 | while n_generated < n_samples: 160 | labels = get_random_labels( 161 | conditional=conditional, 162 | task=task, 163 | n_classes=n_classes, 164 | classifier_free_guidance=classifier_free_guidance, 165 | n_labels=batch_size, 166 | device=device, 167 | ) 168 | 169 | x = diffusion.sample( 170 | model_fn=( 171 | partial(cfg_model_fn, guidance_scale=guidance_scale, task=task) 172 | if classifier_free_guidance and guidance_scale > 0 173 | else None 174 | ), 175 | n=batch_size, 176 | y=labels, 177 | timesteps=timesteps_sampling, 178 | threshold=threshold, 179 | strategy=strategy, 180 | ) 181 | 182 | if conditional: 183 | if classifier_free_guidance: 184 | labels = torch.argmax(labels, dim=1) 185 | 186 | x_df, labels_df = transformation.inverse_transform(x, labels) 187 | x_df[target_column_name] = labels_df 188 | else: 189 | x_df = transformation.inverse_transform(x) 190 | 191 | if cli_args.dropna: 192 | x_df = x_df.dropna() 193 | 194 | n_generated += len(x_df) 195 | pbar.update(len(x_df)) 196 | dfs.append(x_df) 197 | 198 | df = pd.concat(dfs) 199 | df.to_csv(path_out / "samples.csv", index=False) 200 | pbar.close() 201 | -------------------------------------------------------------------------------- /tests/test_fixed_size_binary_table_transformation.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | import tempfile 4 | 5 | import pandas as pd 6 | import torch 7 | import numpy as np 8 | 9 | from binary_diffusion_tabular import FixedSizeBinaryTableTransformation 10 | 11 | 12 | def create_sample_data(): 13 | """Create a sample DataFrame for testing based on adult dataset.""" 14 | data = { 15 | "age": [25, 38, 28, 44, 18], 16 | "fnlwgt": [226802, 89814, 336951, 160323, 103497], 17 | "education-num": [7, 9, 10, 10, 9], 18 | "capital-gain": [0, 0, 0, 7688, 0], 19 | "capital-loss": [1, 2, 3, 4, 5], 20 | "hours-per-week": [40, 50, 40, 40, 30], 21 | "workclass": ["Private", "Self-emp-not-inc", "Private", "Private", "Private"], 22 | "education": ["Bachelors", "HS-grad", "11th", "Masters", "HS-grad"], 23 | "marital-status": [ 24 | "Never-married", 25 | "Married-civ-spouse", 26 | "Married-civ-spouse", 27 | "Divorced", 28 | "Never-married", 29 | ], 30 | "occupation": [ 31 | "Tech-support", 32 | "Craft-repair", 33 | "Other-service", 34 | "Sales", 35 | "Adm-clerical", 36 | ], 37 | "relationship": [ 38 | "Not-in-family", 39 | "Husband", 40 | "Husband", 41 | "Unmarried", 42 | "Not-in-family", 43 | ], 44 | "race": ["White", "White", "Asian-Pac-Islander", "Black", "White"], 45 | "sex": ["Male", "Male", "Male", "Male", "Female"], 46 | "native-country": [ 47 | "United-States", 48 | "United-States", 49 | "United-States", 50 | "United-States", 51 | "United-States", 52 | ], 53 | "label": [1, 0, 1, 0, 1], 54 | } 55 | df = pd.DataFrame(data) 56 | return df 57 | 58 | 59 | class TestFixedSizeBinaryTableTransformation(unittest.TestCase): 60 | def setUp(self): 61 | """Set up sample data and transformation instance before each test.""" 62 | self.df = create_sample_data() 63 | self.numerical_cols = [ 64 | "age", 65 | "fnlwgt", 66 | "education-num", 67 | "capital-gain", 68 | "capital-loss", 69 | "hours-per-week", 70 | ] 71 | self.categorical_cols = [ 72 | "workclass", 73 | "education", 74 | "marital-status", 75 | "occupation", 76 | "relationship", 77 | "race", 78 | "sex", 79 | "native-country", 80 | ] 81 | self.transformation = FixedSizeBinaryTableTransformation( 82 | task="classification", 83 | numerical_columns=self.numerical_cols, 84 | categorical_columns=self.categorical_cols, 85 | parallel=False, # Change to True to test parallel execution 86 | ) 87 | 88 | def test_fit_transform_and_transform_consistency(self): 89 | """Test that fit_transform and transform methods produce consistent results.""" 90 | df_y = self.df["label"] 91 | df_x = self.df.drop("label", axis=1) 92 | 93 | x_binary, y_trans = self.transformation.fit_transform(df_x, df_y) 94 | x_binary_2, y_trans_2 = self.transformation.transform(df_x, df_y) 95 | self.assertTrue( 96 | torch.all(x_binary == x_binary_2), "x_binary and x_binary_2 should be equal" 97 | ) 98 | self.assertTrue( 99 | torch.all(y_trans == y_trans_2), "y_trans and y_trans_2 should be equal" 100 | ) 101 | 102 | def test_inverse_transform(self): 103 | """Test that inverse_transform accurately retrieves the original data.""" 104 | df_y = self.df["label"] 105 | df_x = self.df.drop("label", axis=1) 106 | 107 | x_binary, y_trans = self.transformation.fit_transform(df_x, df_y) 108 | df_x_back, y_back = self.transformation.inverse_transform(x_binary, y_trans) 109 | 110 | for col in self.categorical_cols: 111 | original = self.df[col].reset_index(drop=True) 112 | back = df_x_back[col].reset_index(drop=True) 113 | self.assertTrue( 114 | original.equals(back), 115 | f"Categorical column '{col}' does not match after inverse transform", 116 | ) 117 | 118 | for col in self.numerical_cols: 119 | original = self.df[col].values 120 | back = df_x_back[col].values 121 | self.assertTrue( 122 | np.allclose(original, back, atol=1e-5), 123 | f"Numerical column '{col}' does not match after inverse transform", 124 | ) 125 | 126 | def test_parallel_transformation(self): 127 | """Test that parallel and non-parallel transformations produce the same results.""" 128 | df_y = self.df["label"] 129 | df_x = self.df.drop("label", axis=1) 130 | 131 | self.transformation.parallel = False 132 | x_binary, y_trans = self.transformation.fit_transform(df_x, df_y) 133 | 134 | transformation_parallel = FixedSizeBinaryTableTransformation( 135 | task="classification", 136 | numerical_columns=self.numerical_cols, 137 | categorical_columns=self.categorical_cols, 138 | parallel=True, 139 | ) 140 | x_binary_p, y_trans_p = transformation_parallel.fit_transform(df_x, df_y) 141 | 142 | self.assertTrue( 143 | torch.all(x_binary == x_binary_p), 144 | "Binary tensors should be equal when using parallel and non-parallel transforms", 145 | ) 146 | self.assertTrue( 147 | torch.all(y_trans == y_trans_p), 148 | "Labels should be equal when using parallel and non-parallel transforms", 149 | ) 150 | 151 | def test_invalid_numerical_dtype(self): 152 | """Test that a ValueError is raised when numerical columns have non-numeric types.""" 153 | df_x = self.df.drop("label", axis=1).copy() 154 | df_x["age"] = df_x["age"].astype(str) # Introduce invalid dtype 155 | 156 | with self.assertRaises(ValueError): 157 | self.transformation.fit_transform(df_x, self.df["label"]) 158 | 159 | def test_transform_without_fit(self): 160 | """Test that transforming without fitting raises a RuntimeError.""" 161 | df_x = self.df.drop("label", axis=1) 162 | 163 | with self.assertRaises(RuntimeError): 164 | self.transformation.transform(df_x) 165 | 166 | def test_inverse_transform_without_fit(self): 167 | """Test that inverse_transform without fitting raises a RuntimeError.""" 168 | transformation_unfitted = FixedSizeBinaryTableTransformation( 169 | task="classification", 170 | numerical_columns=self.numerical_cols, 171 | categorical_columns=self.categorical_cols, 172 | parallel=False, 173 | ) 174 | fake_tensor = torch.zeros((5, 32)) 175 | 176 | with self.assertRaises(RuntimeError): 177 | transformation_unfitted.inverse_transform(fake_tensor) 178 | 179 | def test_label_transformation(self): 180 | """Test that labels are correctly transformed and inverse transformed.""" 181 | df_y = self.df["label"] 182 | df_x = self.df.drop("label", axis=1) 183 | 184 | x_binary, y_trans = self.transformation.fit_transform(df_x, df_y) 185 | y_back = self.transformation.inverse_transform_label(y_trans) 186 | 187 | original_labels = df_y.values 188 | self.assertTrue( 189 | np.array_equal(original_labels, y_back), 190 | "Original labels and inverse transformed labels should match", 191 | ) 192 | 193 | def test_save_and_load_transformation(self): 194 | """Test that saving and loading the transformation preserves its state and functionality.""" 195 | df_y = self.df["label"] 196 | df_x = self.df.drop("label", axis=1) 197 | 198 | x_binary_original, y_trans_original = self.transformation.fit_transform( 199 | df_x, df_y 200 | ) 201 | 202 | with tempfile.NamedTemporaryFile(suffix=".joblib", delete=False) as tmp_file: 203 | temp_filepath = tmp_file.name 204 | 205 | try: 206 | self.transformation.save_checkpoint(temp_filepath) 207 | 208 | loaded_transformation = FixedSizeBinaryTableTransformation.from_checkpoint( 209 | temp_filepath 210 | ) 211 | 212 | self.assertTrue( 213 | loaded_transformation.fitted, "Loaded transformer should be fitted." 214 | ) 215 | self.assertTrue( 216 | loaded_transformation.fitted_label, 217 | "Loaded transformer should have fitted labels.", 218 | ) 219 | 220 | x_binary_loaded, y_trans_loaded = loaded_transformation.transform( 221 | df_x, df_y 222 | ) 223 | 224 | self.assertTrue( 225 | torch.all(x_binary_original == x_binary_loaded), 226 | "Transformed data from loaded transformer should match the original transformer.", 227 | ) 228 | self.assertTrue( 229 | torch.all(y_trans_original == y_trans_loaded), 230 | "Transformed labels from loaded transformer should match the original transformer.", 231 | ) 232 | 233 | df_x_back_loaded, y_back_loaded = loaded_transformation.inverse_transform( 234 | x_binary_loaded, y_trans_loaded 235 | ) 236 | 237 | for col in self.categorical_cols: 238 | original = self.df[col].reset_index(drop=True) 239 | back = df_x_back_loaded[col].reset_index(drop=True) 240 | self.assertTrue( 241 | original.equals(back), 242 | f"Categorical column '{col}' does not match after inverse transform with loaded transformer", 243 | ) 244 | 245 | for col in self.numerical_cols: 246 | original = self.df[col].values 247 | back = df_x_back_loaded[col].values 248 | self.assertTrue( 249 | np.allclose(original, back, atol=1e-5), 250 | f"Numerical column '{col}' does not match after inverse transform with loaded transformer", 251 | ) 252 | 253 | finally: 254 | # Clean up the temporary file 255 | os.remove(temp_filepath) 256 | 257 | 258 | if __name__ == "__main__": 259 | unittest.main() 260 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from binary_diffusion_tabular import SimpleTableGenerator 7 | 8 | 9 | class TestSimpleTableGenerator(unittest.TestCase): 10 | def test_classification_conditional_without_classifier_free_guidance(self): 11 | """ 12 | Test SimpleTableGenerator for classification task with conditional=True and classifier_free_guidance=False. 13 | """ 14 | model = SimpleTableGenerator( 15 | data_dim=220, 16 | dim=256, 17 | n_res_blocks=3, 18 | out_dim=220, 19 | task="classification", 20 | conditional=True, 21 | n_classes=3, 22 | classifier_free_guidance=False, 23 | ) 24 | batch_size = 128 25 | tensor = torch.randn((batch_size, 220)) 26 | ts = torch.randint(0, 100, (batch_size,)).float() 27 | cls = torch.randint(0, 3, (batch_size,)) 28 | 29 | out = model(tensor, ts, cls) 30 | expected_shape = (batch_size, 220) 31 | self.assertEqual( 32 | out.shape, 33 | expected_shape, 34 | f"Output shape should be {expected_shape}, got {out.shape}", 35 | ) 36 | 37 | def test_classification_conditional_with_classifier_free_guidance(self): 38 | """ 39 | Test SimpleTableGenerator for classification task with conditional=True and classifier_free_guidance=True. 40 | """ 41 | model = SimpleTableGenerator( 42 | data_dim=220, 43 | dim=256, 44 | n_res_blocks=3, 45 | out_dim=220, 46 | task="classification", 47 | conditional=True, 48 | n_classes=3, 49 | classifier_free_guidance=True, 50 | ) 51 | batch_size = 128 52 | tensor = torch.randn((batch_size, 220)) 53 | ts = torch.randint(0, 100, (batch_size,)).float() 54 | cls = torch.randint(0, 3, (batch_size,)) 55 | 56 | # for classifier free guidance, cls should be one-hot 57 | cls = F.one_hot(cls, num_classes=3).float() 58 | 59 | out = model(tensor, ts, cls) 60 | expected_shape = (batch_size, 220) 61 | self.assertEqual( 62 | out.shape, 63 | expected_shape, 64 | f"Output shape should be {expected_shape}, got {out.shape}", 65 | ) 66 | 67 | def test_classification_unconditional(self): 68 | """ 69 | Test SimpleTableGenerator for classification task with conditional=False. 70 | """ 71 | model = SimpleTableGenerator( 72 | data_dim=220, 73 | dim=256, 74 | n_res_blocks=3, 75 | out_dim=220, 76 | task="classification", 77 | conditional=False, 78 | n_classes=0, # Irrelevant since conditional=False 79 | classifier_free_guidance=False, # Irrelevant since conditional=False 80 | ) 81 | batch_size = 128 82 | tensor = torch.randn((batch_size, 220)) 83 | ts = torch.randint(0, 100, (batch_size,)).float() 84 | 85 | out = model(tensor, ts) # No class labels provided 86 | expected_shape = (batch_size, 220) 87 | self.assertEqual( 88 | out.shape, 89 | expected_shape, 90 | f"Output shape should be {expected_shape}, got {out.shape}", 91 | ) 92 | 93 | def test_regression_conditional(self): 94 | """ 95 | Test SimpleTableGenerator for regression task with conditional=True. 96 | """ 97 | model = SimpleTableGenerator( 98 | data_dim=220, 99 | dim=256, 100 | n_res_blocks=3, 101 | out_dim=220, 102 | task="regression", 103 | conditional=True, 104 | n_classes=0, # Irrelevant for regression 105 | classifier_free_guidance=False, # Irrelevant for regression 106 | ) 107 | batch_size = 128 108 | tensor = torch.randn((batch_size, 220)) 109 | ts = torch.randint(0, 100, (batch_size,)).float() 110 | reg = torch.randn((batch_size, 1)) # Regression targets 111 | 112 | out = model(tensor, ts, reg) 113 | expected_shape = (batch_size, 220) 114 | self.assertEqual( 115 | out.shape, 116 | expected_shape, 117 | f"Output shape should be {expected_shape}, got {out.shape}", 118 | ) 119 | 120 | def test_regression_unconditional(self): 121 | """ 122 | Test SimpleTableGenerator for regression task with conditional=False. 123 | """ 124 | model = SimpleTableGenerator( 125 | data_dim=220, 126 | dim=256, 127 | n_res_blocks=3, 128 | out_dim=220, 129 | task="regression", 130 | conditional=False, 131 | n_classes=0, # Irrelevant since conditional=False 132 | classifier_free_guidance=False, # Irrelevant since conditional=False 133 | ) 134 | batch_size = 128 135 | tensor = torch.randn((batch_size, 220)) 136 | ts = torch.randint(0, 100, (batch_size,)).float() 137 | 138 | out = model(tensor, ts) # No regression targets provided 139 | expected_shape = (batch_size, 220) 140 | self.assertEqual( 141 | out.shape, 142 | expected_shape, 143 | f"Output shape should be {expected_shape}, got {out.shape}", 144 | ) 145 | 146 | def test_invalid_task(self): 147 | """ 148 | Test that providing an invalid task raises an error. 149 | """ 150 | with self.assertRaises(ValueError): 151 | SimpleTableGenerator( 152 | data_dim=220, 153 | dim=256, 154 | n_res_blocks=3, 155 | out_dim=220, 156 | task="invalid_task", # Invalid task 157 | conditional=True, 158 | n_classes=3, 159 | classifier_free_guidance=False, 160 | ) 161 | 162 | def test_incorrect_class_count(self): 163 | """ 164 | Test that providing a non-positive number of classes raises an error for classification. 165 | """ 166 | with self.assertRaises(ValueError): 167 | SimpleTableGenerator( 168 | data_dim=220, 169 | dim=256, 170 | n_res_blocks=3, 171 | out_dim=220, 172 | task="classification", 173 | conditional=True, 174 | n_classes=0, # Invalid number of classes 175 | classifier_free_guidance=False, 176 | ) 177 | 178 | def test_output_dtype(self): 179 | """ 180 | Test that the output tensor has the correct dtype. 181 | """ 182 | model = SimpleTableGenerator( 183 | data_dim=220, 184 | dim=256, 185 | n_res_blocks=3, 186 | out_dim=220, 187 | task="classification", 188 | conditional=True, 189 | n_classes=3, 190 | classifier_free_guidance=False, 191 | ) 192 | batch_size = 128 193 | tensor = torch.randn((batch_size, 220)) 194 | ts = torch.randint(0, 100, (batch_size,)).float() 195 | cls = torch.randint(0, 3, (batch_size,)) 196 | 197 | out = model(tensor, ts, cls) 198 | expected_dtype = torch.float32 # Assuming the model outputs float tensors 199 | self.assertEqual( 200 | out.dtype, 201 | expected_dtype, 202 | f"Output dtype should be {expected_dtype}, got {out.dtype}", 203 | ) 204 | 205 | def test_batch_size_zero(self): 206 | """ 207 | Test that the model can handle a batch size of zero without errors. 208 | """ 209 | model = SimpleTableGenerator( 210 | data_dim=220, 211 | dim=256, 212 | n_res_blocks=3, 213 | out_dim=220, 214 | task="classification", 215 | conditional=True, 216 | n_classes=3, 217 | classifier_free_guidance=False, 218 | ) 219 | batch_size = 0 220 | tensor = torch.randn((batch_size, 220)) 221 | ts = torch.randint(0, 100, (batch_size,)).float() 222 | cls = torch.randint(0, 3, (batch_size,)) 223 | 224 | out = model(tensor, ts, cls) 225 | expected_shape = (batch_size, 220) 226 | self.assertEqual( 227 | out.shape, 228 | expected_shape, 229 | f"Output shape should be {expected_shape}, got {out.shape}", 230 | ) 231 | 232 | def test_large_batch_size(self): 233 | """ 234 | Test that the model can handle a large batch size without errors. 235 | """ 236 | model = SimpleTableGenerator( 237 | data_dim=220, 238 | dim=256, 239 | n_res_blocks=3, 240 | out_dim=220, 241 | task="classification", 242 | conditional=True, 243 | n_classes=3, 244 | classifier_free_guidance=False, 245 | ) 246 | batch_size = 1024 247 | tensor = torch.randn((batch_size, 220)) 248 | ts = torch.randint(0, 100, (batch_size,)).float() 249 | cls = torch.randint(0, 3, (batch_size,)) 250 | 251 | out = model(tensor, ts, cls) 252 | expected_shape = (batch_size, 220) 253 | self.assertEqual( 254 | out.shape, 255 | expected_shape, 256 | f"Output shape should be {expected_shape}, got {out.shape}", 257 | ) 258 | 259 | 260 | if __name__ == "__main__": 261 | unittest.main() 262 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import wandb 4 | 5 | from binary_diffusion_tabular.trainer import FixedSizeTableBinaryDiffusionTrainer 6 | from binary_diffusion_tabular.utils import get_config 7 | 8 | 9 | if __name__ == "__main__": 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--config", "-c", type=str, required=True) 12 | args = parser.parse_args() 13 | config = get_config(args.config) 14 | 15 | comment = config["comment"] 16 | logger = wandb.init(project="binary_diffusion_tabular", name=comment, config=config) 17 | 18 | trainer = FixedSizeTableBinaryDiffusionTrainer.from_config(config, logger=logger) 19 | 20 | if config["fine_tune_from"]: 21 | trainer.load_checkpoint(config["fine_tune_from"]) 22 | 23 | trainer.train() 24 | --------------------------------------------------------------------------------