├── .editorconfig
├── .github
│   ├── CODEOWNERS
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   ├── feature_request.md
│   │   └── question.md
│   └── workflows
│       ├── dependency_checker.yml
│       ├── integration.yml
│       ├── lint.yml
│       ├── minimum.yml
│       ├── prepare_release.yml
│       ├── readme.yml
│       └── unit.yml
├── .gitignore
├── AUTHORS.rst
├── CONTRIBUTING.rst
├── HISTORY.md
├── LICENSE
├── Makefile
├── README.md
├── ctgan
│   ├── __init__.py
│   ├── __main__.py
│   ├── data.py
│   ├── data_sampler.py
│   ├── data_transformer.py
│   ├── demo.py
│   ├── errors.py
│   └── synthesizers
│       ├── __init__.py
│       ├── base.py
│       ├── ctgan.py
│       └── tvae.py
├── examples
│   ├── csv
│   │   ├── adult.csv
│   │   └── adult.json
│   └── tsv
│       ├── acs.dat
│       ├── acs.meta
│       ├── adult.dat
│       ├── adult.meta
│       ├── br2000.dat
│       ├── br2000.meta
│       ├── nltcs.dat
│       └── nltcs.meta
├── latest_requirements.txt
├── pyproject.toml
├── scripts
│   └── release_notes_generator.py
├── static_code_analysis.txt
├── tasks.py
├── tests
│   ├── __init__.py
│   ├── integration
│   │   ├── __init__.py
│   │   ├── synthesizer
│   │   │   ├── __init__.py
│   │   │   ├── test_ctgan.py
│   │   │   └── test_tvae.py
│   │   ├── test_data_transformer.py
│   │   └── test_load_demo.py
│   ├── test_tasks.py
│   └── unit
│       ├── __init__.py
│       ├── synthesizer
│       │   ├── __init__.py
│       │   ├── test_base.py
│       │   ├── test_ctgan.py
│       │   └── test_tvae.py
│       └── test_data_transformer.py
└── tox.ini
-------------------------------------------------------------------------------- /.editorconfig: --------------------------------------------------------------------------------
1 | # http://editorconfig.org
2 |
3 | root = true
4 |
5 | [*]
6 | indent_style = space
7 | indent_size = 4
8 | trim_trailing_whitespace = true
9 | insert_final_newline = true
10 | charset = utf-8
11 | end_of_line = lf
12 |
13 | [*.py]
14 | max_line_length = 99
15 |
16 | [*.bat]
17 | indent_style = tab
18 | end_of_line = crlf
19 |
20 | [LICENSE]
21 | insert_final_newline = false
22 |
23 | [Makefile]
24 | indent_style = tab
25 |
-------------------------------------------------------------------------------- /.github/CODEOWNERS: --------------------------------------------------------------------------------
1 | # Global rule:
2 | * @sdv-dev/core-contributors
3 |
-------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: --------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Report an error that you found when using CTGAN
4 | title: ''
5 | labels: bug, new
6 | assignees: ''
7 |
8 | ---
9 |
10 | ### Environment Details
11 |
12 | Please indicate the following details about the environment in which you found the bug:
13 |
14 | * CTGAN version:
15 | * Python version:
16 | * Operating System:
17 |
18 | ### Error Description
19 |
20 |
22 |
23 | ### Steps to reproduce
24 |
25 |
29 |
30 | ```
31 | Paste the command(s) you ran and the output.
32 | If there was a crash, please include the traceback here.
33 | ``` 34 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Request a new feature that you would like to see implemented in CTGAN 4 | title: '' 5 | labels: new feature, new 6 | assignees: '' 7 | 8 | --- 9 | 10 | ### Problem Description 11 | 12 | 14 | 15 | ### Expected behavior 16 | 17 | 20 | 21 | ### Additional context 22 | 23 | 25 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Question 3 | about: Doubts about CTGAN usage 4 | title: '' 5 | labels: question, new 6 | assignees: '' 7 | 8 | --- 9 | 10 | ### Environment details 11 | 12 | If you are already running CTGAN, please indicate the following details about the environment in 13 | which you are running it: 14 | 15 | * CTGAN version: 16 | * Python version: 17 | * Operating System: 18 | 19 | ### Problem description 20 | 21 | 24 | 25 | ### What I already tried 26 | 27 | 29 | 30 | ``` 31 | Paste the command(s) you ran and the output. 32 | If there was a crash, please include the traceback here. 33 | ``` 34 | -------------------------------------------------------------------------------- /.github/workflows/dependency_checker.yml: -------------------------------------------------------------------------------- 1 | name: Dependency Checker 2 | on: 3 | schedule: 4 | - cron: '0 0 * * 1' 5 | workflow_dispatch: 6 | jobs: 7 | build: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@v4 11 | - name: Set up Python 3.9 12 | uses: actions/setup-python@v5 13 | with: 14 | python-version: 3.9 15 | - name: Install dependencies 16 | run: | 17 | python -m pip install .[dev] 18 | make check-deps OUTPUT_FILEPATH=latest_requirements.txt 19 | make fix-lint 20 | - name: Create pull request 21 | id: cpr 22 | uses: peter-evans/create-pull-request@v4 23 | with: 24 | token: ${{ secrets.GH_ACCESS_TOKEN }} 25 | commit-message: Update latest dependencies 26 | author: "github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>" 27 | committer: "github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>" 28 | title: Automated Latest Dependency Updates 29 | body: "This is an auto-generated PR with **latest** dependency updates." 
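          # `branch` is combined with the `branch-suffix: short-commit-hash` option
          # below, so each run opens its PR from a branch named like
          # latest-dependency-update-<short-hash>.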
30 | branch: latest-dependency-update 31 | branch-suffix: short-commit-hash 32 | base: main 33 | -------------------------------------------------------------------------------- /.github/workflows/integration.yml: -------------------------------------------------------------------------------- 1 | name: Integration Tests 2 | 3 | on: 4 | push: 5 | pull_request: 6 | types: [opened, reopened] 7 | 8 | jobs: 9 | integration: 10 | runs-on: ${{ matrix.os }} 11 | strategy: 12 | matrix: 13 | python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', '3.13'] 14 | os: [ubuntu-latest, windows-latest] 15 | include: 16 | - os: macos-latest 17 | python-version: '3.8' 18 | - os: macos-latest 19 | python-version: '3.13' 20 | steps: 21 | - uses: actions/checkout@v4 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v5 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | python -m pip install invoke .[test] 30 | - name: Run integration tests 31 | run: invoke integration 32 | 33 | - if: matrix.os == 'ubuntu-latest' && matrix.python-version == 3.13 34 | name: Upload integration codecov report 35 | uses: codecov/codecov-action@v4 36 | with: 37 | flags: integration 38 | file: ${{ github.workspace }}/integration_cov.xml 39 | fail_ci_if_error: true 40 | token: ${{ secrets.CODECOV_TOKEN }} 41 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Style Checks 2 | 3 | on: 4 | push: 5 | pull_request: 6 | types: [opened, reopened] 7 | 8 | jobs: 9 | lint: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v4 13 | - name: Set up Python 3.9 14 | uses: actions/setup-python@v5 15 | with: 16 | python-version: 3.9 17 | - name: Install dependencies 18 | run: | 19 | python -m pip install --upgrade pip 20 | python -m pip install invoke .[dev] 21 | - name: Run lint checks 22 | run: invoke lint 23 | -------------------------------------------------------------------------------- /.github/workflows/minimum.yml: -------------------------------------------------------------------------------- 1 | name: Unit Tests Minimum Versions 2 | concurrency: 3 | group: ${{ github.workflow }}-${{ github.ref }} 4 | cancel-in-progress: true 5 | on: 6 | push: 7 | pull_request: 8 | types: [opened, reopened] 9 | jobs: 10 | minimum: 11 | runs-on: ${{ matrix.os }} 12 | strategy: 13 | matrix: 14 | python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', '3.13'] 15 | os: [ubuntu-latest, windows-latest] 16 | include: 17 | - os: macos-13 18 | python-version: '3.8' 19 | - os: macos-latest 20 | python-version: '3.13' 21 | steps: 22 | - uses: actions/checkout@v4 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v5 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | python -m pip install invoke .[test] 31 | - name: Test with minimum versions 32 | run: invoke minimum 33 | -------------------------------------------------------------------------------- /.github/workflows/prepare_release.yml: -------------------------------------------------------------------------------- 1 | name: Release Prep 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | branch: 7 | description: 'Branch to merge release notes and code analysis into.' 
8 | required: true 9 | default: 'main' 10 | version: 11 | description: 12 | 'Version to use for the release. Must be in format: X.Y.Z.' 13 | date: 14 | description: 15 | 'Date of the release. Must be in format YYYY-MM-DD.' 16 | 17 | jobs: 18 | preparerelease: 19 | runs-on: ubuntu-latest 20 | steps: 21 | - uses: actions/checkout@v4 22 | - name: Set up Python 3.10 23 | uses: actions/setup-python@v5 24 | with: 25 | python-version: '3.10' 26 | 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | python -m pip install requests==2.31.0 31 | python -m pip install bandit==1.7.7 32 | python -m pip install .[test] 33 | 34 | - name: Generate release notes 35 | env: 36 | GH_ACCESS_TOKEN: ${{ secrets.GH_ACCESS_TOKEN }} 37 | run: > 38 | python scripts/release_notes_generator.py 39 | -v ${{ inputs.version }} 40 | -d ${{ inputs.date }} 41 | 42 | - name: Save static code analysis 43 | run: bandit -r . -x ./tests,./scripts,./build -f txt -o static_code_analysis.txt --exit-zero 44 | 45 | - name: Create pull request 46 | id: cpr 47 | uses: peter-evans/create-pull-request@v4 48 | with: 49 | token: ${{ secrets.GH_ACCESS_TOKEN }} 50 | commit-message: Prepare release for v${{ inputs.version }} 51 | author: "github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>" 52 | committer: "github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>" 53 | title: v${{ inputs.version }} Release Preparation 54 | body: "This is an auto-generated PR to prepare the release." 55 | branch: prepared-release 56 | branch-suffix: short-commit-hash 57 | base: ${{ inputs.branch }} 58 | -------------------------------------------------------------------------------- /.github/workflows/readme.yml: -------------------------------------------------------------------------------- 1 | name: Test README 2 | 3 | on: 4 | push: 5 | pull_request: 6 | types: [opened, reopened] 7 | 8 | jobs: 9 | readme: 10 | runs-on: ${{ matrix.os }} 11 | strategy: 12 | matrix: 13 | python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', '3.13'] 14 | os: [ubuntu-latest, macos-latest] # skip windows bc rundoc fails 15 | steps: 16 | - uses: actions/checkout@v4 17 | - name: Set up Python ${{ matrix.python-version }} 18 | uses: actions/setup-python@v5 19 | with: 20 | python-version: ${{ matrix.python-version }} 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | python -m pip install invoke rundoc . 
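          # rundoc is the tool that executes the fenced code blocks in README.md as tests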
25 | python -m pip install tomli 26 | python -m pip install packaging 27 | - name: Run the README.md 28 | run: invoke readme 29 | -------------------------------------------------------------------------------- /.github/workflows/unit.yml: -------------------------------------------------------------------------------- 1 | name: Unit Tests 2 | 3 | on: 4 | push: 5 | pull_request: 6 | types: [opened, reopened] 7 | 8 | jobs: 9 | unit: 10 | runs-on: ${{ matrix.os }} 11 | strategy: 12 | matrix: 13 | python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', '3.13'] 14 | os: [ubuntu-latest, windows-latest] 15 | include: 16 | - os: macos-latest 17 | python-version: '3.8' 18 | - os: macos-latest 19 | python-version: '3.13' 20 | steps: 21 | - uses: actions/checkout@v4 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v5 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | python -m pip install invoke .[test] 30 | - name: Run unit tests 31 | run: invoke unit 32 | 33 | - if: matrix.os == 'ubuntu-latest' && matrix.python-version == 3.13 34 | name: Upload unit codecov report 35 | uses: codecov/codecov-action@v4 36 | with: 37 | flags: unit 38 | file: ${{ github.workspace }}/unit_cov.xml 39 | fail_ci_if_error: true 40 | token: ${{ secrets.CODECOV_TOKEN }} 41 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | tests/readme_test/ 50 | *_cov.xml 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | docs/api/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # dotenv 87 | .env 88 | 89 | # virtualenv 90 | .venv 91 | venv/ 92 | ENV/ 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | .spyproject 97 | 98 | # Rope project settings 99 | .ropeproject 100 | 101 | # mkdocs documentation 102 | /site 103 | 104 | # mypy 105 | .mypy_cache/ 106 | 107 | # Vim 108 | .*.swp 109 | -------------------------------------------------------------------------------- /AUTHORS.rst: -------------------------------------------------------------------------------- 1 | See: https://github.com/sdv-dev/CTGAN/graphs/contributors 2 | -------------------------------------------------------------------------------- /CONTRIBUTING.rst: -------------------------------------------------------------------------------- 1 | ============ 2 | Contributing 3 | ============ 4 | 5 | We love our community’s interest in the SDV ecosystem. The SDV software 6 | (and its related libraries) is owned and maintained by DataCebo, Inc. 7 | It is available under the `Business Source License`_ for you to browse. 8 | 9 | We support a large set of users with enterprise-specific intricacies and 10 | reliability needs. This has required us to be deliberate about setting 11 | the roadmap for SDV libraries. As a result, we are unable to prioritize 12 | reviewing and accepting external pull requests. As a policy, we will 13 | not be able to accept external contributions. 14 | 15 | **Would you like a bug or feature request to be addressed?** If you haven't 16 | already, we would greatly appreciate it if you could `file an issue`_ 17 | instead with the overall description of your problem. We can determine 18 | whether it’s aligned with our framework. Once discussed, our team 19 | typically resolves smaller issues within a few release cycles. 20 | We appreciate your understanding. 21 | 22 | 23 | .. _Business Source License: https://github.com/sdv-dev/CTGAN/blob/main/LICENSE 24 | .. 
_file an issue: https://github.com/sdv-dev/CTGAN/issues
25 |
-------------------------------------------------------------------------------- /HISTORY.md: --------------------------------------------------------------------------------
1 | # History
2 |
3 | ## v0.11.0 - 2025-02-26
4 |
5 | ### New Features
6 |
7 | * Surface error to user during fit if training data contains null values - Issue [#414](https://github.com/sdv-dev/CTGAN/issues/414) by @rwedge
8 |
9 | ### Maintenance
10 |
11 | * Combine `static_code_analysis.yml` with `release_notes.yml` - Issue [#421](https://github.com/sdv-dev/CTGAN/issues/421) by @R-Palazzo
12 | * Support Python 3.13 - Issue [#411](https://github.com/sdv-dev/CTGAN/issues/411) by @rwedge
13 | * Update codecov and add flag for integration tests - Issue [#410](https://github.com/sdv-dev/CTGAN/issues/410) by @pvk-developer
14 |
15 | ## v0.10.2 - 2024-10-22
16 |
17 | ### Bugs Fixed
18 |
19 | * Cap numpy to less than 2.0.0 until CTGAN supports it - Issue [#387](https://github.com/sdv-dev/CTGAN/issues/387) by @gsheni
20 | * Redundant whitespace in the demo data - Issue [#233](https://github.com/sdv-dev/CTGAN/issues/233)
21 |
22 | ### Internal
23 |
24 | * Add workflow to generate release notes - Issue [#404](https://github.com/sdv-dev/CTGAN/issues/404) by @amontanez24
25 |
26 | ### Maintenance
27 |
28 | * Switch to using ruff for Python linting and code formatting - Issue [#335](https://github.com/sdv-dev/CTGAN/issues/335) by @gsheni
29 |
30 | ### Miscellaneous
31 |
32 | * Add support for numpy 2.0.0 - Issue [#386](https://github.com/sdv-dev/CTGAN/issues/386) by @R-Palazzo
33 |
34 | ## v0.10.1 - 2024-05-13
35 |
36 | This release removes a warning that was cluttering the console.
37 |
38 | ### Maintenance
39 |
40 | * Cleanup automated PR workflows - Issue [#370](https://github.com/sdv-dev/CTGAN/issues/370) by @R-Palazzo
41 | * Only run unit and integration tests on oldest and latest python versions for macos - Issue [#375](https://github.com/sdv-dev/CTGAN/issues/375) by @R-Palazzo
42 |
43 | ### Internal
44 |
45 | * Remove FutureWarning: Setting an item of incompatible dtype is deprecated - Issue [#373](https://github.com/sdv-dev/CTGAN/issues/373) by @fealho
46 |
47 | ## v0.10.0 - 2024-04-11
48 |
49 | This release adds support for Python 3.12!
50 |
51 | ### Maintenance
52 |
53 | * Support Python 3.12 - Issue [#324](https://github.com/sdv-dev/CTGAN/issues/324) by @fealho
54 | * Remove scikit-learn dependency - Issue [#346](https://github.com/sdv-dev/CTGAN/issues/346) by @R-Palazzo
55 | * Add bandit workflow - Issue [#353](https://github.com/sdv-dev/CTGAN/issues/353) by @R-Palazzo
56 |
57 | ### Internal
58 |
59 | * Replace integration test that uses the iris demo data - Issue [#352](https://github.com/sdv-dev/CTGAN/issues/352) by @R-Palazzo
60 |
61 | ### Bugs Fixed
62 |
63 | * Fix minimum version workflow when pointing to github branch - Issue [#355](https://github.com/sdv-dev/CTGAN/issues/355) by @R-Palazzo
64 |
65 | ## v0.9.1 - 2024-03-14
66 |
67 | This release changes the `loss_values` attribute of a CTGAN model to contain floats instead of `torch.Tensors`.
68 |
69 | ### New Features
70 |
71 | * Return loss values as float values not PyTorch objects - Issue [#332](https://github.com/sdv-dev/CTGAN/issues/332) by @fealho
72 |
73 | ### Maintenance
74 |
75 | * Transition from using setup.py to pyproject.toml to specify project metadata - Issue [#333](https://github.com/sdv-dev/CTGAN/issues/333) by @R-Palazzo
76 | * Remove bumpversion and use bump-my-version - Issue [#334](https://github.com/sdv-dev/CTGAN/issues/334) by @R-Palazzo
77 | * Add dependency checker - Issue [#336](https://github.com/sdv-dev/CTGAN/issues/336) by @amontanez24
78 |
79 | ## v0.9.0 - 2024-02-13
80 |
81 | This release makes CTGAN sampling more efficient by saving the frequency of each categorical value.
82 |
83 | ### New Features
84 |
85 | * Improve DataSampler efficiency - Issue [#327](https://github.com/sdv-dev/CTGAN/issues/327) by @fealho
86 |
87 | ## v0.8.0 - 2023-11-13
88 |
89 | This release adds a progress bar that will show when setting the `verbose` parameter to `True`
90 | when initializing `TVAE`.
91 |
92 | ### New Features
93 |
94 | * Add verbosity TVAE (progress bar + save the loss values) - Issue [#300](https://github.com/sdv-dev/CTGAN/issues/300) by @frances-h
95 |
96 | ## v0.7.5 - 2023-10-05
97 |
98 | This release adds a progress bar that will show when setting the `verbose` parameter to `True` when initializing `CTGAN`. It also removes a warning that was showing.
99 |
100 | ### Maintenance
101 |
102 | * Remove model_missing_values from ClusterBasedNormalizer call - PR [#310](https://github.com/sdv-dev/CTGAN/pull/310) by @fealho
103 | * Switch default branch from master to main - Issue [#311](https://github.com/sdv-dev/CTGAN/issues/311) by @amontanez24
104 | * Remove or implement CTGAN tests - Issue [#312](https://github.com/sdv-dev/CTGAN/issues/312) by @fealho
105 |
106 | ### New Features
107 |
108 | * Add progress bar for CTGAN fitting (+ save the loss values) - Issue [#298](https://github.com/sdv-dev/CTGAN/issues/298) by @frances-h
109 |
110 | ## v0.7.4 - 2023-07-25
111 |
112 | This release adds support for Python 3.11 and drops support for Python 3.7.
113 |
114 | ### Maintenance
115 |
116 | * Why is there an upper bound in the packaging requirement? (packaging<22) - Issue [#276](https://github.com/sdv-dev/CTGAN/issues/276) by @fealho
117 | * Add support for Python 3.11 - Issue [#296](https://github.com/sdv-dev/CTGAN/issues/296) by @fealho
118 | * Drop support for Python 3.7 - Issue [#302](https://github.com/sdv-dev/CTGAN/issues/302) by @fealho
119 |
120 | ## v0.7.3 - 2023-05-25
121 |
122 | This release adds support for Torch 2.0!
123 |
124 | ### Bugs Fixed
125 |
126 | * Torch 2.0 fails with cuda=False - Issue [#288](https://github.com/sdv-dev/CTGAN/issues/288) by @amontanez24
127 |
128 | ### Maintenance
129 |
130 | * Upgrade to torch 2.0 - Issue [#280](https://github.com/sdv-dev/CTGAN/issues/280) by @frances-h
131 |
132 | ## v0.7.2 - 2023-05-09
133 |
134 | This release adds support for Pandas 2.0! It also fixes a bug in the `load_demo` function.
135 |
136 | ### Bugs Fixed
137 |
138 | * load_demo raises urllib.error.HTTPError: HTTP Error 403: Forbidden - Issue [#284](https://github.com/sdv-dev/CTGAN/issues/284) by @amontanez24
139 |
140 | ### Maintenance
141 |
142 | * Remove upper bound for pandas - Issue [#282](https://github.com/sdv-dev/CTGAN/issues/282) by @frances-h
143 |
144 | ## v0.7.1 - 2023-02-23
145 |
146 | This release fixes a bug that prevented the `CTGAN` model from being saved after sampling.
147 |
148 | ### Bugs Fixed
149 |
150 | * Cannot save CTGANSynthesizer after sampling (TypeError) - Issue [#270](https://github.com/sdv-dev/CTGAN/issues/270) by @pvk-developer
151 |
152 | ## v0.7.0 - 2023-01-20
153 |
154 | This release adds support for Python 3.10 and drops support for Python 3.6. It also fixes a couple of the most common warnings that were surfacing.
155 |
156 | ### New Features
157 |
158 | * Support Python 3.10 and 3.11 - Issue [#259](https://github.com/sdv-dev/CTGAN/issues/259) by @pvk-developer
159 |
160 | ### Bugs Fixed
161 |
162 | * Fix SettingWithCopyWarning (may be leading to a numerical calculation bug) - Issue [#215](https://github.com/sdv-dev/CTGAN/issues/215) by @amontanez24
163 | * FutureWarning in data_transformer with pandas 1.5.0 - Issue [#246](https://github.com/sdv-dev/CTGAN/issues/246) by @amontanez24
164 |
165 | ### Maintenance
166 |
167 | * CTGAN Package Maintenance Updates - Issue [#257](https://github.com/sdv-dev/CTGAN/issues/257) by @amontanez24
168 |
169 | ## v0.6.0 - 2022-10-07
170 |
171 | This release renames the models in CTGAN. `CTGANSynthesizer` is now called `CTGAN` and `TVAESynthesizer` is now called `TVAE`.
172 |
173 | ### New Features
174 |
175 | * Rename synthesizers - Issue [#243](https://github.com/sdv-dev/CTGAN/issues/243) by @amontanez24
176 |
177 | ## v0.5.2 - 2022-08-18
178 |
179 | This release updates CTGAN to use the latest version of RDT. It also includes performance and robustness updates to the data transformer.
180 |
181 | ### Issues closed
182 | * Bump rdt version - Issue [#242](https://github.com/sdv-dev/CTGAN/issues/242) by @katxiao
183 | * Single thread data transform is slow for huge table - Issue [#151](https://github.com/sdv-dev/CTGAN/issues/151) by @mfhbree
184 | * Fix RDT api - Issue [#232](https://github.com/sdv-dev/CTGAN/issues/232) by @pvk-developer
185 | * Update macos to use latest version. - Issue [#237](https://github.com/sdv-dev/CTGAN/issues/237) by @pvk-developer
186 | * Update the RDT version to 1.0 - Issue [#224](https://github.com/sdv-dev/CTGAN/issues/224) by @pvk-developer
187 | * Update slack invite link. - Issue [#222](https://github.com/sdv-dev/CTGAN/issues/222) by @pvk-developer
188 | * robustness fix, when data have less rows than the default number of cl… - Issue [#211](https://github.com/sdv-dev/CTGAN/issues/211) by @Deathn0t
189 |
190 | ## v0.5.1 - 2022-02-25
191 |
192 | This release fixes a bug with the decoder instantiation, and also allows users to set a random state for the model
193 | fitting and sampling.
194 |
195 | ### Issues closed
196 |
197 | * Update self.decoder with correct variable name - Issue [#203](https://github.com/sdv-dev/CTGAN/issues/203) by @tejuafonja
198 | * Add random state - Issue [#204](https://github.com/sdv-dev/CTGAN/issues/204) by @katxiao
199 |
200 | ## v0.5.0 - 2021-11-18
201 |
202 | This release adds support for Python 3.9 and updates dependencies to ensure compatibility with the
203 | rest of the SDV ecosystem, and upgrades to the latest [RDT](https://github.com/sdv-dev/RDT/releases/tag/v0.6.1)
204 | release.
205 |
206 | ### Issues closed
207 |
208 | * Add support for Python 3.9 - Issue [#177](https://github.com/sdv-dev/CTGAN/issues/177) by @pvk-developer
209 | * Add pip check to CI workflows - Issue [#174](https://github.com/sdv-dev/CTGAN/issues/174) by @pvk-developer
210 | * Typo in `CTGAN` code - Issue [#158](https://github.com/sdv-dev/CTGAN/issues/158) by @ori-katz100 and @fealho
211 |
212 | ## v0.4.3 - 2021-07-12
213 |
214 | Dependency upgrades to ensure compatibility with the rest of the SDV ecosystem.
215 |
216 | ## v0.4.2 - 2021-04-27
217 |
218 | In this release, the way in which the loss function of the TVAE model was computed has been fixed.
219 | In addition, the default value of the `discriminator_decay` has been changed to a more optimal
220 | value. Also some improvements to the tests were added.
221 |
222 | ### Issues closed
223 |
224 | * `TVAE`: loss function - Issue [#143](https://github.com/sdv-dev/CTGAN/issues/143) by @fealho and @DingfanChen
225 | * Set `discriminator_decay` to `1e-6` - Pull request [#145](https://github.com/sdv-dev/CTGAN/pull/145/) by @fealho
226 | * Adds unit tests - Pull request [#140](https://github.com/sdv-dev/CTGAN/pull/140) by @fealho
227 |
228 | ## v0.4.1 - 2021-03-30
229 |
230 | This release exposes all the hyperparameters which the user may find useful for both `CTGAN`
231 | and `TVAE`. Also `TVAE` can now be fitted on datasets that are shorter than the batch
232 | size and drops the last batch only if the data size is not divisible by the batch size.
233 |
234 | ### Issues closed
235 |
236 | * `TVAE`: Adapt `batch_size` to data size - Issue [#135](https://github.com/sdv-dev/CTGAN/issues/135) by @fealho and @csala
237 | * `ValueError` from `validate_discrete_columns` with `uniqueCombinationConstraint` - Issue [#133](https://github.com/sdv-dev/CTGAN/issues/133) by @fealho and @MLjungg
238 |
239 | ## v0.4.0 - 2021-02-24
240 |
241 | Maintenance release to upgrade dependencies to ensure compatibility with the rest
242 | of the SDV libraries.
243 |
244 | Also adds a validation on the CTGAN `condition_column` and `condition_value` inputs.
245 |
246 | ### Improvements
247 |
248 | * Validate condition_column and condition_value - Issue [#124](https://github.com/sdv-dev/CTGAN/issues/124) by @fealho
249 |
250 | ## v0.3.1 - 2021-01-27
251 |
252 | ### Improvements
253 |
254 | * Check discrete_columns valid before fitting - [Issue #35](https://github.com/sdv-dev/CTGAN/issues/35) by @fealho
255 |
256 | ### Bugs fixed
257 |
258 | * ValueError: max() arg is an empty sequence - [Issue #115](https://github.com/sdv-dev/CTGAN/issues/115) by @fealho
259 |
260 | ## v0.3.0 - 2020-12-18
261 |
262 | In this release we add a new TVAE model which was presented in the original CTGAN paper.
263 | It also exposes more hyperparameters and moves epochs and log_frequency from fit to the constructor.
264 |
265 | A new verbose argument has been added to optionally disable unnecessary printing, and a new hyperparameter
266 | called `discriminator_steps` has been added to CTGAN to control the number of optimization steps performed
267 | in the discriminator for each generator epoch.
268 |
269 | The code has also been reorganized and cleaned up for better readability and interpretability.
270 |
271 | Special thanks to @Baukebrenninkmeijer @fealho @leix28 @csala for the contributions!
272 |
273 | ### Improvements
274 |
275 | * Add TVAE - [Issue #111](https://github.com/sdv-dev/CTGAN/issues/111) by @fealho
276 | * Move `log_frequency` to `__init__` - [Issue #102](https://github.com/sdv-dev/CTGAN/issues/102) by @fealho
277 | * Add discriminator steps hyperparameter - [Issue #101](https://github.com/sdv-dev/CTGAN/issues/101) by @Baukebrenninkmeijer
278 | * Code cleanup / Expose hyperparameters - [Issue #59](https://github.com/sdv-dev/CTGAN/issues/59) by @fealho and @leix28
279 | * Publish to conda repo - [Issue #54](https://github.com/sdv-dev/CTGAN/issues/54) by @fealho
280 |
281 | ### Bugs fixed
282 |
283 | * Fixed NaN != NaN counting bug. - [Issue #100](https://github.com/sdv-dev/CTGAN/issues/100) by @fealho
284 | * Update dependencies and testing - [Issue #90](https://github.com/sdv-dev/CTGAN/issues/90) by @csala
285 |
286 | ## v0.2.2 - 2020-11-13
287 |
288 | In this release we introduce several minor improvements to make CTGAN more versatile and
289 | properly support new types of data, such as categorical NaN values, as well as conditional
290 | sampling and features to save and load models.
291 |
292 | Additionally, the dependency ranges and Python versions have been updated to support
293 | up-to-date runtimes.
294 |
295 | Many thanks @fealho @leix28 @csala @oregonpillow and @lurosenb for working on making this release possible!
296 |
297 | ### Improvements
298 |
299 | * Drop Python 3.5 support - [Issue #79](https://github.com/sdv-dev/CTGAN/issues/79) by @fealho
300 | * Support NaN values in categorical variables - [Issue #78](https://github.com/sdv-dev/CTGAN/issues/78) by @fealho
301 | * Sample synthetic data conditioning on a discrete column - [Issue #69](https://github.com/sdv-dev/CTGAN/issues/69) by @leix28
302 | * Support recent versions of pandas - [Issue #57](https://github.com/sdv-dev/CTGAN/issues/57) by @csala
303 | * Easy solution for restoring original dtypes - [Issue #26](https://github.com/sdv-dev/CTGAN/issues/26) by @oregonpillow
304 |
305 | ### Bugs fixed
306 |
307 | * Loss to nan - [Issue #73](https://github.com/sdv-dev/CTGAN/issues/73) by @fealho
308 | * Swapped the sklearn utils testing import statement - [Issue #53](https://github.com/sdv-dev/CTGAN/issues/53) by @lurosenb
309 |
310 | ## v0.2.1 - 2020-01-27
311 |
312 | Minor version including changes to ensure the logs are properly printed and adding
313 | the option to disable the log transformation of the discrete column frequencies.
314 |
315 | Special thanks to @kevinykuo for the contributions!
316 |
317 | ### Issues Resolved:
318 |
319 | * Option to sample from true data frequency instead of logged frequency - [Issue #16](https://github.com/sdv-dev/CTGAN/issues/16) by @kevinykuo
320 | * Flush stdout buffer for epoch updates - [Issue #14](https://github.com/sdv-dev/CTGAN/issues/14) by @kevinykuo
321 |
322 | ## v0.2.0 - 2019-12-18
323 |
324 | Reorganization of the project structure with a new Python API, new Command Line Interface
325 | and increased data format support.
326 |
327 | ### Issues Resolved:
328 |
329 | * Reorganize the project structure - [Issue #10](https://github.com/sdv-dev/CTGAN/issues/10) by @csala
330 | * Move epochs to the fit method - [Issue #5](https://github.com/sdv-dev/CTGAN/issues/5) by @csala
331 |
332 | ## v0.1.0 - 2019-11-07
333 |
334 | First Release - NeurIPS 2019 Version.
335 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Business Source License 1.1 3 | 4 | Parameters 5 | 6 | Licensor: DataCebo, Inc. 7 | 8 | Licensed Work: CTGAN 9 | The Licensed Work is (c) DataCebo, Inc. 10 | 11 | Additional Use Grant: You may make use of the Licensed Work, and derivatives of the Licensed 12 | Work, provided that you do not use the Licensed Work, or derivatives of 13 | the Licensed Work, for a Synthetic Data Creation Service. 14 | 15 | A "Synthetic Data Creation Service" is a commercial offering 16 | that allows third parties (other than your employees and 17 | contractors) to access the functionality of the Licensed 18 | Work so that such third parties directly benefit from the 19 | data processing, machine learning or synthetic data creation 20 | features of the Licensed Work. 21 | 22 | Change Date: Change date is four years from release date. 23 | Please see https://github.com/sdv-dev/CTGAN/releases 24 | for exact dates. 25 | 26 | Change License: MIT License 27 | 28 | 29 | Notice 30 | 31 | The Business Source License (this document, or the "License") is not an Open 32 | Source license. However, the Licensed Work will eventually be made available 33 | under an Open Source License, as stated in this License. 34 | 35 | License text copyright (c) 2017 MariaDB Corporation Ab, All Rights Reserved. 36 | "Business Source License" is a trademark of MariaDB Corporation Ab. 37 | 38 | ----------------------------------------------------------------------------- 39 | 40 | Business Source License 1.1 41 | 42 | Terms 43 | 44 | The Licensor hereby grants you the right to copy, modify, create derivative 45 | works, redistribute, and make non-production use of the Licensed Work. The 46 | Licensor may make an Additional Use Grant, above, permitting limited 47 | production use. 48 | 49 | Effective on the Change Date, or the fourth anniversary of the first publicly 50 | available distribution of a specific version of the Licensed Work under this 51 | License, whichever comes first, the Licensor hereby grants you rights under 52 | the terms of the Change License, and the rights granted in the paragraph 53 | above terminate. 54 | 55 | If your use of the Licensed Work does not comply with the requirements 56 | currently in effect as described in this License, you must purchase a 57 | commercial license from the Licensor, its affiliated entities, or authorized 58 | resellers, or you must refrain from using the Licensed Work. 59 | 60 | All copies of the original and modified Licensed Work, and derivative works 61 | of the Licensed Work, are subject to this License. This License applies 62 | separately for each version of the Licensed Work and the Change Date may vary 63 | for each version of the Licensed Work released by Licensor. 64 | 65 | You must conspicuously display this License on each original or modified copy 66 | of the Licensed Work. If you receive the Licensed Work in original or 67 | modified form from a third party, the terms and conditions set forth in this 68 | License apply to your use of that work. 69 | 70 | Any use of the Licensed Work in violation of this License will automatically 71 | terminate your rights under this License for the current and all other 72 | versions of the Licensed Work. 
73 | 74 | This License does not grant you any right in any trademark or logo of 75 | Licensor or its affiliates (provided that you may use a trademark or logo of 76 | Licensor as expressly required by this License). 77 | 78 | TO THE EXTENT PERMITTED BY APPLICABLE LAW, THE LICENSED WORK IS PROVIDED ON 79 | AN "AS IS" BASIS. LICENSOR HEREBY DISCLAIMS ALL WARRANTIES AND CONDITIONS, 80 | EXPRESS OR IMPLIED, INCLUDING (WITHOUT LIMITATION) WARRANTIES OF 81 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, AND 82 | TITLE. 83 | 84 | MariaDB hereby grants you permission to use this License’s text to license 85 | your works, and to refer to it using the trademark "Business Source License", 86 | as long as you comply with the Covenants of Licensor below. 87 | 88 | Covenants of Licensor 89 | 90 | In consideration of the right to use this License’s text and the "Business 91 | Source License" name and trademark, Licensor covenants to MariaDB, and to all 92 | other recipients of the licensed work to be provided by Licensor: 93 | 94 | 1. To specify as the Change License the GPL Version 2.0 or any later version, 95 | or a license that is compatible with GPL Version 2.0 or a later version, 96 | where "compatible" means that software provided under the Change License can 97 | be included in a program with software provided under GPL Version 2.0 or a 98 | later version. Licensor may specify additional Change Licenses without 99 | limitation. 100 | 101 | 2. To either: (a) specify an additional grant of rights to use that does not 102 | impose any additional restriction on the right granted in this License, as 103 | the Additional Use Grant; or (b) insert the text "None". 104 | 105 | 3. To specify a Change Date. 106 | 107 | 4. Not to modify this License in any other way. 108 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .DEFAULT_GOAL := help 2 | 3 | define BROWSER_PYSCRIPT 4 | import os, webbrowser, sys 5 | 6 | try: 7 | from urllib import pathname2url 8 | except: 9 | from urllib.request import pathname2url 10 | 11 | webbrowser.open("file://" + pathname2url(os.path.abspath(sys.argv[1]))) 12 | endef 13 | export BROWSER_PYSCRIPT 14 | 15 | define PRINT_HELP_PYSCRIPT 16 | import re, sys 17 | 18 | for line in sys.stdin: 19 | match = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$$', line) 20 | if match: 21 | target, help = match.groups() 22 | print("%-20s %s" % (target, help)) 23 | endef 24 | export PRINT_HELP_PYSCRIPT 25 | 26 | BROWSER := python -c "$$BROWSER_PYSCRIPT" 27 | 28 | .PHONY: help 29 | help: 30 | @python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST) 31 | 32 | 33 | # CLEAN TARGETS 34 | 35 | .PHONY: clean-build 36 | clean-build: ## remove build artifacts 37 | rm -fr build/ 38 | rm -fr dist/ 39 | rm -fr .eggs/ 40 | find . -name '*.egg-info' -exec rm -fr {} + 41 | find . -name '*.egg' -exec rm -f {} + 42 | 43 | .PHONY: clean-pyc 44 | clean-pyc: ## remove Python file artifacts 45 | find . -name '*.pyc' -exec rm -f {} + 46 | find . -name '*.pyo' -exec rm -f {} + 47 | find . -name '*~' -exec rm -f {} + 48 | find . 
-name '__pycache__' -exec rm -fr {} + 49 | 50 | .PHONY: clean-coverage 51 | clean-coverage: ## remove coverage artifacts 52 | rm -f .coverage 53 | rm -f .coverage.* 54 | rm -fr htmlcov/ 55 | 56 | .PHONY: clean-test 57 | clean-test: ## remove test artifacts 58 | rm -fr .tox/ 59 | rm -fr .pytest_cache 60 | 61 | .PHONY: clean 62 | clean: clean-build clean-pyc clean-test clean-coverage ## remove all build, test, coverage and Python artifacts 63 | 64 | 65 | # INSTALL TARGETS 66 | 67 | .PHONY: install 68 | install: clean-build clean-pyc ## install the package to the active Python's site-packages 69 | pip install . 70 | 71 | .PHONY: install-test 72 | install-test: clean-build clean-pyc ## install the package and test dependencies 73 | pip install .[test] 74 | 75 | .PHONY: install-develop 76 | install-develop: clean-build clean-pyc ## install the package in editable mode and dependencies for development 77 | pip install -e .[dev] 78 | 79 | 80 | # LINT TARGETS 81 | 82 | .PHONY: lint 83 | lint: 84 | invoke lint 85 | 86 | .PHONY: fix-lint 87 | fix-lint: 88 | invoke fix-lint 89 | 90 | 91 | # TEST TARGETS 92 | 93 | .PHONY: test-unit 94 | test-unit: ## run unit tests using pytest 95 | invoke unit 96 | 97 | .PHONY: test-integration 98 | test-integration: ## run integration tests using pytest 99 | invoke integration 100 | 101 | .PHONY: test-readme 102 | test-readme: ## run the readme snippets 103 | invoke readme 104 | 105 | .PHONY: check-dependencies 106 | check-dependencies: ## test if there are any broken dependencies 107 | pip check 108 | 109 | .PHONY: test 110 | test: test-unit test-integration test-readme ## test everything that needs test dependencies 111 | 112 | .PHONY: test-devel 113 | test-devel: lint ## test everything that needs development dependencies 114 | 115 | .PHONY: test-all 116 | test-all: ## run tests on every Python version with tox 117 | tox -r 118 | 119 | 120 | .PHONY: coverage 121 | coverage: ## check code coverage quickly with the default Python 122 | coverage run --source ctgan -m pytest 123 | coverage report -m 124 | coverage html 125 | $(BROWSER) htmlcov/index.html 126 | 127 | 128 | # RELEASE TARGETS 129 | 130 | .PHONY: dist 131 | dist: clean ## builds source and wheel package 132 | python -m build --wheel --sdist 133 | ls -l dist 134 | 135 | .PHONY: publish-confirm 136 | publish-confirm: 137 | @echo "WARNING: This will irreversibly upload a new version to PyPI!" 
138 | @echo -n "Please type 'confirm' to proceed: " \
139 | && read answer \
140 | && [ "$${answer}" = "confirm" ]
141 |
142 | .PHONY: publish-test
143 | publish-test: dist publish-confirm ## package and upload a release on TestPyPI
144 | twine upload --repository-url https://test.pypi.org/legacy/ dist/*
145 |
146 | .PHONY: publish
147 | publish: dist publish-confirm ## package and upload a release
148 | twine upload dist/*
149 |
150 | .PHONY: bumpversion-release
151 | bumpversion-release: ## Merge main to stable and bumpversion release
152 | git checkout stable || git checkout -b stable
153 | git merge --no-ff main -m"make release-tag: Merge branch 'main' into stable"
154 | bump-my-version bump release
155 | git push --tags origin stable
156 |
157 | .PHONY: bumpversion-release-test
158 | bumpversion-release-test: ## Merge main to stable and bumpversion release without tagging
159 | git checkout stable || git checkout -b stable
160 | git merge --no-ff main -m"make release-tag: Merge branch 'main' into stable"
161 | bump-my-version bump release --no-tag
162 | @echo git push --tags origin stable
163 |
164 | .PHONY: bumpversion-patch
165 | bumpversion-patch: ## Merge stable to main and bumpversion patch
166 | git checkout main
167 | git merge stable
168 | bump-my-version bump --no-tag patch
169 | git push
170 |
171 | .PHONY: bumpversion-candidate
172 | bumpversion-candidate: ## Bump the version to the next candidate
173 | bump-my-version bump candidate --no-tag
174 |
175 | .PHONY: bumpversion-minor
176 | bumpversion-minor: ## Bump the version to the next minor, skipping the release
177 | bump-my-version bump --no-tag minor
178 |
179 | .PHONY: bumpversion-major
180 | bumpversion-major: ## Bump the version to the next major, skipping the release
181 | bump-my-version bump --no-tag major
182 |
183 | .PHONY: bumpversion-revert
184 | bumpversion-revert: ## Undo a previous bumpversion-release
185 | git checkout main
186 | git branch -D stable
187 |
188 | CLEAN_DIR := $(shell git status --short | grep -v ??)
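# Variables used by the release gates (check-clean, check-main, check-history) below.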
189 | CURRENT_BRANCH := $(shell git rev-parse --abbrev-ref HEAD 2>/dev/null) 190 | CHANGELOG_LINES := $(shell git diff HEAD..origin/stable HISTORY.md 2>&1 | wc -l) 191 | 192 | .PHONY: check-clean 193 | check-clean: ## Check if the directory has uncommitted changes 194 | ifneq ($(CLEAN_DIR),) 195 | $(error There are uncommitted changes) 196 | endif 197 | 198 | .PHONY: check-main 199 | check-main: ## Check if we are in main branch 200 | ifneq ($(CURRENT_BRANCH),main) 201 | $(error Please make the release from main branch\n) 202 | endif 203 | 204 | .PHONY: check-history 205 | check-history: ## Check if HISTORY.md has been modified 206 | ifeq ($(CHANGELOG_LINES),0) 207 | $(error Please insert the release notes in HISTORY.md before releasing) 208 | endif 209 | 210 | .PHONY: git-push 211 | git-push: ## Simply push the repository to github 212 | git push 213 | 214 | .PHONY: check-release 215 | check-release: check-clean check-main check-history ## Check if the release can be made 216 | @echo "A new release can be made" 217 | 218 | .PHONY: release 219 | release: check-release bumpversion-release publish bumpversion-patch 220 | 221 | .PHONY: release-test 222 | release-test: check-release bumpversion-release-test publish-test bumpversion-revert 223 | 224 | .PHONY: release-candidate 225 | release-candidate: check-main publish bumpversion-candidate git-push 226 | 227 | .PHONY: release-candidate-test 228 | release-candidate-test: check-clean check-main publish-test 229 | 230 | .PHONY: release-minor 231 | release-minor: check-release bumpversion-minor release 232 | 233 | .PHONY: release-major 234 | release-major: check-release bumpversion-major release 235 | 236 | # Dependency targets 237 | 238 | .PHONY: check-deps 239 | check-deps: 240 | $(eval allow_list='numpy=|pandas=|tqdm=|torch=|rdt=') 241 | pip freeze | grep -v "CTGAN.git" | grep -E $(allow_list) > $(OUTPUT_FILEPATH) 242 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
4 | This repository is part of The Synthetic Data Vault Project, a project from DataCebo. 5 |

6 | 7 | [![Development Status](https://img.shields.io/badge/Development%20Status-2%20--%20Pre--Alpha-yellow)](https://pypi.org/search/?c=Development+Status+%3A%3A+2+-+Pre-Alpha) 8 | [![PyPI Shield](https://img.shields.io/pypi/v/ctgan.svg)](https://pypi.python.org/pypi/ctgan) 9 | [![Unit Tests](https://github.com/sdv-dev/CTGAN/actions/workflows/unit.yml/badge.svg)](https://github.com/sdv-dev/CTGAN/actions/workflows/unit.yml) 10 | [![Downloads](https://pepy.tech/badge/ctgan)](https://pepy.tech/project/ctgan) 11 | [![Coverage Status](https://codecov.io/gh/sdv-dev/CTGAN/branch/main/graph/badge.svg)](https://codecov.io/gh/sdv-dev/CTGAN) 12 | 13 |
23 | 24 | # Overview 25 | 26 | CTGAN is a collection of Deep Learning based synthetic data generators for single table data, which are able to learn from real data and generate synthetic data with high fidelity. 27 | 28 | | Important Links | | 29 | | --------------------------------------------- | -------------------------------------------------------------------- | 30 | | :computer: **[Website]** | Check out the SDV Website for more information about our overall synthetic data ecosystem.| 31 | | :orange_book: **[Blog]** | A deeper look at open source, synthetic data creation and evaluation.| 32 | | :book: **[Documentation]** | Quickstarts, User and Development Guides, and API Reference. | 33 | | :octocat: **[Repository]** | The link to the Github Repository of this library. | 34 | | :keyboard: **[Development Status]** | This software is in its Pre-Alpha stage. | 35 | | [![][Slack Logo] **Community**][Community] | Join our Slack Workspace for announcements and discussions. | 36 | 37 | [Website]: https://sdv.dev 38 | [Blog]: https://datacebo.com/blog 39 | [Documentation]: https://bit.ly/sdv-docs 40 | [Repository]: https://github.com/sdv-dev/CTGAN 41 | [License]: https://github.com/sdv-dev/CTGAN/blob/main/LICENSE 42 | [Development Status]: https://pypi.org/search/?c=Development+Status+%3A%3A+2+-+Pre-Alpha 43 | [Slack Logo]: https://github.com/sdv-dev/SDV/blob/stable/docs/images/slack.png 44 | [Community]: https://bit.ly/sdv-slack-invite 45 | 46 | Currently, this library implements the **CTGAN** and **TVAE** models described in the [Modeling Tabular data using Conditional GAN](https://arxiv.org/abs/1907.00503) paper, presented at the 2019 NeurIPS conference. 47 | 48 | # Install 49 | 50 | ## Use CTGAN through the SDV library 51 | 52 | :warning: If you're just getting started with synthetic data, we recommend installing the SDV library which provides user-friendly APIs for accessing CTGAN. :warning: 53 | 54 | The SDV library provides wrappers for preprocessing your data as well as additional usability features like constraints. See the [SDV documentation](https://bit.ly/sdv-docs) to get started. 55 | 56 | ## Use the CTGAN standalone library 57 | 58 | Alternatively, you can also install and use **CTGAN** directly, as a standalone library: 59 | 60 | **Using `pip`:** 61 | 62 | ```bash 63 | pip install ctgan 64 | ``` 65 | 66 | **Using `conda`:** 67 | 68 | ```bash 69 | conda install -c pytorch -c conda-forge ctgan 70 | ``` 71 | 72 | When using the CTGAN library directly, you may need to manually preprocess your data into the correct format, for example: 73 | 74 | * Continuous data must be represented as floats 75 | * Discrete data must be represented as ints or strings 76 | * The data should not contain any missing values 77 | 78 | # Usage Example 79 | 80 | In this example we load the [Adult Census Dataset](https://archive.ics.uci.edu/ml/datasets/adult)* which is a built-in demo dataset. We use CTGAN to learn from the real data and then generate some synthetic data. 
81 | 82 | ```python3 83 | from ctgan import CTGAN 84 | from ctgan import load_demo 85 | 86 | real_data = load_demo() 87 | 88 | # Names of the columns that are discrete 89 | discrete_columns = [ 90 | 'workclass', 91 | 'education', 92 | 'marital-status', 93 | 'occupation', 94 | 'relationship', 95 | 'race', 96 | 'sex', 97 | 'native-country', 98 | 'income' 99 | ] 100 | 101 | ctgan = CTGAN(epochs=10) 102 | ctgan.fit(real_data, discrete_columns) 103 | 104 | # Create synthetic data 105 | synthetic_data = ctgan.sample(1000) 106 | ``` 107 | 108 | *For more information about the dataset see: 109 | Dua, D. and Graff, C. (2019). UCI Machine Learning Repository [http://archive.ics.uci.edu/ml]. 110 | Irvine, CA: University of California, School of Information and Computer Science. 111 | 112 | # Join our community 113 | 114 | Join our [Slack channel](https://bit.ly/sdv-slack-invite) to discuss more about CTGAN and synthetic data. If you find a bug or have a feature request, you can also [open an issue](https://github.com/sdv-dev/CTGAN/issues) on our GitHub. 115 | 116 | **Interested in contributing to CTGAN?** Read our [Contribution Guide](CONTRIBUTING.rst) to get started. 117 | 118 | # Citing CTGAN 119 | 120 | If you use CTGAN, please cite the following work: 121 | 122 | *Lei Xu, Maria Skoularidou, Alfredo Cuesta-Infante, Kalyan Veeramachaneni.* **Modeling Tabular data using Conditional GAN**. NeurIPS, 2019. 123 | 124 | ```LaTeX 125 | @inproceedings{ctgan, 126 | title={Modeling Tabular data using Conditional GAN}, 127 | author={Xu, Lei and Skoularidou, Maria and Cuesta-Infante, Alfredo and Veeramachaneni, Kalyan}, 128 | booktitle={Advances in Neural Information Processing Systems}, 129 | year={2019} 130 | } 131 | ``` 132 | 133 | # Related Projects 134 | Please note that these projects are external to the SDV Ecosystem. They are not affiliated with or maintained by DataCebo. 135 | 136 | * **R Interface for CTGAN**: A wrapper around **CTGAN** that brings the functionalities to **R** users. 137 | More details can be found in the corresponding repository: https://github.com/kasaai/ctgan 138 | * **CTGAN Server CLI**: A package to easily deploy CTGAN onto a remote server. Created by Timothy Pillow @oregonpillow at: https://github.com/oregonpillow/ctgan-server-cli 139 | 140 | --- 141 | 142 | 143 |
148 | 149 | [The Synthetic Data Vault Project](https://sdv.dev) was first created at MIT's [Data to AI Lab]( 150 | https://dai.lids.mit.edu/) in 2016. After 4 years of research and traction with enterprise, we 151 | created [DataCebo](https://datacebo.com) in 2020 with the goal of growing the project. 152 | Today, DataCebo is the proud developer of SDV, the largest ecosystem for 153 | synthetic data generation & evaluation. It is home to multiple libraries that support synthetic 154 | data, including: 155 | 156 | * 🔄 Data discovery & transformation. Reverse the transforms to reproduce realistic data. 157 | * 🧠 Multiple machine learning models -- ranging from Copulas to Deep Learning -- to create tabular, 158 | multi table and time series data. 159 | * 📊 Measuring quality and privacy of synthetic data, and comparing different synthetic data 160 | generation models. 161 | 162 | [Get started using the SDV package](https://sdv.dev/SDV/getting_started/install.html) -- a fully 163 | integrated solution and your one-stop shop for synthetic data. Or, use the standalone libraries 164 | for specific needs. 165 | -------------------------------------------------------------------------------- /ctgan/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Top-level package for ctgan.""" 4 | 5 | __author__ = 'DataCebo, Inc.' 6 | __email__ = 'info@sdv.dev' 7 | __version__ = '0.11.1.dev0' 8 | 9 | from ctgan.demo import load_demo 10 | from ctgan.synthesizers.ctgan import CTGAN 11 | from ctgan.synthesizers.tvae import TVAE 12 | 13 | __all__ = ('CTGAN', 'TVAE', 'load_demo') 14 | -------------------------------------------------------------------------------- /ctgan/__main__.py: -------------------------------------------------------------------------------- 1 | """CLI.""" 2 | 3 | import argparse 4 | 5 | from ctgan.data import read_csv, read_tsv, write_tsv 6 | from ctgan.synthesizers.ctgan import CTGAN 7 | 8 | 9 | def _parse_args(): 10 | parser = argparse.ArgumentParser(description='CTGAN Command Line Interface') 11 | parser.add_argument('-e', '--epochs', default=300, type=int, help='Number of training epochs') 12 | parser.add_argument( 13 | '-t', '--tsv', action='store_true', help='Load data in TSV format instead of CSV' 14 | ) 15 | parser.add_argument( 16 | '--no-header', 17 | dest='header', 18 | action='store_false', 19 | help='The CSV file has no header. Discrete columns will be indices.', 20 | ) 21 | 22 | parser.add_argument('-m', '--metadata', help='Path to the metadata') 23 | parser.add_argument( 24 | '-d', '--discrete', help='Comma separated list of discrete columns without whitespaces.' 25 | ) 26 | parser.add_argument( 27 | '-n', 28 | '--num-samples', 29 | type=int, 30 | help='Number of rows to sample. Defaults to the training data size', 31 | ) 32 | 33 | parser.add_argument( 34 | '--generator_lr', type=float, default=2e-4, help='Learning rate for the generator.' 35 | ) 36 | parser.add_argument( 37 | '--discriminator_lr', type=float, default=2e-4, help='Learning rate for the discriminator.' 38 | ) 39 | 40 | parser.add_argument( 41 | '--generator_decay', type=float, default=1e-6, help='Weight decay for the generator.' 42 | ) 43 | parser.add_argument( 44 | '--discriminator_decay', type=float, default=0, help='Weight decay for the discriminator.' 45 | ) 46 | 47 | parser.add_argument( 48 | '--embedding_dim', type=int, default=128, help='Dimension of input z to the generator.' 
49 | ) 50 | parser.add_argument( 51 | '--generator_dim', 52 | type=str, 53 | default='256,256', 54 | help='Dimension of each generator layer. Comma separated integers with no whitespaces.', 55 | ) 56 | parser.add_argument( 57 | '--discriminator_dim', 58 | type=str, 59 | default='256,256', 60 | help='Dimension of each discriminator layer. Comma separated integers with no whitespaces.', 61 | ) 62 | 63 | parser.add_argument( 64 | '--batch_size', type=int, default=500, help='Batch size. Must be an even number.' 65 | ) 66 | parser.add_argument( 67 | '--save', default=None, type=str, help='A filename to save the trained synthesizer.' 68 | ) 69 | parser.add_argument( 70 | '--load', default=None, type=str, help='A filename to load a trained synthesizer.' 71 | ) 72 | 73 | parser.add_argument( 74 | '--sample_condition_column', default=None, type=str, help='Select a discrete column name.' 75 | ) 76 | parser.add_argument( 77 | '--sample_condition_column_value', 78 | default=None, 79 | type=str, 80 | help='Specify the value of the selected discrete column.', 81 | ) 82 | 83 | parser.add_argument('data', help='Path to training data') 84 | parser.add_argument('output', help='Path of the output file') 85 | 86 | return parser.parse_args() 87 | 88 | 89 | def main(): 90 | """CLI.""" 91 | args = _parse_args() 92 | if args.tsv: 93 | data, discrete_columns = read_tsv(args.data, args.metadata) 94 | else: 95 | data, discrete_columns = read_csv(args.data, args.metadata, args.header, args.discrete) 96 | 97 | if args.load: 98 | model = CTGAN.load(args.load) 99 | else: 100 | generator_dim = [int(x) for x in args.generator_dim.split(',')] 101 | discriminator_dim = [int(x) for x in args.discriminator_dim.split(',')] 102 | model = CTGAN( 103 | embedding_dim=args.embedding_dim, 104 | generator_dim=generator_dim, 105 | discriminator_dim=discriminator_dim, 106 | generator_lr=args.generator_lr, 107 | generator_decay=args.generator_decay, 108 | discriminator_lr=args.discriminator_lr, 109 | discriminator_decay=args.discriminator_decay, 110 | batch_size=args.batch_size, 111 | epochs=args.epochs, 112 | ) 113 | model.fit(data, discrete_columns) 114 | 115 | if args.save is not None: 116 | model.save(args.save) 117 | 118 | num_samples = args.num_samples or len(data) 119 | 120 | if args.sample_condition_column is not None: 121 | assert args.sample_condition_column_value is not None 122 | 123 | sampled = model.sample( 124 | num_samples, args.sample_condition_column, args.sample_condition_column_value 125 | ) 126 | 127 | if args.tsv: 128 | write_tsv(sampled, args.metadata, args.output) 129 | else: 130 | sampled.to_csv(args.output, index=False) 131 | 132 | 133 | if __name__ == '__main__': 134 | main() 135 | -------------------------------------------------------------------------------- /ctgan/data.py: -------------------------------------------------------------------------------- 1 | """Data loading.""" 2 | 3 | import json 4 | 5 | import numpy as np 6 | import pandas as pd 7 | 8 | 9 | def read_csv(csv_filename, meta_filename=None, header=True, discrete=None): 10 | """Read a csv file.""" 11 | data = pd.read_csv(csv_filename, header='infer' if header else None) 12 | 13 | if meta_filename: 14 | with open(meta_filename) as meta_file: 15 | metadata = json.load(meta_file) 16 | 17 | discrete_columns = [ 18 | column['name'] for column in metadata['columns'] if column['type'] != 'continuous' 19 | ] 20 | 21 | elif discrete: 22 | discrete_columns = discrete.split(',') 23 | if not header: 24 | discrete_columns = [int(i) for i in 
discrete_columns] 25 | 26 | else: 27 | discrete_columns = [] 28 | 29 | return data, discrete_columns 30 | 31 | 32 | def read_tsv(data_filename, meta_filename): 33 | """Read a tsv file.""" 34 | with open(meta_filename) as f: 35 | column_info = f.readlines() 36 | 37 | column_info_raw = [x.replace('{', ' ').replace('}', ' ').split() for x in column_info] 38 | 39 | discrete = [] 40 | continuous = [] 41 | column_info = [] 42 | 43 | for idx, item in enumerate(column_info_raw): 44 | if item[0] == 'C': 45 | continuous.append(idx) 46 | column_info.append((float(item[1]), float(item[2]))) 47 | else: 48 | assert item[0] == 'D' 49 | discrete.append(idx) 50 | column_info.append(item[1:]) 51 | 52 | meta = { 53 | 'continuous_columns': continuous, 54 | 'discrete_columns': discrete, 55 | 'column_info': column_info, 56 | } 57 | 58 | with open(data_filename) as f: 59 | lines = f.readlines() 60 | 61 | data = [] 62 | for row in lines: 63 | row_raw = row.split() 64 | row = [] 65 | for idx, col in enumerate(row_raw): 66 | if idx in continuous: 67 | row.append(col) 68 | else: 69 | assert idx in discrete 70 | row.append(column_info[idx].index(col)) 71 | 72 | data.append(row) 73 | 74 | return np.asarray(data, dtype='float32'), meta['discrete_columns'] 75 | 76 | 77 | def write_tsv(data, meta, output_filename): 78 | """Write to a tsv file.""" 79 | with open(output_filename, 'w') as f: 80 | for row in data: 81 | for idx, col in enumerate(row): 82 | if idx in meta['continuous_columns']: 83 | print(col, end=' ', file=f) 84 | else: 85 | assert idx in meta['discrete_columns'] 86 | print(meta['column_info'][idx][int(col)], end=' ', file=f) 87 | 88 | print(file=f) 89 | -------------------------------------------------------------------------------- /ctgan/data_sampler.py: -------------------------------------------------------------------------------- 1 | """DataSampler module.""" 2 | 3 | import numpy as np 4 | 5 | 6 | class DataSampler(object): 7 | """DataSampler samples the conditional vector and corresponding data for CTGAN.""" 8 | 9 | def __init__(self, data, output_info, log_frequency): 10 | self._data_length = len(data) 11 | 12 | def is_discrete_column(column_info): 13 | return len(column_info) == 1 and column_info[0].activation_fn == 'softmax' 14 | 15 | n_discrete_columns = sum([ 16 | 1 for column_info in output_info if is_discrete_column(column_info) 17 | ]) 18 | 19 | self._discrete_column_matrix_st = np.zeros(n_discrete_columns, dtype='int32') 20 | 21 | # Store the row id for each category in each discrete column. 22 | # For example _rid_by_cat_cols[a][b] is a list of all rows with the 23 | # a-th discrete column equal value b. 
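        # A hedged illustration (hypothetical data, not from the codebase): for a
        # single one-hot discrete column whose rows are [1, 0], [0, 1], [1, 0],
        # the result is _rid_by_cat_cols == [[array([0, 2]), array([1])]], i.e.
        # category 0 appears in rows 0 and 2, and category 1 in row 1.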
24 | self._rid_by_cat_cols = [] 25 | 26 | # Compute _rid_by_cat_cols 27 | st = 0 28 | for column_info in output_info: 29 | if is_discrete_column(column_info): 30 | span_info = column_info[0] 31 | ed = st + span_info.dim 32 | 33 | rid_by_cat = [] 34 | for j in range(span_info.dim): 35 | rid_by_cat.append(np.nonzero(data[:, st + j])[0]) 36 | self._rid_by_cat_cols.append(rid_by_cat) 37 | st = ed 38 | else: 39 | st += sum([span_info.dim for span_info in column_info]) 40 | assert st == data.shape[1] 41 | 42 | # Prepare an interval matrix for efficiently sampling the conditional vector 43 | max_category = max( 44 | [column_info[0].dim for column_info in output_info if is_discrete_column(column_info)], 45 | default=0, 46 | ) 47 | 48 | self._discrete_column_cond_st = np.zeros(n_discrete_columns, dtype='int32') 49 | self._discrete_column_n_category = np.zeros(n_discrete_columns, dtype='int32') 50 | self._discrete_column_category_prob = np.zeros((n_discrete_columns, max_category)) 51 | self._n_discrete_columns = n_discrete_columns 52 | self._n_categories = sum([ 53 | column_info[0].dim for column_info in output_info if is_discrete_column(column_info) 54 | ]) 55 | 56 | st = 0 57 | current_id = 0 58 | current_cond_st = 0 59 | for column_info in output_info: 60 | if is_discrete_column(column_info): 61 | span_info = column_info[0] 62 | ed = st + span_info.dim 63 | category_freq = np.sum(data[:, st:ed], axis=0) 64 | if log_frequency: 65 | category_freq = np.log(category_freq + 1) 66 | category_prob = category_freq / np.sum(category_freq) 67 | self._discrete_column_category_prob[current_id, : span_info.dim] = category_prob 68 | self._discrete_column_cond_st[current_id] = current_cond_st 69 | self._discrete_column_n_category[current_id] = span_info.dim 70 | current_cond_st += span_info.dim 71 | current_id += 1 72 | st = ed 73 | else: 74 | st += sum([span_info.dim for span_info in column_info]) 75 | 76 | def _random_choice_prob_index(self, discrete_column_id): 77 | probs = self._discrete_column_category_prob[discrete_column_id] 78 | r = np.expand_dims(np.random.rand(probs.shape[0]), axis=1) 79 | return (probs.cumsum(axis=1) > r).argmax(axis=1) 80 | 81 | def sample_condvec(self, batch): 82 | """Generate the conditional vector for training. 83 | 84 | Returns: 85 | cond (batch x #categories): 86 | The conditional vector. 87 | mask (batch x #discrete columns): 88 | A one-hot vector indicating the selected discrete column. 89 | discrete column id (batch): 90 | Integer representation of mask. 91 | category_id_in_col (batch): 92 | Selected category in the selected discrete column. 
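
        Example:
            For a hypothetical sampler built on a single 2-category discrete
            column, ``sample_condvec(batch=2)`` could return
            ``cond=[[1, 0], [0, 1]]``, ``mask=[[1], [1]]``,
            ``discrete_column_id=[0, 0]`` and ``category_id_in_col=[0, 1]``.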
93 | """ 94 | if self._n_discrete_columns == 0: 95 | return None 96 | 97 | discrete_column_id = np.random.choice(np.arange(self._n_discrete_columns), batch) 98 | 99 | cond = np.zeros((batch, self._n_categories), dtype='float32') 100 | mask = np.zeros((batch, self._n_discrete_columns), dtype='float32') 101 | mask[np.arange(batch), discrete_column_id] = 1 102 | category_id_in_col = self._random_choice_prob_index(discrete_column_id) 103 | category_id = self._discrete_column_cond_st[discrete_column_id] + category_id_in_col 104 | cond[np.arange(batch), category_id] = 1 105 | 106 | return cond, mask, discrete_column_id, category_id_in_col 107 | 108 | def sample_original_condvec(self, batch): 109 | """Generate the conditional vector for generation using the original frequency.""" 110 | if self._n_discrete_columns == 0: 111 | return None 112 | 113 | category_freq = self._discrete_column_category_prob.flatten() 114 | category_freq = category_freq[category_freq != 0] 115 | category_freq = category_freq / np.sum(category_freq) 116 | col_idxs = np.random.choice(np.arange(len(category_freq)), batch, p=category_freq) 117 | cond = np.zeros((batch, self._n_categories), dtype='float32') 118 | cond[np.arange(batch), col_idxs] = 1 119 | 120 | return cond 121 | 122 | def sample_data(self, data, n, col, opt): 123 | """Sample data from original training data satisfying the sampled conditional vector. 124 | 125 | Args: 126 | data: 127 | The training data. 128 | 129 | Returns: 130 | numpy.ndarray: 131 | n rows of matrix data satisfying the sampled conditional vector. 132 | """ 133 | if col is None: 134 | idx = np.random.randint(len(data), size=n) 135 | return data[idx] 136 | 137 | idx = [] 138 | for c, o in zip(col, opt): 139 | idx.append(np.random.choice(self._rid_by_cat_cols[c][o])) 140 | 141 | return data[idx] 142 | 143 | def dim_cond_vec(self): 144 | """Return the total number of categories.""" 145 | return self._n_categories 146 | 147 | def generate_cond_from_condition_column_info(self, condition_info, batch): 148 | """Generate the condition vector.""" 149 | vec = np.zeros((batch, self._n_categories), dtype='float32') 150 | id_ = self._discrete_column_matrix_st[condition_info['discrete_column_id']] 151 | id_ += condition_info['value_id'] 152 | vec[:, id_] = 1 153 | return vec 154 | -------------------------------------------------------------------------------- /ctgan/data_transformer.py: -------------------------------------------------------------------------------- 1 | """DataTransformer module.""" 2 | 3 | from collections import namedtuple 4 | 5 | import numpy as np 6 | import pandas as pd 7 | from joblib import Parallel, delayed 8 | from rdt.transformers import ClusterBasedNormalizer, OneHotEncoder 9 | 10 | SpanInfo = namedtuple('SpanInfo', ['dim', 'activation_fn']) 11 | ColumnTransformInfo = namedtuple( 12 | 'ColumnTransformInfo', 13 | ['column_name', 'column_type', 'transform', 'output_info', 'output_dimensions'], 14 | ) 15 | 16 | 17 | class DataTransformer(object): 18 | """Data Transformer. 19 | 20 | Model continuous columns with a BayesianGMM and normalize them to a scalar in [-1, 1] 21 | and a vector. Discrete columns are encoded using a OneHotEncoder. 22 | """ 23 | 24 | def __init__(self, max_clusters=10, weight_threshold=0.005): 25 | """Create a data transformer. 26 | 27 | Args: 28 | max_clusters (int): 29 | Maximum number of Gaussian distributions in Bayesian GMM. 30 | weight_threshold (float): 31 | Weight threshold for a Gaussian distribution to be kept. 
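
        Example:
            A minimal usage sketch (``df`` is an assumed DataFrame with a
            discrete ``'workclass'`` column)::

                dt = DataTransformer(max_clusters=10)
                dt.fit(df, discrete_columns=['workclass'])
                matrix = dt.transform(df)
                recovered = dt.inverse_transform(matrix)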
32 | """ 33 | self._max_clusters = max_clusters 34 | self._weight_threshold = weight_threshold 35 | 36 | def _fit_continuous(self, data): 37 | """Train Bayesian GMM for continuous columns. 38 | 39 | Args: 40 | data (pd.DataFrame): 41 | A dataframe containing a column. 42 | 43 | Returns: 44 | namedtuple: 45 | A ``ColumnTransformInfo`` object. 46 | """ 47 | column_name = data.columns[0] 48 | gm = ClusterBasedNormalizer( 49 | missing_value_generation='from_column', 50 | max_clusters=min(len(data), self._max_clusters), 51 | weight_threshold=self._weight_threshold, 52 | ) 53 | gm.fit(data, column_name) 54 | num_components = sum(gm.valid_component_indicator) 55 | 56 | return ColumnTransformInfo( 57 | column_name=column_name, 58 | column_type='continuous', 59 | transform=gm, 60 | output_info=[SpanInfo(1, 'tanh'), SpanInfo(num_components, 'softmax')], 61 | output_dimensions=1 + num_components, 62 | ) 63 | 64 | def _fit_discrete(self, data): 65 | """Fit one hot encoder for discrete column. 66 | 67 | Args: 68 | data (pd.DataFrame): 69 | A dataframe containing a column. 70 | 71 | Returns: 72 | namedtuple: 73 | A ``ColumnTransformInfo`` object. 74 | """ 75 | column_name = data.columns[0] 76 | ohe = OneHotEncoder() 77 | ohe.fit(data, column_name) 78 | num_categories = len(ohe.dummies) 79 | 80 | return ColumnTransformInfo( 81 | column_name=column_name, 82 | column_type='discrete', 83 | transform=ohe, 84 | output_info=[SpanInfo(num_categories, 'softmax')], 85 | output_dimensions=num_categories, 86 | ) 87 | 88 | def fit(self, raw_data, discrete_columns=()): 89 | """Fit the ``DataTransformer``. 90 | 91 | Fits a ``ClusterBasedNormalizer`` for continuous columns and a 92 | ``OneHotEncoder`` for discrete columns. 93 | 94 | This step also records the number of columns in the matrix data and their span information. 95 | """ 96 | self.output_info_list = [] 97 | self.output_dimensions = 0 98 | self.dataframe = True 99 | 100 | if not isinstance(raw_data, pd.DataFrame): 101 | self.dataframe = False 102 | # workaround for RDT issue #328: fitting with numerical column names fails 103 | discrete_columns = [str(column) for column in discrete_columns] 104 | column_names = [str(num) for num in range(raw_data.shape[1])] 105 | raw_data = pd.DataFrame(raw_data, columns=column_names) 106 | 107 | self._column_raw_dtypes = raw_data.infer_objects().dtypes 108 | self._column_transform_info_list = [] 109 | for column_name in raw_data.columns: 110 | if column_name in discrete_columns: 111 | column_transform_info = self._fit_discrete(raw_data[[column_name]]) 112 | else: 113 | column_transform_info = self._fit_continuous(raw_data[[column_name]]) 114 | 115 | self.output_info_list.append(column_transform_info.output_info) 116 | self.output_dimensions += column_transform_info.output_dimensions 117 | self._column_transform_info_list.append(column_transform_info) 118 | 119 | def _transform_continuous(self, column_transform_info, data): 120 | column_name = data.columns[0] 121 | flattened_column = data[column_name].to_numpy().flatten() 122 | data = data.assign(**{column_name: flattened_column}) 123 | gm = column_transform_info.transform 124 | transformed = gm.transform(data) 125 | 126 | # Converts the transformed data to the appropriate output format. 127 | # The first column (ending in '.normalized') stays the same, 128 | # but the label encoded column (ending in '.component') is one hot encoded. 
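        # Illustration (hypothetical values, assuming 3 valid components): a row
        # whose normalized value is 0.25 and whose selected component is 1 becomes
        # [0.25, 0.0, 1.0, 0.0] -- one tanh slot followed by a one-hot softmax block.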
129 | output = np.zeros((len(transformed), column_transform_info.output_dimensions)) 130 | output[:, 0] = transformed[f'{column_name}.normalized'].to_numpy() 131 | index = transformed[f'{column_name}.component'].to_numpy().astype(int) 132 | output[np.arange(index.size), index + 1] = 1.0 133 | 134 | return output 135 | 136 | def _transform_discrete(self, column_transform_info, data): 137 | ohe = column_transform_info.transform 138 | return ohe.transform(data).to_numpy() 139 | 140 | def _synchronous_transform(self, raw_data, column_transform_info_list): 141 | """Take a Pandas DataFrame and transform columns synchronously. 142 | 143 | Outputs a list with Numpy arrays. 144 | """ 145 | column_data_list = [] 146 | for column_transform_info in column_transform_info_list: 147 | column_name = column_transform_info.column_name 148 | data = raw_data[[column_name]] 149 | if column_transform_info.column_type == 'continuous': 150 | column_data_list.append(self._transform_continuous(column_transform_info, data)) 151 | else: 152 | column_data_list.append(self._transform_discrete(column_transform_info, data)) 153 | 154 | return column_data_list 155 | 156 | def _parallel_transform(self, raw_data, column_transform_info_list): 157 | """Take a Pandas DataFrame and transform columns in parallel. 158 | 159 | Outputs a list with Numpy arrays. 160 | """ 161 | processes = [] 162 | for column_transform_info in column_transform_info_list: 163 | column_name = column_transform_info.column_name 164 | data = raw_data[[column_name]] 165 | process = None 166 | if column_transform_info.column_type == 'continuous': 167 | process = delayed(self._transform_continuous)(column_transform_info, data) 168 | else: 169 | process = delayed(self._transform_discrete)(column_transform_info, data) 170 | processes.append(process) 171 | 172 | return Parallel(n_jobs=-1)(processes) 173 | 174 | def transform(self, raw_data): 175 | """Take raw data and output matrix data.""" 176 | if not isinstance(raw_data, pd.DataFrame): 177 | column_names = [str(num) for num in range(raw_data.shape[1])] 178 | raw_data = pd.DataFrame(raw_data, columns=column_names) 179 | 180 | # Only use parallelization with larger data sizes. 181 | # Otherwise, the transformation will be slower. 182 | if raw_data.shape[0] < 500: 183 | column_data_list = self._synchronous_transform( 184 | raw_data, self._column_transform_info_list 185 | ) 186 | else: 187 | column_data_list = self._parallel_transform(raw_data, self._column_transform_info_list) 188 | 189 | return np.concatenate(column_data_list, axis=1).astype(float) 190 | 191 | def _inverse_transform_continuous(self, column_transform_info, column_data, sigmas, st): 192 | gm = column_transform_info.transform 193 | data = pd.DataFrame(column_data[:, :2], columns=list(gm.get_output_sdtypes())).astype(float) 194 | data[data.columns[1]] = np.argmax(column_data[:, 1:], axis=1) 195 | if sigmas is not None: 196 | selected_normalized_value = np.random.normal(data.iloc[:, 0], sigmas[st]) 197 | data.iloc[:, 0] = selected_normalized_value 198 | 199 | return gm.reverse_transform(data) 200 | 201 | def _inverse_transform_discrete(self, column_transform_info, column_data): 202 | ohe = column_transform_info.transform 203 | data = pd.DataFrame(column_data, columns=list(ohe.get_output_sdtypes())) 204 | return ohe.reverse_transform(data)[column_transform_info.column_name] 205 | 206 | def inverse_transform(self, data, sigmas=None): 207 | """Take matrix data and output raw data. 
208 | 209 | Output uses the same type as input to the transform function. 210 | Either np array or pd dataframe. 211 | """ 212 | st = 0 213 | recovered_column_data_list = [] 214 | column_names = [] 215 | for column_transform_info in self._column_transform_info_list: 216 | dim = column_transform_info.output_dimensions 217 | column_data = data[:, st : st + dim] 218 | if column_transform_info.column_type == 'continuous': 219 | recovered_column_data = self._inverse_transform_continuous( 220 | column_transform_info, column_data, sigmas, st 221 | ) 222 | else: 223 | recovered_column_data = self._inverse_transform_discrete( 224 | column_transform_info, column_data 225 | ) 226 | 227 | recovered_column_data_list.append(recovered_column_data) 228 | column_names.append(column_transform_info.column_name) 229 | st += dim 230 | 231 | recovered_data = np.column_stack(recovered_column_data_list) 232 | recovered_data = pd.DataFrame(recovered_data, columns=column_names).astype( 233 | self._column_raw_dtypes 234 | ) 235 | if not self.dataframe: 236 | recovered_data = recovered_data.to_numpy() 237 | 238 | return recovered_data 239 | 240 | def convert_column_name_value_to_id(self, column_name, value): 241 | """Get the ids of the given `column_name`.""" 242 | discrete_counter = 0 243 | column_id = 0 244 | for column_transform_info in self._column_transform_info_list: 245 | if column_transform_info.column_name == column_name: 246 | break 247 | if column_transform_info.column_type == 'discrete': 248 | discrete_counter += 1 249 | 250 | column_id += 1 251 | 252 | else: 253 | raise ValueError(f"The column_name `{column_name}` doesn't exist in the data.") 254 | 255 | ohe = column_transform_info.transform 256 | data = pd.DataFrame([value], columns=[column_transform_info.column_name]) 257 | one_hot = ohe.transform(data).to_numpy()[0] 258 | if sum(one_hot) == 0: 259 | raise ValueError(f"The value `{value}` doesn't exist in the column `{column_name}`.") 260 | 261 | return { 262 | 'discrete_column_id': discrete_counter, 263 | 'column_id': column_id, 264 | 'value_id': np.argmax(one_hot), 265 | } 266 | -------------------------------------------------------------------------------- /ctgan/demo.py: -------------------------------------------------------------------------------- 1 | """Demo module.""" 2 | 3 | import pandas as pd 4 | 5 | DEMO_URL = 'http://ctgan-demo.s3.amazonaws.com/census.csv.gz' 6 | 7 | 8 | def load_demo(): 9 | """Load the demo.""" 10 | return pd.read_csv(DEMO_URL, compression='gzip') 11 | -------------------------------------------------------------------------------- /ctgan/errors.py: -------------------------------------------------------------------------------- 1 | """Custom errors for CTGAN.""" 2 | 3 | 4 | class InvalidDataError(Exception): 5 | """Error to raise when data is not valid.""" 6 | -------------------------------------------------------------------------------- /ctgan/synthesizers/__init__.py: -------------------------------------------------------------------------------- 1 | """Synthesizers module.""" 2 | 3 | from ctgan.synthesizers.ctgan import CTGAN 4 | from ctgan.synthesizers.tvae import TVAE 5 | 6 | __all__ = ('CTGAN', 'TVAE') 7 | 8 | 9 | def get_all_synthesizers(): 10 | return {name: globals()[name] for name in __all__} 11 | -------------------------------------------------------------------------------- /ctgan/synthesizers/base.py: -------------------------------------------------------------------------------- 1 | """BaseSynthesizer module.""" 2 | 3 | import contextlib 4 | 5 | 
import numpy as np 6 | import torch 7 | 8 | 9 | @contextlib.contextmanager 10 | def set_random_states(random_state, set_model_random_state): 11 | """Context manager for managing the random state. 12 | 13 | Args: 14 | random_state (int or tuple): 15 | The random seed or a tuple of (numpy.random.RandomState, torch.Generator). 16 | set_model_random_state (function): 17 | Function to set the random state on the model. 18 | """ 19 | original_np_state = np.random.get_state() 20 | original_torch_state = torch.get_rng_state() 21 | 22 | random_np_state, random_torch_state = random_state 23 | 24 | np.random.set_state(random_np_state.get_state()) 25 | torch.set_rng_state(random_torch_state.get_state()) 26 | 27 | try: 28 | yield 29 | finally: 30 | current_np_state = np.random.RandomState() 31 | current_np_state.set_state(np.random.get_state()) 32 | current_torch_state = torch.Generator() 33 | current_torch_state.set_state(torch.get_rng_state()) 34 | set_model_random_state((current_np_state, current_torch_state)) 35 | 36 | np.random.set_state(original_np_state) 37 | torch.set_rng_state(original_torch_state) 38 | 39 | 40 | def random_state(function): 41 | """Set the random state before calling the function. 42 | 43 | Args: 44 | function (Callable): 45 | The function to wrap around. 46 | """ 47 | 48 | def wrapper(self, *args, **kwargs): 49 | if self.random_states is None: 50 | return function(self, *args, **kwargs) 51 | 52 | else: 53 | with set_random_states(self.random_states, self.set_random_state): 54 | return function(self, *args, **kwargs) 55 | 56 | return wrapper 57 | 58 | 59 | class BaseSynthesizer: 60 | """Base class for all default synthesizers of ``CTGAN``.""" 61 | 62 | random_states = None 63 | 64 | def __getstate__(self): 65 | """Improve pickling state for ``BaseSynthesizer``. 66 | 67 | Convert to ``cpu`` device before starting the pickling process in order to be able to 68 | load the model even when used from an external tool such as ``SDV``. Also, if 69 | ``random_states`` are set, store their states as dictionaries rather than generators. 70 | 71 | Returns: 72 | dict: 73 | Python dict representing the object. 74 | """ 75 | device_backup = self._device 76 | self.set_device(torch.device('cpu')) 77 | state = self.__dict__.copy() 78 | self.set_device(device_backup) 79 | if ( 80 | isinstance(self.random_states, tuple) 81 | and isinstance(self.random_states[0], np.random.RandomState) 82 | and isinstance(self.random_states[1], torch.Generator) 83 | ): 84 | state['_numpy_random_state'] = self.random_states[0].get_state() 85 | state['_torch_random_state'] = self.random_states[1].get_state() 86 | state.pop('random_states') 87 | 88 | return state 89 | 90 | def __setstate__(self, state): 91 | """Restore the state of a ``BaseSynthesizer``. 92 | 93 | Restore the ``random_states`` from the state dict if those are present and then 94 | set the device according to the current hardware. 
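        If CUDA is available, the restored model is placed on ``cuda:0``;
        otherwise it stays on the CPU.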
95 | """ 96 | if '_numpy_random_state' in state and '_torch_random_state' in state: 97 | np_state = state.pop('_numpy_random_state') 98 | torch_state = state.pop('_torch_random_state') 99 | 100 | current_torch_state = torch.Generator() 101 | current_torch_state.set_state(torch_state) 102 | 103 | current_numpy_state = np.random.RandomState() 104 | current_numpy_state.set_state(np_state) 105 | state['random_states'] = (current_numpy_state, current_torch_state) 106 | 107 | self.__dict__ = state 108 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 109 | self.set_device(device) 110 | 111 | def save(self, path): 112 | """Save the model in the passed `path`.""" 113 | device_backup = self._device 114 | self.set_device(torch.device('cpu')) 115 | torch.save(self, path) 116 | self.set_device(device_backup) 117 | 118 | @classmethod 119 | def load(cls, path): 120 | """Load the model stored in the passed `path`.""" 121 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 122 | model = torch.load(path, weights_only=False) 123 | model.set_device(device) 124 | return model 125 | 126 | def set_random_state(self, random_state): 127 | """Set the random state. 128 | 129 | Args: 130 | random_state (int, tuple, or None): 131 | Either a tuple containing the (numpy.random.RandomState, torch.Generator) 132 | or an int representing the random seed to use for both random states. 133 | """ 134 | if random_state is None: 135 | self.random_states = random_state 136 | elif isinstance(random_state, int): 137 | self.random_states = ( 138 | np.random.RandomState(seed=random_state), 139 | torch.Generator().manual_seed(random_state), 140 | ) 141 | elif ( 142 | isinstance(random_state, tuple) 143 | and isinstance(random_state[0], np.random.RandomState) 144 | and isinstance(random_state[1], torch.Generator) 145 | ): 146 | self.random_states = random_state 147 | else: 148 | raise TypeError( 149 | f'`random_state` {random_state} expected to be an int or a tuple of ' 150 | '(`np.random.RandomState`, `torch.Generator`)' 151 | ) 152 | -------------------------------------------------------------------------------- /ctgan/synthesizers/ctgan.py: -------------------------------------------------------------------------------- 1 | """CTGAN module.""" 2 | 3 | import warnings 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | from torch import optim 9 | from torch.nn import BatchNorm1d, Dropout, LeakyReLU, Linear, Module, ReLU, Sequential, functional 10 | from tqdm import tqdm 11 | 12 | from ctgan.data_sampler import DataSampler 13 | from ctgan.data_transformer import DataTransformer 14 | from ctgan.errors import InvalidDataError 15 | from ctgan.synthesizers.base import BaseSynthesizer, random_state 16 | 17 | 18 | class Discriminator(Module): 19 | """Discriminator for the CTGAN.""" 20 | 21 | def __init__(self, input_dim, discriminator_dim, pac=10): 22 | super(Discriminator, self).__init__() 23 | dim = input_dim * pac 24 | self.pac = pac 25 | self.pacdim = dim 26 | seq = [] 27 | for item in list(discriminator_dim): 28 | seq += [Linear(dim, item), LeakyReLU(0.2), Dropout(0.5)] 29 | dim = item 30 | 31 | seq += [Linear(dim, 1)] 32 | self.seq = Sequential(*seq) 33 | 34 | def calc_gradient_penalty(self, real_data, fake_data, device='cpu', pac=10, lambda_=10): 35 | """Compute the gradient penalty.""" 36 | alpha = torch.rand(real_data.size(0) // pac, 1, 1, device=device) 37 | alpha = alpha.repeat(1, pac, real_data.size(1)) 38 | alpha = alpha.view(-1, real_data.size(1)) 39 | 40 | 
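# WGAN-GP (https://arxiv.org/abs/1704.00028): score random interpolates between
        # real and fake rows, then penalize per-pac-group gradient norms that
        # deviate from 1.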
interpolates = alpha * real_data + ((1 - alpha) * fake_data) 41 | 42 | disc_interpolates = self(interpolates) 43 | 44 | gradients = torch.autograd.grad( 45 | outputs=disc_interpolates, 46 | inputs=interpolates, 47 | grad_outputs=torch.ones(disc_interpolates.size(), device=device), 48 | create_graph=True, 49 | retain_graph=True, 50 | only_inputs=True, 51 | )[0] 52 | 53 | gradients_view = gradients.view(-1, pac * real_data.size(1)).norm(2, dim=1) - 1 54 | gradient_penalty = ((gradients_view) ** 2).mean() * lambda_ 55 | 56 | return gradient_penalty 57 | 58 | def forward(self, input_): 59 | """Apply the Discriminator to the `input_`.""" 60 | assert input_.size()[0] % self.pac == 0 61 | return self.seq(input_.view(-1, self.pacdim)) 62 | 63 | 64 | class Residual(Module): 65 | """Residual layer for the CTGAN.""" 66 | 67 | def __init__(self, i, o): 68 | super(Residual, self).__init__() 69 | self.fc = Linear(i, o) 70 | self.bn = BatchNorm1d(o) 71 | self.relu = ReLU() 72 | 73 | def forward(self, input_): 74 | """Apply the Residual layer to the `input_`.""" 75 | out = self.fc(input_) 76 | out = self.bn(out) 77 | out = self.relu(out) 78 | return torch.cat([out, input_], dim=1) 79 | 80 | 81 | class Generator(Module): 82 | """Generator for the CTGAN.""" 83 | 84 | def __init__(self, embedding_dim, generator_dim, data_dim): 85 | super(Generator, self).__init__() 86 | dim = embedding_dim 87 | seq = [] 88 | for item in list(generator_dim): 89 | seq += [Residual(dim, item)] 90 | dim += item 91 | seq.append(Linear(dim, data_dim)) 92 | self.seq = Sequential(*seq) 93 | 94 | def forward(self, input_): 95 | """Apply the Generator to the `input_`.""" 96 | data = self.seq(input_) 97 | return data 98 | 99 | 100 | class CTGAN(BaseSynthesizer): 101 | """Conditional Table GAN Synthesizer. 102 | 103 | This is the core class of the CTGAN project, where the different components 104 | are orchestrated together. 105 | For more details about the process, please check the [Modeling Tabular data using 106 | Conditional GAN](https://arxiv.org/abs/1907.00503) paper. 107 | 108 | Args: 109 | embedding_dim (int): 110 | Size of the random sample passed to the Generator. Defaults to 128. 111 | generator_dim (tuple or list of ints): 112 | Size of the output samples for each one of the Residuals. A Residual Layer 113 | will be created for each one of the values provided. Defaults to (256, 256). 114 | discriminator_dim (tuple or list of ints): 115 | Size of the output samples for each one of the Discriminator Layers. A Linear Layer 116 | will be created for each one of the values provided. Defaults to (256, 256). 117 | generator_lr (float): 118 | Learning rate for the generator. Defaults to 2e-4. 119 | generator_decay (float): 120 | Generator weight decay for the Adam Optimizer. Defaults to 1e-6. 121 | discriminator_lr (float): 122 | Learning rate for the discriminator. Defaults to 2e-4. 123 | discriminator_decay (float): 124 | Discriminator weight decay for the Adam Optimizer. Defaults to 1e-6. 125 | batch_size (int): 126 | Number of data samples to process in each step. 127 | discriminator_steps (int): 128 | Number of discriminator updates to do for each generator update. 129 | From the WGAN paper: https://arxiv.org/abs/1701.07875. WGAN paper 130 | default is 5. Default used is 1 to match original CTGAN implementation. 131 | log_frequency (boolean): 132 | Whether to use log frequency of categorical levels in conditional 133 | sampling. Defaults to ``True``. 
134 | verbose (boolean): 135 | Whether to have print statements for progress results. Defaults to ``False``. 136 | epochs (int): 137 | Number of training epochs. Defaults to 300. 138 | pac (int): 139 | Number of samples to group together when applying the discriminator. 140 | Defaults to 10. 141 | cuda (bool): 142 | Whether to attempt to use cuda for GPU computation. 143 | If this is False or CUDA is not available, CPU will be used. 144 | Defaults to ``True``. 145 | """ 146 | 147 | def __init__( 148 | self, 149 | embedding_dim=128, 150 | generator_dim=(256, 256), 151 | discriminator_dim=(256, 256), 152 | generator_lr=2e-4, 153 | generator_decay=1e-6, 154 | discriminator_lr=2e-4, 155 | discriminator_decay=1e-6, 156 | batch_size=500, 157 | discriminator_steps=1, 158 | log_frequency=True, 159 | verbose=False, 160 | epochs=300, 161 | pac=10, 162 | cuda=True, 163 | ): 164 | assert batch_size % 2 == 0 165 | 166 | self._embedding_dim = embedding_dim 167 | self._generator_dim = generator_dim 168 | self._discriminator_dim = discriminator_dim 169 | 170 | self._generator_lr = generator_lr 171 | self._generator_decay = generator_decay 172 | self._discriminator_lr = discriminator_lr 173 | self._discriminator_decay = discriminator_decay 174 | 175 | self._batch_size = batch_size 176 | self._discriminator_steps = discriminator_steps 177 | self._log_frequency = log_frequency 178 | self._verbose = verbose 179 | self._epochs = epochs 180 | self.pac = pac 181 | 182 | if not cuda or not torch.cuda.is_available(): 183 | device = 'cpu' 184 | elif isinstance(cuda, str): 185 | device = cuda 186 | else: 187 | device = 'cuda' 188 | 189 | self._device = torch.device(device) 190 | 191 | self._transformer = None 192 | self._data_sampler = None 193 | self._generator = None 194 | self.loss_values = None 195 | 196 | @staticmethod 197 | def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1): 198 | """Deals with the instability of the gumbel_softmax for older versions of torch. 199 | 200 | For more details about the issue: 201 | https://drive.google.com/file/d/1AA5wPfZ1kquaRtVruCd6BiYZGcDeNxyP/view?usp=sharing 202 | 203 | Args: 204 | logits […, num_features]: 205 | Unnormalized log probabilities 206 | tau: 207 | Non-negative scalar temperature 208 | hard (bool): 209 | If True, the returned samples will be discretized as one-hot vectors, 210 | but will be differentiated as if it is the soft sample in autograd 211 | dim (int): 212 | A dimension along which softmax will be computed. Default: -1. 213 | 214 | Returns: 215 | Sampled tensor of same shape as logits from the Gumbel-Softmax distribution. 
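
        Example:
            An illustrative call on random logits (shape and ``tau`` are
            arbitrary)::

                sample = CTGAN._gumbel_softmax(torch.randn(4, 3), tau=0.2)
                # ``sample`` has shape (4, 3) and each row sums to 1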
216 | """ 217 | for _ in range(10): 218 | transformed = functional.gumbel_softmax(logits, tau=tau, hard=hard, eps=eps, dim=dim) 219 | if not torch.isnan(transformed).any(): 220 | return transformed 221 | 222 | raise ValueError('gumbel_softmax returning NaN.') 223 | 224 | def _apply_activate(self, data): 225 | """Apply proper activation function to the output of the generator.""" 226 | data_t = [] 227 | st = 0 228 | for column_info in self._transformer.output_info_list: 229 | for span_info in column_info: 230 | if span_info.activation_fn == 'tanh': 231 | ed = st + span_info.dim 232 | data_t.append(torch.tanh(data[:, st:ed])) 233 | st = ed 234 | elif span_info.activation_fn == 'softmax': 235 | ed = st + span_info.dim 236 | transformed = self._gumbel_softmax(data[:, st:ed], tau=0.2) 237 | data_t.append(transformed) 238 | st = ed 239 | else: 240 | raise ValueError(f'Unexpected activation function {span_info.activation_fn}.') 241 | 242 | return torch.cat(data_t, dim=1) 243 | 244 | def _cond_loss(self, data, c, m): 245 | """Compute the cross entropy loss on the fixed discrete column.""" 246 | loss = [] 247 | st = 0 248 | st_c = 0 249 | for column_info in self._transformer.output_info_list: 250 | for span_info in column_info: 251 | if len(column_info) != 1 or span_info.activation_fn != 'softmax': 252 | # not discrete column 253 | st += span_info.dim 254 | else: 255 | ed = st + span_info.dim 256 | ed_c = st_c + span_info.dim 257 | tmp = functional.cross_entropy( 258 | data[:, st:ed], torch.argmax(c[:, st_c:ed_c], dim=1), reduction='none' 259 | ) 260 | loss.append(tmp) 261 | st = ed 262 | st_c = ed_c 263 | 264 | loss = torch.stack(loss, dim=1) # noqa: PD013 265 | 266 | return (loss * m).sum() / data.size()[0] 267 | 268 | def _validate_discrete_columns(self, train_data, discrete_columns): 269 | """Check whether ``discrete_columns`` exists in ``train_data``. 270 | 271 | Args: 272 | train_data (numpy.ndarray or pandas.DataFrame): 273 | Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame. 274 | discrete_columns (list-like): 275 | List of discrete columns to be used to generate the Conditional 276 | Vector. If ``train_data`` is a Numpy array, this list should 277 | contain the integer indices of the columns. Otherwise, if it is 278 | a ``pandas.DataFrame``, this list should contain the column names. 279 | """ 280 | if isinstance(train_data, pd.DataFrame): 281 | invalid_columns = set(discrete_columns) - set(train_data.columns) 282 | elif isinstance(train_data, np.ndarray): 283 | invalid_columns = [] 284 | for column in discrete_columns: 285 | if column < 0 or column >= train_data.shape[1]: 286 | invalid_columns.append(column) 287 | else: 288 | raise TypeError('``train_data`` should be either pd.DataFrame or np.array.') 289 | 290 | if invalid_columns: 291 | raise ValueError(f'Invalid columns found: {invalid_columns}') 292 | 293 | def _validate_null_data(self, train_data, discrete_columns): 294 | """Check whether null values exist in continuous ``train_data``. 295 | 296 | Args: 297 | train_data (numpy.ndarray or pandas.DataFrame): 298 | Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame. 299 | discrete_columns (list-like): 300 | List of discrete columns to be used to generate the Conditional 301 | Vector. If ``train_data`` is a Numpy array, this list should 302 | contain the integer indices of the columns. Otherwise, if it is 303 | a ``pandas.DataFrame``, this list should contain the column names. 
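
        Raises:
            InvalidDataError:
                If any continuous column contains null values.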
304 | """ 305 | if isinstance(train_data, pd.DataFrame): 306 | continuous_cols = list(set(train_data.columns) - set(discrete_columns)) 307 | any_nulls = train_data[continuous_cols].isna().any().any() 308 | else: 309 | continuous_cols = [i for i in range(train_data.shape[1]) if i not in discrete_columns] 310 | any_nulls = pd.DataFrame(train_data)[continuous_cols].isna().any().any() 311 | 312 | if any_nulls: 313 | raise InvalidDataError( 314 | 'CTGAN does not support null values in the continuous training data. ' 315 | 'Please remove all null values from your continuous training data.' 316 | ) 317 | 318 | @random_state 319 | def fit(self, train_data, discrete_columns=(), epochs=None): 320 | """Fit the CTGAN Synthesizer models to the training data. 321 | 322 | Args: 323 | train_data (numpy.ndarray or pandas.DataFrame): 324 | Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame. 325 | discrete_columns (list-like): 326 | List of discrete columns to be used to generate the Conditional 327 | Vector. If ``train_data`` is a Numpy array, this list should 328 | contain the integer indices of the columns. Otherwise, if it is 329 | a ``pandas.DataFrame``, this list should contain the column names. 330 | """ 331 | self._validate_discrete_columns(train_data, discrete_columns) 332 | self._validate_null_data(train_data, discrete_columns) 333 | 334 | if epochs is None: 335 | epochs = self._epochs 336 | else: 337 | warnings.warn( 338 | ( 339 | '`epochs` argument in `fit` method has been deprecated and will be removed ' 340 | 'in a future version. Please pass `epochs` to the constructor instead' 341 | ), 342 | DeprecationWarning, 343 | ) 344 | 345 | self._transformer = DataTransformer() 346 | self._transformer.fit(train_data, discrete_columns) 347 | 348 | train_data = self._transformer.transform(train_data) 349 | 350 | self._data_sampler = DataSampler( 351 | train_data, self._transformer.output_info_list, self._log_frequency 352 | ) 353 | 354 | data_dim = self._transformer.output_dimensions 355 | 356 | self._generator = Generator( 357 | self._embedding_dim + self._data_sampler.dim_cond_vec(), self._generator_dim, data_dim 358 | ).to(self._device) 359 | 360 | discriminator = Discriminator( 361 | data_dim + self._data_sampler.dim_cond_vec(), self._discriminator_dim, pac=self.pac 362 | ).to(self._device) 363 | 364 | optimizerG = optim.Adam( 365 | self._generator.parameters(), 366 | lr=self._generator_lr, 367 | betas=(0.5, 0.9), 368 | weight_decay=self._generator_decay, 369 | ) 370 | 371 | optimizerD = optim.Adam( 372 | discriminator.parameters(), 373 | lr=self._discriminator_lr, 374 | betas=(0.5, 0.9), 375 | weight_decay=self._discriminator_decay, 376 | ) 377 | 378 | mean = torch.zeros(self._batch_size, self._embedding_dim, device=self._device) 379 | std = mean + 1 380 | 381 | self.loss_values = pd.DataFrame(columns=['Epoch', 'Generator Loss', 'Discriminator Loss']) 382 | 383 | epoch_iterator = tqdm(range(epochs), disable=(not self._verbose)) 384 | if self._verbose: 385 | description = 'Gen. ({gen:.2f}) | Discrim. 
({dis:.2f})' 386 | epoch_iterator.set_description(description.format(gen=0, dis=0)) 387 | 388 | steps_per_epoch = max(len(train_data) // self._batch_size, 1) 389 | for i in epoch_iterator: 390 | for id_ in range(steps_per_epoch): 391 | for n in range(self._discriminator_steps): 392 | fakez = torch.normal(mean=mean, std=std) 393 | 394 | condvec = self._data_sampler.sample_condvec(self._batch_size) 395 | if condvec is None: 396 | c1, m1, col, opt = None, None, None, None 397 | real = self._data_sampler.sample_data( 398 | train_data, self._batch_size, col, opt 399 | ) 400 | else: 401 | c1, m1, col, opt = condvec 402 | c1 = torch.from_numpy(c1).to(self._device) 403 | m1 = torch.from_numpy(m1).to(self._device) 404 | fakez = torch.cat([fakez, c1], dim=1) 405 | 406 | perm = np.arange(self._batch_size) 407 | np.random.shuffle(perm) 408 | real = self._data_sampler.sample_data( 409 | train_data, self._batch_size, col[perm], opt[perm] 410 | ) 411 | c2 = c1[perm] 412 | 413 | fake = self._generator(fakez) 414 | fakeact = self._apply_activate(fake) 415 | 416 | real = torch.from_numpy(real.astype('float32')).to(self._device) 417 | 418 | if c1 is not None: 419 | fake_cat = torch.cat([fakeact, c1], dim=1) 420 | real_cat = torch.cat([real, c2], dim=1) 421 | else: 422 | real_cat = real 423 | fake_cat = fakeact 424 | 425 | y_fake = discriminator(fake_cat) 426 | y_real = discriminator(real_cat) 427 | 428 | pen = discriminator.calc_gradient_penalty( 429 | real_cat, fake_cat, self._device, self.pac 430 | ) 431 | loss_d = -(torch.mean(y_real) - torch.mean(y_fake)) 432 | 433 | optimizerD.zero_grad(set_to_none=False) 434 | pen.backward(retain_graph=True) 435 | loss_d.backward() 436 | optimizerD.step() 437 | 438 | fakez = torch.normal(mean=mean, std=std) 439 | condvec = self._data_sampler.sample_condvec(self._batch_size) 440 | 441 | if condvec is None: 442 | c1, m1, col, opt = None, None, None, None 443 | else: 444 | c1, m1, col, opt = condvec 445 | c1 = torch.from_numpy(c1).to(self._device) 446 | m1 = torch.from_numpy(m1).to(self._device) 447 | fakez = torch.cat([fakez, c1], dim=1) 448 | 449 | fake = self._generator(fakez) 450 | fakeact = self._apply_activate(fake) 451 | 452 | if c1 is not None: 453 | y_fake = discriminator(torch.cat([fakeact, c1], dim=1)) 454 | else: 455 | y_fake = discriminator(fakeact) 456 | 457 | if condvec is None: 458 | cross_entropy = 0 459 | else: 460 | cross_entropy = self._cond_loss(fake, c1, m1) 461 | 462 | loss_g = -torch.mean(y_fake) + cross_entropy 463 | 464 | optimizerG.zero_grad(set_to_none=False) 465 | loss_g.backward() 466 | optimizerG.step() 467 | 468 | generator_loss = loss_g.detach().cpu().item() 469 | discriminator_loss = loss_d.detach().cpu().item() 470 | 471 | epoch_loss_df = pd.DataFrame({ 472 | 'Epoch': [i], 473 | 'Generator Loss': [generator_loss], 474 | 'Discriminator Loss': [discriminator_loss], 475 | }) 476 | if not self.loss_values.empty: 477 | self.loss_values = pd.concat([self.loss_values, epoch_loss_df]).reset_index( 478 | drop=True 479 | ) 480 | else: 481 | self.loss_values = epoch_loss_df 482 | 483 | if self._verbose: 484 | epoch_iterator.set_description( 485 | description.format(gen=generator_loss, dis=discriminator_loss) 486 | ) 487 | 488 | @random_state 489 | def sample(self, n, condition_column=None, condition_value=None): 490 | """Sample data similar to the training data. 491 | 492 | Choosing a condition_column and condition_value will increase the probability of the 493 | discrete condition_value happening in the condition_column. 
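        The sampled rows are decoded back to the original schema through the
        fitted ``DataTransformer``, so fitting on a ``pandas.DataFrame`` yields
        a ``pandas.DataFrame``.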
494 | 495 | Args: 496 | n (int): 497 | Number of rows to sample. 498 | condition_column (string): 499 | Name of a discrete column. 500 | condition_value (string): 501 | Name of the category in the condition_column which we wish to increase the 502 | probability of happening. 503 | 504 | Returns: 505 | numpy.ndarray or pandas.DataFrame 506 | """ 507 | if condition_column is not None and condition_value is not None: 508 | condition_info = self._transformer.convert_column_name_value_to_id( 509 | condition_column, condition_value 510 | ) 511 | global_condition_vec = self._data_sampler.generate_cond_from_condition_column_info( 512 | condition_info, self._batch_size 513 | ) 514 | else: 515 | global_condition_vec = None 516 | 517 | steps = n // self._batch_size + 1 518 | data = [] 519 | for i in range(steps): 520 | mean = torch.zeros(self._batch_size, self._embedding_dim) 521 | std = mean + 1 522 | fakez = torch.normal(mean=mean, std=std).to(self._device) 523 | 524 | if global_condition_vec is not None: 525 | condvec = global_condition_vec.copy() 526 | else: 527 | condvec = self._data_sampler.sample_original_condvec(self._batch_size) 528 | 529 | if condvec is not None: 530 | c1 = torch.from_numpy(condvec).to(self._device) 531 | fakez = torch.cat([fakez, c1], dim=1) 532 | 533 | fake = self._generator(fakez) 534 | fakeact = self._apply_activate(fake) 535 | data.append(fakeact.detach().cpu().numpy()) 536 | 537 | data = np.concatenate(data, axis=0) 538 | data = data[:n] 539 | 540 | return self._transformer.inverse_transform(data) 541 | 542 | def set_device(self, device): 543 | """Set the `device` to be used ('GPU' or 'CPU').""" 544 | self._device = device 545 | if self._generator is not None: 546 | self._generator.to(self._device) 547 | -------------------------------------------------------------------------------- /ctgan/synthesizers/tvae.py: -------------------------------------------------------------------------------- 1 | """TVAE module.""" 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import torch 6 | from torch.nn import Linear, Module, Parameter, ReLU, Sequential 7 | from torch.nn.functional import cross_entropy 8 | from torch.optim import Adam 9 | from torch.utils.data import DataLoader, TensorDataset 10 | from tqdm import tqdm 11 | 12 | from ctgan.data_transformer import DataTransformer 13 | from ctgan.synthesizers.base import BaseSynthesizer, random_state 14 | 15 | 16 | class Encoder(Module): 17 | """Encoder for the TVAE. 18 | 19 | Args: 20 | data_dim (int): 21 | Dimensions of the data. 22 | compress_dims (tuple or list of ints): 23 | Size of each hidden layer. 24 | embedding_dim (int): 25 | Size of the output vector. 26 | """ 27 | 28 | def __init__(self, data_dim, compress_dims, embedding_dim): 29 | super(Encoder, self).__init__() 30 | dim = data_dim 31 | seq = [] 32 | for item in list(compress_dims): 33 | seq += [Linear(dim, item), ReLU()] 34 | dim = item 35 | 36 | self.seq = Sequential(*seq) 37 | self.fc1 = Linear(dim, embedding_dim) 38 | self.fc2 = Linear(dim, embedding_dim) 39 | 40 | def forward(self, input_): 41 | """Encode the passed `input_`.""" 42 | feature = self.seq(input_) 43 | mu = self.fc1(feature) 44 | logvar = self.fc2(feature) 45 | std = torch.exp(0.5 * logvar) 46 | return mu, std, logvar 47 | 48 | 49 | class Decoder(Module): 50 | """Decoder for the TVAE. 51 | 52 | Args: 53 | embedding_dim (int): 54 | Size of the input vector. 55 | decompress_dims (tuple or list of ints): 56 | Size of each hidden layer. 
57 | data_dim (int): 58 | Dimensions of the data. 59 | """ 60 | 61 | def __init__(self, embedding_dim, decompress_dims, data_dim): 62 | super(Decoder, self).__init__() 63 | dim = embedding_dim 64 | seq = [] 65 | for item in list(decompress_dims): 66 | seq += [Linear(dim, item), ReLU()] 67 | dim = item 68 | 69 | seq.append(Linear(dim, data_dim)) 70 | self.seq = Sequential(*seq) 71 | self.sigma = Parameter(torch.ones(data_dim) * 0.1) 72 | 73 | def forward(self, input_): 74 | """Decode the passed `input_`.""" 75 | return self.seq(input_), self.sigma 76 | 77 | 78 | def _loss_function(recon_x, x, sigmas, mu, logvar, output_info, factor): 79 | st = 0 80 | loss = [] 81 | for column_info in output_info: 82 | for span_info in column_info: 83 | if span_info.activation_fn != 'softmax': 84 | ed = st + span_info.dim 85 | std = sigmas[st] 86 | eq = x[:, st] - torch.tanh(recon_x[:, st]) 87 | loss.append((eq**2 / 2 / (std**2)).sum()) 88 | loss.append(torch.log(std) * x.size()[0]) 89 | st = ed 90 | 91 | else: 92 | ed = st + span_info.dim 93 | loss.append( 94 | cross_entropy( 95 | recon_x[:, st:ed], torch.argmax(x[:, st:ed], dim=-1), reduction='sum' 96 | ) 97 | ) 98 | st = ed 99 | 100 | assert st == recon_x.size()[1] 101 | KLD = -0.5 * torch.sum(1 + logvar - mu**2 - logvar.exp()) 102 | return sum(loss) * factor / x.size()[0], KLD / x.size()[0] 103 | 104 | 105 | class TVAE(BaseSynthesizer): 106 | """TVAE.""" 107 | 108 | def __init__( 109 | self, 110 | embedding_dim=128, 111 | compress_dims=(128, 128), 112 | decompress_dims=(128, 128), 113 | l2scale=1e-5, 114 | batch_size=500, 115 | epochs=300, 116 | loss_factor=2, 117 | cuda=True, 118 | verbose=False, 119 | ): 120 | self.embedding_dim = embedding_dim 121 | self.compress_dims = compress_dims 122 | self.decompress_dims = decompress_dims 123 | 124 | self.l2scale = l2scale 125 | self.batch_size = batch_size 126 | self.loss_factor = loss_factor 127 | self.epochs = epochs 128 | self.loss_values = pd.DataFrame(columns=['Epoch', 'Batch', 'Loss']) 129 | self.verbose = verbose 130 | 131 | if not cuda or not torch.cuda.is_available(): 132 | device = 'cpu' 133 | elif isinstance(cuda, str): 134 | device = cuda 135 | else: 136 | device = 'cuda' 137 | 138 | self._device = torch.device(device) 139 | 140 | @random_state 141 | def fit(self, train_data, discrete_columns=()): 142 | """Fit the TVAE Synthesizer models to the training data. 143 | 144 | Args: 145 | train_data (numpy.ndarray or pandas.DataFrame): 146 | Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame. 147 | discrete_columns (list-like): 148 | List of discrete columns to be used to generate the Conditional 149 | Vector. If ``train_data`` is a Numpy array, this list should 150 | contain the integer indices of the columns. Otherwise, if it is 151 | a ``pandas.DataFrame``, this list should contain the column names. 
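
        Example:
            A minimal sketch (``df`` is an assumed DataFrame and ``cat_cols``
            its discrete column names)::

                model = TVAE(epochs=10)
                model.fit(df, cat_cols)
                synthetic = model.sample(100)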
152 | """ 153 | self.transformer = DataTransformer() 154 | self.transformer.fit(train_data, discrete_columns) 155 | train_data = self.transformer.transform(train_data) 156 | dataset = TensorDataset(torch.from_numpy(train_data.astype('float32')).to(self._device)) 157 | loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True, drop_last=False) 158 | 159 | data_dim = self.transformer.output_dimensions 160 | encoder = Encoder(data_dim, self.compress_dims, self.embedding_dim).to(self._device) 161 | self.decoder = Decoder(self.embedding_dim, self.decompress_dims, data_dim).to(self._device) 162 | optimizerAE = Adam( 163 | list(encoder.parameters()) + list(self.decoder.parameters()), weight_decay=self.l2scale 164 | ) 165 | 166 | self.loss_values = pd.DataFrame(columns=['Epoch', 'Batch', 'Loss']) 167 | iterator = tqdm(range(self.epochs), disable=(not self.verbose)) 168 | if self.verbose: 169 | iterator_description = 'Loss: {loss:.3f}' 170 | iterator.set_description(iterator_description.format(loss=0)) 171 | 172 | for i in iterator: 173 | loss_values = [] 174 | batch = [] 175 | for id_, data in enumerate(loader): 176 | optimizerAE.zero_grad() 177 | real = data[0].to(self._device) 178 | mu, std, logvar = encoder(real) 179 | eps = torch.randn_like(std) 180 | emb = eps * std + mu 181 | rec, sigmas = self.decoder(emb) 182 | loss_1, loss_2 = _loss_function( 183 | rec, 184 | real, 185 | sigmas, 186 | mu, 187 | logvar, 188 | self.transformer.output_info_list, 189 | self.loss_factor, 190 | ) 191 | loss = loss_1 + loss_2 192 | loss.backward() 193 | optimizerAE.step() 194 | self.decoder.sigma.data.clamp_(0.01, 1.0) 195 | 196 | batch.append(id_) 197 | loss_values.append(loss.detach().cpu().item()) 198 | 199 | epoch_loss_df = pd.DataFrame({ 200 | 'Epoch': [i] * len(batch), 201 | 'Batch': batch, 202 | 'Loss': loss_values, 203 | }) 204 | if not self.loss_values.empty: 205 | self.loss_values = pd.concat([self.loss_values, epoch_loss_df]).reset_index( 206 | drop=True 207 | ) 208 | else: 209 | self.loss_values = epoch_loss_df 210 | 211 | if self.verbose: 212 | iterator.set_description( 213 | iterator_description.format(loss=loss.detach().cpu().item()) 214 | ) 215 | 216 | @random_state 217 | def sample(self, samples): 218 | """Sample data similar to the training data. 219 | 220 | Args: 221 | samples (int): 222 | Number of rows to sample. 
223 | 224 | Returns: 225 | numpy.ndarray or pandas.DataFrame 226 | """ 227 | self.decoder.eval() 228 | 229 | steps = samples // self.batch_size + 1 230 | data = [] 231 | for _ in range(steps): 232 | mean = torch.zeros(self.batch_size, self.embedding_dim) 233 | std = mean + 1 234 | noise = torch.normal(mean=mean, std=std).to(self._device) 235 | fake, sigmas = self.decoder(noise) 236 | fake = torch.tanh(fake) 237 | data.append(fake.detach().cpu().numpy()) 238 | 239 | data = np.concatenate(data, axis=0) 240 | data = data[:samples] 241 | return self.transformer.inverse_transform(data, sigmas.detach().cpu().numpy()) 242 | 243 | def set_device(self, device): 244 | """Set the `device` to be used ('GPU' or 'CPU').""" 245 | self._device = device 246 | self.decoder.to(self._device) 247 | -------------------------------------------------------------------------------- /examples/csv/adult.json: -------------------------------------------------------------------------------- 1 | { 2 | "columns": [ 3 | { 4 | "name": "age", 5 | "type": "continuous" 6 | }, 7 | { 8 | "name": "workclass", 9 | "type": "categorical" 10 | }, 11 | { 12 | "name": "fnlwgt", 13 | "type": "continuous" 14 | }, 15 | { 16 | "name": "education", 17 | "type": "ordinal" 18 | }, 19 | { 20 | "name": "education-num", 21 | "type": "continuous" 22 | }, 23 | { 24 | "name": "marital-status", 25 | "type": "categorical" 26 | }, 27 | { 28 | "name": "occupation", 29 | "type": "categorical" 30 | }, 31 | { 32 | "name": "relationship", 33 | "type": "categorical" 34 | }, 35 | { 36 | "name": "race", 37 | "type": "categorical" 38 | }, 39 | { 40 | "name": "sex", 41 | "type": "categorical" 42 | }, 43 | { 44 | "name": "capital-gain", 45 | "type": "continuous" 46 | }, 47 | { 48 | "name": "capital-loss", 49 | "type": "continuous" 50 | }, 51 | { 52 | "name": "hours-per-week", 53 | "type": "continuous" 54 | }, 55 | { 56 | "name": "native-country", 57 | "type": "categorical" 58 | }, 59 | { 60 | "name": "income", 61 | "type": "categorical" 62 | } 63 | ] 64 | } 65 | -------------------------------------------------------------------------------- /examples/tsv/acs.meta: -------------------------------------------------------------------------------- 1 | D 0 1 2 | D 0 1 3 | D 0 1 4 | D 0 1 5 | D 0 1 6 | D 0 1 7 | D 0 1 8 | D 0 1 9 | D 0 1 10 | D 0 1 11 | D 0 1 12 | D 0 1 13 | D 0 1 14 | D 0 1 15 | D 0 1 16 | D 0 1 17 | D 0 1 18 | D 0 1 19 | D 0 1 20 | D 0 1 21 | D 0 1 22 | D 0 1 23 | D 0 1 24 | -------------------------------------------------------------------------------- /examples/tsv/adult.meta: -------------------------------------------------------------------------------- 1 | C 17.0 90.0 2 | D Federal-gov Local-gov State-gov Private Self-emp-inc Self-emp-not-inc Without-pay 3 | C 13492.0 1490400.0 4 | D Preschool 1st-4th 5th-6th 7th-8th 9th 10th 11th 12th HS-grad Some-college Assoc-voc Assoc-acdm Bachelors Masters Prof-school Doctorate 5 | C 1.0 16.0 6 | D Never-married Married-AF-spouse Married-civ-spouse Married-spouse-absent Separated Widowed Divorced 7 | D Adm-clerical Armed-Forces Craft-repair Exec-managerial Farming-fishing Handlers-cleaners Machine-op-inspct Other-service Priv-house-serv Prof-specialty Protective-serv Sales Tech-support Transport-moving 8 | D Husband Wife Own-child Other-relative Not-in-family Unmarried 9 | D White Black Amer-Indian-Eskimo Asian-Pac-Islander Other 10 | D Female Male 11 | C 0.0 99999.0 12 | C 0.0 4356.0 13 | C 1.0 99.0 14 | D United-States Outlying-US(Guam-USVI-etc) Puerto-Rico Canada Mexico Cambodia China Hong Japan 
Laos Philippines South Taiwan Thailand Vietnam Columbia Ecuador Peru Cuba Dominican-Republic El-Salvador Guatemala Haiti Honduras Jamaica Nicaragua Trinadad&Tobago England France Germany Greece Holand-Netherlands Hungary Ireland Italy Poland Portugal Scotland Yugoslavia India Iran 15 | D <=50K >50K 16 | -------------------------------------------------------------------------------- /examples/tsv/br2000.meta: -------------------------------------------------------------------------------- 1 | D 0 1 2 | D 0 1 2 3 4 5 6 3 | C 0.0 21.0 4 | D 0 1 2 3 4 5 6 7 8 9 5 | C 0.0 90.0 6 | D 0 1 7 | D 0 1 8 | D 0 1 9 | D 0 1 10 | D 0 1 2 3 11 | C 0.0 125.0 12 | D 0 1 13 | C 0.0 90.0 14 | D 0 1 15 | -------------------------------------------------------------------------------- /examples/tsv/nltcs.meta: -------------------------------------------------------------------------------- 1 | D 0 1 2 | D 0 1 3 | D 0 1 4 | D 0 1 5 | D 0 1 6 | D 0 1 7 | D 0 1 8 | D 0 1 9 | D 0 1 10 | D 0 1 11 | D 0 1 12 | D 0 1 13 | D 0 1 14 | D 0 1 15 | D 0 1 16 | D 0 1 17 | -------------------------------------------------------------------------------- /latest_requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==2.0.2 2 | pandas==2.2.3 3 | rdt==1.17.0 4 | torch==2.7.0 5 | tqdm==4.67.1 6 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = 'ctgan' 3 | description = 'Create tabular synthetic data using a conditional GAN' 4 | authors = [{ name = 'DataCebo, Inc.', email = 'info@sdv.dev' }] 5 | classifiers = [ 6 | 'Development Status :: 2 - Pre-Alpha', 7 | 'Intended Audience :: Developers', 8 | 'License :: Free for non-commercial use', 9 | 'Natural Language :: English', 10 | 'Programming Language :: Python :: 3', 11 | 'Programming Language :: Python :: 3.8', 12 | 'Programming Language :: Python :: 3.9', 13 | 'Programming Language :: Python :: 3.10', 14 | 'Programming Language :: Python :: 3.11', 15 | 'Programming Language :: Python :: 3.12', 16 | 'Programming Language :: Python :: 3.13', 17 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 18 | ] 19 | keywords = ['ctgan', 'CTGAN'] 20 | dynamic = ['version'] 21 | license = { text = 'BSL-1.1' } 22 | requires-python = '>=3.8,<3.14' 23 | readme = 'README.md' 24 | dependencies = [ 25 | "numpy>=1.21.0;python_version<'3.10'", 26 | "numpy>=1.23.3;python_version>='3.10' and python_version<'3.12'", 27 | "numpy>=1.26.0;python_version>='3.12' and python_version<'3.13'", 28 | "numpy>=2.1.0;python_version>='3.13'", 29 | "pandas>=1.4.0;python_version<'3.11'", 30 | "pandas>=1.5.0;python_version>='3.11' and python_version<'3.12'", 31 | "pandas>=2.1.1;python_version>='3.12' and python_version<'3.13'", 32 | "pandas>=2.2.3;python_version>='3.13'", 33 | "torch>=1.13.0;python_version<'3.11'", 34 | "torch>=2.0.0;python_version>='3.11' and python_version<'3.12'", 35 | "torch>=2.2.0;python_version>='3.12' and python_version<'3.13'", 36 | "torch>=2.6.0;python_version>='3.13'", 37 | 'tqdm>=4.29,<5', 38 | 'rdt>=1.14.0', 39 | ] 40 | 41 | [project.urls] 42 | "Source Code"= "https://github.com/sdv-dev/CTGAN/" 43 | "Issue Tracker" = "https://github.com/sdv-dev/CTGAN/issues" 44 | "Changes" = "https://github.com/sdv-dev/CTGAN/blob/main/HISTORY.md" 45 | "Twitter" = "https://twitter.com/sdv_dev" 46 | "Chat" = "https://bit.ly/sdv-slack-invite" 47 | 48 | [project.entry-points] 49 | ctgan = { main 
= 'ctgan.__main__:main' } 50 | 51 | [project.optional-dependencies] 52 | test = [ 53 | 'pytest>=3.4.2', 54 | 'pytest-rerunfailures>=10.3,<15', 55 | 'pytest-cov>=2.6.0', 56 | 'pytest-runner >= 2.11.1', 57 | 'tomli>=2.0.0,<3', 58 | ] 59 | dev = [ 60 | 'ctgan[test]', 61 | 62 | # general 63 | 'pip>=9.0.1', 64 | 'build>=1.0.0,<2', 65 | 'bump-my-version>=0.18.3', 66 | 'watchdog>=1.0.1,<5', 67 | 68 | # style check 69 | 'ruff>=0.4.5,<1', 70 | 71 | # distribute on PyPI 72 | 'twine>=1.10.0', 73 | 'wheel>=0.30.0', 74 | 75 | # Advanced testing 76 | 'coverage>=4.5.1,<6', 77 | 'tox>=2.9.1,<4', 78 | 79 | 'invoke', 80 | ] 81 | readme = ['rundoc>=0.4.3,<0.5',] 82 | 83 | [tool.setuptools] 84 | include-package-data = true 85 | license-files = ['LICENSE'] 86 | 87 | [tool.setuptools.packages.find] 88 | include = ['ctgan', 'ctgan.*'] 89 | namespaces = false 90 | 91 | [tool.setuptools.package-data] 92 | '*' = [ 93 | 'AUTHORS.rst', 94 | 'CONTRIBUTING.rst', 95 | 'HISTORY.md', 96 | 'README.md', 97 | '*.md', 98 | '*.rst', 99 | 'conf.py', 100 | 'Makefile', 101 | 'make.bat', 102 | '*.jpg', 103 | '*.png', 104 | '*.gif' 105 | ] 106 | 107 | [tool.setuptools.exclude-package-data] 108 | '*' = [ 109 | '* __pycache__', 110 | '*.py[co]', 111 | 'static_code_analysis.txt', 112 | ] 113 | 114 | [tool.setuptools.dynamic] 115 | version = {attr = 'ctgan.__version__'} 116 | 117 | [tool.pytest.ini_options] 118 | collect_ignore = ['pyproject.toml'] 119 | 120 | [tool.bumpversion] 121 | current_version = "0.11.1.dev0" 122 | parse = '(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\.(?P<release>[a-z]+)(?P<candidate>\d+))?' 123 | serialize = [ 124 | '{major}.{minor}.{patch}.{release}{candidate}', 125 | '{major}.{minor}.{patch}' 126 | ] 127 | search = '{current_version}' 128 | replace = '{new_version}' 129 | regex = false 130 | ignore_missing_version = false 131 | tag = true 132 | sign_tags = false 133 | tag_name = 'v{new_version}' 134 | tag_message = 'Bump version: {current_version} → {new_version}' 135 | allow_dirty = false 136 | commit = true 137 | message = 'Bump version: {current_version} → {new_version}' 138 | commit_args = '' 139 | 140 | [tool.bumpversion.parts.release] 141 | first_value = 'dev' 142 | optional_value = 'release' 143 | values = [ 144 | 'dev', 145 | 'release' 146 | ] 147 | 148 | [[tool.bumpversion.files]] 149 | filename = "ctgan/__init__.py" 150 | search = "__version__ = '{current_version}'" 151 | replace = "__version__ = '{new_version}'" 152 | 153 | [build-system] 154 | requires = ['setuptools', 'wheel'] 155 | build-backend = 'setuptools.build_meta' 156 | 157 | [tool.ruff] 158 | preview = true 159 | line-length = 100 160 | indent-width = 4 161 | src = ["ctgan"] 162 | exclude = [ 163 | "docs", 164 | ".tox", 165 | ".git", 166 | "__pycache__", 167 | ".ipynb_checkpoints", 168 | "tasks.py", 169 | ] 170 | 171 | [tool.ruff.lint] 172 | select = [ 173 | # Pyflakes 174 | "F", 175 | # Pycodestyle 176 | "E", 177 | "W", 178 | # pydocstyle 179 | "D", 180 | # isort 181 | "I001", 182 | # print statements 183 | "T201", 184 | # pandas-vet 185 | "PD", 186 | # numpy 2.0 187 | "NPY201" 188 | ] 189 | ignore = [ 190 | # pydocstyle 191 | "D107", # Missing docstring in __init__ 192 | "D417", # Missing argument descriptions in the docstring, this is a bug from pydocstyle: https://github.com/PyCQA/pydocstyle/issues/449 193 | "PD901", 194 | "PD101", 195 | ] 196 | 197 | [tool.ruff.format] 198 | quote-style = "single" 199 | indent-style = "space" 200 | preview = true 201 | docstring-code-format = true 202 | docstring-code-line-length = "dynamic" 203 | 204 | [tool.ruff.lint.isort] 205 | 
known-first-party = ["ctgan"] 206 | lines-between-types = 0 207 | 208 | [tool.ruff.lint.per-file-ignores] 209 | "__init__.py" = ["F401", "E402", "F403", "F405", "E501", "I001"] 210 | "errors.py" = ["D105"] 211 | "tests/**.py" = ["D"] 212 | 213 | [tool.ruff.lint.pydocstyle] 214 | convention = "google" 215 | 216 | [tool.ruff.lint.pycodestyle] 217 | max-doc-length = 100 218 | max-line-length = 100 219 | -------------------------------------------------------------------------------- /scripts/release_notes_generator.py: -------------------------------------------------------------------------------- 1 | """Script to generate release notes.""" 2 | 3 | import argparse 4 | import os 5 | from collections import defaultdict 6 | 7 | import requests 8 | 9 | LABEL_TO_HEADER = { 10 | 'feature request': 'New Features', 11 | 'bug': 'Bugs Fixed', 12 | 'internal': 'Internal', 13 | 'maintenance': 'Maintenance', 14 | 'customer success': 'Customer Success', 15 | 'documentation': 'Documentation', 16 | 'misc': 'Miscellaneous', 17 | } 18 | ISSUE_LABELS = [ 19 | 'documentation', 20 | 'maintenance', 21 | 'internal', 22 | 'bug', 23 | 'feature request', 24 | 'customer success', 25 | ] 26 | ISSUE_LABELS_ORDERED_BY_IMPORTANCE = [ 27 | 'feature request', 28 | 'customer success', 29 | 'bug', 30 | 'documentation', 31 | 'internal', 32 | 'maintenance', 33 | ] 34 | NEW_LINE = '\n' 35 | GITHUB_URL = 'https://api.github.com/repos/sdv-dev/ctgan' 36 | GITHUB_TOKEN = os.getenv('GH_ACCESS_TOKEN') 37 | 38 | 39 | def _get_milestone_number(milestone_title): 40 | url = f'{GITHUB_URL}/milestones' 41 | headers = {'Authorization': f'Bearer {GITHUB_TOKEN}'} 42 | query_params = {'milestone': milestone_title, 'state': 'all', 'per_page': 100} 43 | response = requests.get(url, headers=headers, params=query_params, timeout=10) 44 | body = response.json() 45 | if response.status_code != 200: 46 | raise Exception(str(body)) 47 | 48 | milestones = body 49 | for milestone in milestones: 50 | if milestone.get('title') == milestone_title: 51 | return milestone.get('number') 52 | 53 | raise ValueError(f'Milestone {milestone_title} not found in past 100 milestones.') 54 | 55 | 56 | def _get_issues_by_milestone(milestone): 57 | headers = {'Authorization': f'Bearer {GITHUB_TOKEN}'} 58 | # get milestone number 59 | milestone_number = _get_milestone_number(milestone) 60 | url = f'{GITHUB_URL}/issues' 61 | page = 1 62 | query_params = {'milestone': milestone_number, 'state': 'all'} 63 | issues = [] 64 | while True: 65 | query_params['page'] = page 66 | response = requests.get(url, headers=headers, params=query_params, timeout=10) 67 | body = response.json() 68 | if response.status_code != 200: 69 | raise Exception(str(body)) 70 | 71 | issues_on_page = body 72 | if not issues_on_page: 73 | break 74 | 75 | # Filter out PRs 76 | issues_on_page = [issue for issue in issues_on_page if issue.get('pull_request') is None] 77 | issues.extend(issues_on_page) 78 | page += 1 79 | 80 | return issues 81 | 82 | 83 | def _get_issues_by_category(release_issues): 84 | category_to_issues = defaultdict(list) 85 | 86 | for issue in release_issues: 87 | issue_title = issue['title'] 88 | issue_number = issue['number'] 89 | issue_url = issue['html_url'] 90 | line = f'* {issue_title} - Issue [#{issue_number}]({issue_url})' 91 | assignee = issue.get('assignee') 92 | if assignee: 93 | login = assignee['login'] 94 | line += f' by @{login}' 95 | 96 | # Check if any known label is marked on the issue 97 | labels = [label['name'] for label in issue['labels']] 98 | found_category = 
False 99 | for category in ISSUE_LABELS: 100 | if category in labels: 101 | category_to_issues[category].append(line) 102 | found_category = True 103 | break 104 | 105 | if not found_category: 106 | category_to_issues['misc'].append(line) 107 | 108 | return category_to_issues 109 | 110 | 111 | def _create_release_notes(issues_by_category, version, date): 112 | title = f'## v{version} - {date}' 113 | release_notes = f'{title}{NEW_LINE}{NEW_LINE}' 114 | 115 | for category in ISSUE_LABELS_ORDERED_BY_IMPORTANCE + ['misc']: 116 | issues = issues_by_category.get(category) 117 | if issues: 118 | section_text = ( 119 | f'### {LABEL_TO_HEADER[category]}{NEW_LINE}{NEW_LINE}' 120 | f'{NEW_LINE.join(issues)}{NEW_LINE}{NEW_LINE}' 121 | ) 122 | 123 | release_notes += section_text 124 | 125 | return release_notes 126 | 127 | 128 | def update_release_notes(release_notes): 129 | """Add the release notes for the new release to the ``HISTORY.md``.""" 130 | file_path = 'HISTORY.md' 131 | with open(file_path, 'r') as history_file: 132 | history = history_file.read() 133 | 134 | token = '# History\n\n' 135 | split_index = history.find(token) + len(token) + 1 136 | header = history[:split_index] 137 | new_notes = f'{header}{release_notes}{history[split_index:]}' 138 | 139 | with open(file_path, 'w') as new_history_file: 140 | new_history_file.write(new_notes) 141 | 142 | 143 | if __name__ == '__main__': 144 | parser = argparse.ArgumentParser() 145 | parser.add_argument('-v', '--version', type=str, help='Release version number (e.g. v1.0.1)') 146 | parser.add_argument('-d', '--date', type=str, help='Date of release in format YYYY-MM-DD') 147 | args = parser.parse_args() 148 | release_number = args.version 149 | release_issues = _get_issues_by_milestone(release_number) 150 | issues_by_category = _get_issues_by_category(release_issues) 151 | release_notes = _create_release_notes(issues_by_category, release_number, args.date) 152 | update_release_notes(release_notes) 153 | -------------------------------------------------------------------------------- /static_code_analysis.txt: -------------------------------------------------------------------------------- 1 | Run started:2025-02-25 21:10:33.223731 2 | 3 | Test results: 4 | >> Issue: [B101:assert_used] Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. 5 | Severity: Low Confidence: High 6 | CWE: CWE-703 (https://cwe.mitre.org/data/definitions/703.html) 7 | More Info: https://bandit.readthedocs.io/en/1.7.7/plugins/b101_assert_used.html 8 | Location: ./ctgan/__main__.py:121:8 9 | 120 if args.sample_condition_column is not None: 10 | 121 assert args.sample_condition_column_value is not None 11 | 122 12 | 13 | -------------------------------------------------- 14 | >> Issue: [B101:assert_used] Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. 15 | Severity: Low Confidence: High 16 | CWE: CWE-703 (https://cwe.mitre.org/data/definitions/703.html) 17 | More Info: https://bandit.readthedocs.io/en/1.7.7/plugins/b101_assert_used.html 18 | Location: ./ctgan/data.py:48:12 19 | 47 else: 20 | 48 assert item[0] == 'D' 21 | 49 discrete.append(idx) 22 | 23 | -------------------------------------------------- 24 | >> Issue: [B101:assert_used] Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. 
25 | Severity: Low Confidence: High 26 | CWE: CWE-703 (https://cwe.mitre.org/data/definitions/703.html) 27 | More Info: https://bandit.readthedocs.io/en/1.7.7/plugins/b101_assert_used.html 28 | Location: ./ctgan/data.py:69:16 29 | 68 else: 30 | 69 assert idx in discrete 31 | 70 row.append(column_info[idx].index(col)) 32 | 33 | -------------------------------------------------- 34 | >> Issue: [B101:assert_used] Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. 35 | Severity: Low Confidence: High 36 | CWE: CWE-703 (https://cwe.mitre.org/data/definitions/703.html) 37 | More Info: https://bandit.readthedocs.io/en/1.7.7/plugins/b101_assert_used.html 38 | Location: ./ctgan/data.py:85:20 39 | 84 else: 40 | 85 assert idx in meta['discrete_columns'] 41 | 86 print(meta['column_info'][idx][int(col)], end=' ', file=f) 42 | 43 | -------------------------------------------------- 44 | >> Issue: [B101:assert_used] Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. 45 | Severity: Low Confidence: High 46 | CWE: CWE-703 (https://cwe.mitre.org/data/definitions/703.html) 47 | More Info: https://bandit.readthedocs.io/en/1.7.7/plugins/b101_assert_used.html 48 | Location: ./ctgan/data_sampler.py:40:8 49 | 39 st += sum([span_info.dim for span_info in column_info]) 50 | 40 assert st == data.shape[1] 51 | 41 52 | 53 | -------------------------------------------------- 54 | >> Issue: [B101:assert_used] Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. 55 | Severity: Low Confidence: High 56 | CWE: CWE-703 (https://cwe.mitre.org/data/definitions/703.html) 57 | More Info: https://bandit.readthedocs.io/en/1.7.7/plugins/b101_assert_used.html 58 | Location: ./ctgan/synthesizers/ctgan.py:60:8 59 | 59 """Apply the Discriminator to the `input_`.""" 60 | 60 assert input_.size()[0] % self.pac == 0 61 | 61 return self.seq(input_.view(-1, self.pacdim)) 62 | 63 | -------------------------------------------------- 64 | >> Issue: [B101:assert_used] Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. 65 | Severity: Low Confidence: High 66 | CWE: CWE-703 (https://cwe.mitre.org/data/definitions/703.html) 67 | More Info: https://bandit.readthedocs.io/en/1.7.7/plugins/b101_assert_used.html 68 | Location: ./ctgan/synthesizers/ctgan.py:164:8 69 | 163 ): 70 | 164 assert batch_size % 2 == 0 71 | 165 72 | 73 | -------------------------------------------------- 74 | >> Issue: [B101:assert_used] Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. 
75 | Severity: Low Confidence: High 76 | CWE: CWE-703 (https://cwe.mitre.org/data/definitions/703.html) 77 | More Info: https://bandit.readthedocs.io/en/1.7.7/plugins/b101_assert_used.html 78 | Location: ./ctgan/synthesizers/tvae.py:100:4 79 | 99 80 | 100 assert st == recon_x.size()[1] 81 | 101 KLD = -0.5 * torch.sum(1 + logvar - mu**2 - logvar.exp()) 82 | 83 | -------------------------------------------------- 84 | 85 | Code scanned: 86 | Total lines of code: 1414 87 | Total lines skipped (#nosec): 0 88 | Total potential issues skipped due to specifically being disabled (e.g., #nosec BXXX): 0 89 | 90 | Run metrics: 91 | Total issues (by severity): 92 | Undefined: 0 93 | Low: 8 94 | Medium: 0 95 | High: 0 96 | Total issues (by confidence): 97 | Undefined: 0 98 | Low: 0 99 | Medium: 0 100 | High: 8 101 | Files skipped (0): 102 | -------------------------------------------------------------------------------- /tasks.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import operator 3 | import os 4 | import shutil 5 | import stat 6 | import sys 7 | from pathlib import Path 8 | 9 | import tomli 10 | from invoke import task 11 | from packaging.requirements import Requirement 12 | from packaging.version import Version 13 | 14 | COMPARISONS = {'>=': operator.ge, '>': operator.gt, '<': operator.lt, '<=': operator.le} 15 | 16 | if not hasattr(inspect, 'getargspec'): 17 | inspect.getargspec = inspect.getfullargspec 18 | 19 | 20 | @task 21 | def check_dependencies(c): 22 | c.run('python -m pip check') 23 | 24 | 25 | @task 26 | def unit(c): 27 | c.run('python -m pytest ./tests/unit --cov=ctgan --cov-report=xml:./unit_cov.xml') 28 | 29 | 30 | @task 31 | def integration(c): 32 | c.run('python -m pytest ./tests/integration --reruns 3 --cov=ctgan --cov-report=xml:./integration_cov.xml') 33 | 34 | 35 | @task 36 | def readme(c): 37 | test_path = Path('tests/readme_test') 38 | if test_path.exists() and test_path.is_dir(): 39 | shutil.rmtree(test_path) 40 | 41 | cwd = os.getcwd() 42 | os.makedirs(test_path, exist_ok=True) 43 | shutil.copy('README.md', test_path / 'README.md') 44 | os.chdir(test_path) 45 | c.run('rundoc run --single-session python3 -t python3 README.md') 46 | os.chdir(cwd) 47 | shutil.rmtree(test_path) 48 | 49 | 50 | def _get_minimum_versions(dependencies, python_version): 51 | min_versions = {} 52 | for dependency in dependencies: 53 | if '@' in dependency: 54 | name, url = dependency.split(' @ ') 55 | min_versions[name] = f'{url}#egg={name}' 56 | continue 57 | 58 | req = Requirement(dependency) 59 | if ';' in dependency: 60 | marker = req.marker 61 | if marker and not marker.evaluate({'python_version': python_version}): 62 | continue # Skip this dependency if the marker does not apply to the current Python version 63 | 64 | if req.name not in min_versions: 65 | min_version = next( 66 | (spec.version for spec in req.specifier if spec.operator in ('>=', '==')), None 67 | ) 68 | if min_version: 69 | min_versions[req.name] = f'{req.name}=={min_version}' 70 | 71 | elif '@' not in min_versions[req.name]: 72 | existing_version = Version(min_versions[req.name].split('==')[1]) 73 | new_version = next( 74 | (spec.version for spec in req.specifier if spec.operator in ('>=', '==')), 75 | existing_version, 76 | ) 77 | if new_version > existing_version: 78 | min_versions[req.name] = ( 79 | f'{req.name}=={new_version}' # Change when a valid newer version is found 80 | ) 81 | 82 | return list(min_versions.values()) 83 | 84 | 85 | @task 86 | def 
install_minimum(c): 87 | with open('pyproject.toml', 'rb') as pyproject_file: 88 | pyproject_data = tomli.load(pyproject_file) 89 | 90 | dependencies = pyproject_data.get('project', {}).get('dependencies', []) 91 | python_version = '.'.join(map(str, sys.version_info[:2])) 92 | minimum_versions = _get_minimum_versions(dependencies, python_version) 93 | 94 | if minimum_versions: 95 | install_deps = ' '.join(minimum_versions) 96 | c.run(f'python -m pip install {install_deps}') 97 | 98 | 99 | @task 100 | def minimum(c): 101 | install_minimum(c) 102 | check_dependencies(c) 103 | unit(c) 104 | integration(c) 105 | 106 | 107 | @task 108 | def lint(c): 109 | check_dependencies(c) 110 | c.run('ruff check .') 111 | c.run('ruff format --check --diff .') 112 | 113 | 114 | @task 115 | def fix_lint(c): 116 | check_dependencies(c) 117 | c.run('ruff check --fix .') 118 | c.run('ruff format .') 119 | 120 | 121 | def remove_readonly(func, path, _): 122 | """Clear the readonly bit and reattempt the removal""" 123 | os.chmod(path, stat.S_IWRITE) 124 | func(path) 125 | 126 | 127 | @task 128 | def rmdir(c, path): 129 | try: 130 | shutil.rmtree(path, onerror=remove_readonly) 131 | except PermissionError: 132 | pass 133 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """CTGAN tests.""" 2 | -------------------------------------------------------------------------------- /tests/integration/__init__.py: -------------------------------------------------------------------------------- 1 | """Integration testing subpackage.""" 2 | -------------------------------------------------------------------------------- /tests/integration/synthesizer/__init__.py: -------------------------------------------------------------------------------- 1 | """Subpackage for integration testing of synthesizers.""" 2 | -------------------------------------------------------------------------------- /tests/integration/synthesizer/test_ctgan.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """Integration tests for ctgan. 5 | 6 | These tests only ensure that the software does not crash and that 7 | the API works as expected in terms of input and output data formats, 8 | but correctness of the data values and the internal behavior of the 9 | model are not checked. 
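A minimal sketch of the call pattern exercised throughout this module
(column names are illustrative):

    data = pd.DataFrame({'num': np.random.random(100), 'cat': ['a', 'b'] * 50})
    ctgan = CTGAN(epochs=1)
    ctgan.fit(data, discrete_columns=['cat'])
    samples = ctgan.sample(100)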
10 | """ 11 | 12 | import tempfile as tf 13 | 14 | import numpy as np 15 | import pandas as pd 16 | import pytest 17 | 18 | from ctgan.errors import InvalidDataError 19 | from ctgan.synthesizers.ctgan import CTGAN 20 | 21 | 22 | def test_ctgan_no_categoricals(): 23 | """Test the CTGAN with no categorical values.""" 24 | data = pd.DataFrame({'continuous': np.random.random(1000)}) 25 | 26 | ctgan = CTGAN(epochs=1) 27 | ctgan.fit(data, []) 28 | 29 | sampled = ctgan.sample(100) 30 | 31 | assert sampled.shape == (100, 1) 32 | assert isinstance(sampled, pd.DataFrame) 33 | assert set(sampled.columns) == {'continuous'} 34 | assert len(ctgan.loss_values) == 1 35 | assert list(ctgan.loss_values.columns) == ['Epoch', 'Generator Loss', 'Discriminator Loss'] 36 | 37 | 38 | def test_ctgan_dataframe(): 39 | """Test the CTGAN when passed a dataframe.""" 40 | data = pd.DataFrame({ 41 | 'continuous': np.random.random(100), 42 | 'discrete': np.random.choice(['a', 'b', 'c'], 100), 43 | }) 44 | discrete_columns = ['discrete'] 45 | 46 | ctgan = CTGAN(epochs=1) 47 | ctgan.fit(data, discrete_columns) 48 | 49 | sampled = ctgan.sample(100) 50 | 51 | assert sampled.shape == (100, 2) 52 | assert isinstance(sampled, pd.DataFrame) 53 | assert set(sampled.columns) == {'continuous', 'discrete'} 54 | assert set(sampled['discrete'].unique()) == {'a', 'b', 'c'} 55 | assert len(ctgan.loss_values) == 1 56 | assert list(ctgan.loss_values.columns) == ['Epoch', 'Generator Loss', 'Discriminator Loss'] 57 | 58 | 59 | def test_ctgan_numpy(): 60 | """Test the CTGAN when passed a numpy array.""" 61 | data = pd.DataFrame({ 62 | 'continuous': np.random.random(100), 63 | 'discrete': np.random.choice(['a', 'b', 'c'], 100), 64 | }) 65 | discrete_columns = [1] 66 | 67 | ctgan = CTGAN(epochs=1) 68 | ctgan.fit(data.to_numpy(), discrete_columns) 69 | 70 | sampled = ctgan.sample(100) 71 | 72 | assert sampled.shape == (100, 2) 73 | assert isinstance(sampled, np.ndarray) 74 | assert set(np.unique(sampled[:, 1])) == {'a', 'b', 'c'} 75 | assert len(ctgan.loss_values) == 1 76 | assert list(ctgan.loss_values.columns) == ['Epoch', 'Generator Loss', 'Discriminator Loss'] 77 | 78 | 79 | def test_log_frequency(): 80 | """Test the CTGAN with no `log_frequency` set to False.""" 81 | data = pd.DataFrame({ 82 | 'continuous': np.random.random(1000), 83 | 'discrete': np.repeat(['a', 'b', 'c'], [950, 25, 25]), 84 | }) 85 | 86 | discrete_columns = ['discrete'] 87 | 88 | ctgan = CTGAN(epochs=100) 89 | ctgan.fit(data, discrete_columns) 90 | 91 | assert len(ctgan.loss_values) == 100 92 | assert list(ctgan.loss_values.columns) == ['Epoch', 'Generator Loss', 'Discriminator Loss'] 93 | pd.testing.assert_series_equal(ctgan.loss_values['Epoch'], pd.Series(range(100), name='Epoch')) 94 | 95 | sampled = ctgan.sample(10000) 96 | counts = sampled['discrete'].value_counts() 97 | assert counts['a'] < 6500 98 | 99 | ctgan = CTGAN(log_frequency=False, epochs=100) 100 | ctgan.fit(data, discrete_columns) 101 | 102 | assert len(ctgan.loss_values) == 100 103 | assert list(ctgan.loss_values.columns) == ['Epoch', 'Generator Loss', 'Discriminator Loss'] 104 | pd.testing.assert_series_equal(ctgan.loss_values['Epoch'], pd.Series(range(100), name='Epoch')) 105 | 106 | sampled = ctgan.sample(10000) 107 | counts = sampled['discrete'].value_counts() 108 | assert counts['a'] > 9000 109 | 110 | 111 | def test_categorical_nan(): 112 | """Test the CTGAN with no categorical values.""" 113 | data = pd.DataFrame({ 114 | 'continuous': np.random.random(30), 115 | # This must be a list (not a 
np.array) or NaN will be cast to a string. 116 | 'discrete': [np.nan, 'b', 'c'] * 10, 117 | }) 118 | discrete_columns = ['discrete'] 119 | 120 | ctgan = CTGAN(epochs=1) 121 | ctgan.fit(data, discrete_columns) 122 | 123 | sampled = ctgan.sample(100) 124 | 125 | assert sampled.shape == (100, 2) 126 | assert isinstance(sampled, pd.DataFrame) 127 | assert set(sampled.columns) == {'continuous', 'discrete'} 128 | 129 | # since np.nan != np.nan, we need to be careful here 130 | values = set(sampled['discrete'].unique()) 131 | assert len(values) == 3 132 | assert any(pd.isna(x) for x in values) 133 | assert {'b', 'c'}.issubset(values) 134 | 135 | 136 | def test_continuous_nan(): 137 | """Test the CTGAN with missing numerical values.""" 138 | # Setup 139 | data = pd.DataFrame({ 140 | 'continuous': [np.nan, 1.0, 2.0] * 10, 141 | 'discrete': ['a', 'b', 'c'] * 10, 142 | }) 143 | discrete_columns = ['discrete'] 144 | error_message = ( 145 | 'CTGAN does not support null values in the continuous training data. ' 146 | 'Please remove all null values from your continuous training data.' 147 | ) 148 | 149 | # Run and Assert 150 | ctgan = CTGAN(epochs=1) 151 | with pytest.raises(InvalidDataError, match=error_message): 152 | ctgan.fit(data, discrete_columns) 153 | 154 | 155 | def test_synthesizer_sample(): 156 | """Test the CTGAN samples the correct datatype.""" 157 | data = pd.DataFrame({'discrete': np.random.choice(['a', 'b', 'c'], 100)}) 158 | discrete_columns = ['discrete'] 159 | 160 | ctgan = CTGAN(epochs=1) 161 | ctgan.fit(data, discrete_columns) 162 | 163 | samples = ctgan.sample(1000, 'discrete', 'a') 164 | assert isinstance(samples, pd.DataFrame) 165 | 166 | 167 | def test_save_load(): 168 | """Test the CTGAN load/save methods.""" 169 | data = pd.DataFrame({ 170 | 'continuous': np.random.random(100), 171 | 'discrete': np.random.choice(['a', 'b', 'c'], 100), 172 | }) 173 | discrete_columns = ['discrete'] 174 | 175 | ctgan = CTGAN(epochs=1) 176 | ctgan.fit(data, discrete_columns) 177 | 178 | with tf.TemporaryDirectory() as temporary_directory: 179 | ctgan.save(temporary_directory + '/test_ctgan.pkl') 180 | ctgan = CTGAN.load(temporary_directory + '/test_ctgan.pkl') 181 | 182 | sampled = ctgan.sample(1000) 183 | assert set(sampled.columns) == {'continuous', 'discrete'} 184 | assert set(sampled['discrete'].unique()) == {'a', 'b', 'c'} 185 | 186 | 187 | def test_wrong_discrete_columns_dataframe(): 188 | """Test the CTGAN correctly crashes when passed non-existing discrete columns.""" 189 | data = pd.DataFrame({'discrete': ['a', 'b']}) 190 | discrete_columns = ['b', 'c'] 191 | 192 | ctgan = CTGAN(epochs=1) 193 | with pytest.raises(ValueError, match="Invalid columns found: {'.*', '.*'}"): 194 | ctgan.fit(data, discrete_columns) 195 | 196 | 197 | def test_wrong_discrete_columns_numpy(): 198 | """Test the CTGAN correctly crashes when passed non-existing discrete columns.""" 199 | data = pd.DataFrame({'discrete': ['a', 'b']}) 200 | discrete_columns = [0, 1] 201 | 202 | ctgan = CTGAN(epochs=1) 203 | with pytest.raises(ValueError, match=r'Invalid columns found: \[1\]'): 204 | ctgan.fit(data.to_numpy(), discrete_columns) 205 | 206 | 207 | def test_wrong_sampling_conditions(): 208 | """Test the CTGAN correctly crashes when passed incorrect sampling conditions.""" 209 | data = pd.DataFrame({ 210 | 'continuous': np.random.random(100), 211 | 'discrete': np.random.choice(['a', 'b', 'c'], 100), 212 | }) 213 | discrete_columns = ['discrete'] 214 | 215 | ctgan = CTGAN(epochs=1) 216 | ctgan.fit(data, discrete_columns) 217 | 
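# For contrast, a valid conditional call on this fitted model would be
# `ctgan.sample(1, 'discrete', 'a')` (see `test_synthesizer_sample` above);
# the calls below exercise the invalid counterparts.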
218 | with pytest.raises(ValueError, match="The column_name `cardinal` doesn't exist in the data."): 219 | ctgan.sample(1, 'cardinal', "doesn't matter") 220 | 221 | with pytest.raises(ValueError): # noqa: RDT currently incorrectly raises a tuple instead of a string 222 | ctgan.sample(1, 'discrete', 'd') 223 | 224 | 225 | def test_fixed_random_seed(): 226 | """Test the CTGAN with a fixed seed. 227 | 228 | Expect that when the random seed is reset with the same seed, the same sequence 229 | of data will be produced. Expect that the data generated with the seed is 230 | different than randomly sampled data. 231 | """ 232 | # Setup 233 | data = pd.DataFrame({ 234 | 'continuous': np.random.random(100), 235 | 'discrete': np.random.choice(['a', 'b', 'c'], 100), 236 | }) 237 | discrete_columns = ['discrete'] 238 | 239 | ctgan = CTGAN(epochs=1, cuda=False) 240 | 241 | # Run 242 | ctgan.fit(data, discrete_columns) 243 | sampled_random = ctgan.sample(10) 244 | 245 | ctgan.set_random_state(0) 246 | sampled_0_0 = ctgan.sample(10) 247 | sampled_0_1 = ctgan.sample(10) 248 | 249 | ctgan.set_random_state(0) 250 | sampled_1_0 = ctgan.sample(10) 251 | sampled_1_1 = ctgan.sample(10) 252 | 253 | # Assert 254 | assert not np.array_equal(sampled_random, sampled_0_0) 255 | assert not np.array_equal(sampled_random, sampled_0_1) 256 | np.testing.assert_array_equal(sampled_0_0, sampled_1_0) 257 | np.testing.assert_array_equal(sampled_0_1, sampled_1_1) 258 | 259 | 260 | def test_ctgan_save_and_load(tmpdir): 261 | """Test that the ``CTGAN`` model can be saved and loaded.""" 262 | # Setup 263 | data = pd.DataFrame({ 264 | 'continuous': np.random.random(100), 265 | 'discrete': np.random.choice(['a', 'b', 'c'], 100), 266 | }) 267 | discrete_columns = [1] 268 | 269 | ctgan = CTGAN(epochs=1) 270 | ctgan.fit(data.to_numpy(), discrete_columns) 271 | ctgan.set_random_state(0) 272 | 273 | ctgan.sample(100) 274 | model_path = tmpdir / 'model.pkl' 275 | 276 | # Save 277 | ctgan.save(str(model_path)) 278 | 279 | # Load 280 | loaded_instance = CTGAN.load(str(model_path)) 281 | loaded_instance.sample(100) 282 | -------------------------------------------------------------------------------- /tests/integration/synthesizer/test_tvae.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """Integration tests for tvae. 5 | 6 | These tests only ensure that the software does not crash and that 7 | the API works as expected in terms of input and output data formats, 8 | but correctness of the data values and the internal behavior of the 9 | model are not checked. 
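The TVAE API mirrors the CTGAN one (a minimal sketch; names are illustrative):

    tvae = TVAE(epochs=1)
    tvae.fit(data, ['discrete_col'])
    samples = tvae.sample(100)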
10 | """ 11 | 12 | import numpy as np 13 | import pandas as pd 14 | 15 | from ctgan.synthesizers.tvae import TVAE 16 | 17 | 18 | def test_drop_last_false(): 19 | """Test the TVAE predicts the correct values.""" 20 | data = pd.DataFrame({'1': ['a', 'b', 'c'] * 150, '2': ['a', 'b', 'c'] * 150}) 21 | 22 | tvae = TVAE(epochs=300) 23 | tvae.fit(data, ['1', '2']) 24 | 25 | sampled = tvae.sample(100) 26 | correct = 0 27 | for _, row in sampled.iterrows(): 28 | if row['1'] == row['2']: 29 | correct += 1 30 | 31 | assert correct >= 95 32 | 33 | 34 | def test__loss_function(): 35 | """Test the TVAE produces average values similar to the training data.""" 36 | data = pd.DataFrame({ 37 | '1': [float(i) for i in range(1000)], 38 | '2': [float(2 * i) for i in range(1000)], 39 | }) 40 | 41 | tvae = TVAE(epochs=300) 42 | tvae.fit(data) 43 | 44 | num_samples = 1000 45 | sampled = tvae.sample(num_samples) 46 | error = 0 47 | for _, row in sampled.iterrows(): 48 | error += abs(2 * row['1'] - row['2']) 49 | 50 | avg_error = error / num_samples 51 | 52 | assert avg_error < 400 53 | 54 | 55 | def test_fixed_random_seed(): 56 | """Test the TVAE with a fixed seed. 57 | 58 | Expect that when the random seed is reset with the same seed, the same sequence 59 | of data will be produced. Expect that the data generated with the seed is 60 | different than randomly sampled data. 61 | """ 62 | # Setup 63 | data = pd.DataFrame({ 64 | 'continuous': np.random.random(100), 65 | 'discrete': np.random.choice(['a', 'b', 'c'], 100), 66 | }) 67 | discrete_columns = ['discrete'] 68 | 69 | tvae = TVAE(epochs=1) 70 | 71 | # Run 72 | tvae.fit(data, discrete_columns) 73 | sampled_random = tvae.sample(10) 74 | 75 | tvae.set_random_state(0) 76 | sampled_0_0 = tvae.sample(10) 77 | sampled_0_1 = tvae.sample(10) 78 | 79 | tvae.set_random_state(0) 80 | sampled_1_0 = tvae.sample(10) 81 | sampled_1_1 = tvae.sample(10) 82 | 83 | # Assert 84 | assert not np.array_equal(sampled_random, sampled_0_0) 85 | assert not np.array_equal(sampled_random, sampled_0_1) 86 | np.testing.assert_array_equal(sampled_0_0, sampled_1_0) 87 | np.testing.assert_array_equal(sampled_0_1, sampled_1_1) 88 | 89 | 90 | def test_tvae_save(tmpdir, capsys): 91 | """Test that the ``TVAE`` model can be saved and loaded.""" 92 | # Setup 93 | data = pd.DataFrame({ 94 | 'continuous': np.random.random(100), 95 | 'discrete': np.random.choice(['a', 'b', 'c'], 100), 96 | }) 97 | discrete_columns = ['discrete'] 98 | 99 | tvae = TVAE(epochs=10, verbose=True) 100 | tvae.fit(data, discrete_columns) 101 | captured_out = capsys.readouterr().err 102 | tvae.set_random_state(0) 103 | 104 | tvae.sample(100) 105 | model_path = tmpdir / 'model.pkl' 106 | 107 | # Save 108 | tvae.save(str(model_path)) 109 | 110 | # Load 111 | loaded_instance = TVAE.load(str(model_path)) 112 | sampled = loaded_instance.sample(100) 113 | 114 | # Assert 115 | assert sampled.shape == (100, 2) 116 | assert isinstance(sampled, pd.DataFrame) 117 | assert set(sampled.columns) == set(data.columns) 118 | assert set(sampled.dtypes) == set(data.dtypes) 119 | loss_values = tvae.loss_values 120 | assert len(loss_values) == 10 121 | assert set(loss_values.columns) == {'Epoch', 'Batch', 'Loss'} 122 | assert all(loss_values['Batch'] == 0) 123 | last_loss_val = loss_values['Loss'].iloc[-1] 124 | assert f'Loss: {round(last_loss_val, 3):.3f}: 100%' in captured_out 125 | -------------------------------------------------------------------------------- /tests/integration/test_data_transformer.py: 
-------------------------------------------------------------------------------- 1 | """Data transformer integration testing module.""" 2 | 3 | from unittest import TestCase 4 | 5 | import numpy as np 6 | import pandas as pd 7 | 8 | from ctgan.data_transformer import DataTransformer 9 | 10 | 11 | class TestDataTransformer(TestCase): 12 | def test_constant(self): 13 | """Test transforming a dataframe containing constant values.""" 14 | # Setup 15 | data = pd.DataFrame({'cnt': [123] * 1000}) 16 | transformer = DataTransformer() 17 | 18 | # Run 19 | transformer.fit(data, []) 20 | new_data = transformer.transform(data) 21 | transformer.inverse_transform(new_data) 22 | 23 | # Assert transformed values are between -1 and 1 24 | assert (new_data[:, 0] > -np.ones(len(new_data))).all() 25 | assert (new_data[:, 0] < np.ones(len(new_data))).all() 26 | 27 | # Assert transformed values are a gaussian centered at 0 and with std ~ 0 28 | assert -0.1 < np.mean(new_data[:, 0]) < 0.1 29 | assert 0 <= np.std(new_data[:, 0]) < 0.1 30 | 31 | # Assert there are at most `max_columns=10` one hot columns 32 | assert new_data.shape[0] == 1000 33 | assert new_data.shape[1] <= 11 34 | assert np.isin(new_data[:, 1:], [0, 1]).all() 35 | 36 | def test_df_continuous(self): 37 | """Test transforming a dataframe containing only continuous values.""" 38 | # Setup 39 | data = pd.DataFrame({'col': np.random.normal(size=1000)}) 40 | transformer = DataTransformer() 41 | 42 | # Run 43 | transformer.fit(data, []) 44 | new_data = transformer.transform(data) 45 | transformer.inverse_transform(new_data) 46 | 47 | # Assert transformed values are between -1 and 1 48 | assert (new_data[:, 0] > -np.ones(len(new_data))).all() 49 | assert (new_data[:, 0] < np.ones(len(new_data))).all() 50 | 51 | # Assert transformed values are a gaussian centered at 0 and with std = 1/4 52 | assert -0.1 < np.mean(new_data[:, 0]) < 0.1 53 | assert 0.2 < np.std(new_data[:, 0]) < 0.3 54 | 55 | # Assert there are at most `max_columns=10` one hot columns 56 | assert new_data.shape[0] == 1000 57 | assert new_data.shape[1] <= 11 58 | assert np.isin(new_data[:, 1:], [0, 1]).all() 59 | 60 | def test_df_categorical_constant(self): 61 | """Test transforming a dataframe containing only constant categorical values.""" 62 | # Setup 63 | data = pd.DataFrame({'cnt': [123] * 1000}) 64 | transformer = DataTransformer() 65 | 66 | # Run 67 | transformer.fit(data, ['cnt']) 68 | new_data = transformer.transform(data) 69 | transformer.inverse_transform(new_data) 70 | 71 | # Assert there is only 1 one hot vector 72 | assert np.array_equal(new_data, np.ones((len(data), 1))) 73 | 74 | def test_df_categorical(self): 75 | """Test transforming a dataframe containing only categorical values.""" 76 | # Setup 77 | data = pd.DataFrame({'cat': np.random.choice(['a', 'b', 'c'], size=1000)}) 78 | transformer = DataTransformer() 79 | 80 | # Run 81 | transformer.fit(data, ['cat']) 82 | new_data = transformer.transform(data) 83 | transformer.inverse_transform(new_data) 84 | 85 | # Assert there are 3 one hot vectors 86 | assert new_data.shape[0] == 1000 87 | assert new_data.shape[1] == 3 88 | assert np.isin(new_data[:, 1:], [0, 1]).all() 89 | 90 | def test_df_mixed(self): 91 | """Test transforming a dataframe containing mixed data types.""" 92 | # Setup 93 | data = pd.DataFrame({ 94 | 'num': np.random.normal(size=1000), 95 | 'cat': np.random.choice(['a', 'b', 'c'], size=1000), 96 | }) 97 | transformer = DataTransformer() 98 | 99 | # Run 100 | transformer.fit(data, ['cat']) 101 | new_data = 
transformer.transform(data) 102 | transformer.inverse_transform(new_data) 103 | 104 | # Assert transformed numerical values are between -1 and 1 105 | assert (new_data[:, 0] > -np.ones(len(new_data))).all() 106 | assert (new_data[:, 0] < np.ones(len(new_data))).all() 107 | 108 | # Assert transformed numerical values are a gaussian centered at 0 and with std = 1/4 109 | assert -0.1 < np.mean(new_data[:, 0]) < 0.1 110 | assert 0.2 < np.std(new_data[:, 0]) < 0.3 111 | 112 | # Assert there are at most `max_columns=10` one hot columns for the numerical values 113 | # and 3 for the categorical ones 114 | assert new_data.shape[0] == 1000 115 | assert 5 <= new_data.shape[1] <= 17 116 | assert np.isin(new_data[:, 1:], [0, 1]).all() 117 | 118 | def test_numpy(self): 119 | """Test transforming a numpy array.""" 120 | # Setup 121 | data = pd.DataFrame({ 122 | 'num': np.random.normal(size=1000), 123 | 'cat': np.random.choice(['a', 'b', 'c'], size=1000), 124 | }) 125 | data = np.array(data) 126 | transformer = DataTransformer() 127 | 128 | # Run 129 | transformer.fit(data, [1]) 130 | new_data = transformer.transform(data) 131 | transformer.inverse_transform(new_data) 132 | 133 | # Assert transformed numerical values are between -1 and 1 134 | assert (new_data[:, 0] > -np.ones(len(new_data))).all() 135 | assert (new_data[:, 0] < np.ones(len(new_data))).all() 136 | 137 | # Assert transformed numerical values are a gaussian centered at 0 and with std = 1/4 138 | assert -0.1 < np.mean(new_data[:, 0]) < 0.1 139 | assert 0.2 < np.std(new_data[:, 0]) < 0.3 140 | 141 | # Assert there are at most `max_columns=10` one hot columns for the numerical values 142 | # and 3 for the categorical ones 143 | assert new_data.shape[0] == 1000 144 | assert 5 <= new_data.shape[1] <= 17 145 | assert np.isin(new_data[:, 1:], [0, 1]).all() 146 | -------------------------------------------------------------------------------- /tests/integration/test_load_demo.py: -------------------------------------------------------------------------------- 1 | from ctgan import CTGAN, load_demo 2 | 3 | 4 | def test_load_demo(): 5 | """End-to-end test to load and synthesize data.""" 6 | # Setup 7 | discrete_columns = [ 8 | 'workclass', 9 | 'education', 10 | 'marital-status', 11 | 'occupation', 12 | 'relationship', 13 | 'race', 14 | 'sex', 15 | 'native-country', 16 | 'income', 17 | ] 18 | ctgan = CTGAN(epochs=1) 19 | 20 | # Run 21 | data = load_demo() 22 | ctgan.fit(data, discrete_columns) 23 | samples = ctgan.sample(1000, condition_column='native-country', condition_value='United-States') 24 | 25 | # Assert 26 | assert samples.shape == (1000, 15) 27 | assert all([col[0] != ' ' for col in samples.columns]) 28 | assert not samples.isna().any().any() 29 | -------------------------------------------------------------------------------- /tests/test_tasks.py: -------------------------------------------------------------------------------- 1 | """Tests for the ``tasks.py`` file.""" 2 | 3 | from tasks import _get_minimum_versions 4 | 5 | 6 | def test_get_minimum_versions(): 7 | """Test the ``_get_minimum_versions`` method. 8 | 9 | The method should return the minimum versions of the dependencies for the given python version. 10 | If a library is linked to a URL, the minimum version should be the URL. 
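For example, under Python 3.9 the marker on "numpy>=1.20.0,<2;python_version<'3.10'"
applies, so that entry should collapse to 'numpy==1.20.0', while a URL dependency
such as 'pandas @ git+...' should pass through as '<url>#egg=pandas'.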
11 | """ 12 | # Setup 13 | dependencies = [ 14 | "numpy>=1.20.0,<2;python_version<'3.10'", 15 | "numpy>=1.23.3,<2;python_version>='3.10'", 16 | "pandas>=1.2.0,<2;python_version<'3.10'", 17 | "pandas>=1.3.0,<2;python_version>='3.10'", 18 | 'humanfriendly>=8.2,<11', 19 | 'pandas @ git+https://github.com/pandas-dev/pandas.git@master', 20 | ] 21 | 22 | # Run 23 | minimum_versions_39 = _get_minimum_versions(dependencies, '3.9') 24 | minimum_versions_310 = _get_minimum_versions(dependencies, '3.10') 25 | 26 | # Assert 27 | expected_versions_39 = [ 28 | 'numpy==1.20.0', 29 | 'git+https://github.com/pandas-dev/pandas.git@master#egg=pandas', 30 | 'humanfriendly==8.2', 31 | ] 32 | expected_versions_310 = [ 33 | 'numpy==1.23.3', 34 | 'git+https://github.com/pandas-dev/pandas.git@master#egg=pandas', 35 | 'humanfriendly==8.2', 36 | ] 37 | 38 | assert minimum_versions_39 == expected_versions_39 39 | assert minimum_versions_310 == expected_versions_310 40 | -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- 1 | """Unit testing module.""" 2 | -------------------------------------------------------------------------------- /tests/unit/synthesizer/__init__.py: -------------------------------------------------------------------------------- 1 | """CTGAN testing module.""" 2 | -------------------------------------------------------------------------------- /tests/unit/synthesizer/test_base.py: -------------------------------------------------------------------------------- 1 | """BaseSynthesizer unit testing module.""" 2 | 3 | from unittest.mock import MagicMock, call, patch 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from ctgan.synthesizers.base import BaseSynthesizer, random_state 9 | 10 | 11 | @patch('ctgan.synthesizers.base.torch') 12 | @patch('ctgan.synthesizers.base.np.random') 13 | def test_valid_random_state(random_mock, torch_mock): 14 | """Test the ``random_state`` attribute with a valid random state. 15 | 16 | Expect that the decorated function uses the random_state attribute. 
17 | """ 18 | # Setup 19 | my_function = MagicMock() 20 | instance = MagicMock() 21 | 22 | random_state_mock = MagicMock() 23 | random_state_mock.get_state.return_value = 'desired numpy state' 24 | torch_generator_mock = MagicMock() 25 | torch_generator_mock.get_state.return_value = 'desired torch state' 26 | instance.random_states = (random_state_mock, torch_generator_mock) 27 | 28 | args = {'some', 'args'} 29 | kwargs = {'keyword': 'value'} 30 | 31 | random_mock.RandomState.return_value = random_state_mock 32 | random_mock.get_state.return_value = 'random state' 33 | torch_mock.Generator.return_value = torch_generator_mock 34 | torch_mock.get_rng_state.return_value = 'torch random state' 35 | 36 | # Run 37 | decorated_function = random_state(my_function) 38 | decorated_function(instance, *args, **kwargs) 39 | 40 | # Assert 41 | my_function.assert_called_once_with(instance, *args, **kwargs) 42 | 43 | instance.assert_not_called 44 | assert random_mock.get_state.call_count == 2 45 | assert torch_mock.get_rng_state.call_count == 2 46 | random_mock.RandomState.assert_has_calls([ 47 | call().get_state(), 48 | call(), 49 | call().set_state('random state'), 50 | ]) 51 | random_mock.set_state.assert_has_calls([call('desired numpy state'), call('random state')]) 52 | torch_mock.set_rng_state.assert_has_calls([ 53 | call('desired torch state'), 54 | call('torch random state'), 55 | ]) 56 | 57 | 58 | @patch('ctgan.synthesizers.base.torch') 59 | @patch('ctgan.synthesizers.base.np.random') 60 | def test_no_random_seed(random_mock, torch_mock): 61 | """Test the ``random_state`` attribute with no random state. 62 | 63 | Expect that the decorated function calls the original function 64 | when there is no random state. 65 | """ 66 | # Setup 67 | my_function = MagicMock() 68 | instance = MagicMock() 69 | instance.random_states = None 70 | 71 | args = {'some', 'args'} 72 | kwargs = {'keyword': 'value'} 73 | 74 | # Run 75 | decorated_function = random_state(my_function) 76 | decorated_function(instance, *args, **kwargs) 77 | 78 | # Assert 79 | my_function.assert_called_once_with(instance, *args, **kwargs) 80 | 81 | instance.assert_not_called 82 | random_mock.get_state.assert_not_called() 83 | random_mock.RandomState.assert_not_called() 84 | random_mock.set_state.assert_not_called() 85 | torch_mock.get_rng_state.assert_not_called() 86 | torch_mock.Generator.assert_not_called() 87 | torch_mock.set_rng_state.assert_not_called() 88 | 89 | 90 | class TestBaseSynthesizer: 91 | def test_set_random_state(self): 92 | """Test ``set_random_state`` works as expected.""" 93 | # Setup 94 | instance = BaseSynthesizer() 95 | 96 | # Run 97 | instance.set_random_state(3) 98 | 99 | # Assert 100 | assert isinstance(instance.random_states, tuple) 101 | assert isinstance(instance.random_states[0], np.random.RandomState) 102 | assert isinstance(instance.random_states[1], torch.Generator) 103 | 104 | def test_set_random_state_with_none(self): 105 | """Test ``set_random_state`` with None.""" 106 | # Setup 107 | instance = BaseSynthesizer() 108 | 109 | # Run and assert 110 | instance.set_random_state(3) 111 | assert instance.random_states is not None 112 | 113 | instance.set_random_state(None) 114 | assert instance.random_states is None 115 | -------------------------------------------------------------------------------- /tests/unit/synthesizer/test_ctgan.py: -------------------------------------------------------------------------------- 1 | """CTGAN unit testing module.""" 2 | 3 | from unittest import TestCase 4 | from 
unittest.mock import Mock 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import pytest 9 | import torch 10 | 11 | from ctgan.data_transformer import SpanInfo 12 | from ctgan.errors import InvalidDataError 13 | from ctgan.synthesizers.ctgan import CTGAN, Discriminator, Generator, Residual 14 | 15 | 16 | class TestDiscriminator(TestCase): 17 | def test___init__(self): 18 | """Test `__init__` for a generic case. 19 | 20 | Make sure `self.seq` has length 3 * len(`discriminator_dim`) + 1. 21 | 22 | Setup: 23 | - Create Discriminator 24 | 25 | Input: 26 | - input_dim = positive integer 27 | - discriminator_dim = list of integers 28 | - pac = positive integer 29 | 30 | Output: 31 | - None 32 | 33 | Side Effects: 34 | - Set `self.seq`, `self.pac` and `self.pacdim` 35 | """ 36 | discriminator_dim = [1, 2, 3] 37 | discriminator = Discriminator(input_dim=50, discriminator_dim=discriminator_dim, pac=7) 38 | 39 | assert discriminator.pac == 7 40 | assert discriminator.pacdim == 350 41 | assert len(discriminator.seq) == 3 * len(discriminator_dim) + 1 42 | 43 | def test_forward(self): 44 | """Test `forward` for a generic case. 45 | 46 | Check that the output shapes are correct. 47 | We can also test that all parameters have a gradient attached to them 48 | by running `encoder.parameters()`. To do that, we just need to use `loss.backward()` 49 | for some loss, like `loss = torch.mean(output)`. Notice that the input_dim = input_size. 50 | 51 | Setup: 52 | - initialize with input_size, discriminator_dim, pac 53 | - Create random tensor as input 54 | 55 | Input: 56 | - input = random tensor of shape (N, input_size) 57 | 58 | Output: 59 | - tensor of shape (N/pac, 1) 60 | """ 61 | discriminator = Discriminator(input_dim=50, discriminator_dim=[100, 200, 300], pac=7) 62 | output = discriminator(torch.randn(70, 50)) 63 | assert output.shape == (10, 1) 64 | 65 | # Check to make sure no gradients attached 66 | for parameter in discriminator.parameters(): 67 | assert parameter.grad is None 68 | 69 | # Backpropagate 70 | output.mean().backward() 71 | 72 | # Check to make sure all parameters have gradients 73 | for parameter in discriminator.parameters(): 74 | assert parameter.grad is not None 75 | 76 | 77 | class TestResidual(TestCase): 78 | def test_forward(self): 79 | """Test `forward` for a generic case. 80 | 81 | Check that the output shapes are correct. 82 | We can also test that all parameters have a gradient attached to them 83 | by running `encoder.parameters()`. To do that, we just need to use `loss.backward()` 84 | for some loss, like `loss = torch.mean(output)`. 85 | 86 | Setup: 87 | - initialize with input_size, output_size 88 | - Create random tensor as input 89 | 90 | Input: 91 | - input = random tensor of shape (N, input_size) 92 | 93 | Output: 94 | - tensor of shape (N, input_size + output_size) 95 | """ 96 | residual = Residual(10, 2) 97 | output = residual(torch.randn(100, 10)) 98 | assert output.shape == (100, 12) 99 | 100 | # Check to make sure no gradients attached 101 | for parameter in residual.parameters(): 102 | assert parameter.grad is None 103 | 104 | # Backpropagate 105 | output.mean().backward() 106 | 107 | # Check to make sure all parameters have gradients 108 | for parameter in residual.parameters(): 109 | assert parameter.grad is not None 110 | 111 | 112 | class TestGenerator(TestCase): 113 | def test___init__(self): 114 | """Test `__init__` for a generic case. 115 | 116 | Make sure `self.seq` has length len(`generator_dim`) + 1. 
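(One Residual block is stacked per entry in `generator_dim`, followed by a
single final Linear layer projecting to `data_dim`.)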
117 | 118 | Setup: 119 | - Create Generator 120 | 121 | Input: 122 | - embedding_dim = positive integer 123 | - generator_dim = list of integers 124 | - data_dim = positive integer 125 | 126 | Output: 127 | - None 128 | 129 | Side Effects: 130 | - Set `self.seq` 131 | """ 132 | generator_dim = [1, 2, 3] 133 | generator = Generator(embedding_dim=50, generator_dim=generator_dim, data_dim=7) 134 | 135 | assert len(generator.seq) == len(generator_dim) + 1 136 | 137 | def test_forward(self): 138 | """Test `forward` for a generic case. 139 | 140 | Check that the output shapes are correct. 141 | We can also test that all parameters have a gradient attached to them 142 | by running `encoder.parameters()`. To do that, we just need to use `loss.backward()` 143 | for some loss, like `loss = torch.mean(output)`. 144 | 145 | Setup: 146 | - initialize with embedding_dim, generator_dim, data_dim 147 | - Create random tensor as input 148 | 149 | Input: 150 | - input = random tensor of shape (N, input_size) 151 | 152 | Output: 153 | - tensor of shape (N, data_dim) 154 | """ 155 | generator = Generator(embedding_dim=60, generator_dim=[100, 200, 300], data_dim=500) 156 | output = generator(torch.randn(70, 60)) 157 | assert output.shape == (70, 500) 158 | 159 | # Check to make sure no gradients attached 160 | for parameter in generator.parameters(): 161 | assert parameter.grad is None 162 | 163 | # Backpropagate 164 | output.mean().backward() 165 | 166 | # Check to make sure all parameters have gradients 167 | for parameter in generator.parameters(): 168 | assert parameter.grad is not None 169 | 170 | 171 | def _assert_is_between(data, lower, upper): 172 | """Assert all values of the tensor 'data' are within range.""" 173 | assert all((data >= lower).numpy().tolist()) 174 | assert all((data <= upper).numpy().tolist()) 175 | 176 | 177 | class TestCTGAN(TestCase): 178 | def test__apply_activate_(self): 179 | """Test `_apply_activate` for tables with both continuous and categoricals. 180 | 181 | Check every continuous column has all values between -1 and 1 182 | (since they are normalized), and check every categorical column adds up to 1. 183 | 184 | Setup: 185 | - Mock `self._transformer.output_info_list` 186 | 187 | Input: 188 | - data = tensor of shape (N, data_dims) 189 | 190 | Output: 191 | - tensor = tensor of shape (N, data_dims) 192 | """ 193 | model = CTGAN() 194 | model._transformer = Mock() 195 | model._transformer.output_info_list = [ 196 | [SpanInfo(3, 'softmax')], 197 | [SpanInfo(1, 'tanh'), SpanInfo(2, 'softmax')], 198 | ] 199 | 200 | data = torch.randn(100, 6) 201 | result = model._apply_activate(data) 202 | 203 | assert result.shape == (100, 6) 204 | _assert_is_between(result[:, 0:3], 0.0, 1.0) 205 | _assert_is_between(result[:, 3], -1.0, 1.0) 206 | _assert_is_between(result[:, 4:6], 0.0, 1.0) 207 | 208 | def test__cond_loss(self): 209 | """Test `_cond_loss`. 210 | 211 | Test that the loss is purely a function of the target categorical. 
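Concretely, for the data row used below, the conditioned 3-way categorical
has logits [0.05, 0.05, 0.9] and target index 2, so the expected loss is
-log(softmax([0.05, 0.05, 0.9])[2]), roughly 0.62, regardless of the values
in the other columns.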
212 | 213 | Setup: 214 | - mock transformer.output_info_list 215 | - create two categoricals, one continuous 216 | - compute the conditional loss, conditioned on the 1st categorical 217 | - compare the loss to the cross-entropy of the 1st categorical, manually computed 218 | 219 | Input: 220 | data - the synthetic data generated by the model 221 | c - a tensor with the same shape as the data but with only a specific one-hot vector 222 | corresponding to the target column filled in 223 | m - binary mask used to select the categorical column to condition on 224 | 225 | Output: 226 | loss scalar; this should only be affected by the target column 227 | 228 | Note: 229 | - even though the implementation of this is probably right, I'm not sure if the idea 230 | behind it is correct 231 | """ 232 | model = CTGAN() 233 | model._transformer = Mock() 234 | model._transformer.output_info_list = [ 235 | [SpanInfo(1, 'tanh'), SpanInfo(2, 'softmax')], 236 | [SpanInfo(3, 'softmax')], # this is the categorical column we are conditioning on 237 | [SpanInfo(2, 'softmax')], # this is the categorical column we are not conditioning on 238 | ] 239 | 240 | data = torch.tensor([ 241 | # first 3 dims ignored, next 3 dims are the prediction, last 2 dims are ignored 242 | [0.0, -1.0, 0.0, 0.05, 0.05, 0.9, 0.1, 0.4], 243 | ]) 244 | 245 | c = torch.tensor([ 246 | # first 3 dims are a one-hot for the categorical, 247 | # next 2 are for a different categorical that we are not conditioning on 248 | # (continuous values are not stored in this tensor) 249 | [0.0, 0.0, 1.0, 0.0, 0.0], 250 | ]) 251 | 252 | # this indicates that we are conditioning on the first categorical 253 | m = torch.tensor([[1, 0]]) 254 | 255 | result = model._cond_loss(data, c, m) 256 | expected = torch.nn.functional.cross_entropy( 257 | torch.tensor([ 258 | [0.05, 0.05, 0.9], # 3 categories, one hot 259 | ]), 260 | torch.tensor([2]), 261 | ) 262 | 263 | assert (result - expected).abs() < 1e-3 264 | 265 | def test__validate_discrete_columns(self): 266 | """Test `_validate_discrete_columns` if the discrete column doesn't exist. 267 | 268 | Check the appropriate error is raised if `discrete_columns` is invalid, both 269 | for numpy arrays and dataframes. 270 | 271 | Setup: 272 | - Create dataframe with a discrete column 273 | - Define `discrete_columns` as something not in the dataframe 274 | 275 | Input: 276 | - train_data = 2-dimensional numpy array or a pandas.DataFrame 277 | - discrete_columns = list of strings or integers 278 | 279 | Output: 280 | None 281 | 282 | Side Effects: 283 | - Raises error if the discrete column is invalid. 284 | 285 | Note: 286 | - could create another function for numpy array 287 | """ 288 | data = pd.DataFrame({'discrete': ['a', 'b']}) 289 | discrete_columns = ['doesnt exist'] 290 | 291 | ctgan = CTGAN(epochs=1) 292 | with pytest.raises(ValueError, match=r'Invalid columns found: {\'doesnt exist\'}'): 293 | ctgan.fit(data, discrete_columns) 294 | 295 | def test__validate_null_data(self): 296 | """Test `_validate_null_data` with pandas and numpy data. 297 | 298 | Check the appropriate error is raised if null values are present in 299 | continuous columns, both for numpy arrays and dataframes. 
300 | """ 301 | # Setup 302 | discrete_df = pd.DataFrame({'discrete': ['a', 'b']}) 303 | discrete_array = np.array([['a'], ['b']]) 304 | continuous_no_nulls_df = pd.DataFrame({'continuous': [0, 1]}) 305 | continuous_no_nulls_array = np.array([[0], [1]]) 306 | continuous_with_null_df = pd.DataFrame({'continuous': [1, np.nan]}) 307 | continuous_with_null_array = np.array([[1], [np.nan]]) 308 | ctgan = CTGAN(epochs=1) 309 | error_message = ( 310 | 'CTGAN does not support null values in the continuous training data. ' 311 | 'Please remove all null values from your continuous training data.' 312 | ) 313 | 314 | # Test discrete DataFrame fits without error 315 | ctgan.fit(discrete_df, ['discrete']) 316 | 317 | # Test discrete array fits without error 318 | ctgan.fit(discrete_array, [0]) 319 | 320 | # Test continuous DataFrame without nulls fits without error 321 | ctgan.fit(continuous_no_nulls_df) 322 | 323 | # Test continuous array without nulls fits without error 324 | ctgan.fit(continuous_no_nulls_array) 325 | 326 | # Test nulls in continuous columns DataFrame errors on fit 327 | with pytest.raises(InvalidDataError, match=error_message): 328 | ctgan.fit(continuous_with_null_df) 329 | 330 | # Test nulls in continuous columns array errors on fit 331 | with pytest.raises(InvalidDataError, match=error_message): 332 | ctgan.fit(continuous_with_null_array) 333 | -------------------------------------------------------------------------------- /tests/unit/synthesizer/test_tvae.py: -------------------------------------------------------------------------------- 1 | """TVAE unit testing module.""" 2 | 3 | from unittest.mock import MagicMock, Mock, call, patch 4 | 5 | import pandas as pd 6 | 7 | from ctgan.synthesizers import TVAE 8 | 9 | 10 | class TestTVAE: 11 | @patch('ctgan.synthesizers.tvae._loss_function') 12 | @patch('ctgan.synthesizers.tvae.tqdm') 13 | def test_fit_verbose(self, tqdm_mock, loss_func_mock): 14 | """Test verbose parameter prints progress bar.""" 15 | # Setup 16 | epochs = 1 17 | 18 | def mock_iter(): 19 | for i in range(epochs): 20 | yield i 21 | 22 | def mock_add(a, b): 23 | mock_loss = Mock() 24 | mock_loss.detach().cpu().item.return_value = 1.23456789 25 | return mock_loss 26 | 27 | loss_mock = MagicMock() 28 | loss_mock.__add__ = mock_add 29 | loss_func_mock.return_value = (loss_mock, loss_mock) 30 | 31 | iterator_mock = MagicMock() 32 | iterator_mock.__iter__.side_effect = mock_iter 33 | tqdm_mock.return_value = iterator_mock 34 | synth = TVAE(epochs=epochs, verbose=True) 35 | train_data = pd.DataFrame({'col1': [0, 1, 2, 3, 4], 'col2': [10, 11, 12, 13, 14]}) 36 | 37 | # Run 38 | synth.fit(train_data) 39 | 40 | # Assert 41 | tqdm_mock.assert_called_once_with(range(epochs), disable=False) 42 | assert iterator_mock.set_description.call_args_list[0] == call('Loss: 0.000') 43 | assert iterator_mock.set_description.call_args_list[1] == call('Loss: 1.235') 44 | assert iterator_mock.set_description.call_count == 2 45 | -------------------------------------------------------------------------------- /tests/unit/test_data_transformer.py: -------------------------------------------------------------------------------- 1 | """Data transformer unit testing module.""" 2 | 3 | from unittest import TestCase 4 | from unittest.mock import Mock, patch 5 | 6 | import numpy as np 7 | import pandas as pd 8 | 9 | from ctgan.data_transformer import ColumnTransformInfo, DataTransformer, SpanInfo 10 | 11 | 12 | class TestDataTransformer(TestCase): 13 | 
@patch('ctgan.data_transformer.ClusterBasedNormalizer') 
14 | def test__fit_continuous(self, MockCBN): 
15 | """Test ``_fit_continuous`` on a simple continuous column. 
16 | 
17 | A ``ClusterBasedNormalizer`` will be created and fit with some ``data``. 
18 | 
19 | Setup: 
20 | - Mock the ``ClusterBasedNormalizer`` with ``valid_component_indicator`` as 
21 | ``[True, False, True]``. 
22 | - Initialize a ``DataTransformer``. 
23 | 
24 | Input: 
25 | - A dataframe with only one column containing random float values. 
26 | 
27 | Output: 
28 | - A ``ColumnTransformInfo`` object where: 
29 | - ``column_name`` matches the column of the data. 
30 | - ``transform`` is the ``ClusterBasedNormalizer`` instance. 
31 | - ``output_dimensions`` is 3 (matches size of ``valid_component_indicator``). 
32 | - ``output_info`` assigns the correct activation functions. 
33 | 
34 | Side Effects: 
35 | - ``fit`` should be called with the data. 
36 | """ 
37 | # Setup 
38 | cbn_instance = MockCBN.return_value 
39 | cbn_instance.valid_component_indicator = [True, False, True] 
40 | transformer = DataTransformer() 
41 | data = pd.DataFrame(np.random.normal(size=(100, 1)), columns=['column']) 
42 | 
43 | # Run 
44 | info = transformer._fit_continuous(data) 
45 | 
46 | # Assert 
47 | assert info.column_name == 'column' 
48 | assert info.transform == cbn_instance 
49 | assert info.output_dimensions == 3 
50 | assert info.output_info[0].dim == 1 
51 | assert info.output_info[0].activation_fn == 'tanh' 
52 | assert info.output_info[1].dim == 2 
53 | assert info.output_info[1].activation_fn == 'softmax' 
54 | 
55 | @patch('ctgan.data_transformer.ClusterBasedNormalizer') 
56 | def test__fit_continuous_max_clusters(self, MockCBN): 
57 | """Test ``_fit_continuous`` with data that has fewer than 10 rows. 
58 | 
59 | Expect that a ``ClusterBasedNormalizer`` is created with the max number of clusters 
60 | set to the length of the data. 
61 | 
62 | Input: 
63 | - Data with fewer than 10 rows. 
64 | 
65 | Side Effects: 
66 | - A ``ClusterBasedNormalizer`` is created with the max number of clusters set to the 
67 | length of the data. 
68 | """ 
69 | # Setup 
70 | data = pd.DataFrame(np.random.normal(size=(7, 1)), columns=['column']) 
71 | transformer = DataTransformer() 
72 | 
73 | # Run 
74 | transformer._fit_continuous(data) 
75 | 
76 | # Assert 
77 | MockCBN.assert_called_once_with( 
78 | missing_value_generation='from_column', max_clusters=len(data), weight_threshold=0.005 
79 | ) 
80 | 
81 | @patch('ctgan.data_transformer.OneHotEncoder') 
82 | def test__fit_discrete(self, MockOHE): 
83 | """Test ``_fit_discrete`` on a simple discrete column. 
84 | 
85 | A ``OneHotEncoder`` will be created and fit with the ``data``. 
86 | 
87 | Setup: 
88 | - Mock the ``OneHotEncoder``. 
89 | - Create ``DataTransformer``. 
90 | 
91 | Input: 
92 | - A dataframe with only one column containing ``['a', 'b']`` values. 
93 | 
94 | Output: 
95 | - A ``ColumnTransformInfo`` object where: 
96 | - ``column_name`` matches the column of the data. 
97 | - ``transform`` is the ``OneHotEncoder`` instance. 
98 | - ``output_dimensions`` is 2. 
99 | - ``output_info`` assigns the correct activation function. 
100 | 
101 | Side Effects: 
102 | - ``fit`` should be called with the data. 
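For example, per the data below, a column of ``['a', 'b']`` values should yield 
``output_info == [SpanInfo(2, 'softmax')]`` and ``output_dimensions == 2``. 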
103 | """ 
104 | # Setup 
105 | ohe_instance = MockOHE.return_value 
106 | ohe_instance.dummies = ['a', 'b'] 
107 | transformer = DataTransformer() 
108 | data = pd.DataFrame(np.array(['a', 'b'] * 100), columns=['column']) 
109 | 
110 | # Run 
111 | info = transformer._fit_discrete(data) 
112 | 
113 | # Assert 
114 | assert info.column_name == 'column' 
115 | assert info.transform == ohe_instance 
116 | assert info.output_dimensions == 2 
117 | assert info.output_info[0].dim == 2 
118 | assert info.output_info[0].activation_fn == 'softmax' 
119 | 
120 | def test_fit(self): 
121 | """Test ``fit`` on a dataframe with one continuous and one discrete column. 
122 | 
123 | The ``fit`` method should: 
124 | - Set ``self.dataframe`` to ``True``. 
125 | - Set ``self._column_raw_dtypes`` to the appropriate dtypes. 
126 | - Use the appropriate ``_fit`` type for each column. 
127 | - Update ``self.output_info_list``, ``self.output_dimensions`` and 
128 | ``self._column_transform_info_list`` appropriately. 
129 | 
130 | Setup: 
131 | - Create ``DataTransformer``. 
132 | - Mock ``_fit_discrete``. 
133 | - Mock ``_fit_continuous``. 
134 | 
135 | Input: 
136 | - A table with one continuous and one discrete column. 
137 | - A list with the name of the discrete column. 
138 | 
139 | Side Effects: 
140 | - ``_fit_discrete`` and ``_fit_continuous`` should each be called once. 
141 | - Assigns ``self._column_raw_dtypes`` the appropriate dtypes. 
142 | - Assigns ``self.output_info_list`` the appropriate ``output_info``. 
143 | - Assigns ``self.output_dimensions`` the appropriate ``output_dimensions``. 
144 | - Assigns ``self._column_transform_info_list`` the appropriate 
145 | ``column_transform_info``. 
146 | """ 
147 | # Setup 
148 | transformer = DataTransformer() 
149 | transformer._fit_continuous = Mock() 
150 | transformer._fit_continuous.return_value = ColumnTransformInfo( 
151 | column_name='x', 
152 | column_type='continuous', 
153 | transform=None, 
154 | output_info=[SpanInfo(1, 'tanh'), SpanInfo(3, 'softmax')], 
155 | output_dimensions=1 + 3, 
156 | ) 
157 | 
158 | transformer._fit_discrete = Mock() 
159 | transformer._fit_discrete.return_value = ColumnTransformInfo( 
160 | column_name='y', 
161 | column_type='discrete', 
162 | transform=None, 
163 | output_info=[SpanInfo(2, 'softmax')], 
164 | output_dimensions=2, 
165 | ) 
166 | 
167 | data = pd.DataFrame({ 
168 | 'x': np.random.random(size=100), 
169 | 'y': np.random.choice(['yes', 'no'], size=100), 
170 | }) 
171 | 
172 | # Run 
173 | transformer.fit(data, discrete_columns=['y']) 
174 | 
175 | # Assert 
176 | transformer._fit_discrete.assert_called_once() 
177 | transformer._fit_continuous.assert_called_once() 
178 | assert transformer.output_dimensions == 6 
179 | 
180 | @patch('ctgan.data_transformer.ClusterBasedNormalizer') 
181 | def test__transform_continuous(self, MockCBN): 
182 | """Test ``_transform_continuous``. 
183 | 
184 | Setup: 
185 | - Mock the ``ClusterBasedNormalizer`` with the transform method returning 
186 | some dataframe. 
187 | - Create ``DataTransformer``. 
188 | 
189 | Input: 
190 | - ``ColumnTransformInfo`` object. 
191 | - A dataframe containing a continuous column. 
192 | 
193 | Output: 
194 | - A np.array where the first column contains the normalized part 
195 | of the mocked transform, and the other columns are a one hot encoding 
196 | representation of the component part of the mocked transform. 
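For example, with three components, a mocked transform row of 
``(x.normalized=0.1, x.component=0)`` becomes the output row ``[0.1, 1, 0, 0]``. 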
197 | """ 
198 | # Setup 
199 | cbn_instance = MockCBN.return_value 
200 | cbn_instance.transform.return_value = pd.DataFrame({ 
201 | 'x.normalized': [0.1, 0.2, 0.3], 
202 | 'x.component': [0.0, 1.0, 1.0], 
203 | }) 
204 | 
205 | transformer = DataTransformer() 
206 | data = pd.DataFrame({'x': np.array([0.1, 0.3, 0.5])}) 
207 | column_transform_info = ColumnTransformInfo( 
208 | column_name='x', 
209 | column_type='continuous', 
210 | transform=cbn_instance, 
211 | output_info=[SpanInfo(1, 'tanh'), SpanInfo(3, 'softmax')], 
212 | output_dimensions=1 + 3, 
213 | ) 
214 | 
215 | # Run 
216 | result = transformer._transform_continuous(column_transform_info, data) 
217 | 
218 | # Assert 
219 | expected = np.array([ 
220 | [0.1, 1, 0, 0], 
221 | [0.2, 0, 1, 0], 
222 | [0.3, 0, 1, 0], 
223 | ]) 
224 | np.testing.assert_array_equal(result, expected) 
225 | 
226 | def test_transform(self): 
227 | """Test ``transform`` on a dataframe with one continuous and one discrete column. 
228 | 
229 | It should use the appropriate ``_transform`` type for each column and should return 
230 | them concatenated appropriately. 
231 | 
232 | Setup: 
233 | - Initialize a ``DataTransformer`` with a ``column_transform_info`` detailing 
234 | a continuous and a discrete column. 
235 | - Mock the ``_transform_discrete`` and ``_transform_continuous`` methods. 
236 | 
237 | Input: 
238 | - A table with one continuous and one discrete column. 
239 | 
240 | Output: 
241 | - np.array containing the transformed columns. 
242 | """ 
243 | # Setup 
244 | data = pd.DataFrame({'x': np.array([0.1, 0.3, 0.5]), 'y': np.array(['yes', 'yes', 'no'])}) 
245 | 
246 | transformer = DataTransformer() 
247 | transformer._column_transform_info_list = [ 
248 | ColumnTransformInfo( 
249 | column_name='x', 
250 | column_type='continuous', 
251 | transform=None, 
252 | output_info=[SpanInfo(1, 'tanh'), SpanInfo(3, 'softmax')], 
253 | output_dimensions=1 + 3, 
254 | ), 
255 | ColumnTransformInfo( 
256 | column_name='y', 
257 | column_type='discrete', 
258 | transform=None, 
259 | output_info=[SpanInfo(2, 'softmax')], 
260 | output_dimensions=2, 
261 | ), 
262 | ] 
263 | 
264 | transformer._transform_continuous = Mock() 
265 | selected_normalized_value = np.array([[0.1], [0.3], [0.5]]) 
266 | selected_component_onehot = np.array([ 
267 | [1, 0, 0], 
268 | [0, 1, 0], 
269 | [0, 1, 0], 
270 | ]) 
271 | return_value = np.concatenate( 
272 | (selected_normalized_value, selected_component_onehot), axis=1 
273 | ) 
274 | transformer._transform_continuous.return_value = return_value 
275 | 
276 | transformer._transform_discrete = Mock() 
277 | transformer._transform_discrete.return_value = np.array([ 
278 | [0, 1], 
279 | [0, 1], 
280 | [1, 0], 
281 | ]) 
282 | 
283 | # Run 
284 | result = transformer.transform(data) 
285 | 
286 | # Assert 
287 | expected = np.array([ 
288 | [0.1, 1, 0, 0, 0, 1], 
289 | [0.3, 0, 1, 0, 0, 1], 
290 | [0.5, 0, 1, 0, 1, 0], 
291 | ]) 
292 | assert result.shape == (3, 6) 
293 | assert (result[:, 0] == expected[:, 0]).all(), 'continuous-cdf' 
294 | assert (result[:, 1:4] == expected[:, 1:4]).all(), 'continuous-softmax' 
295 | assert (result[:, 4:6] == expected[:, 4:6]).all(), 'discrete' 
296 | 
297 | def test_parallel_sync_transform_same_output(self): 
298 | """Test ``_parallel_transform`` and ``_synchronous_transform`` on a dataframe. 
299 | 
300 | The output of ``_parallel_transform`` should be the same as the output of 
301 | ``_synchronous_transform``. 
302 | 
303 | Setup: 
304 | - Initialize a ``DataTransformer`` with a ``column_transform_info`` detailing 
305 | a continuous and a discrete column. 
306 | - Mock the ``_transform_discrete`` and ``_transform_continuous`` methods. 
307 | 
308 | Input: 
309 | - A table with one continuous and one discrete column. 
310 | 
311 | Output: 
312 | - A list containing the transformed columns. 
313 | """ 
314 | # Setup 
315 | data = pd.DataFrame({'x': np.array([0.1, 0.3, 0.5]), 'y': np.array(['yes', 'yes', 'no'])}) 
316 | 
317 | transformer = DataTransformer() 
318 | transformer._column_transform_info_list = [ 
319 | ColumnTransformInfo( 
320 | column_name='x', 
321 | column_type='continuous', 
322 | transform=None, 
323 | output_info=[SpanInfo(1, 'tanh'), SpanInfo(3, 'softmax')], 
324 | output_dimensions=1 + 3, 
325 | ), 
326 | ColumnTransformInfo( 
327 | column_name='y', 
328 | column_type='discrete', 
329 | transform=None, 
330 | output_info=[SpanInfo(2, 'softmax')], 
331 | output_dimensions=2, 
332 | ), 
333 | ] 
334 | 
335 | transformer._transform_continuous = Mock() 
336 | selected_normalized_value = np.array([[0.1], [0.3], [0.5]]) 
337 | selected_component_onehot = np.array([ 
338 | [1, 0, 0], 
339 | [0, 1, 0], 
340 | [0, 1, 0], 
341 | ]) 
342 | return_value = np.concatenate( 
343 | (selected_normalized_value, selected_component_onehot), axis=1 
344 | ) 
345 | transformer._transform_continuous.return_value = return_value 
346 | 
347 | transformer._transform_discrete = Mock() 
348 | transformer._transform_discrete.return_value = np.array([ 
349 | [0, 1], 
350 | [0, 1], 
351 | [1, 0], 
352 | ]) 
353 | 
354 | # Run 
355 | parallel_result = transformer._parallel_transform( 
356 | data, transformer._column_transform_info_list 
357 | ) 
358 | sync_result = transformer._synchronous_transform( 
359 | data, transformer._column_transform_info_list 
360 | ) 
361 | parallel_result_np = np.concatenate(parallel_result, axis=1).astype(float) 
362 | sync_result_np = np.concatenate(sync_result, axis=1).astype(float) 
363 | 
364 | # Assert 
365 | assert len(parallel_result) == len(sync_result) 
366 | np.testing.assert_array_equal(parallel_result_np, sync_result_np) 
367 | 
368 | @patch('ctgan.data_transformer.ClusterBasedNormalizer') 
369 | def test__inverse_transform_continuous(self, MockCBN): 
370 | """Test ``_inverse_transform_continuous``. 
371 | 
372 | Setup: 
373 | - Create ``DataTransformer``. 
374 | - Mock the ``ClusterBasedNormalizer`` where: 
375 | - ``get_output_sdtypes`` returns the appropriate dictionary. 
376 | - ``reverse_transform`` returns some dataframe. 
377 | 
378 | Input: 
379 | - A ``ColumnTransformInfo`` object. 
380 | - A np.ndarray where: 
381 | - The first column contains the normalized value 
382 | - The remaining columns correspond to the one-hot 
383 | - sigmas = np.ndarray of floats 
384 | - st = index of the sigmas ndarray 
385 | 
386 | Output: 
387 | - A dataframe whose first column contains floats and whose second is a label encoding. 
388 | 
389 | Side Effects: 
390 | - The ``reverse_transform`` method should be called with a dataframe 
391 | whose first column contains floats and whose second is a label encoding. 
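For example, the input row ``[0.1, 1, 0, 0]`` below should reach ``reverse_transform`` 
as ``(x.normalized=0.1, x.component=0)``, since the argmax of the one-hot part 
becomes the label. 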
392 | """ 
393 | # Setup 
394 | cbn_instance = MockCBN.return_value 
395 | cbn_instance.get_output_sdtypes.return_value = { 
396 | 'x.normalized': 'numerical', 
397 | 'x.component': 'numerical', 
398 | } 
399 | 
400 | cbn_instance.reverse_transform.return_value = pd.DataFrame({ 
401 | 'x.normalized': [0.1, 0.2, 0.3], 
402 | 'x.component': [0.0, 1.0, 1.0], 
403 | }) 
404 | 
405 | transformer = DataTransformer() 
406 | column_data = np.array([ 
407 | [0.1, 1, 0, 0], 
408 | [0.3, 0, 1, 0], 
409 | [0.5, 0, 1, 0], 
410 | ]) 
411 | 
412 | column_transform_info = ColumnTransformInfo( 
413 | column_name='x', 
414 | column_type='continuous', 
415 | transform=cbn_instance, 
416 | output_info=[SpanInfo(1, 'tanh'), SpanInfo(3, 'softmax')], 
417 | output_dimensions=1 + 3, 
418 | ) 
419 | 
420 | # Run 
421 | result = transformer._inverse_transform_continuous( 
422 | column_transform_info, column_data, None, None 
423 | ) 
424 | 
425 | # Assert 
426 | expected = pd.DataFrame({'x.normalized': [0.1, 0.2, 0.3], 'x.component': [0.0, 1.0, 1.0]}) 
427 | 
428 | np.testing.assert_array_equal(result, expected) 
429 | 
430 | expected_data = pd.DataFrame({'x.normalized': [0.1, 0.3, 0.5], 'x.component': [0, 1, 1]}) 
431 | 
432 | pd.testing.assert_frame_equal(cbn_instance.reverse_transform.call_args[0][0], expected_data) 
433 | 
434 | def test_convert_column_name_value_to_id(self): 
435 | """Test ``convert_column_name_value_to_id`` on a simple ``_column_transform_info_list``. 
436 | 
437 | Tests that the appropriate indexes are returned when a table of two columns, 
438 | continuous then discrete, is passed as ``_column_transform_info_list``. 
439 | 
440 | Setup: 
441 | - Mock ``_column_transform_info_list``. 
442 | 
443 | Input: 
444 | - column_name = the name of a discrete column 
445 | - value = the categorical value 
446 | 
447 | Output: 
448 | - dictionary containing: 
449 | - ``discrete_column_id`` = the index of the target column, 
450 | when considering only discrete columns 
451 | - ``column_id`` = the index of the target column 
452 | (e.g. 
2 = the third column in the data, since indices are 0-based) 
453 | - ``value_id`` = the index of the indicator value in the one-hot encoding 
454 | """ 
455 | # Setup 
456 | ohe = Mock() 
457 | ohe.transform.return_value = pd.DataFrame([ 
458 | [0, 1] # one hot encoding, second dimension 
459 | ]) 
460 | transformer = DataTransformer() 
461 | transformer._column_transform_info_list = [ 
462 | ColumnTransformInfo( 
463 | column_name='x', 
464 | column_type='continuous', 
465 | transform=None, 
466 | output_info=[SpanInfo(1, 'tanh'), SpanInfo(3, 'softmax')], 
467 | output_dimensions=1 + 3, 
468 | ), 
469 | ColumnTransformInfo( 
470 | column_name='y', 
471 | column_type='discrete', 
472 | transform=ohe, 
473 | output_info=[SpanInfo(2, 'softmax')], 
474 | output_dimensions=2, 
475 | ), 
476 | ] 
477 | 
478 | # Run 
479 | result = transformer.convert_column_name_value_to_id('y', 'yes') 
480 | 
481 | # Assert 
482 | assert result['column_id'] == 1 # this is the 2nd column 
483 | assert result['discrete_column_id'] == 0 # this is the 1st discrete column 
484 | assert result['value_id'] == 1 # this is the 2nd dimension in the one hot encoding 
485 | 
486 | def test_convert_column_name_value_to_id_multiple(self): 
487 | """Test ``convert_column_name_value_to_id``.""" 
488 | # Setup 
489 | ohe = Mock() 
490 | ohe.transform.return_value = pd.DataFrame([ 
491 | [0, 1, 0] # one hot encoding, second dimension 
492 | ]) 
493 | transformer = DataTransformer() 
494 | transformer._column_transform_info_list = [ 
495 | ColumnTransformInfo( 
496 | column_name='x', 
497 | column_type='continuous', 
498 | transform=None, 
499 | output_info=[SpanInfo(1, 'tanh'), SpanInfo(3, 'softmax')], 
500 | output_dimensions=1 + 3, 
501 | ), 
502 | ColumnTransformInfo( 
503 | column_name='y', 
504 | column_type='discrete', 
505 | transform=ohe, 
506 | output_info=[SpanInfo(2, 'softmax')], 
507 | output_dimensions=2, 
508 | ), 
509 | ColumnTransformInfo( 
510 | column_name='z', 
511 | column_type='discrete', 
512 | transform=ohe, 
513 | output_info=[SpanInfo(2, 'softmax')], 
514 | output_dimensions=2, 
515 | ), 
516 | ] 
517 | 
518 | # Run 
519 | result = transformer.convert_column_name_value_to_id('z', 'yes') 
520 | 
521 | # Assert 
522 | assert result['column_id'] == 2 # this is the 3rd column 
523 | assert result['discrete_column_id'] == 1 # this is the 2nd discrete column 
524 | assert result['value_id'] == 1 # this is the 2nd dimension in the one hot encoding 
525 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py39-lint, py3{8,9,10,11,12,13}-{unit,integration,readme} 3 | 4 | [testenv] 5 | skipsdist = false 6 | skip_install = false 7 | deps = 8 | invoke 9 | readme: rundoc 10 | extras = 11 | lint: dev 12 | unit: test 13 | integration: test 14 | commands = 15 | lint: invoke lint 16 | unit: invoke unit 17 | integration: invoke integration 18 | readme: invoke readme 19 | invoke rmdir --path {envdir} 20 | --------------------------------------------------------------------------------