├── .editorconfig
├── .github
│   ├── CODEOWNERS
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   ├── feature_request.md
│   │   └── question.md
│   └── workflows
│       ├── dependency_checker.yml
│       ├── integration.yml
│       ├── lint.yml
│       ├── minimum.yml
│       ├── prepare_release.yml
│       ├── readme.yml
│       └── unit.yml
├── .gitignore
├── AUTHORS.rst
├── CONTRIBUTING.rst
├── HISTORY.md
├── LICENSE
├── Makefile
├── README.md
├── ctgan
│   ├── __init__.py
│   ├── __main__.py
│   ├── data.py
│   ├── data_sampler.py
│   ├── data_transformer.py
│   ├── demo.py
│   ├── errors.py
│   └── synthesizers
│       ├── __init__.py
│       ├── base.py
│       ├── ctgan.py
│       └── tvae.py
├── examples
│   ├── csv
│   │   ├── adult.csv
│   │   └── adult.json
│   └── tsv
│       ├── acs.dat
│       ├── acs.meta
│       ├── adult.dat
│       ├── adult.meta
│       ├── br2000.dat
│       ├── br2000.meta
│       ├── nltcs.dat
│       └── nltcs.meta
├── latest_requirements.txt
├── pyproject.toml
├── scripts
│   └── release_notes_generator.py
├── static_code_analysis.txt
├── tasks.py
├── tests
│   ├── __init__.py
│   ├── integration
│   │   ├── __init__.py
│   │   ├── synthesizer
│   │   │   ├── __init__.py
│   │   │   ├── test_ctgan.py
│   │   │   └── test_tvae.py
│   │   ├── test_data_transformer.py
│   │   └── test_load_demo.py
│   ├── test_tasks.py
│   └── unit
│       ├── __init__.py
│       ├── synthesizer
│       │   ├── __init__.py
│       │   ├── test_base.py
│       │   ├── test_ctgan.py
│       │   └── test_tvae.py
│       └── test_data_transformer.py
└── tox.ini
/.editorconfig:
--------------------------------------------------------------------------------
1 | # http://editorconfig.org
2 |
3 | root = true
4 |
5 | [*]
6 | indent_style = space
7 | indent_size = 4
8 | trim_trailing_whitespace = true
9 | insert_final_newline = true
10 | charset = utf-8
11 | end_of_line = lf
12 |
13 | [*.py]
14 | max_line_length = 99
15 |
16 | [*.bat]
17 | indent_style = tab
18 | end_of_line = crlf
19 |
20 | [LICENSE]
21 | insert_final_newline = false
22 |
23 | [Makefile]
24 | indent_style = tab
25 |
--------------------------------------------------------------------------------
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # Global rule:
2 | * @sdv-dev/core-contributors
3 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Report an error that you found when using CTGAN
4 | title: ''
5 | labels: bug, new
6 | assignees: ''
7 |
8 | ---
9 |
10 | ### Environment Details
11 |
12 | Please indicate the following details about the environment in which you found the bug:
13 |
14 | * CTGAN version:
15 | * Python version:
16 | * Operating System:
17 |
18 | ### Error Description
19 |
20 |
22 |
23 | ### Steps to reproduce
24 |
25 |
29 |
30 | ```
31 | Paste the command(s) you ran and the output.
32 | If there was a crash, please include the traceback here.
33 | ```
34 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Request a new feature that you would like to see implemented in CTGAN
4 | title: ''
5 | labels: new feature, new
6 | assignees: ''
7 |
8 | ---
9 |
10 | ### Problem Description
11 |
12 |
14 |
15 | ### Expected behavior
16 |
17 |
20 |
21 | ### Additional context
22 |
23 |
25 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/question.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Question
3 | about: Doubts about CTGAN usage
4 | title: ''
5 | labels: question, new
6 | assignees: ''
7 |
8 | ---
9 |
10 | ### Environment details
11 |
12 | If you are already running CTGAN, please indicate the following details about the environment in
13 | which you are running it:
14 |
15 | * CTGAN version:
16 | * Python version:
17 | * Operating System:
18 |
19 | ### Problem description
20 |
21 |
24 |
25 | ### What I already tried
26 |
27 |
29 |
30 | ```
31 | Paste the command(s) you ran and the output.
32 | If there was a crash, please include the traceback here.
33 | ```
34 |
--------------------------------------------------------------------------------
/.github/workflows/dependency_checker.yml:
--------------------------------------------------------------------------------
1 | name: Dependency Checker
2 | on:
3 | schedule:
4 | - cron: '0 0 * * 1'
5 | workflow_dispatch:
6 | jobs:
7 | build:
8 | runs-on: ubuntu-latest
9 | steps:
10 | - uses: actions/checkout@v4
11 | - name: Set up Python 3.9
12 | uses: actions/setup-python@v5
13 | with:
14 | python-version: 3.9
15 | - name: Install dependencies
16 | run: |
17 | python -m pip install .[dev]
18 | make check-deps OUTPUT_FILEPATH=latest_requirements.txt
19 | make fix-lint
20 | - name: Create pull request
21 | id: cpr
22 | uses: peter-evans/create-pull-request@v4
23 | with:
24 | token: ${{ secrets.GH_ACCESS_TOKEN }}
25 | commit-message: Update latest dependencies
26 | author: "github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>"
27 | committer: "github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>"
28 | title: Automated Latest Dependency Updates
29 | body: "This is an auto-generated PR with **latest** dependency updates."
30 | branch: latest-dependency-update
31 | branch-suffix: short-commit-hash
32 | base: main
33 |
--------------------------------------------------------------------------------
/.github/workflows/integration.yml:
--------------------------------------------------------------------------------
1 | name: Integration Tests
2 |
3 | on:
4 | push:
5 | pull_request:
6 | types: [opened, reopened]
7 |
8 | jobs:
9 | integration:
10 | runs-on: ${{ matrix.os }}
11 | strategy:
12 | matrix:
13 | python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', '3.13']
14 | os: [ubuntu-latest, windows-latest]
15 | include:
16 | - os: macos-latest
17 | python-version: '3.8'
18 | - os: macos-latest
19 | python-version: '3.13'
20 | steps:
21 | - uses: actions/checkout@v4
22 | - name: Set up Python ${{ matrix.python-version }}
23 | uses: actions/setup-python@v5
24 | with:
25 | python-version: ${{ matrix.python-version }}
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | python -m pip install invoke .[test]
30 | - name: Run integration tests
31 | run: invoke integration
32 |
33 | - if: matrix.os == 'ubuntu-latest' && matrix.python-version == 3.13
34 | name: Upload integration codecov report
35 | uses: codecov/codecov-action@v4
36 | with:
37 | flags: integration
38 | file: ${{ github.workspace }}/integration_cov.xml
39 | fail_ci_if_error: true
40 | token: ${{ secrets.CODECOV_TOKEN }}
41 |
--------------------------------------------------------------------------------
/.github/workflows/lint.yml:
--------------------------------------------------------------------------------
1 | name: Style Checks
2 |
3 | on:
4 | push:
5 | pull_request:
6 | types: [opened, reopened]
7 |
8 | jobs:
9 | lint:
10 | runs-on: ubuntu-latest
11 | steps:
12 | - uses: actions/checkout@v4
13 | - name: Set up Python 3.9
14 | uses: actions/setup-python@v5
15 | with:
16 | python-version: 3.9
17 | - name: Install dependencies
18 | run: |
19 | python -m pip install --upgrade pip
20 | python -m pip install invoke .[dev]
21 | - name: Run lint checks
22 | run: invoke lint
23 |
--------------------------------------------------------------------------------
/.github/workflows/minimum.yml:
--------------------------------------------------------------------------------
1 | name: Unit Tests Minimum Versions
2 | concurrency:
3 | group: ${{ github.workflow }}-${{ github.ref }}
4 | cancel-in-progress: true
5 | on:
6 | push:
7 | pull_request:
8 | types: [opened, reopened]
9 | jobs:
10 | minimum:
11 | runs-on: ${{ matrix.os }}
12 | strategy:
13 | matrix:
14 | python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', '3.13']
15 | os: [ubuntu-latest, windows-latest]
16 | include:
17 | - os: macos-13
18 | python-version: '3.8'
19 | - os: macos-latest
20 | python-version: '3.13'
21 | steps:
22 | - uses: actions/checkout@v4
23 | - name: Set up Python ${{ matrix.python-version }}
24 | uses: actions/setup-python@v5
25 | with:
26 | python-version: ${{ matrix.python-version }}
27 | - name: Install dependencies
28 | run: |
29 | python -m pip install --upgrade pip
30 | python -m pip install invoke .[test]
31 | - name: Test with minimum versions
32 | run: invoke minimum
33 |
--------------------------------------------------------------------------------
/.github/workflows/prepare_release.yml:
--------------------------------------------------------------------------------
1 | name: Release Prep
2 |
3 | on:
4 | workflow_dispatch:
5 | inputs:
6 | branch:
7 | description: 'Branch to merge release notes and code analysis into.'
8 | required: true
9 | default: 'main'
10 | version:
11 | description:
12 | 'Version to use for the release. Must be in format: X.Y.Z.'
13 | date:
14 | description:
15 | 'Date of the release. Must be in format YYYY-MM-DD.'
16 |
17 | jobs:
18 | preparerelease:
19 | runs-on: ubuntu-latest
20 | steps:
21 | - uses: actions/checkout@v4
22 | - name: Set up Python 3.10
23 | uses: actions/setup-python@v5
24 | with:
25 | python-version: '3.10'
26 |
27 | - name: Install dependencies
28 | run: |
29 | python -m pip install --upgrade pip
30 | python -m pip install requests==2.31.0
31 | python -m pip install bandit==1.7.7
32 | python -m pip install .[test]
33 |
34 | - name: Generate release notes
35 | env:
36 | GH_ACCESS_TOKEN: ${{ secrets.GH_ACCESS_TOKEN }}
37 | run: >
38 | python scripts/release_notes_generator.py
39 | -v ${{ inputs.version }}
40 | -d ${{ inputs.date }}
41 |
42 | - name: Save static code analysis
43 | run: bandit -r . -x ./tests,./scripts,./build -f txt -o static_code_analysis.txt --exit-zero
44 |
45 | - name: Create pull request
46 | id: cpr
47 | uses: peter-evans/create-pull-request@v4
48 | with:
49 | token: ${{ secrets.GH_ACCESS_TOKEN }}
50 | commit-message: Prepare release for v${{ inputs.version }}
51 | author: "github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>"
52 | committer: "github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>"
53 | title: v${{ inputs.version }} Release Preparation
54 | body: "This is an auto-generated PR to prepare the release."
55 | branch: prepared-release
56 | branch-suffix: short-commit-hash
57 | base: ${{ inputs.branch }}
58 |
--------------------------------------------------------------------------------
/.github/workflows/readme.yml:
--------------------------------------------------------------------------------
1 | name: Test README
2 |
3 | on:
4 | push:
5 | pull_request:
6 | types: [opened, reopened]
7 |
8 | jobs:
9 | readme:
10 | runs-on: ${{ matrix.os }}
11 | strategy:
12 | matrix:
13 | python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', '3.13']
14 | os: [ubuntu-latest, macos-latest] # skip windows bc rundoc fails
15 | steps:
16 | - uses: actions/checkout@v4
17 | - name: Set up Python ${{ matrix.python-version }}
18 | uses: actions/setup-python@v5
19 | with:
20 | python-version: ${{ matrix.python-version }}
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | python -m pip install invoke rundoc .
25 | python -m pip install tomli
26 | python -m pip install packaging
27 | - name: Run the README.md
28 | run: invoke readme
29 |
--------------------------------------------------------------------------------
/.github/workflows/unit.yml:
--------------------------------------------------------------------------------
1 | name: Unit Tests
2 |
3 | on:
4 | push:
5 | pull_request:
6 | types: [opened, reopened]
7 |
8 | jobs:
9 | unit:
10 | runs-on: ${{ matrix.os }}
11 | strategy:
12 | matrix:
13 | python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', '3.13']
14 | os: [ubuntu-latest, windows-latest]
15 | include:
16 | - os: macos-latest
17 | python-version: '3.8'
18 | - os: macos-latest
19 | python-version: '3.13'
20 | steps:
21 | - uses: actions/checkout@v4
22 | - name: Set up Python ${{ matrix.python-version }}
23 | uses: actions/setup-python@v5
24 | with:
25 | python-version: ${{ matrix.python-version }}
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | python -m pip install invoke .[test]
30 | - name: Run unit tests
31 | run: invoke unit
32 |
33 | - if: matrix.os == 'ubuntu-latest' && matrix.python-version == 3.13
34 | name: Upload unit codecov report
35 | uses: codecov/codecov-action@v4
36 | with:
37 | flags: unit
38 | file: ${{ github.workspace }}/unit_cov.xml
39 | fail_ci_if_error: true
40 | token: ${{ secrets.CODECOV_TOKEN }}
41 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 | tests/readme_test/
50 | *_cov.xml
51 |
52 | # Translations
53 | *.mo
54 | *.pot
55 |
56 | # Django stuff:
57 | *.log
58 | local_settings.py
59 |
60 | # Flask stuff:
61 | instance/
62 | .webassets-cache
63 |
64 | # Scrapy stuff:
65 | .scrapy
66 |
67 | # Sphinx documentation
68 | docs/_build/
69 | docs/api/
70 |
71 | # PyBuilder
72 | target/
73 |
74 | # Jupyter Notebook
75 | .ipynb_checkpoints
76 |
77 | # pyenv
78 | .python-version
79 |
80 | # celery beat schedule file
81 | celerybeat-schedule
82 |
83 | # SageMath parsed files
84 | *.sage.py
85 |
86 | # dotenv
87 | .env
88 |
89 | # virtualenv
90 | .venv
91 | venv/
92 | ENV/
93 |
94 | # Spyder project settings
95 | .spyderproject
96 | .spyproject
97 |
98 | # Rope project settings
99 | .ropeproject
100 |
101 | # mkdocs documentation
102 | /site
103 |
104 | # mypy
105 | .mypy_cache/
106 |
107 | # Vim
108 | .*.swp
109 |
--------------------------------------------------------------------------------
/AUTHORS.rst:
--------------------------------------------------------------------------------
1 | See: https://github.com/sdv-dev/CTGAN/graphs/contributors
2 |
--------------------------------------------------------------------------------
/CONTRIBUTING.rst:
--------------------------------------------------------------------------------
1 | ============
2 | Contributing
3 | ============
4 |
5 | We love our community’s interest in the SDV ecosystem. The SDV software
6 | (and its related libraries) is owned and maintained by DataCebo, Inc.
7 | It is available under the `Business Source License`_ for you to browse.
8 |
9 | We support a large set of users with enterprise-specific intricacies and
10 | reliability needs. This has required us to be deliberate about setting
11 | the roadmap for SDV libraries. As a result, we are unable to prioritize
12 | reviewing and accepting external pull requests. As a policy, we will
13 | not be able to accept external contributions.
14 |
15 | **Would you like a bug or feature request to be addressed?** If you haven't
16 | already, we would greatly appreciate it if you could `file an issue`_
17 | instead with the overall description of your problem. We can determine
18 | whether it’s aligned with our framework. Once discussed, our team
19 | typically resolves smaller issues within a few release cycles.
20 | We appreciate your understanding.
21 |
22 |
23 | .. _Business Source License: https://github.com/sdv-dev/CTGAN/blob/main/LICENSE
24 | .. _file an issue: https://github.com/sdv-dev/CTGAN/issues
25 |
--------------------------------------------------------------------------------
/HISTORY.md:
--------------------------------------------------------------------------------
1 | # History
2 |
3 | ## v0.11.0 - 2025-02-26
4 |
5 | ### New Features
6 |
7 | * Surface error to user during fit if training data contains null values - Issue [#414](https://github.com/sdv-dev/CTGAN/issues/414) by @rwedge
8 |
9 | ### Maintenance
10 |
11 | * Combine `static_code_analysis.yml` with `release_notes.yml` - Issue [#421](https://github.com/sdv-dev/CTGAN/issues/421) by @R-Palazzo
12 | * Support Python 3.13 - Issue [#411](https://github.com/sdv-dev/CTGAN/issues/411) by @rwedge
13 | * Update codecov and add flag for integration tests - Issue [#410](https://github.com/sdv-dev/CTGAN/issues/410) by @pvk-developer
14 |
15 | ## v0.10.2 - 2024-10-22
16 |
17 | ### Bugs Fixed
18 |
19 | * Cap numpy to less than 2.0.0 until CTGAN supports - Issue [#387](https://github.com/sdv-dev/CTGAN/issues/387) by @gsheni
20 | * Redundant whitespace in the demo data - Issue [#233](https://github.com/sdv-dev/CTGAN/issues/233)
21 |
22 | ### Internal
23 |
24 | * Add workflow to generate release notes - Issue [#404](https://github.com/sdv-dev/CTGAN/issues/404) by @amontanez24
25 |
26 | ### Maintenance
27 |
28 | * Switch to using ruff for Python linting and code formatting - Issue [#335](https://github.com/sdv-dev/CTGAN/issues/335) by @gsheni
29 |
30 | ### Miscellaneous
31 |
32 | * Add support for numpy 2.0.0 - Issue [#386](https://github.com/sdv-dev/CTGAN/issues/386) by @R-Palazzo
33 |
34 | ## v0.10.1 - 2024-05-13
35 |
36 | This release removes a warning that was cluttering the console.
37 |
38 | ### Maintenance
39 |
40 | * Cleanup automated PR workflows - Issue [#370](https://github.com/sdv-dev/CTGAN/issues/370) by @R-Palazzo
41 | * Only run unit and integration tests on oldest and latest python versions for macos - Issue [#375](https://github.com/sdv-dev/CTGAN/issues/375) by @R-Palazzo
42 |
43 | ### Internal
44 |
45 | * Remove FutureWarning: Setting an item of incompatible dtype is deprecated - Issue [#373](https://github.com/sdv-dev/CTGAN/issues/373) by @fealho
46 |
47 | ## v0.10.0 - 2024-04-11
48 |
49 | This release adds support for Python 3.12!
50 |
51 | ### Maintenance
52 |
53 | * Support Python 3.12 - Issue [#324](https://github.com/sdv-dev/CTGAN/issues/324) by @fealho
54 | * Remove scikit-learn dependency - Issue [#346](https://github.com/sdv-dev/CTGAN/issues/346) by @R-Palazzo
55 | * Add bandit workflow - Issue [#353](https://github.com/sdv-dev/CTGAN/issues/353) by @R-Palazzo
56 |
57 | ### Internal
58 |
59 | * Replace integration test that uses the iris demo data - Issue [#352](https://github.com/sdv-dev/CTGAN/issues/352) by @R-Palazzo
60 |
61 | ### Bugs Fixed
62 |
63 | * Fix minimum version workflow when pointing to github branch - Issue [#355](https://github.com/sdv-dev/CTGAN/issues/355) by @R-Palazzo
64 |
65 | ## v0.9.1 - 2024-03-14
66 |
67 | This release changes the `loss_values` attribute of a CTGAN model to contain floats instead of `torch.Tensors`.
68 |
69 | ### New Features
70 |
71 | * Return loss values as float values not PyTorch objects - Issue [#332](https://github.com/sdv-dev/CTGAN/issues/332) by @fealho
72 |
73 | ### Maintenance
74 |
75 | * Transition from using setup.py to pyproject.toml to specify project metadata - Issue [#333](https://github.com/sdv-dev/CTGAN/issues/333) by @R-Palazzo
76 | * Remove bumpversion and use bump-my-version - Issue [#334](https://github.com/sdv-dev/CTGAN/issues/334) by @R-Palazzo
77 | * Add dependency checker - Issue [#336](https://github.com/sdv-dev/CTGAN/issues/336) by @amontanez24
78 |
79 | ## v0.9.0 - 2024-02-13
80 |
81 | This release makes CTGAN sampling more efficient by saving the frequency of each categorical value.
82 |
83 | ### New Features
84 |
85 | * Improve DataSampler efficiency - Issue [#327](https://github.com/sdv-dev/CTGAN/issues/327) by @fealho
86 |
87 | ## v0.8.0 - 2023-11-13
88 |
89 | This release adds a progress bar that will show when setting the `verbose` parameter to `True`
90 | when initializing `TVAE`.
91 |
92 | ### New Features
93 |
94 | * Add verbosity TVAE (progress bar + save the loss values) - Issue [#300](https://github.com/sdv-dev/CTGAN/issues/300) by @frances-h
95 |
96 | ## v0.7.5 - 2023-10-05
97 |
98 | This release adds a progress bar that will show when setting the `verbose` parameter to True when initializing `CTGAN`. It also removes a warning that was showing.
99 |
100 | ### Maintenance
101 |
102 | * Remove model_missing_values from ClusterBasedNormalizer call - PR [#310](https://github.com/sdv-dev/CTGAN/pull/310) by @fealho
103 | * Switch default branch from master to main - Issue [#311](https://github.com/sdv-dev/CTGAN/issues/311) by @amontanez24
104 | * Remove or implement CTGAN tests - Issue [#312](https://github.com/sdv-dev/CTGAN/issues/312) by @fealho
105 |
106 | ### New Features
107 |
108 | * Add progress bar for CTGAN fitting (+ save the loss values) - Issue [#298](https://github.com/sdv-dev/CTGAN/issues/298) by @frances-h
109 |
110 | ## v0.7.4 - 2023-07-25
111 |
112 | This release adds support for Python 3.11 and drops support for Python 3.7.
113 |
114 | ### Maintenance
115 |
116 | * Why is there an upper bound in the packaging requirement? (packaging<22) - Issue [#276](https://github.com/sdv-dev/CTGAN/issues/276) by @fealho
117 | * Add support for Python 3.11 - Issue [#296](https://github.com/sdv-dev/CTGAN/issues/296) by @fealho
118 | * Drop support for Python 3.7 - Issue [#302](https://github.com/sdv-dev/CTGAN/issues/302) by @fealho
119 |
120 | ## v0.7.3 - 2023-05-25
121 |
122 | This release adds support for Torch 2.0!
123 |
124 | ### Bugs Fixed
125 |
126 | * Torch 2.0 fails with cuda=False - Issue [#288](https://github.com/sdv-dev/CTGAN/issues/288) by @amontanez24
127 |
128 | ### Maintenance
129 |
130 | * Upgrade to torch 2.0 - Issue [#280](https://github.com/sdv-dev/CTGAN/issues/280) by @frances-h
131 |
132 | ## v0.7.2 - 2023-05-09
133 |
134 | This release adds support for Pandas 2.0! It also fixes a bug in the `load_demo` function.
135 |
136 | ### Bugs Fixed
137 |
138 | * load_demo raises urllib.error.HTTPError: HTTP Error 403: Forbidden - Issue [#284](https://github.com/sdv-dev/CTGAN/issues/284) by @amontanez24
139 |
140 | ### Maintenance
141 |
142 | * Remove upper bound for pandas - Issue [#282](https://github.com/sdv-dev/CTGAN/issues/282) by @frances-h
143 |
144 | ## v0.7.1 - 2023-02-23
145 |
146 | This release fixes a bug that prevented the `CTGAN` model from being saved after sampling.
147 |
148 | ### Bugs Fixed
149 |
150 | * Cannot save CTGANSynthesizer after sampling (TypeError) - Issue [#270](https://github.com/sdv-dev/CTGAN/issues/270) by @pvk-developer
151 |
152 | ## v0.7.0 - 2023-01-20
153 |
154 | This release adds support for Python 3.10 and drops support for Python 3.6. It also fixes a couple of the most common warnings that were surfacing.
155 |
156 | ### New Features
157 |
158 | * Support Python 3.10 and 3.11 - Issue [#259](https://github.com/sdv-dev/CTGAN/issues/259) by @pvk-developer
159 |
160 | ### Bugs Fixed
161 |
162 | * Fix SettingWithCopyWarning (may be leading to a numerical calculation bug) - Issue [#215](https://github.com/sdv-dev/CTGAN/issues/215) by @amontanez24
163 | * FutureWarning in data_transformer with pandas 1.5.0 - Issue [#246](https://github.com/sdv-dev/CTGAN/issues/246) by @amontanez24
164 |
165 | ### Maintenance
166 |
167 | * CTGAN Package Maintenance Updates - Issue [#257](https://github.com/sdv-dev/CTGAN/issues/257) by @amontanez24
168 |
169 | ## v0.6.0 - 2022-10-07
170 |
171 | This release renames the models in CTGAN. `CTGANSynthesizer` is now called `CTGAN` and `TVAESynthesizer` is now called `TVAE`.
172 |
173 | ### New Features
174 |
175 | * Rename synthesizers - Issue [#243](https://github.com/sdv-dev/CTGAN/issues/243) by @amontanez24
176 |
177 | ## v0.5.2 - 2022-08-18
178 |
179 | This release updates CTGAN to use the latest version of RDT. It also includes performance and robustness updates to the data transformer.
180 |
181 | ### Issues closed
182 | * Bump rdt version - Issue [#242](https://github.com/sdv-dev/CTGAN/issues/242) by @katxiao
183 | * Single thread data transform is slow for huge table - Issue [#151](https://github.com/sdv-dev/CTGAN/issues/151) by @mfhbree
184 | * Fix RDT api - Issue [#232](https://github.com/sdv-dev/CTGAN/issues/232) by @pvk-developer
185 | * Update macos to use latest version. - Issue [#237](https://github.com/sdv-dev/CTGAN/issues/237) by @pvk-developer
186 | * Update the RDT version to 1.0 - Issue [#224](https://github.com/sdv-dev/CTGAN/issues/224) by @pvk-developer
187 | * Update slack invite link. - Issue [#222](https://github.com/sdv-dev/CTGAN/issues/222) by @pvk-developer
188 | * robustness fix, when data have less rows than the default number of cl… - Issue [#211](https://github.com/sdv-dev/CTGAN/issues/211) by @Deathn0t
189 |
190 | ## v0.5.1 - 2022-02-25
191 |
192 | This release fixes a bug with the decoder instantiation, and also allows users to set a random state for the model
193 | fitting and sampling.
194 |
195 | ### Issues closed
196 |
197 | * Update self.decoder with correct variable name - Issue [#203](https://github.com/sdv-dev/CTGAN/issues/203) by @tejuafonja
198 | * Add random state - Issue [#204](https://github.com/sdv-dev/CTGAN/issues/204) by @katxiao
199 |
200 | ## v0.5.0 - 2021-11-18
201 |
202 | This release adds support for Python 3.9 and updates dependencies to ensure compatibility with the
203 | rest of the SDV ecosystem, and upgrades to the latest [RDT](https://github.com/sdv-dev/RDT/releases/tag/v0.6.1)
204 | release.
205 |
206 | ### Issues closed
207 |
208 | * Add support for Python 3.9 - Issue [#177](https://github.com/sdv-dev/CTGAN/issues/177) by @pvk-developer
209 | * Add pip check to CI workflows - Issue [#174](https://github.com/sdv-dev/CTGAN/issues/174) by @pvk-developer
210 | * Typo in `CTGAN` code - Issue [#158](https://github.com/sdv-dev/CTGAN/issues/158) by @ori-katz100 and @fealho
211 |
212 | ## v0.4.3 - 2021-07-12
213 |
214 | Dependency upgrades to ensure compatibility with the rest of the SDV ecosystem.
215 |
216 | ## v0.4.2 - 2021-04-27
217 |
218 | In this release, the way in which the loss function of the TVAE model was computed has been fixed.
219 | In addition, the default value of the `discriminator_decay` has been changed to a more optimal
220 | value. Also some improvements to the tests were added.
221 |
222 | ### Issues closed
223 |
224 | * `TVAE`: loss function - Issue [#143](https://github.com/sdv-dev/CTGAN/issues/143) by @fealho and @DingfanChen
225 | * Set `discriminator_decay` to `1e-6` - Pull request [#145](https://github.com/sdv-dev/CTGAN/pull/145/) by @fealho
226 | * Adds unit tests - Pull requests [#140](https://github.com/sdv-dev/CTGAN/pull/140) by @fealho
227 |
228 | ## v0.4.1 - 2021-03-30
229 |
230 | This release exposes all the hyperparameters which the user may find useful for both `CTGAN`
231 | and `TVAE`. Also `TVAE` can now be fitted on datasets that are shorter than the batch
232 | size and drops the last batch only if the data size is not divisible by the batch size.
233 |
234 | ### Issues closed
235 |
236 | * `TVAE`: Adapt `batch_size` to data size - Issue [#135](https://github.com/sdv-dev/CTGAN/issues/135) by @fealho and @csala
237 | * `ValueError` from `validate_discrete_columns` with `uniqueCombinationConstraint` - Issue [#133](https://github.com/sdv-dev/CTGAN/issues/133) by @fealho and @MLjungg
238 |
239 | ## v0.4.0 - 2021-02-24
240 |
241 | Maintenance release to upgrade dependencies to ensure compatibility with the rest
242 | of the SDV libraries.
243 |
244 | It also adds validation of the CTGAN `condition_column` and `condition_value` inputs.
245 |
246 | ### Improvements
247 |
248 | * Validate condition_column and condition_value - Issue [#124](https://github.com/sdv-dev/CTGAN/issues/124) by @fealho
249 |
250 | ## v0.3.1 - 2021-01-27
251 |
252 | ### Improvements
253 |
254 | * Check discrete_columns valid before fitting - [Issue #35](https://github.com/sdv-dev/CTGAN/issues/35) by @fealho
255 |
256 | ### Bugs fixed
257 |
258 | * ValueError: max() arg is an empty sequence - [Issue #115](https://github.com/sdv-dev/CTGAN/issues/115) by @fealho
259 |
260 | ## v0.3.0 - 2020-12-18
261 |
262 | In this release we add a new TVAE model which was presented in the original CTGAN paper.
263 | It also exposes more hyperparameters and moves `epochs` and `log_frequency` from `fit` to the constructor.
264 |
265 | A new verbose argument has been added to optionally disable unnecessary printing, and a new hyperparameter
266 | called `discriminator_steps` has been added to CTGAN to control the number of optimization steps performed
267 | in the discriminator for each generator epoch.
268 |
269 | The code has also been reorganized and cleaned up for better readability and interpretability.
270 |
271 | Special thanks to @Baukebrenninkmeijer @fealho @leix28 @csala for the contributions!
272 |
273 | ### Improvements
274 |
275 | * Add TVAE - [Issue #111](https://github.com/sdv-dev/CTGAN/issues/111) by @fealho
276 | * Move `log_frequency` to `__init__` - [Issue #102](https://github.com/sdv-dev/CTGAN/issues/102) by @fealho
277 | * Add discriminator steps hyperparameter - [Issue #101](https://github.com/sdv-dev/CTGAN/issues/101) by @Baukebrenninkmeijer
278 | * Code cleanup / Expose hyperparameters - [Issue #59](https://github.com/sdv-dev/CTGAN/issues/59) by @fealho and @leix28
279 | * Publish to conda repo - [Issue #54](https://github.com/sdv-dev/CTGAN/issues/54) by @fealho
280 |
281 | ### Bugs fixed
282 |
283 | * Fixed NaN != NaN counting bug. - [Issue #100](https://github.com/sdv-dev/CTGAN/issues/100) by @fealho
284 | * Update dependencies and testing - [Issue #90](https://github.com/sdv-dev/CTGAN/issues/90) by @csala
285 |
286 | ## v0.2.2 - 2020-11-13
287 |
288 | In this release we introduce several minor improvements to make CTGAN more versatile and
289 | properly support new types of data, such as categorical NaN values, as well as conditional
290 | sampling and features to save and load models.
291 |
292 | Additionally, the dependency ranges and Python versions have been updated to support
293 | up-to-date runtimes.
294 |
295 | Many thanks @fealho @leix28 @csala @oregonpillow and @lurosenb for working on making this release possible!
296 |
297 | ### Improvements
298 |
299 | * Drop Python 3.5 support - [Issue #79](https://github.com/sdv-dev/CTGAN/issues/79) by @fealho
300 | * Support NaN values in categorical variables - [Issue #78](https://github.com/sdv-dev/CTGAN/issues/78) by @fealho
301 | * Sample synthetic data conditioning on a discrete column - [Issue #69](https://github.com/sdv-dev/CTGAN/issues/69) by @leix28
302 | * Support recent versions of pandas - [Issue #57](https://github.com/sdv-dev/CTGAN/issues/57) by @csala
303 | * Easy solution for restoring original dtypes - [Issue #26](https://github.com/sdv-dev/CTGAN/issues/26) by @oregonpillow
304 |
305 | ### Bugs fixed
306 |
307 | * Loss to nan - [Issue #73](https://github.com/sdv-dev/CTGAN/issues/73) by @fealho
308 | * Swapped the sklearn utils testing import statement - [Issue #53](https://github.com/sdv-dev/CTGAN/issues/53) by @lurosenb
309 |
310 | ## v0.2.1 - 2020-01-27
311 |
312 | Minor version including changes to ensure the logs are properly printed, and adding
313 | the option to disable the log transformation of the discrete column frequencies.
314 |
315 | Special thanks to @kevinykuo for the contributions!
316 |
317 | ### Issues Resolved:
318 |
319 | * Option to sample from true data frequency instead of logged frequency - [Issue #16](https://github.com/sdv-dev/CTGAN/issues/16) by @kevinykuo
320 | * Flush stdout buffer for epoch updates - [Issue #14](https://github.com/sdv-dev/CTGAN/issues/14) by @kevinykuo
321 |
322 | ## v0.2.0 - 2019-12-18
323 |
324 | Reorganization of the project structure with a new Python API, new Command Line Interface
325 | and increased data format support.
326 |
327 | ### Issues Resolved:
328 |
329 | * Reorganize the project structure - [Issue #10](https://github.com/sdv-dev/CTGAN/issues/10) by @csala
330 | * Move epochs to the fit method - [Issue #5](https://github.com/sdv-dev/CTGAN/issues/5) by @csala
331 |
332 | ## v0.1.0 - 2019-11-07
333 |
334 | First Release - NeurIPS 2019 Version.
335 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Business Source License 1.1
3 |
4 | Parameters
5 |
6 | Licensor: DataCebo, Inc.
7 |
8 | Licensed Work: CTGAN
9 | The Licensed Work is (c) DataCebo, Inc.
10 |
11 | Additional Use Grant: You may make use of the Licensed Work, and derivatives of the Licensed
12 | Work, provided that you do not use the Licensed Work, or derivatives of
13 | the Licensed Work, for a Synthetic Data Creation Service.
14 |
15 | A "Synthetic Data Creation Service" is a commercial offering
16 | that allows third parties (other than your employees and
17 | contractors) to access the functionality of the Licensed
18 | Work so that such third parties directly benefit from the
19 | data processing, machine learning or synthetic data creation
20 | features of the Licensed Work.
21 |
22 | Change Date: Change date is four years from release date.
23 | Please see https://github.com/sdv-dev/CTGAN/releases
24 | for exact dates.
25 |
26 | Change License: MIT License
27 |
28 |
29 | Notice
30 |
31 | The Business Source License (this document, or the "License") is not an Open
32 | Source license. However, the Licensed Work will eventually be made available
33 | under an Open Source License, as stated in this License.
34 |
35 | License text copyright (c) 2017 MariaDB Corporation Ab, All Rights Reserved.
36 | "Business Source License" is a trademark of MariaDB Corporation Ab.
37 |
38 | -----------------------------------------------------------------------------
39 |
40 | Business Source License 1.1
41 |
42 | Terms
43 |
44 | The Licensor hereby grants you the right to copy, modify, create derivative
45 | works, redistribute, and make non-production use of the Licensed Work. The
46 | Licensor may make an Additional Use Grant, above, permitting limited
47 | production use.
48 |
49 | Effective on the Change Date, or the fourth anniversary of the first publicly
50 | available distribution of a specific version of the Licensed Work under this
51 | License, whichever comes first, the Licensor hereby grants you rights under
52 | the terms of the Change License, and the rights granted in the paragraph
53 | above terminate.
54 |
55 | If your use of the Licensed Work does not comply with the requirements
56 | currently in effect as described in this License, you must purchase a
57 | commercial license from the Licensor, its affiliated entities, or authorized
58 | resellers, or you must refrain from using the Licensed Work.
59 |
60 | All copies of the original and modified Licensed Work, and derivative works
61 | of the Licensed Work, are subject to this License. This License applies
62 | separately for each version of the Licensed Work and the Change Date may vary
63 | for each version of the Licensed Work released by Licensor.
64 |
65 | You must conspicuously display this License on each original or modified copy
66 | of the Licensed Work. If you receive the Licensed Work in original or
67 | modified form from a third party, the terms and conditions set forth in this
68 | License apply to your use of that work.
69 |
70 | Any use of the Licensed Work in violation of this License will automatically
71 | terminate your rights under this License for the current and all other
72 | versions of the Licensed Work.
73 |
74 | This License does not grant you any right in any trademark or logo of
75 | Licensor or its affiliates (provided that you may use a trademark or logo of
76 | Licensor as expressly required by this License).
77 |
78 | TO THE EXTENT PERMITTED BY APPLICABLE LAW, THE LICENSED WORK IS PROVIDED ON
79 | AN "AS IS" BASIS. LICENSOR HEREBY DISCLAIMS ALL WARRANTIES AND CONDITIONS,
80 | EXPRESS OR IMPLIED, INCLUDING (WITHOUT LIMITATION) WARRANTIES OF
81 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, AND
82 | TITLE.
83 |
84 | MariaDB hereby grants you permission to use this License’s text to license
85 | your works, and to refer to it using the trademark "Business Source License",
86 | as long as you comply with the Covenants of Licensor below.
87 |
88 | Covenants of Licensor
89 |
90 | In consideration of the right to use this License’s text and the "Business
91 | Source License" name and trademark, Licensor covenants to MariaDB, and to all
92 | other recipients of the licensed work to be provided by Licensor:
93 |
94 | 1. To specify as the Change License the GPL Version 2.0 or any later version,
95 | or a license that is compatible with GPL Version 2.0 or a later version,
96 | where "compatible" means that software provided under the Change License can
97 | be included in a program with software provided under GPL Version 2.0 or a
98 | later version. Licensor may specify additional Change Licenses without
99 | limitation.
100 |
101 | 2. To either: (a) specify an additional grant of rights to use that does not
102 | impose any additional restriction on the right granted in this License, as
103 | the Additional Use Grant; or (b) insert the text "None".
104 |
105 | 3. To specify a Change Date.
106 |
107 | 4. Not to modify this License in any other way.
108 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .DEFAULT_GOAL := help
2 |
3 | define BROWSER_PYSCRIPT
4 | import os, webbrowser, sys
5 |
6 | try:
7 | from urllib import pathname2url
8 | except:
9 | from urllib.request import pathname2url
10 |
11 | webbrowser.open("file://" + pathname2url(os.path.abspath(sys.argv[1])))
12 | endef
13 | export BROWSER_PYSCRIPT
14 |
15 | define PRINT_HELP_PYSCRIPT
16 | import re, sys
17 |
18 | for line in sys.stdin:
19 | match = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$$', line)
20 | if match:
21 | target, help = match.groups()
22 | print("%-20s %s" % (target, help))
23 | endef
24 | export PRINT_HELP_PYSCRIPT
25 |
26 | BROWSER := python -c "$$BROWSER_PYSCRIPT"
27 |
28 | .PHONY: help
29 | help:
30 | @python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST)
31 |
32 |
33 | # CLEAN TARGETS
34 |
35 | .PHONY: clean-build
36 | clean-build: ## remove build artifacts
37 | rm -fr build/
38 | rm -fr dist/
39 | rm -fr .eggs/
40 | find . -name '*.egg-info' -exec rm -fr {} +
41 | find . -name '*.egg' -exec rm -f {} +
42 |
43 | .PHONY: clean-pyc
44 | clean-pyc: ## remove Python file artifacts
45 | find . -name '*.pyc' -exec rm -f {} +
46 | find . -name '*.pyo' -exec rm -f {} +
47 | find . -name '*~' -exec rm -f {} +
48 | find . -name '__pycache__' -exec rm -fr {} +
49 |
50 | .PHONY: clean-coverage
51 | clean-coverage: ## remove coverage artifacts
52 | rm -f .coverage
53 | rm -f .coverage.*
54 | rm -fr htmlcov/
55 |
56 | .PHONY: clean-test
57 | clean-test: ## remove test artifacts
58 | rm -fr .tox/
59 | rm -fr .pytest_cache
60 |
61 | .PHONY: clean
62 | clean: clean-build clean-pyc clean-test clean-coverage ## remove all build, test, coverage and Python artifacts
63 |
64 |
65 | # INSTALL TARGETS
66 |
67 | .PHONY: install
68 | install: clean-build clean-pyc ## install the package to the active Python's site-packages
69 | pip install .
70 |
71 | .PHONY: install-test
72 | install-test: clean-build clean-pyc ## install the package and test dependencies
73 | pip install .[test]
74 |
75 | .PHONY: install-develop
76 | install-develop: clean-build clean-pyc ## install the package in editable mode and dependencies for development
77 | pip install -e .[dev]
78 |
79 |
80 | # LINT TARGETS
81 |
82 | .PHONY: lint
83 | lint:
84 | invoke lint
85 |
86 | .PHONY: fix-lint
87 | fix-lint:
88 | invoke fix-lint
89 |
90 |
91 | # TEST TARGETS
92 |
93 | .PHONY: test-unit
94 | test-unit: ## run unit tests using pytest
95 | invoke unit
96 |
97 | .PHONY: test-integration
98 | test-integration: ## run integration tests using pytest
99 | invoke integration
100 |
101 | .PHONY: test-readme
102 | test-readme: ## run the readme snippets
103 | invoke readme
104 |
105 | .PHONY: check-dependencies
106 | check-dependencies: ## test if there are any broken dependencies
107 | pip check
108 |
109 | .PHONY: test
110 | test: test-unit test-integration test-readme ## test everything that needs test dependencies
111 |
112 | .PHONY: test-devel
113 | test-devel: lint ## test everything that needs development dependencies
114 |
115 | .PHONY: test-all
116 | test-all: ## run tests on every Python version with tox
117 | tox -r
118 |
119 |
120 | .PHONY: coverage
121 | coverage: ## check code coverage quickly with the default Python
122 | coverage run --source ctgan -m pytest
123 | coverage report -m
124 | coverage html
125 | $(BROWSER) htmlcov/index.html
126 |
127 |
128 | # RELEASE TARGETS
129 |
130 | .PHONY: dist
131 | dist: clean ## builds source and wheel package
132 | python -m build --wheel --sdist
133 | ls -l dist
134 |
135 | .PHONY: publish-confirm
136 | publish-confirm:
137 | @echo "WARNING: This will irreversibly upload a new version to PyPI!"
138 | @echo -n "Please type 'confirm' to proceed: " \
139 | && read answer \
140 | && [ "$${answer}" = "confirm" ]
141 |
142 | .PHONY: publish-test
143 | publish-test: dist publish-confirm ## package and upload a release on TestPyPI
144 | twine upload --repository-url https://test.pypi.org/legacy/ dist/*
145 |
146 | .PHONY: publish
147 | publish: dist publish-confirm ## package and upload a release
148 | twine upload dist/*
149 |
150 | .PHONY: bumpversion-release
151 | bumpversion-release: ## Merge main to stable and bumpversion release
152 | git checkout stable || git checkout -b stable
153 | git merge --no-ff main -m"make release-tag: Merge branch 'main' into stable"
154 | bump-my-version bump release
155 | git push --tags origin stable
156 |
157 | .PHONY: bumpversion-release-test
158 | bumpversion-release-test: ## Merge main to stable and bumpversion release
159 | git checkout stable || git checkout -b stable
160 | git merge --no-ff main -m"make release-tag: Merge branch 'main' into stable"
161 | bump-my-version bump release --no-tag
162 | @echo git push --tags origin stable
163 |
164 | .PHONY: bumpversion-patch
165 | bumpversion-patch: ## Merge stable to main and bumpversion patch
166 | git checkout main
167 | git merge stable
168 | bump-my-version bump --no-tag patch
169 | git push
170 |
171 | .PHONY: bumpversion-candidate
172 | bumpversion-candidate: ## Bump the version to the next candidate
173 | bump-my-version bump candidate --no-tag
174 |
175 | .PHONY: bumpversion-minor
176 | bumpversion-minor: ## Bump the version the next minor skipping the release
177 | bump-my-version bump --no-tag minor
178 |
179 | .PHONY: bumpversion-major
180 | bumpversion-major: ## Bump the version the next major skipping the release
181 | bump-my-version bump --no-tag major
182 |
183 | .PHONY: bumpversion-revert
184 | bumpversion-revert: ## Undo a previous bumpversion-release
185 | git checkout main
186 | git branch -D stable
187 |
188 | CLEAN_DIR := $(shell git status --short | grep -v ??)
189 | CURRENT_BRANCH := $(shell git rev-parse --abbrev-ref HEAD 2>/dev/null)
190 | CHANGELOG_LINES := $(shell git diff HEAD..origin/stable HISTORY.md 2>&1 | wc -l)
191 |
192 | .PHONY: check-clean
193 | check-clean: ## Check if the directory has uncommitted changes
194 | ifneq ($(CLEAN_DIR),)
195 | $(error There are uncommitted changes)
196 | endif
197 |
198 | .PHONY: check-main
199 | check-main: ## Check if we are in main branch
200 | ifneq ($(CURRENT_BRANCH),main)
201 | $(error Please make the release from main branch\n)
202 | endif
203 |
204 | .PHONY: check-history
205 | check-history: ## Check if HISTORY.md has been modified
206 | ifeq ($(CHANGELOG_LINES),0)
207 | $(error Please insert the release notes in HISTORY.md before releasing)
208 | endif
209 |
210 | .PHONY: git-push
211 | git-push: ## Simply push the repository to github
212 | git push
213 |
214 | .PHONY: check-release
215 | check-release: check-clean check-main check-history ## Check if the release can be made
216 | @echo "A new release can be made"
217 |
218 | .PHONY: release
219 | release: check-release bumpversion-release publish bumpversion-patch
220 |
221 | .PHONY: release-test
222 | release-test: check-release bumpversion-release-test publish-test bumpversion-revert
223 |
224 | .PHONY: release-candidate
225 | release-candidate: check-main publish bumpversion-candidate git-push
226 |
227 | .PHONY: release-candidate-test
228 | release-candidate-test: check-clean check-main publish-test
229 |
230 | .PHONY: release-minor
231 | release-minor: check-release bumpversion-minor release
232 |
233 | .PHONY: release-major
234 | release-major: check-release bumpversion-major release
235 |
236 | # Dependency targets
237 |
238 | .PHONY: check-deps
239 | check-deps:
240 | $(eval allow_list='numpy=|pandas=|tqdm=|torch=|rdt=')
241 | pip freeze | grep -v "CTGAN.git" | grep -E $(allow_list) > $(OUTPUT_FILEPATH)
242 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | This repository is part of The Synthetic Data Vault Project, a project from DataCebo.
5 |
6 |
7 | [](https://pypi.org/search/?c=Development+Status+%3A%3A+2+-+Pre-Alpha)
8 | [](https://pypi.python.org/pypi/ctgan)
9 | [](https://github.com/sdv-dev/CTGAN/actions/workflows/unit.yml)
10 | [](https://pepy.tech/project/ctgan)
11 | [](https://codecov.io/gh/sdv-dev/CTGAN)
12 |
13 |
21 |
22 |
23 |
24 | # Overview
25 |
26 | CTGAN is a collection of Deep Learning-based synthetic data generators for single-table data, which can learn from real data and generate synthetic data with high fidelity.
27 |
28 | | Important Links | |
29 | | --------------------------------------------- | -------------------------------------------------------------------- |
30 | | :computer: **[Website]** | Check out the SDV Website for more information about our overall synthetic data ecosystem.|
31 | | :orange_book: **[Blog]** | A deeper look at open source, synthetic data creation and evaluation.|
32 | | :book: **[Documentation]** | Quickstarts, User and Development Guides, and API Reference. |
33 | | :octocat: **[Repository]** | The link to the Github Repository of this library. |
34 | | :keyboard: **[Development Status]** | This software is in its Pre-Alpha stage. |
35 | | [![][Slack Logo] **Community**][Community] | Join our Slack Workspace for announcements and discussions. |
36 |
37 | [Website]: https://sdv.dev
38 | [Blog]: https://datacebo.com/blog
39 | [Documentation]: https://bit.ly/sdv-docs
40 | [Repository]: https://github.com/sdv-dev/CTGAN
41 | [License]: https://github.com/sdv-dev/CTGAN/blob/main/LICENSE
42 | [Development Status]: https://pypi.org/search/?c=Development+Status+%3A%3A+2+-+Pre-Alpha
43 | [Slack Logo]: https://github.com/sdv-dev/SDV/blob/stable/docs/images/slack.png
44 | [Community]: https://bit.ly/sdv-slack-invite
45 |
46 | Currently, this library implements the **CTGAN** and **TVAE** models described in the [Modeling Tabular data using Conditional GAN](https://arxiv.org/abs/1907.00503) paper, presented at the 2019 NeurIPS conference.
47 |
48 | # Install
49 |
50 | ## Use CTGAN through the SDV library
51 |
52 | :warning: If you're just getting started with synthetic data, we recommend installing the SDV library which provides user-friendly APIs for accessing CTGAN. :warning:
53 |
54 | The SDV library provides wrappers for preprocessing your data as well as additional usability features like constraints. See the [SDV documentation](https://bit.ly/sdv-docs) to get started.
55 |
56 | ## Use the CTGAN standalone library
57 |
58 | Alternatively, you can also install and use **CTGAN** directly, as a standalone library:
59 |
60 | **Using `pip`:**
61 |
62 | ```bash
63 | pip install ctgan
64 | ```
65 |
66 | **Using `conda`:**
67 |
68 | ```bash
69 | conda install -c pytorch -c conda-forge ctgan
70 | ```
71 |
72 | When using the CTGAN library directly, you may need to manually preprocess your data into the correct format, for example:
73 |
74 | * Continuous data must be represented as floats
75 | * Discrete data must be represented as ints or strings
76 | * The data should not contain any missing values
77 |
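For example, here is a minimal preprocessing sketch using pandas. The column names and raw values below are hypothetical placeholders, not part of the CTGAN API:

```python3
import pandas as pd

# Hypothetical raw table with one continuous and one discrete column.
raw_data = pd.DataFrame({
    'age': ['32', '47', None],  # continuous values stored as strings
    'workclass': ['Private', 'State-gov', 'Private'],  # discrete
})

# Continuous columns must be represented as floats.
raw_data['age'] = pd.to_numeric(raw_data['age']).astype(float)

# Discrete columns must be represented as ints or strings.
raw_data['workclass'] = raw_data['workclass'].astype(str)

# The data should not contain missing values; drop (or impute) them first.
clean_data = raw_data.dropna()
```
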
78 | # Usage Example
79 |
80 | In this example we load the [Adult Census Dataset](https://archive.ics.uci.edu/ml/datasets/adult)*, which is a built-in demo dataset. We use CTGAN to learn from the real data and then generate some synthetic data.
81 |
82 | ```python3
83 | from ctgan import CTGAN
84 | from ctgan import load_demo
85 |
86 | real_data = load_demo()
87 |
88 | # Names of the columns that are discrete
89 | discrete_columns = [
90 | 'workclass',
91 | 'education',
92 | 'marital-status',
93 | 'occupation',
94 | 'relationship',
95 | 'race',
96 | 'sex',
97 | 'native-country',
98 | 'income'
99 | ]
100 |
101 | ctgan = CTGAN(epochs=10)
102 | ctgan.fit(real_data, discrete_columns)
103 |
104 | # Create synthetic data
105 | synthetic_data = ctgan.sample(1000)
106 | ```
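
A fitted synthesizer can also be saved, reloaded, and used for conditional sampling. The sketch below mirrors the calls made by the CLI in `ctgan/__main__.py`; the file name is just a placeholder:

```python3
# Persist the fitted synthesizer and load it back later.
ctgan.save('ctgan_adult.pkl')
loaded = CTGAN.load('ctgan_adult.pkl')

# Sample rows conditioned on a specific value of a discrete column
# (arguments: number of rows, condition column, condition value).
conditioned = loaded.sample(1000, 'workclass', 'Private')
```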
107 |
108 | *For more information about the dataset see:
109 | Dua, D. and Graff, C. (2019). UCI Machine Learning Repository [http://archive.ics.uci.edu/ml].
110 | Irvine, CA: University of California, School of Information and Computer Science.
111 |
112 | # Join our community
113 |
114 | Join our [Slack channel](https://bit.ly/sdv-slack-invite) to discuss more about CTGAN and synthetic data. If you find a bug or have a feature request, you can also [open an issue](https://github.com/sdv-dev/CTGAN/issues) on our GitHub.
115 |
116 | **Interested in contributing to CTGAN?** Read our [Contribution Guide](CONTRIBUTING.rst) to get started.
117 |
118 | # Citing CTGAN
119 |
120 | If you use CTGAN, please cite the following work:
121 |
122 | *Lei Xu, Maria Skoularidou, Alfredo Cuesta-Infante, Kalyan Veeramachaneni.* **Modeling Tabular data using Conditional GAN**. NeurIPS, 2019.
123 |
124 | ```LaTeX
125 | @inproceedings{ctgan,
126 | title={Modeling Tabular data using Conditional GAN},
127 | author={Xu, Lei and Skoularidou, Maria and Cuesta-Infante, Alfredo and Veeramachaneni, Kalyan},
128 | booktitle={Advances in Neural Information Processing Systems},
129 | year={2019}
130 | }
131 | ```
132 |
133 | # Related Projects
134 | Please note that these projects are external to the SDV Ecosystem. They are not affiliated with or maintained by DataCebo.
135 |
136 | * **R Interface for CTGAN**: A wrapper around **CTGAN** that brings its functionality to **R** users.
137 | More details can be found in the corresponding repository: https://github.com/kasaai/ctgan
138 | * **CTGAN Server CLI**: A package to easily deploy CTGAN onto a remote server. Created by Timothy Pillow @oregonpillow at: https://github.com/oregonpillow/ctgan-server-cli
139 |
140 | ---
141 |
142 |
143 |
144 |

145 |
146 |
147 |
148 |
149 | [The Synthetic Data Vault Project](https://sdv.dev) was first created at MIT's [Data to AI Lab](
150 | https://dai.lids.mit.edu/) in 2016. After 4 years of research and traction with enterprise, we
151 | created [DataCebo](https://datacebo.com) in 2020 with the goal of growing the project.
152 | Today, DataCebo is the proud developer of SDV, the largest ecosystem for
153 | synthetic data generation & evaluation. It is home to multiple libraries that support synthetic
154 | data, including:
155 |
156 | * 🔄 Data discovery & transformation. Reverse the transforms to reproduce realistic data.
157 | * 🧠 Multiple machine learning models -- ranging from Copulas to Deep Learning -- to create tabular,
158 | multi-table and time series data.
159 | * 📊 Measuring quality and privacy of synthetic data, and comparing different synthetic data
160 | generation models.
161 |
162 | [Get started using the SDV package](https://sdv.dev/SDV/getting_started/install.html) -- a fully
163 | integrated solution and your one-stop shop for synthetic data. Or, use the standalone libraries
164 | for specific needs.
165 |
--------------------------------------------------------------------------------
/ctgan/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """Top-level package for ctgan."""
4 |
5 | __author__ = 'DataCebo, Inc.'
6 | __email__ = 'info@sdv.dev'
7 | __version__ = '0.11.1.dev0'
8 |
9 | from ctgan.demo import load_demo
10 | from ctgan.synthesizers.ctgan import CTGAN
11 | from ctgan.synthesizers.tvae import TVAE
12 |
13 | __all__ = ('CTGAN', 'TVAE', 'load_demo')
14 |
--------------------------------------------------------------------------------
/ctgan/__main__.py:
--------------------------------------------------------------------------------
1 | """CLI."""
2 |
3 | import argparse
4 |
5 | from ctgan.data import read_csv, read_tsv, write_tsv
6 | from ctgan.synthesizers.ctgan import CTGAN
7 |
8 |
9 | def _parse_args():
10 | parser = argparse.ArgumentParser(description='CTGAN Command Line Interface')
11 | parser.add_argument('-e', '--epochs', default=300, type=int, help='Number of training epochs')
12 | parser.add_argument(
13 | '-t', '--tsv', action='store_true', help='Load data in TSV format instead of CSV'
14 | )
15 | parser.add_argument(
16 | '--no-header',
17 | dest='header',
18 | action='store_false',
19 | help='The CSV file has no header. Discrete columns will be indices.',
20 | )
21 |
22 | parser.add_argument('-m', '--metadata', help='Path to the metadata')
23 | parser.add_argument(
24 | '-d', '--discrete', help='Comma separated list of discrete columns without whitespaces.'
25 | )
26 | parser.add_argument(
27 | '-n',
28 | '--num-samples',
29 | type=int,
30 | help='Number of rows to sample. Defaults to the training data size',
31 | )
32 |
33 | parser.add_argument(
34 | '--generator_lr', type=float, default=2e-4, help='Learning rate for the generator.'
35 | )
36 | parser.add_argument(
37 | '--discriminator_lr', type=float, default=2e-4, help='Learning rate for the discriminator.'
38 | )
39 |
40 | parser.add_argument(
41 | '--generator_decay', type=float, default=1e-6, help='Weight decay for the generator.'
42 | )
43 | parser.add_argument(
44 | '--discriminator_decay', type=float, default=0, help='Weight decay for the discriminator.'
45 | )
46 |
47 | parser.add_argument(
48 | '--embedding_dim', type=int, default=128, help='Dimension of input z to the generator.'
49 | )
50 | parser.add_argument(
51 | '--generator_dim',
52 | type=str,
53 | default='256,256',
54 | help='Dimension of each generator layer. Comma separated integers with no whitespaces.',
55 | )
56 | parser.add_argument(
57 | '--discriminator_dim',
58 | type=str,
59 | default='256,256',
60 | help='Dimension of each discriminator layer. Comma separated integers with no whitespaces.',
61 | )
62 |
63 | parser.add_argument(
64 | '--batch_size', type=int, default=500, help='Batch size. Must be an even number.'
65 | )
66 | parser.add_argument(
67 | '--save', default=None, type=str, help='A filename to save the trained synthesizer.'
68 | )
69 | parser.add_argument(
70 | '--load', default=None, type=str, help='A filename to load a trained synthesizer.'
71 | )
72 |
73 | parser.add_argument(
74 | '--sample_condition_column', default=None, type=str, help='Select a discrete column name.'
75 | )
76 | parser.add_argument(
77 | '--sample_condition_column_value',
78 | default=None,
79 | type=str,
80 | help='Specify the value of the selected discrete column.',
81 | )
82 |
83 | parser.add_argument('data', help='Path to training data')
84 | parser.add_argument('output', help='Path of the output file')
85 |
86 | return parser.parse_args()
87 |
88 |
89 | def main():
90 | """CLI."""
91 | args = _parse_args()
92 | if args.tsv:
93 | data, discrete_columns = read_tsv(args.data, args.metadata)
94 | else:
95 | data, discrete_columns = read_csv(args.data, args.metadata, args.header, args.discrete)
96 |
97 | if args.load:
98 | model = CTGAN.load(args.load)
99 | else:
100 | generator_dim = [int(x) for x in args.generator_dim.split(',')]
101 | discriminator_dim = [int(x) for x in args.discriminator_dim.split(',')]
102 | model = CTGAN(
103 | embedding_dim=args.embedding_dim,
104 | generator_dim=generator_dim,
105 | discriminator_dim=discriminator_dim,
106 | generator_lr=args.generator_lr,
107 | generator_decay=args.generator_decay,
108 | discriminator_lr=args.discriminator_lr,
109 | discriminator_decay=args.discriminator_decay,
110 | batch_size=args.batch_size,
111 | epochs=args.epochs,
112 | )
113 | model.fit(data, discrete_columns)
114 |
115 | if args.save is not None:
116 | model.save(args.save)
117 |
118 | num_samples = args.num_samples or len(data)
119 |
120 | if args.sample_condition_column is not None:
121 | assert args.sample_condition_column_value is not None
122 |
123 | sampled = model.sample(
124 | num_samples, args.sample_condition_column, args.sample_condition_column_value
125 | )
126 |
127 | if args.tsv:
128 | write_tsv(sampled, args.metadata, args.output)
129 | else:
130 | sampled.to_csv(args.output, index=False)
131 |
132 |
133 | if __name__ == '__main__':
134 | main()
135 |
--------------------------------------------------------------------------------
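The module above is a straight read -> fit -> sample -> write pipeline. A minimal programmatic sketch of the same flow, assuming the bundled examples/csv files; the epoch count is illustrative:

    from ctgan.data import read_csv
    from ctgan.synthesizers.ctgan import CTGAN

    # The JSON metadata marks every non-continuous column as discrete (see read_csv below).
    data, discrete_columns = read_csv('examples/csv/adult.csv', 'examples/csv/adult.json')

    model = CTGAN(epochs=10)  # illustrative; the CLI default is 300
    model.fit(data, discrete_columns)

    sampled = model.sample(len(data))  # same row count as the training data, like the CLI default
    sampled.to_csv('adult_synthetic.csv', index=False)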
/ctgan/data.py:
--------------------------------------------------------------------------------
1 | """Data loading."""
2 |
3 | import json
4 |
5 | import numpy as np
6 | import pandas as pd
7 |
8 |
9 | def read_csv(csv_filename, meta_filename=None, header=True, discrete=None):
10 |     """Read a CSV file."""
11 | data = pd.read_csv(csv_filename, header='infer' if header else None)
12 |
13 | if meta_filename:
14 | with open(meta_filename) as meta_file:
15 | metadata = json.load(meta_file)
16 |
17 | discrete_columns = [
18 | column['name'] for column in metadata['columns'] if column['type'] != 'continuous'
19 | ]
20 |
21 | elif discrete:
22 | discrete_columns = discrete.split(',')
23 | if not header:
24 | discrete_columns = [int(i) for i in discrete_columns]
25 |
26 | else:
27 | discrete_columns = []
28 |
29 | return data, discrete_columns
30 |
31 |
32 | def read_tsv(data_filename, meta_filename):
33 |     """Read a TSV file."""
34 | with open(meta_filename) as f:
35 | column_info = f.readlines()
36 |
37 | column_info_raw = [x.replace('{', ' ').replace('}', ' ').split() for x in column_info]
38 |
39 | discrete = []
40 | continuous = []
41 | column_info = []
42 |
43 | for idx, item in enumerate(column_info_raw):
44 | if item[0] == 'C':
45 | continuous.append(idx)
46 | column_info.append((float(item[1]), float(item[2])))
47 | else:
48 | assert item[0] == 'D'
49 | discrete.append(idx)
50 | column_info.append(item[1:])
51 |
52 | meta = {
53 | 'continuous_columns': continuous,
54 | 'discrete_columns': discrete,
55 | 'column_info': column_info,
56 | }
57 |
58 | with open(data_filename) as f:
59 | lines = f.readlines()
60 |
61 | data = []
62 | for row in lines:
63 | row_raw = row.split()
64 | row = []
65 | for idx, col in enumerate(row_raw):
66 | if idx in continuous:
67 | row.append(col)
68 | else:
69 | assert idx in discrete
70 | row.append(column_info[idx].index(col))
71 |
72 | data.append(row)
73 |
74 | return np.asarray(data, dtype='float32'), meta['discrete_columns']
75 |
76 |
77 | def write_tsv(data, meta, output_filename):
78 |     """Write to a TSV file."""
79 | with open(output_filename, 'w') as f:
80 | for row in data:
81 | for idx, col in enumerate(row):
82 | if idx in meta['continuous_columns']:
83 | print(col, end=' ', file=f)
84 | else:
85 | assert idx in meta['discrete_columns']
86 | print(meta['column_info'][idx][int(col)], end=' ', file=f)
87 |
88 | print(file=f)
89 |
--------------------------------------------------------------------------------
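read_tsv expects one .meta line per column: 'C <min> <max>' for continuous columns and 'D <cat1> <cat2> ...' for discrete ones, whose values are stored by category index (compare the examples/tsv/*.meta files). A self-contained sketch with illustrative file names:

    from ctgan.data import read_tsv

    # One continuous column in [0, 100] and one discrete column with categories yes/no.
    with open('toy.meta', 'w') as f:
        f.write('C 0.0 100.0\nD yes no\n')
    with open('toy.dat', 'w') as f:
        f.write('12.5 yes\n80.0 no\n')

    data, discrete_columns = read_tsv('toy.dat', 'toy.meta')
    print(data)              # [[12.5  0.] [80.  1.]] with the categories label encoded
    print(discrete_columns)  # [1]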
/ctgan/data_sampler.py:
--------------------------------------------------------------------------------
1 | """DataSampler module."""
2 |
3 | import numpy as np
4 |
5 |
6 | class DataSampler(object):
7 | """DataSampler samples the conditional vector and corresponding data for CTGAN."""
8 |
9 | def __init__(self, data, output_info, log_frequency):
10 | self._data_length = len(data)
11 |
12 | def is_discrete_column(column_info):
13 | return len(column_info) == 1 and column_info[0].activation_fn == 'softmax'
14 |
15 | n_discrete_columns = sum([
16 | 1 for column_info in output_info if is_discrete_column(column_info)
17 | ])
18 |
19 | self._discrete_column_matrix_st = np.zeros(n_discrete_columns, dtype='int32')
20 |
21 | # Store the row id for each category in each discrete column.
22 | # For example _rid_by_cat_cols[a][b] is a list of all rows with the
23 |         # a-th discrete column equal to value b.
24 | self._rid_by_cat_cols = []
25 |
26 | # Compute _rid_by_cat_cols
27 | st = 0
28 | for column_info in output_info:
29 | if is_discrete_column(column_info):
30 | span_info = column_info[0]
31 | ed = st + span_info.dim
32 |
33 | rid_by_cat = []
34 | for j in range(span_info.dim):
35 | rid_by_cat.append(np.nonzero(data[:, st + j])[0])
36 | self._rid_by_cat_cols.append(rid_by_cat)
37 | st = ed
38 | else:
39 | st += sum([span_info.dim for span_info in column_info])
40 | assert st == data.shape[1]
41 |
42 |         # Prepare an interval matrix for efficiently sampling the conditional vector.
43 | max_category = max(
44 | [column_info[0].dim for column_info in output_info if is_discrete_column(column_info)],
45 | default=0,
46 | )
47 |
48 | self._discrete_column_cond_st = np.zeros(n_discrete_columns, dtype='int32')
49 | self._discrete_column_n_category = np.zeros(n_discrete_columns, dtype='int32')
50 | self._discrete_column_category_prob = np.zeros((n_discrete_columns, max_category))
51 | self._n_discrete_columns = n_discrete_columns
52 | self._n_categories = sum([
53 | column_info[0].dim for column_info in output_info if is_discrete_column(column_info)
54 | ])
55 |
56 | st = 0
57 | current_id = 0
58 | current_cond_st = 0
59 | for column_info in output_info:
60 | if is_discrete_column(column_info):
61 | span_info = column_info[0]
62 | ed = st + span_info.dim
63 | category_freq = np.sum(data[:, st:ed], axis=0)
64 | if log_frequency:
65 | category_freq = np.log(category_freq + 1)
66 | category_prob = category_freq / np.sum(category_freq)
67 | self._discrete_column_category_prob[current_id, : span_info.dim] = category_prob
68 | self._discrete_column_cond_st[current_id] = current_cond_st
69 | self._discrete_column_n_category[current_id] = span_info.dim
70 | current_cond_st += span_info.dim
71 | current_id += 1
72 | st = ed
73 | else:
74 | st += sum([span_info.dim for span_info in column_info])
75 |
76 | def _random_choice_prob_index(self, discrete_column_id):
77 | probs = self._discrete_column_category_prob[discrete_column_id]
78 | r = np.expand_dims(np.random.rand(probs.shape[0]), axis=1)
79 | return (probs.cumsum(axis=1) > r).argmax(axis=1)
80 |
81 | def sample_condvec(self, batch):
82 | """Generate the conditional vector for training.
83 |
84 | Returns:
85 | cond (batch x #categories):
86 | The conditional vector.
87 | mask (batch x #discrete columns):
88 | A one-hot vector indicating the selected discrete column.
89 | discrete column id (batch):
90 | Integer representation of mask.
91 | category_id_in_col (batch):
92 | Selected category in the selected discrete column.
93 | """
94 | if self._n_discrete_columns == 0:
95 | return None
96 |
97 | discrete_column_id = np.random.choice(np.arange(self._n_discrete_columns), batch)
98 |
99 | cond = np.zeros((batch, self._n_categories), dtype='float32')
100 | mask = np.zeros((batch, self._n_discrete_columns), dtype='float32')
101 | mask[np.arange(batch), discrete_column_id] = 1
102 | category_id_in_col = self._random_choice_prob_index(discrete_column_id)
103 | category_id = self._discrete_column_cond_st[discrete_column_id] + category_id_in_col
104 | cond[np.arange(batch), category_id] = 1
105 |
106 | return cond, mask, discrete_column_id, category_id_in_col
107 |
108 | def sample_original_condvec(self, batch):
109 |         """Generate the conditional vector for sampling, using the original category frequencies."""
110 | if self._n_discrete_columns == 0:
111 | return None
112 |
113 | category_freq = self._discrete_column_category_prob.flatten()
114 | category_freq = category_freq[category_freq != 0]
115 | category_freq = category_freq / np.sum(category_freq)
116 | col_idxs = np.random.choice(np.arange(len(category_freq)), batch, p=category_freq)
117 | cond = np.zeros((batch, self._n_categories), dtype='float32')
118 | cond[np.arange(batch), col_idxs] = 1
119 |
120 | return cond
121 |
122 | def sample_data(self, data, n, col, opt):
123 | """Sample data from original training data satisfying the sampled conditional vector.
124 |
125 |         Args:
126 |             data:
127 |                 The training data.
128 |
129 |         Returns:
130 |             numpy.ndarray:
131 |                 ``n`` rows of ``data`` matching the sampled categories in ``col`` and ``opt``.
132 |         """
133 | if col is None:
134 | idx = np.random.randint(len(data), size=n)
135 | return data[idx]
136 |
137 | idx = []
138 | for c, o in zip(col, opt):
139 | idx.append(np.random.choice(self._rid_by_cat_cols[c][o]))
140 |
141 | return data[idx]
142 |
143 | def dim_cond_vec(self):
144 | """Return the total number of categories."""
145 | return self._n_categories
146 |
147 | def generate_cond_from_condition_column_info(self, condition_info, batch):
148 | """Generate the condition vector."""
149 | vec = np.zeros((batch, self._n_categories), dtype='float32')
150 | id_ = self._discrete_column_matrix_st[condition_info['discrete_column_id']]
151 | id_ += condition_info['value_id']
152 | vec[:, id_] = 1
153 | return vec
154 |
--------------------------------------------------------------------------------
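DataSampler works on the already-transformed matrix, so it is normally constructed from a fitted DataTransformer (defined in the next file). A minimal sketch of that handshake on toy data:

    import pandas as pd

    from ctgan.data_sampler import DataSampler
    from ctgan.data_transformer import DataTransformer

    df = pd.DataFrame({'amount': [1.0, 2.0, 3.0, 4.0], 'label': ['a', 'b', 'a', 'b']})
    transformer = DataTransformer()
    transformer.fit(df, discrete_columns=['label'])
    matrix = transformer.transform(df)

    sampler = DataSampler(matrix, transformer.output_info_list, log_frequency=True)
    cond, mask, col_ids, cat_ids = sampler.sample_condvec(batch=8)
    print(cond.shape)  # (8, n_categories): one one-hot conditional vector per row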
/ctgan/data_transformer.py:
--------------------------------------------------------------------------------
1 | """DataTransformer module."""
2 |
3 | from collections import namedtuple
4 |
5 | import numpy as np
6 | import pandas as pd
7 | from joblib import Parallel, delayed
8 | from rdt.transformers import ClusterBasedNormalizer, OneHotEncoder
9 |
10 | SpanInfo = namedtuple('SpanInfo', ['dim', 'activation_fn'])
11 | ColumnTransformInfo = namedtuple(
12 | 'ColumnTransformInfo',
13 | ['column_name', 'column_type', 'transform', 'output_info', 'output_dimensions'],
14 | )
15 |
16 |
17 | class DataTransformer(object):
18 | """Data Transformer.
19 |
20 |     Model continuous columns with a BayesianGMM and normalize them to a scalar in [-1, 1]
21 |     plus a one-hot component vector. Discrete columns are encoded using a OneHotEncoder.
22 | """
23 |
24 | def __init__(self, max_clusters=10, weight_threshold=0.005):
25 | """Create a data transformer.
26 |
27 | Args:
28 | max_clusters (int):
29 | Maximum number of Gaussian distributions in Bayesian GMM.
30 | weight_threshold (float):
31 | Weight threshold for a Gaussian distribution to be kept.
32 | """
33 | self._max_clusters = max_clusters
34 | self._weight_threshold = weight_threshold
35 |
36 | def _fit_continuous(self, data):
37 | """Train Bayesian GMM for continuous columns.
38 |
39 | Args:
40 | data (pd.DataFrame):
41 | A dataframe containing a column.
42 |
43 | Returns:
44 | namedtuple:
45 | A ``ColumnTransformInfo`` object.
46 | """
47 | column_name = data.columns[0]
48 | gm = ClusterBasedNormalizer(
49 | missing_value_generation='from_column',
50 | max_clusters=min(len(data), self._max_clusters),
51 | weight_threshold=self._weight_threshold,
52 | )
53 | gm.fit(data, column_name)
54 | num_components = sum(gm.valid_component_indicator)
55 |
56 | return ColumnTransformInfo(
57 | column_name=column_name,
58 | column_type='continuous',
59 | transform=gm,
60 | output_info=[SpanInfo(1, 'tanh'), SpanInfo(num_components, 'softmax')],
61 | output_dimensions=1 + num_components,
62 | )
63 |
64 | def _fit_discrete(self, data):
65 | """Fit one hot encoder for discrete column.
66 |
67 | Args:
68 | data (pd.DataFrame):
69 | A dataframe containing a column.
70 |
71 | Returns:
72 | namedtuple:
73 | A ``ColumnTransformInfo`` object.
74 | """
75 | column_name = data.columns[0]
76 | ohe = OneHotEncoder()
77 | ohe.fit(data, column_name)
78 | num_categories = len(ohe.dummies)
79 |
80 | return ColumnTransformInfo(
81 | column_name=column_name,
82 | column_type='discrete',
83 | transform=ohe,
84 | output_info=[SpanInfo(num_categories, 'softmax')],
85 | output_dimensions=num_categories,
86 | )
87 |
88 | def fit(self, raw_data, discrete_columns=()):
89 | """Fit the ``DataTransformer``.
90 |
91 | Fits a ``ClusterBasedNormalizer`` for continuous columns and a
92 | ``OneHotEncoder`` for discrete columns.
93 |
94 | This step also counts the #columns in matrix data and span information.
95 | """
96 | self.output_info_list = []
97 | self.output_dimensions = 0
98 | self.dataframe = True
99 |
100 | if not isinstance(raw_data, pd.DataFrame):
101 | self.dataframe = False
102 | # work around for RDT issue #328 Fitting with numerical column names fails
103 | discrete_columns = [str(column) for column in discrete_columns]
104 | column_names = [str(num) for num in range(raw_data.shape[1])]
105 | raw_data = pd.DataFrame(raw_data, columns=column_names)
106 |
107 | self._column_raw_dtypes = raw_data.infer_objects().dtypes
108 | self._column_transform_info_list = []
109 | for column_name in raw_data.columns:
110 | if column_name in discrete_columns:
111 | column_transform_info = self._fit_discrete(raw_data[[column_name]])
112 | else:
113 | column_transform_info = self._fit_continuous(raw_data[[column_name]])
114 |
115 | self.output_info_list.append(column_transform_info.output_info)
116 | self.output_dimensions += column_transform_info.output_dimensions
117 | self._column_transform_info_list.append(column_transform_info)
118 |
119 | def _transform_continuous(self, column_transform_info, data):
120 | column_name = data.columns[0]
121 | flattened_column = data[column_name].to_numpy().flatten()
122 | data = data.assign(**{column_name: flattened_column})
123 | gm = column_transform_info.transform
124 | transformed = gm.transform(data)
125 |
126 | # Converts the transformed data to the appropriate output format.
127 | # The first column (ending in '.normalized') stays the same,
128 |         # but the label encoded column (ending in '.component') is one-hot encoded.
129 | output = np.zeros((len(transformed), column_transform_info.output_dimensions))
130 | output[:, 0] = transformed[f'{column_name}.normalized'].to_numpy()
131 | index = transformed[f'{column_name}.component'].to_numpy().astype(int)
132 | output[np.arange(index.size), index + 1] = 1.0
133 |
134 | return output
135 |
136 | def _transform_discrete(self, column_transform_info, data):
137 | ohe = column_transform_info.transform
138 | return ohe.transform(data).to_numpy()
139 |
140 | def _synchronous_transform(self, raw_data, column_transform_info_list):
141 |         """Take a Pandas DataFrame and transform columns synchronously.
142 |
143 |         Outputs a list of Numpy arrays.
144 | """
145 | column_data_list = []
146 | for column_transform_info in column_transform_info_list:
147 | column_name = column_transform_info.column_name
148 | data = raw_data[[column_name]]
149 | if column_transform_info.column_type == 'continuous':
150 | column_data_list.append(self._transform_continuous(column_transform_info, data))
151 | else:
152 | column_data_list.append(self._transform_discrete(column_transform_info, data))
153 |
154 | return column_data_list
155 |
156 | def _parallel_transform(self, raw_data, column_transform_info_list):
157 | """Take a Pandas DataFrame and transform columns in parallel.
158 |
159 |         Outputs a list of Numpy arrays.
160 | """
161 | processes = []
162 | for column_transform_info in column_transform_info_list:
163 | column_name = column_transform_info.column_name
164 | data = raw_data[[column_name]]
165 | process = None
166 | if column_transform_info.column_type == 'continuous':
167 | process = delayed(self._transform_continuous)(column_transform_info, data)
168 | else:
169 | process = delayed(self._transform_discrete)(column_transform_info, data)
170 | processes.append(process)
171 |
172 | return Parallel(n_jobs=-1)(processes)
173 |
174 | def transform(self, raw_data):
175 |         """Take raw data and output the transformed matrix data."""
176 | if not isinstance(raw_data, pd.DataFrame):
177 | column_names = [str(num) for num in range(raw_data.shape[1])]
178 | raw_data = pd.DataFrame(raw_data, columns=column_names)
179 |
180 | # Only use parallelization with larger data sizes.
181 | # Otherwise, the transformation will be slower.
182 | if raw_data.shape[0] < 500:
183 | column_data_list = self._synchronous_transform(
184 | raw_data, self._column_transform_info_list
185 | )
186 | else:
187 | column_data_list = self._parallel_transform(raw_data, self._column_transform_info_list)
188 |
189 | return np.concatenate(column_data_list, axis=1).astype(float)
190 |
191 | def _inverse_transform_continuous(self, column_transform_info, column_data, sigmas, st):
192 | gm = column_transform_info.transform
193 | data = pd.DataFrame(column_data[:, :2], columns=list(gm.get_output_sdtypes())).astype(float)
194 | data[data.columns[1]] = np.argmax(column_data[:, 1:], axis=1)
195 | if sigmas is not None:
196 | selected_normalized_value = np.random.normal(data.iloc[:, 0], sigmas[st])
197 | data.iloc[:, 0] = selected_normalized_value
198 |
199 | return gm.reverse_transform(data)
200 |
201 | def _inverse_transform_discrete(self, column_transform_info, column_data):
202 | ohe = column_transform_info.transform
203 | data = pd.DataFrame(column_data, columns=list(ohe.get_output_sdtypes()))
204 | return ohe.reverse_transform(data)[column_transform_info.column_name]
205 |
206 | def inverse_transform(self, data, sigmas=None):
207 | """Take matrix data and output raw data.
208 |
209 |         Output uses the same type as the input passed to ``fit``:
210 |         either a numpy array or a pandas DataFrame.
211 | """
212 | st = 0
213 | recovered_column_data_list = []
214 | column_names = []
215 | for column_transform_info in self._column_transform_info_list:
216 | dim = column_transform_info.output_dimensions
217 | column_data = data[:, st : st + dim]
218 | if column_transform_info.column_type == 'continuous':
219 | recovered_column_data = self._inverse_transform_continuous(
220 | column_transform_info, column_data, sigmas, st
221 | )
222 | else:
223 | recovered_column_data = self._inverse_transform_discrete(
224 | column_transform_info, column_data
225 | )
226 |
227 | recovered_column_data_list.append(recovered_column_data)
228 | column_names.append(column_transform_info.column_name)
229 | st += dim
230 |
231 | recovered_data = np.column_stack(recovered_column_data_list)
232 | recovered_data = pd.DataFrame(recovered_data, columns=column_names).astype(
233 | self._column_raw_dtypes
234 | )
235 | if not self.dataframe:
236 | recovered_data = recovered_data.to_numpy()
237 |
238 | return recovered_data
239 |
240 | def convert_column_name_value_to_id(self, column_name, value):
241 |         """Get the ids of the given ``column_name`` and ``value``."""
242 | discrete_counter = 0
243 | column_id = 0
244 | for column_transform_info in self._column_transform_info_list:
245 | if column_transform_info.column_name == column_name:
246 | break
247 | if column_transform_info.column_type == 'discrete':
248 | discrete_counter += 1
249 |
250 | column_id += 1
251 |
252 | else:
253 | raise ValueError(f"The column_name `{column_name}` doesn't exist in the data.")
254 |
255 | ohe = column_transform_info.transform
256 | data = pd.DataFrame([value], columns=[column_transform_info.column_name])
257 | one_hot = ohe.transform(data).to_numpy()[0]
258 | if sum(one_hot) == 0:
259 | raise ValueError(f"The value `{value}` doesn't exist in the column `{column_name}`.")
260 |
261 | return {
262 | 'discrete_column_id': discrete_counter,
263 | 'column_id': column_id,
264 | 'value_id': np.argmax(one_hot),
265 | }
266 |
--------------------------------------------------------------------------------
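The transformer is a reversible encoding: each continuous column becomes a tanh-scaled scalar plus a one-hot component indicator, and each discrete column becomes a plain one-hot vector. A round-trip sketch on toy data:

    import pandas as pd

    from ctgan.data_transformer import DataTransformer

    df = pd.DataFrame({'age': [23.0, 45.0, 61.0, 30.0], 'sex': ['F', 'M', 'F', 'M']})
    transformer = DataTransformer()
    transformer.fit(df, discrete_columns=['sex'])

    matrix = transformer.transform(df)
    print(matrix.shape[1] == transformer.output_dimensions)  # True

    recovered = transformer.inverse_transform(matrix)
    print(list(recovered.columns))  # ['age', 'sex'], same schema as the input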
/ctgan/demo.py:
--------------------------------------------------------------------------------
1 | """Demo module."""
2 |
3 | import pandas as pd
4 |
5 | DEMO_URL = 'http://ctgan-demo.s3.amazonaws.com/census.csv.gz'
6 |
7 |
8 | def load_demo():
9 | """Load the demo."""
10 | return pd.read_csv(DEMO_URL, compression='gzip')
11 |
--------------------------------------------------------------------------------
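A usage sketch for the loader; the discrete column list below is the one typically paired with this census table and is given here for illustration:

    from ctgan.demo import load_demo

    data = load_demo()  # downloads the census table on first use

    discrete_columns = [
        'workclass', 'education', 'marital-status', 'occupation',
        'relationship', 'race', 'sex', 'native-country', 'income',
    ]
    print(data.shape)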
/ctgan/errors.py:
--------------------------------------------------------------------------------
1 | """Custom errors for CTGAN."""
2 |
3 |
4 | class InvalidDataError(Exception):
5 | """Error to raise when data is not valid."""
6 |
--------------------------------------------------------------------------------
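InvalidDataError is what CTGAN.fit raises when continuous training columns contain nulls (see _validate_null_data in ctgan.py further down). A sketch of the failure mode:

    import numpy as np
    import pandas as pd

    from ctgan.errors import InvalidDataError
    from ctgan.synthesizers import CTGAN

    df = pd.DataFrame({'x': [1.0, np.nan, 3.0], 'label': ['a', 'b', 'a']})
    try:
        CTGAN(epochs=1).fit(df, discrete_columns=['label'])
    except InvalidDataError as error:
        print(error)  # validation fails before any training starts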
/ctgan/synthesizers/__init__.py:
--------------------------------------------------------------------------------
1 | """Synthesizers module."""
2 |
3 | from ctgan.synthesizers.ctgan import CTGAN
4 | from ctgan.synthesizers.tvae import TVAE
5 |
6 | __all__ = ('CTGAN', 'TVAE')
7 |
8 |
9 | def get_all_synthesizers():
10 | return {name: globals()[name] for name in __all__}
11 |
--------------------------------------------------------------------------------
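get_all_synthesizers returns a name-to-class mapping, which is convenient for registry-style dispatch; for example:

    from ctgan.synthesizers import get_all_synthesizers

    synthesizers = get_all_synthesizers()
    print(sorted(synthesizers))             # ['CTGAN', 'TVAE']

    model = synthesizers['TVAE'](epochs=5)  # instantiate by name (illustrative settings)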
/ctgan/synthesizers/base.py:
--------------------------------------------------------------------------------
1 | """BaseSynthesizer module."""
2 |
3 | import contextlib
4 |
5 | import numpy as np
6 | import torch
7 |
8 |
9 | @contextlib.contextmanager
10 | def set_random_states(random_state, set_model_random_state):
11 | """Context manager for managing the random state.
12 |
13 | Args:
14 | random_state (int or tuple):
15 | The random seed or a tuple of (numpy.random.RandomState, torch.Generator).
16 | set_model_random_state (function):
17 | Function to set the random state on the model.
18 | """
19 | original_np_state = np.random.get_state()
20 | original_torch_state = torch.get_rng_state()
21 |
22 | random_np_state, random_torch_state = random_state
23 |
24 | np.random.set_state(random_np_state.get_state())
25 | torch.set_rng_state(random_torch_state.get_state())
26 |
27 | try:
28 | yield
29 | finally:
30 | current_np_state = np.random.RandomState()
31 | current_np_state.set_state(np.random.get_state())
32 | current_torch_state = torch.Generator()
33 | current_torch_state.set_state(torch.get_rng_state())
34 | set_model_random_state((current_np_state, current_torch_state))
35 |
36 | np.random.set_state(original_np_state)
37 | torch.set_rng_state(original_torch_state)
38 |
39 |
40 | def random_state(function):
41 | """Set the random state before calling the function.
42 |
43 | Args:
44 | function (Callable):
45 | The function to wrap around.
46 | """
47 |
48 | def wrapper(self, *args, **kwargs):
49 | if self.random_states is None:
50 | return function(self, *args, **kwargs)
51 |
52 | else:
53 | with set_random_states(self.random_states, self.set_random_state):
54 | return function(self, *args, **kwargs)
55 |
56 | return wrapper
57 |
58 |
59 | class BaseSynthesizer:
60 | """Base class for all default synthesizers of ``CTGAN``."""
61 |
62 | random_states = None
63 |
64 | def __getstate__(self):
65 | """Improve pickling state for ``BaseSynthesizer``.
66 |
67 | Convert to ``cpu`` device before starting the pickling process in order to be able to
68 | load the model even when used from an external tool such as ``SDV``. Also, if
69 | ``random_states`` are set, store their states as dictionaries rather than generators.
70 |
71 | Returns:
72 | dict:
73 | Python dict representing the object.
74 | """
75 | device_backup = self._device
76 | self.set_device(torch.device('cpu'))
77 | state = self.__dict__.copy()
78 | self.set_device(device_backup)
79 | if (
80 | isinstance(self.random_states, tuple)
81 | and isinstance(self.random_states[0], np.random.RandomState)
82 | and isinstance(self.random_states[1], torch.Generator)
83 | ):
84 | state['_numpy_random_state'] = self.random_states[0].get_state()
85 | state['_torch_random_state'] = self.random_states[1].get_state()
86 | state.pop('random_states')
87 |
88 | return state
89 |
90 | def __setstate__(self, state):
91 | """Restore the state of a ``BaseSynthesizer``.
92 |
93 | Restore the ``random_states`` from the state dict if those are present and then
94 | set the device according to the current hardware.
95 | """
96 | if '_numpy_random_state' in state and '_torch_random_state' in state:
97 | np_state = state.pop('_numpy_random_state')
98 | torch_state = state.pop('_torch_random_state')
99 |
100 | current_torch_state = torch.Generator()
101 | current_torch_state.set_state(torch_state)
102 |
103 | current_numpy_state = np.random.RandomState()
104 | current_numpy_state.set_state(np_state)
105 | state['random_states'] = (current_numpy_state, current_torch_state)
106 |
107 | self.__dict__ = state
108 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
109 | self.set_device(device)
110 |
111 | def save(self, path):
112 | """Save the model in the passed `path`."""
113 | device_backup = self._device
114 | self.set_device(torch.device('cpu'))
115 | torch.save(self, path)
116 | self.set_device(device_backup)
117 |
118 | @classmethod
119 | def load(cls, path):
120 | """Load the model stored in the passed `path`."""
121 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
122 | model = torch.load(path, weights_only=False)
123 | model.set_device(device)
124 | return model
125 |
126 | def set_random_state(self, random_state):
127 | """Set the random state.
128 |
129 | Args:
130 | random_state (int, tuple, or None):
131 | Either a tuple containing the (numpy.random.RandomState, torch.Generator)
132 | or an int representing the random seed to use for both random states.
133 | """
134 | if random_state is None:
135 | self.random_states = random_state
136 | elif isinstance(random_state, int):
137 | self.random_states = (
138 | np.random.RandomState(seed=random_state),
139 | torch.Generator().manual_seed(random_state),
140 | )
141 | elif (
142 | isinstance(random_state, tuple)
143 | and isinstance(random_state[0], np.random.RandomState)
144 | and isinstance(random_state[1], torch.Generator)
145 | ):
146 | self.random_states = random_state
147 | else:
148 | raise TypeError(
149 | f'`random_state` {random_state} expected to be an int or a tuple of '
150 | '(`np.random.RandomState`, `torch.Generator`)'
151 | )
152 |
--------------------------------------------------------------------------------
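Because the @random_state decorator swaps the global numpy/torch generators in and out around each call, re-seeding a model makes sampling reproducible. A sketch of the intended behaviour, assuming `model` is an already-fitted synthesizer:

    # `model` is assumed to be a fitted CTGAN or TVAE instance.
    model.set_random_state(42)
    first = model.sample(10)

    model.set_random_state(42)  # reset to the same seed
    second = model.sample(10)

    print(first.equals(second))  # the two draws should match row for row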
/ctgan/synthesizers/ctgan.py:
--------------------------------------------------------------------------------
1 | """CTGAN module."""
2 |
3 | import warnings
4 |
5 | import numpy as np
6 | import pandas as pd
7 | import torch
8 | from torch import optim
9 | from torch.nn import BatchNorm1d, Dropout, LeakyReLU, Linear, Module, ReLU, Sequential, functional
10 | from tqdm import tqdm
11 |
12 | from ctgan.data_sampler import DataSampler
13 | from ctgan.data_transformer import DataTransformer
14 | from ctgan.errors import InvalidDataError
15 | from ctgan.synthesizers.base import BaseSynthesizer, random_state
16 |
17 |
18 | class Discriminator(Module):
19 | """Discriminator for the CTGAN."""
20 |
21 | def __init__(self, input_dim, discriminator_dim, pac=10):
22 | super(Discriminator, self).__init__()
23 | dim = input_dim * pac
24 | self.pac = pac
25 | self.pacdim = dim
26 | seq = []
27 | for item in list(discriminator_dim):
28 | seq += [Linear(dim, item), LeakyReLU(0.2), Dropout(0.5)]
29 | dim = item
30 |
31 | seq += [Linear(dim, 1)]
32 | self.seq = Sequential(*seq)
33 |
34 | def calc_gradient_penalty(self, real_data, fake_data, device='cpu', pac=10, lambda_=10):
35 | """Compute the gradient penalty."""
36 | alpha = torch.rand(real_data.size(0) // pac, 1, 1, device=device)
37 | alpha = alpha.repeat(1, pac, real_data.size(1))
38 | alpha = alpha.view(-1, real_data.size(1))
39 |
40 | interpolates = alpha * real_data + ((1 - alpha) * fake_data)
41 |
42 | disc_interpolates = self(interpolates)
43 |
44 | gradients = torch.autograd.grad(
45 | outputs=disc_interpolates,
46 | inputs=interpolates,
47 | grad_outputs=torch.ones(disc_interpolates.size(), device=device),
48 | create_graph=True,
49 | retain_graph=True,
50 | only_inputs=True,
51 | )[0]
52 |
53 | gradients_view = gradients.view(-1, pac * real_data.size(1)).norm(2, dim=1) - 1
54 | gradient_penalty = ((gradients_view) ** 2).mean() * lambda_
55 |
56 | return gradient_penalty
57 |
58 | def forward(self, input_):
59 | """Apply the Discriminator to the `input_`."""
60 | assert input_.size()[0] % self.pac == 0
61 | return self.seq(input_.view(-1, self.pacdim))
62 |
63 |
64 | class Residual(Module):
65 | """Residual layer for the CTGAN."""
66 |
67 | def __init__(self, i, o):
68 | super(Residual, self).__init__()
69 | self.fc = Linear(i, o)
70 | self.bn = BatchNorm1d(o)
71 | self.relu = ReLU()
72 |
73 | def forward(self, input_):
74 | """Apply the Residual layer to the `input_`."""
75 | out = self.fc(input_)
76 | out = self.bn(out)
77 | out = self.relu(out)
78 | return torch.cat([out, input_], dim=1)
79 |
80 |
81 | class Generator(Module):
82 | """Generator for the CTGAN."""
83 |
84 | def __init__(self, embedding_dim, generator_dim, data_dim):
85 | super(Generator, self).__init__()
86 | dim = embedding_dim
87 | seq = []
88 | for item in list(generator_dim):
89 | seq += [Residual(dim, item)]
90 | dim += item
91 | seq.append(Linear(dim, data_dim))
92 | self.seq = Sequential(*seq)
93 |
94 | def forward(self, input_):
95 | """Apply the Generator to the `input_`."""
96 | data = self.seq(input_)
97 | return data
98 |
99 |
100 | class CTGAN(BaseSynthesizer):
101 | """Conditional Table GAN Synthesizer.
102 |
103 | This is the core class of the CTGAN project, where the different components
104 | are orchestrated together.
105 | For more details about the process, please check the [Modeling Tabular data using
106 | Conditional GAN](https://arxiv.org/abs/1907.00503) paper.
107 |
108 | Args:
109 | embedding_dim (int):
110 | Size of the random sample passed to the Generator. Defaults to 128.
111 | generator_dim (tuple or list of ints):
112 | Size of the output samples for each one of the Residuals. A Residual Layer
113 | will be created for each one of the values provided. Defaults to (256, 256).
114 | discriminator_dim (tuple or list of ints):
115 | Size of the output samples for each one of the Discriminator Layers. A Linear Layer
116 | will be created for each one of the values provided. Defaults to (256, 256).
117 | generator_lr (float):
118 | Learning rate for the generator. Defaults to 2e-4.
119 | generator_decay (float):
120 | Generator weight decay for the Adam Optimizer. Defaults to 1e-6.
121 | discriminator_lr (float):
122 | Learning rate for the discriminator. Defaults to 2e-4.
123 | discriminator_decay (float):
124 | Discriminator weight decay for the Adam Optimizer. Defaults to 1e-6.
125 | batch_size (int):
126 | Number of data samples to process in each step.
127 | discriminator_steps (int):
128 | Number of discriminator updates to do for each generator update.
129 | From the WGAN paper: https://arxiv.org/abs/1701.07875. WGAN paper
130 |             default is 5; the default here is 1 to match the original CTGAN implementation.
131 | log_frequency (boolean):
132 | Whether to use log frequency of categorical levels in conditional
133 | sampling. Defaults to ``True``.
134 | verbose (boolean):
135 | Whether to have print statements for progress results. Defaults to ``False``.
136 | epochs (int):
137 | Number of training epochs. Defaults to 300.
138 | pac (int):
139 | Number of samples to group together when applying the discriminator.
140 | Defaults to 10.
141 | cuda (bool):
142 | Whether to attempt to use cuda for GPU computation.
143 | If this is False or CUDA is not available, CPU will be used.
144 | Defaults to ``True``.
145 | """
146 |
147 | def __init__(
148 | self,
149 | embedding_dim=128,
150 | generator_dim=(256, 256),
151 | discriminator_dim=(256, 256),
152 | generator_lr=2e-4,
153 | generator_decay=1e-6,
154 | discriminator_lr=2e-4,
155 | discriminator_decay=1e-6,
156 | batch_size=500,
157 | discriminator_steps=1,
158 | log_frequency=True,
159 | verbose=False,
160 | epochs=300,
161 | pac=10,
162 | cuda=True,
163 | ):
164 | assert batch_size % 2 == 0
165 |
166 | self._embedding_dim = embedding_dim
167 | self._generator_dim = generator_dim
168 | self._discriminator_dim = discriminator_dim
169 |
170 | self._generator_lr = generator_lr
171 | self._generator_decay = generator_decay
172 | self._discriminator_lr = discriminator_lr
173 | self._discriminator_decay = discriminator_decay
174 |
175 | self._batch_size = batch_size
176 | self._discriminator_steps = discriminator_steps
177 | self._log_frequency = log_frequency
178 | self._verbose = verbose
179 | self._epochs = epochs
180 | self.pac = pac
181 |
182 | if not cuda or not torch.cuda.is_available():
183 | device = 'cpu'
184 | elif isinstance(cuda, str):
185 | device = cuda
186 | else:
187 | device = 'cuda'
188 |
189 | self._device = torch.device(device)
190 |
191 | self._transformer = None
192 | self._data_sampler = None
193 | self._generator = None
194 | self.loss_values = None
195 |
196 | @staticmethod
197 | def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
198 | """Deals with the instability of the gumbel_softmax for older versions of torch.
199 |
200 | For more details about the issue:
201 | https://drive.google.com/file/d/1AA5wPfZ1kquaRtVruCd6BiYZGcDeNxyP/view?usp=sharing
202 |
203 | Args:
204 | logits […, num_features]:
205 | Unnormalized log probabilities
206 | tau:
207 | Non-negative scalar temperature
208 | hard (bool):
209 | If True, the returned samples will be discretized as one-hot vectors,
210 | but will be differentiated as if it is the soft sample in autograd
211 | dim (int):
212 | A dimension along which softmax will be computed. Default: -1.
213 |
214 | Returns:
215 | Sampled tensor of same shape as logits from the Gumbel-Softmax distribution.
216 | """
217 | for _ in range(10):
218 | transformed = functional.gumbel_softmax(logits, tau=tau, hard=hard, eps=eps, dim=dim)
219 | if not torch.isnan(transformed).any():
220 | return transformed
221 |
222 | raise ValueError('gumbel_softmax returning NaN.')
223 |
224 | def _apply_activate(self, data):
225 | """Apply proper activation function to the output of the generator."""
226 | data_t = []
227 | st = 0
228 | for column_info in self._transformer.output_info_list:
229 | for span_info in column_info:
230 | if span_info.activation_fn == 'tanh':
231 | ed = st + span_info.dim
232 | data_t.append(torch.tanh(data[:, st:ed]))
233 | st = ed
234 | elif span_info.activation_fn == 'softmax':
235 | ed = st + span_info.dim
236 | transformed = self._gumbel_softmax(data[:, st:ed], tau=0.2)
237 | data_t.append(transformed)
238 | st = ed
239 | else:
240 | raise ValueError(f'Unexpected activation function {span_info.activation_fn}.')
241 |
242 | return torch.cat(data_t, dim=1)
243 |
244 | def _cond_loss(self, data, c, m):
245 | """Compute the cross entropy loss on the fixed discrete column."""
246 | loss = []
247 | st = 0
248 | st_c = 0
249 | for column_info in self._transformer.output_info_list:
250 | for span_info in column_info:
251 | if len(column_info) != 1 or span_info.activation_fn != 'softmax':
252 | # not discrete column
253 | st += span_info.dim
254 | else:
255 | ed = st + span_info.dim
256 | ed_c = st_c + span_info.dim
257 | tmp = functional.cross_entropy(
258 | data[:, st:ed], torch.argmax(c[:, st_c:ed_c], dim=1), reduction='none'
259 | )
260 | loss.append(tmp)
261 | st = ed
262 | st_c = ed_c
263 |
264 | loss = torch.stack(loss, dim=1) # noqa: PD013
265 |
266 | return (loss * m).sum() / data.size()[0]
267 |
268 | def _validate_discrete_columns(self, train_data, discrete_columns):
269 | """Check whether ``discrete_columns`` exists in ``train_data``.
270 |
271 | Args:
272 | train_data (numpy.ndarray or pandas.DataFrame):
273 | Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame.
274 | discrete_columns (list-like):
275 | List of discrete columns to be used to generate the Conditional
276 | Vector. If ``train_data`` is a Numpy array, this list should
277 | contain the integer indices of the columns. Otherwise, if it is
278 | a ``pandas.DataFrame``, this list should contain the column names.
279 | """
280 | if isinstance(train_data, pd.DataFrame):
281 | invalid_columns = set(discrete_columns) - set(train_data.columns)
282 | elif isinstance(train_data, np.ndarray):
283 | invalid_columns = []
284 | for column in discrete_columns:
285 | if column < 0 or column >= train_data.shape[1]:
286 | invalid_columns.append(column)
287 | else:
288 | raise TypeError('``train_data`` should be either pd.DataFrame or np.array.')
289 |
290 | if invalid_columns:
291 | raise ValueError(f'Invalid columns found: {invalid_columns}')
292 |
293 | def _validate_null_data(self, train_data, discrete_columns):
294 | """Check whether null values exist in continuous ``train_data``.
295 |
296 | Args:
297 | train_data (numpy.ndarray or pandas.DataFrame):
298 | Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame.
299 | discrete_columns (list-like):
300 | List of discrete columns to be used to generate the Conditional
301 | Vector. If ``train_data`` is a Numpy array, this list should
302 | contain the integer indices of the columns. Otherwise, if it is
303 | a ``pandas.DataFrame``, this list should contain the column names.
304 | """
305 | if isinstance(train_data, pd.DataFrame):
306 | continuous_cols = list(set(train_data.columns) - set(discrete_columns))
307 | any_nulls = train_data[continuous_cols].isna().any().any()
308 | else:
309 | continuous_cols = [i for i in range(train_data.shape[1]) if i not in discrete_columns]
310 | any_nulls = pd.DataFrame(train_data)[continuous_cols].isna().any().any()
311 |
312 | if any_nulls:
313 | raise InvalidDataError(
314 | 'CTGAN does not support null values in the continuous training data. '
315 | 'Please remove all null values from your continuous training data.'
316 | )
317 |
318 | @random_state
319 | def fit(self, train_data, discrete_columns=(), epochs=None):
320 | """Fit the CTGAN Synthesizer models to the training data.
321 |
322 | Args:
323 | train_data (numpy.ndarray or pandas.DataFrame):
324 | Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame.
325 | discrete_columns (list-like):
326 | List of discrete columns to be used to generate the Conditional
327 | Vector. If ``train_data`` is a Numpy array, this list should
328 | contain the integer indices of the columns. Otherwise, if it is
329 | a ``pandas.DataFrame``, this list should contain the column names.
330 | """
331 | self._validate_discrete_columns(train_data, discrete_columns)
332 | self._validate_null_data(train_data, discrete_columns)
333 |
334 | if epochs is None:
335 | epochs = self._epochs
336 | else:
337 | warnings.warn(
338 | (
339 | '`epochs` argument in `fit` method has been deprecated and will be removed '
340 | 'in a future version. Please pass `epochs` to the constructor instead'
341 | ),
342 | DeprecationWarning,
343 | )
344 |
345 | self._transformer = DataTransformer()
346 | self._transformer.fit(train_data, discrete_columns)
347 |
348 | train_data = self._transformer.transform(train_data)
349 |
350 | self._data_sampler = DataSampler(
351 | train_data, self._transformer.output_info_list, self._log_frequency
352 | )
353 |
354 | data_dim = self._transformer.output_dimensions
355 |
356 | self._generator = Generator(
357 | self._embedding_dim + self._data_sampler.dim_cond_vec(), self._generator_dim, data_dim
358 | ).to(self._device)
359 |
360 | discriminator = Discriminator(
361 | data_dim + self._data_sampler.dim_cond_vec(), self._discriminator_dim, pac=self.pac
362 | ).to(self._device)
363 |
364 | optimizerG = optim.Adam(
365 | self._generator.parameters(),
366 | lr=self._generator_lr,
367 | betas=(0.5, 0.9),
368 | weight_decay=self._generator_decay,
369 | )
370 |
371 | optimizerD = optim.Adam(
372 | discriminator.parameters(),
373 | lr=self._discriminator_lr,
374 | betas=(0.5, 0.9),
375 | weight_decay=self._discriminator_decay,
376 | )
377 |
378 | mean = torch.zeros(self._batch_size, self._embedding_dim, device=self._device)
379 | std = mean + 1
380 |
381 |         self.loss_values = pd.DataFrame(columns=['Epoch', 'Generator Loss', 'Discriminator Loss'])
382 |
383 | epoch_iterator = tqdm(range(epochs), disable=(not self._verbose))
384 | if self._verbose:
385 | description = 'Gen. ({gen:.2f}) | Discrim. ({dis:.2f})'
386 | epoch_iterator.set_description(description.format(gen=0, dis=0))
387 |
388 | steps_per_epoch = max(len(train_data) // self._batch_size, 1)
389 | for i in epoch_iterator:
390 | for id_ in range(steps_per_epoch):
391 | for n in range(self._discriminator_steps):
392 | fakez = torch.normal(mean=mean, std=std)
393 |
394 | condvec = self._data_sampler.sample_condvec(self._batch_size)
395 | if condvec is None:
396 | c1, m1, col, opt = None, None, None, None
397 | real = self._data_sampler.sample_data(
398 | train_data, self._batch_size, col, opt
399 | )
400 | else:
401 | c1, m1, col, opt = condvec
402 | c1 = torch.from_numpy(c1).to(self._device)
403 | m1 = torch.from_numpy(m1).to(self._device)
404 | fakez = torch.cat([fakez, c1], dim=1)
405 |
406 | perm = np.arange(self._batch_size)
407 | np.random.shuffle(perm)
408 | real = self._data_sampler.sample_data(
409 | train_data, self._batch_size, col[perm], opt[perm]
410 | )
411 | c2 = c1[perm]
412 |
413 | fake = self._generator(fakez)
414 | fakeact = self._apply_activate(fake)
415 |
416 | real = torch.from_numpy(real.astype('float32')).to(self._device)
417 |
418 | if c1 is not None:
419 | fake_cat = torch.cat([fakeact, c1], dim=1)
420 | real_cat = torch.cat([real, c2], dim=1)
421 | else:
422 | real_cat = real
423 | fake_cat = fakeact
424 |
425 | y_fake = discriminator(fake_cat)
426 | y_real = discriminator(real_cat)
427 |
428 | pen = discriminator.calc_gradient_penalty(
429 | real_cat, fake_cat, self._device, self.pac
430 | )
431 | loss_d = -(torch.mean(y_real) - torch.mean(y_fake))
432 |
433 | optimizerD.zero_grad(set_to_none=False)
434 | pen.backward(retain_graph=True)
435 | loss_d.backward()
436 | optimizerD.step()
437 |
438 | fakez = torch.normal(mean=mean, std=std)
439 | condvec = self._data_sampler.sample_condvec(self._batch_size)
440 |
441 | if condvec is None:
442 | c1, m1, col, opt = None, None, None, None
443 | else:
444 | c1, m1, col, opt = condvec
445 | c1 = torch.from_numpy(c1).to(self._device)
446 | m1 = torch.from_numpy(m1).to(self._device)
447 | fakez = torch.cat([fakez, c1], dim=1)
448 |
449 | fake = self._generator(fakez)
450 | fakeact = self._apply_activate(fake)
451 |
452 | if c1 is not None:
453 | y_fake = discriminator(torch.cat([fakeact, c1], dim=1))
454 | else:
455 | y_fake = discriminator(fakeact)
456 |
457 | if condvec is None:
458 | cross_entropy = 0
459 | else:
460 | cross_entropy = self._cond_loss(fake, c1, m1)
461 |
462 | loss_g = -torch.mean(y_fake) + cross_entropy
463 |
464 | optimizerG.zero_grad(set_to_none=False)
465 | loss_g.backward()
466 | optimizerG.step()
467 |
468 | generator_loss = loss_g.detach().cpu().item()
469 | discriminator_loss = loss_d.detach().cpu().item()
470 |
471 | epoch_loss_df = pd.DataFrame({
472 | 'Epoch': [i],
473 | 'Generator Loss': [generator_loss],
474 | 'Discriminator Loss': [discriminator_loss],
475 | })
476 | if not self.loss_values.empty:
477 | self.loss_values = pd.concat([self.loss_values, epoch_loss_df]).reset_index(
478 | drop=True
479 | )
480 | else:
481 | self.loss_values = epoch_loss_df
482 |
483 | if self._verbose:
484 | epoch_iterator.set_description(
485 | description.format(gen=generator_loss, dis=discriminator_loss)
486 | )
487 |
488 | @random_state
489 | def sample(self, n, condition_column=None, condition_value=None):
490 | """Sample data similar to the training data.
491 |
492 | Choosing a condition_column and condition_value will increase the probability of the
493 | discrete condition_value happening in the condition_column.
494 |
495 | Args:
496 | n (int):
497 | Number of rows to sample.
498 | condition_column (string):
499 | Name of a discrete column.
500 | condition_value (string):
501 |                 Category of the condition_column whose probability of
502 |                 occurrence should be increased.
503 |
504 | Returns:
505 | numpy.ndarray or pandas.DataFrame
506 | """
507 | if condition_column is not None and condition_value is not None:
508 | condition_info = self._transformer.convert_column_name_value_to_id(
509 | condition_column, condition_value
510 | )
511 | global_condition_vec = self._data_sampler.generate_cond_from_condition_column_info(
512 | condition_info, self._batch_size
513 | )
514 | else:
515 | global_condition_vec = None
516 |
517 | steps = n // self._batch_size + 1
518 | data = []
519 | for i in range(steps):
520 | mean = torch.zeros(self._batch_size, self._embedding_dim)
521 | std = mean + 1
522 | fakez = torch.normal(mean=mean, std=std).to(self._device)
523 |
524 | if global_condition_vec is not None:
525 | condvec = global_condition_vec.copy()
526 | else:
527 | condvec = self._data_sampler.sample_original_condvec(self._batch_size)
528 |
529 | if condvec is None:
530 | pass
531 | else:
532 | c1 = condvec
533 | c1 = torch.from_numpy(c1).to(self._device)
534 | fakez = torch.cat([fakez, c1], dim=1)
535 |
536 | fake = self._generator(fakez)
537 | fakeact = self._apply_activate(fake)
538 | data.append(fakeact.detach().cpu().numpy())
539 |
540 | data = np.concatenate(data, axis=0)
541 | data = data[:n]
542 |
543 | return self._transformer.inverse_transform(data)
544 |
545 | def set_device(self, device):
546 |         """Set the `device` to be used ('GPU' or 'CPU')."""
547 | self._device = device
548 | if self._generator is not None:
549 | self._generator.to(self._device)
550 |
--------------------------------------------------------------------------------
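Conditional sampling ties convert_column_name_value_to_id (in the transformer) to generate_cond_from_condition_column_info (in the sampler). A minimal end-to-end sketch on toy data; the settings are illustrative:

    import pandas as pd

    from ctgan.synthesizers import CTGAN

    df = pd.DataFrame({
        'amount': [1.0, 2.5, 3.1, 4.8] * 50,
        'status': ['new', 'paid', 'new', 'void'] * 50,
    })

    model = CTGAN(epochs=5, batch_size=100)  # batch_size must be even and divisible by pac
    model.fit(df, discrete_columns=['status'])

    # Increase the probability of sampling rows with status == 'paid'.
    samples = model.sample(200, condition_column='status', condition_value='paid')
    print(samples['status'].value_counts())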
/ctgan/synthesizers/tvae.py:
--------------------------------------------------------------------------------
1 | """TVAE module."""
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import torch
6 | from torch.nn import Linear, Module, Parameter, ReLU, Sequential
7 | from torch.nn.functional import cross_entropy
8 | from torch.optim import Adam
9 | from torch.utils.data import DataLoader, TensorDataset
10 | from tqdm import tqdm
11 |
12 | from ctgan.data_transformer import DataTransformer
13 | from ctgan.synthesizers.base import BaseSynthesizer, random_state
14 |
15 |
16 | class Encoder(Module):
17 | """Encoder for the TVAE.
18 |
19 | Args:
20 | data_dim (int):
21 | Dimensions of the data.
22 | compress_dims (tuple or list of ints):
23 | Size of each hidden layer.
24 | embedding_dim (int):
25 | Size of the output vector.
26 | """
27 |
28 | def __init__(self, data_dim, compress_dims, embedding_dim):
29 | super(Encoder, self).__init__()
30 | dim = data_dim
31 | seq = []
32 | for item in list(compress_dims):
33 | seq += [Linear(dim, item), ReLU()]
34 | dim = item
35 |
36 | self.seq = Sequential(*seq)
37 | self.fc1 = Linear(dim, embedding_dim)
38 | self.fc2 = Linear(dim, embedding_dim)
39 |
40 | def forward(self, input_):
41 | """Encode the passed `input_`."""
42 | feature = self.seq(input_)
43 | mu = self.fc1(feature)
44 | logvar = self.fc2(feature)
45 | std = torch.exp(0.5 * logvar)
46 | return mu, std, logvar
47 |
48 |
49 | class Decoder(Module):
50 | """Decoder for the TVAE.
51 |
52 | Args:
53 | embedding_dim (int):
54 | Size of the input vector.
55 | decompress_dims (tuple or list of ints):
56 | Size of each hidden layer.
57 | data_dim (int):
58 | Dimensions of the data.
59 | """
60 |
61 | def __init__(self, embedding_dim, decompress_dims, data_dim):
62 | super(Decoder, self).__init__()
63 | dim = embedding_dim
64 | seq = []
65 | for item in list(decompress_dims):
66 | seq += [Linear(dim, item), ReLU()]
67 | dim = item
68 |
69 | seq.append(Linear(dim, data_dim))
70 | self.seq = Sequential(*seq)
71 | self.sigma = Parameter(torch.ones(data_dim) * 0.1)
72 |
73 | def forward(self, input_):
74 | """Decode the passed `input_`."""
75 | return self.seq(input_), self.sigma
76 |
77 |
78 | def _loss_function(recon_x, x, sigmas, mu, logvar, output_info, factor):
79 | st = 0
80 | loss = []
81 | for column_info in output_info:
82 | for span_info in column_info:
83 | if span_info.activation_fn != 'softmax':
84 | ed = st + span_info.dim
85 | std = sigmas[st]
86 | eq = x[:, st] - torch.tanh(recon_x[:, st])
87 | loss.append((eq**2 / 2 / (std**2)).sum())
88 | loss.append(torch.log(std) * x.size()[0])
89 | st = ed
90 |
91 | else:
92 | ed = st + span_info.dim
93 | loss.append(
94 | cross_entropy(
95 | recon_x[:, st:ed], torch.argmax(x[:, st:ed], dim=-1), reduction='sum'
96 | )
97 | )
98 | st = ed
99 |
100 | assert st == recon_x.size()[1]
101 | KLD = -0.5 * torch.sum(1 + logvar - mu**2 - logvar.exp())
102 | return sum(loss) * factor / x.size()[0], KLD / x.size()[0]
103 |
104 |
105 | class TVAE(BaseSynthesizer):
106 |     """TVAE (Tabular Variational Autoencoder)."""
107 |
108 | def __init__(
109 | self,
110 | embedding_dim=128,
111 | compress_dims=(128, 128),
112 | decompress_dims=(128, 128),
113 | l2scale=1e-5,
114 | batch_size=500,
115 | epochs=300,
116 | loss_factor=2,
117 | cuda=True,
118 | verbose=False,
119 | ):
120 | self.embedding_dim = embedding_dim
121 | self.compress_dims = compress_dims
122 | self.decompress_dims = decompress_dims
123 |
124 | self.l2scale = l2scale
125 | self.batch_size = batch_size
126 | self.loss_factor = loss_factor
127 | self.epochs = epochs
128 | self.loss_values = pd.DataFrame(columns=['Epoch', 'Batch', 'Loss'])
129 | self.verbose = verbose
130 |
131 | if not cuda or not torch.cuda.is_available():
132 | device = 'cpu'
133 | elif isinstance(cuda, str):
134 | device = cuda
135 | else:
136 | device = 'cuda'
137 |
138 | self._device = torch.device(device)
139 |
140 | @random_state
141 | def fit(self, train_data, discrete_columns=()):
142 | """Fit the TVAE Synthesizer models to the training data.
143 |
144 | Args:
145 | train_data (numpy.ndarray or pandas.DataFrame):
146 | Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame.
147 | discrete_columns (list-like):
148 | List of discrete columns to be used to generate the Conditional
149 | Vector. If ``train_data`` is a Numpy array, this list should
150 | contain the integer indices of the columns. Otherwise, if it is
151 | a ``pandas.DataFrame``, this list should contain the column names.
152 | """
153 | self.transformer = DataTransformer()
154 | self.transformer.fit(train_data, discrete_columns)
155 | train_data = self.transformer.transform(train_data)
156 | dataset = TensorDataset(torch.from_numpy(train_data.astype('float32')).to(self._device))
157 | loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True, drop_last=False)
158 |
159 | data_dim = self.transformer.output_dimensions
160 | encoder = Encoder(data_dim, self.compress_dims, self.embedding_dim).to(self._device)
161 | self.decoder = Decoder(self.embedding_dim, self.decompress_dims, data_dim).to(self._device)
162 | optimizerAE = Adam(
163 | list(encoder.parameters()) + list(self.decoder.parameters()), weight_decay=self.l2scale
164 | )
165 |
166 | self.loss_values = pd.DataFrame(columns=['Epoch', 'Batch', 'Loss'])
167 | iterator = tqdm(range(self.epochs), disable=(not self.verbose))
168 | if self.verbose:
169 | iterator_description = 'Loss: {loss:.3f}'
170 | iterator.set_description(iterator_description.format(loss=0))
171 |
172 | for i in iterator:
173 | loss_values = []
174 | batch = []
175 | for id_, data in enumerate(loader):
176 | optimizerAE.zero_grad()
177 | real = data[0].to(self._device)
178 | mu, std, logvar = encoder(real)
179 | eps = torch.randn_like(std)
180 | emb = eps * std + mu
181 | rec, sigmas = self.decoder(emb)
182 | loss_1, loss_2 = _loss_function(
183 | rec,
184 | real,
185 | sigmas,
186 | mu,
187 | logvar,
188 | self.transformer.output_info_list,
189 | self.loss_factor,
190 | )
191 | loss = loss_1 + loss_2
192 | loss.backward()
193 | optimizerAE.step()
194 | self.decoder.sigma.data.clamp_(0.01, 1.0)
195 |
196 | batch.append(id_)
197 | loss_values.append(loss.detach().cpu().item())
198 |
199 | epoch_loss_df = pd.DataFrame({
200 | 'Epoch': [i] * len(batch),
201 | 'Batch': batch,
202 | 'Loss': loss_values,
203 | })
204 | if not self.loss_values.empty:
205 | self.loss_values = pd.concat([self.loss_values, epoch_loss_df]).reset_index(
206 | drop=True
207 | )
208 | else:
209 | self.loss_values = epoch_loss_df
210 |
211 | if self.verbose:
212 | iterator.set_description(
213 | iterator_description.format(loss=loss.detach().cpu().item())
214 | )
215 |
216 | @random_state
217 | def sample(self, samples):
218 | """Sample data similar to the training data.
219 |
220 | Args:
221 | samples (int):
222 | Number of rows to sample.
223 |
224 | Returns:
225 | numpy.ndarray or pandas.DataFrame
226 | """
227 | self.decoder.eval()
228 |
229 | steps = samples // self.batch_size + 1
230 | data = []
231 | for _ in range(steps):
232 | mean = torch.zeros(self.batch_size, self.embedding_dim)
233 | std = mean + 1
234 | noise = torch.normal(mean=mean, std=std).to(self._device)
235 | fake, sigmas = self.decoder(noise)
236 | fake = torch.tanh(fake)
237 | data.append(fake.detach().cpu().numpy())
238 |
239 | data = np.concatenate(data, axis=0)
240 | data = data[:samples]
241 | return self.transformer.inverse_transform(data, sigmas.detach().cpu().numpy())
242 |
243 | def set_device(self, device):
244 |         """Set the `device` to be used ('GPU' or 'CPU')."""
245 | self._device = device
246 | self.decoder.to(self._device)
247 |
--------------------------------------------------------------------------------
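TVAE exposes the same fit/sample surface as CTGAN, minus conditional sampling. A minimal sketch on toy data with illustrative settings:

    import pandas as pd

    from ctgan.synthesizers import TVAE

    df = pd.DataFrame({
        'score': [0.1, 0.9, 0.4, 0.7] * 50,
        'grade': ['a', 'b', 'c', 'a'] * 50,
    })

    model = TVAE(epochs=5, batch_size=100)
    model.fit(df, discrete_columns=['grade'])

    synthetic = model.sample(100)
    print(synthetic.head())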
/examples/csv/adult.json:
--------------------------------------------------------------------------------
1 | {
2 | "columns": [
3 | {
4 | "name": "age",
5 | "type": "continuous"
6 | },
7 | {
8 | "name": "workclass",
9 | "type": "categorical"
10 | },
11 | {
12 | "name": "fnlwgt",
13 | "type": "continuous"
14 | },
15 | {
16 | "name": "education",
17 | "type": "ordinal"
18 | },
19 | {
20 | "name": "education-num",
21 | "type": "continuous"
22 | },
23 | {
24 | "name": "marital-status",
25 | "type": "categorical"
26 | },
27 | {
28 | "name": "occupation",
29 | "type": "categorical"
30 | },
31 | {
32 | "name": "relationship",
33 | "type": "categorical"
34 | },
35 | {
36 | "name": "race",
37 | "type": "categorical"
38 | },
39 | {
40 | "name": "sex",
41 | "type": "categorical"
42 | },
43 | {
44 | "name": "capital-gain",
45 | "type": "continuous"
46 | },
47 | {
48 | "name": "capital-loss",
49 | "type": "continuous"
50 | },
51 | {
52 | "name": "hours-per-week",
53 | "type": "continuous"
54 | },
55 | {
56 | "name": "native-country",
57 | "type": "categorical"
58 | },
59 | {
60 | "name": "income",
61 | "type": "categorical"
62 | }
63 | ]
64 | }
65 |
--------------------------------------------------------------------------------
/examples/tsv/acs.meta:
--------------------------------------------------------------------------------
1 | D 0 1
2 | D 0 1
3 | D 0 1
4 | D 0 1
5 | D 0 1
6 | D 0 1
7 | D 0 1
8 | D 0 1
9 | D 0 1
10 | D 0 1
11 | D 0 1
12 | D 0 1
13 | D 0 1
14 | D 0 1
15 | D 0 1
16 | D 0 1
17 | D 0 1
18 | D 0 1
19 | D 0 1
20 | D 0 1
21 | D 0 1
22 | D 0 1
23 | D 0 1
24 |
--------------------------------------------------------------------------------
/examples/tsv/adult.meta:
--------------------------------------------------------------------------------
1 | C 17.0 90.0
2 | D Federal-gov Local-gov State-gov Private Self-emp-inc Self-emp-not-inc Without-pay
3 | C 13492.0 1490400.0
4 | D Preschool 1st-4th 5th-6th 7th-8th 9th 10th 11th 12th HS-grad Some-college Assoc-voc Assoc-acdm Bachelors Masters Prof-school Doctorate
5 | C 1.0 16.0
6 | D Never-married Married-AF-spouse Married-civ-spouse Married-spouse-absent Separated Widowed Divorced
7 | D Adm-clerical Armed-Forces Craft-repair Exec-managerial Farming-fishing Handlers-cleaners Machine-op-inspct Other-service Priv-house-serv Prof-specialty Protective-serv Sales Tech-support Transport-moving
8 | D Husband Wife Own-child Other-relative Not-in-family Unmarried
9 | D White Black Amer-Indian-Eskimo Asian-Pac-Islander Other
10 | D Female Male
11 | C 0.0 99999.0
12 | C 0.0 4356.0
13 | C 1.0 99.0
14 | D United-States Outlying-US(Guam-USVI-etc) Puerto-Rico Canada Mexico Cambodia China Hong Japan Laos Philippines South Taiwan Thailand Vietnam Columbia Ecuador Peru Cuba Dominican-Republic El-Salvador Guatemala Haiti Honduras Jamaica Nicaragua Trinadad&Tobago England France Germany Greece Holand-Netherlands Hungary Ireland Italy Poland Portugal Scotland Yugoslavia India Iran
15 | D <=50K >50K
16 |
--------------------------------------------------------------------------------
/examples/tsv/br2000.meta:
--------------------------------------------------------------------------------
1 | D 0 1
2 | D 0 1 2 3 4 5 6
3 | C 0.0 21.0
4 | D 0 1 2 3 4 5 6 7 8 9
5 | C 0.0 90.0
6 | D 0 1
7 | D 0 1
8 | D 0 1
9 | D 0 1
10 | D 0 1 2 3
11 | C 0.0 125.0
12 | D 0 1
13 | C 0.0 90.0
14 | D 0 1
15 |
--------------------------------------------------------------------------------
/examples/tsv/nltcs.meta:
--------------------------------------------------------------------------------
1 | D 0 1
2 | D 0 1
3 | D 0 1
4 | D 0 1
5 | D 0 1
6 | D 0 1
7 | D 0 1
8 | D 0 1
9 | D 0 1
10 | D 0 1
11 | D 0 1
12 | D 0 1
13 | D 0 1
14 | D 0 1
15 | D 0 1
16 | D 0 1
17 |
--------------------------------------------------------------------------------
/latest_requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==2.0.2
2 | pandas==2.2.3
3 | rdt==1.17.0
4 | torch==2.7.0
5 | tqdm==4.67.1
6 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = 'ctgan'
3 | description = 'Create tabular synthetic data using a conditional GAN'
4 | authors = [{ name = 'DataCebo, Inc.', email = 'info@sdv.dev' }]
5 | classifiers = [
6 | 'Development Status :: 2 - Pre-Alpha',
7 | 'Intended Audience :: Developers',
8 | 'License :: Free for non-commercial use',
9 | 'Natural Language :: English',
10 | 'Programming Language :: Python :: 3',
11 | 'Programming Language :: Python :: 3.8',
12 | 'Programming Language :: Python :: 3.9',
13 | 'Programming Language :: Python :: 3.10',
14 | 'Programming Language :: Python :: 3.11',
15 | 'Programming Language :: Python :: 3.12',
16 | 'Programming Language :: Python :: 3.13',
17 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
18 | ]
19 | keywords = ['ctgan', 'CTGAN']
20 | dynamic = ['version']
21 | license = { text = 'BSL-1.1' }
22 | requires-python = '>=3.8,<3.14'
23 | readme = 'README.md'
24 | dependencies = [
25 | "numpy>=1.21.0;python_version<'3.10'",
26 | "numpy>=1.23.3;python_version>='3.10' and python_version<'3.12'",
27 | "numpy>=1.26.0;python_version>='3.12' and python_version<'3.13'",
28 | "numpy>=2.1.0;python_version>='3.13'",
29 | "pandas>=1.4.0;python_version<'3.11'",
30 | "pandas>=1.5.0;python_version>='3.11' and python_version<'3.12'",
31 | "pandas>=2.1.1;python_version>='3.12' and python_version<'3.13'",
32 | "pandas>=2.2.3;python_version>='3.13'",
33 | "torch>=1.13.0;python_version<'3.11'",
34 | "torch>=2.0.0;python_version>='3.11' and python_version<'3.12'",
35 | "torch>=2.2.0;python_version>='3.12' and python_version<'3.13'",
36 | "torch>=2.6.0;python_version>='3.13'",
37 | 'tqdm>=4.29,<5',
38 | 'rdt>=1.14.0',
39 | ]
40 |
41 | [project.urls]
42 | "Source Code"= "https://github.com/sdv-dev/CTGAN/"
43 | "Issue Tracker" = "https://github.com/sdv-dev/CTGAN/issues"
44 | "Changes" = "https://github.com/sdv-dev/CTGAN/blob/main/HISTORY.md"
45 | "Twitter" = "https://twitter.com/sdv_dev"
46 | "Chat" = "https://bit.ly/sdv-slack-invite"
47 |
48 | [project.entry-points]
49 | ctgan = { main = 'ctgan.__main__:main' }
50 |
51 | [project.optional-dependencies]
52 | test = [
53 | 'pytest>=3.4.2',
54 | 'pytest-rerunfailures>=10.3,<15',
55 | 'pytest-cov>=2.6.0',
56 | 'pytest-runner >= 2.11.1',
57 | 'tomli>=2.0.0,<3',
58 | ]
59 | dev = [
60 | 'ctgan[test]',
61 |
62 | # general
63 | 'pip>=9.0.1',
64 | 'build>=1.0.0,<2',
65 | 'bump-my-version>=0.18.3',
66 | 'watchdog>=1.0.1,<5',
67 |
68 | # style check
69 | 'ruff>=0.4.5,<1',
70 |
71 | # distribute on PyPI
72 | 'twine>=1.10.0',
73 | 'wheel>=0.30.0',
74 |
75 | # Advanced testing
76 | 'coverage>=4.5.1,<6',
77 | 'tox>=2.9.1,<4',
78 |
79 | 'invoke',
80 | ]
81 | readme = ['rundoc>=0.4.3,<0.5',]
82 |
83 | [tool.setuptools]
84 | include-package-data = true
85 | license-files = ['LICENSE']
86 |
87 | [tool.setuptools.packages.find]
88 | include = ['ctgan', 'ctgan.*']
89 | namespaces = false
90 |
91 | [tool.setuptools.package-data]
92 | '*' = [
93 | 'AUTHORS.rst',
94 | 'CONTRIBUTING.rst',
95 | 'HISTORY.md',
96 | 'README.md',
97 | '*.md',
98 | '*.rst',
99 | 'conf.py',
100 | 'Makefile',
101 | 'make.bat',
102 | '*.jpg',
103 | '*.png',
104 | '*.gif'
105 | ]
106 |
107 | [tool.setuptools.exclude-package-data]
108 | '*' = [
109 | '* __pycache__',
110 | '*.py[co]',
111 | 'static_code_analysis.txt',
112 | ]
113 |
114 | [tool.setuptools.dynamic]
115 | version = {attr = 'ctgan.__version__'}
116 |
117 | [tool.pytest.ini_options]
118 | collect_ignore = ['pyproject.toml']
119 |
120 | [tool.bumpversion]
121 | current_version = "0.11.1.dev0"
122 | parse = '(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\.(?P<release>[a-z]+)(?P<candidate>\d+))?'
123 | serialize = [
124 | '{major}.{minor}.{patch}.{release}{candidate}',
125 | '{major}.{minor}.{patch}'
126 | ]
127 | search = '{current_version}'
128 | replace = '{new_version}'
129 | regex = false
130 | ignore_missing_version = false
131 | tag = true
132 | sign_tags = false
133 | tag_name = 'v{new_version}'
134 | tag_message = 'Bump version: {current_version} → {new_version}'
135 | allow_dirty = false
136 | commit = true
137 | message = 'Bump version: {current_version} → {new_version}'
138 | commit_args = ''
139 |
140 | [tool.bumpversion.parts.release]
141 | first_value = 'dev'
142 | optional_value = 'release'
143 | values = [
144 | 'dev',
145 | 'release'
146 | ]
147 |
148 | [[tool.bumpversion.files]]
149 | filename = "ctgan/__init__.py"
150 | search = "__version__ = '{current_version}'"
151 | replace = "__version__ = '{new_version}'"
152 |
153 | [build-system]
154 | requires = ['setuptools', 'wheel']
155 | build-backend = 'setuptools.build_meta'
156 |
157 | [tool.ruff]
158 | preview = true
159 | line-length = 100
160 | indent-width = 4
161 | src = ["ctgan"]
162 | exclude = [
163 | "docs",
164 | ".tox",
165 | ".git",
166 | "__pycache__",
167 | ".ipynb_checkpoints",
168 | "tasks.py",
169 | ]
170 |
171 | [tool.ruff.lint]
172 | select = [
173 | # Pyflakes
174 | "F",
175 | # Pycodestyle
176 | "E",
177 | "W",
178 | # pydocstyle
179 | "D",
180 | # isort
181 | "I001",
182 | # print statements
183 | "T201",
184 | # pandas-vet
185 | "PD",
186 | # numpy 2.0
187 | "NPY201"
188 | ]
189 | ignore = [
190 | # pydocstyle
191 | "D107", # Missing docstring in __init__
192 | "D417", # Missing argument descriptions in the docstring, this is a bug from pydocstyle: https://github.com/PyCQA/pydocstyle/issues/449
193 | "PD901",
194 | "PD101",
195 | ]
196 |
197 | [tool.ruff.format]
198 | quote-style = "single"
199 | indent-style = "space"
200 | preview = true
201 | docstring-code-format = true
202 | docstring-code-line-length = "dynamic"
203 |
204 | [tool.ruff.lint.isort]
205 | known-first-party = ["ctgan"]
206 | lines-between-types = 0
207 |
208 | [tool.ruff.lint.per-file-ignores]
209 | "__init__.py" = ["F401", "E402", "F403", "F405", "E501", "I001"]
210 | "errors.py" = ["D105"]
211 | "tests/**.py" = ["D"]
212 |
213 | [tool.ruff.lint.pydocstyle]
214 | convention = "google"
215 |
216 | [tool.ruff.lint.pycodestyle]
217 | max-doc-length = 100
218 | max-line-length = 100
219 |
--------------------------------------------------------------------------------
/scripts/release_notes_generator.py:
--------------------------------------------------------------------------------
1 | """Script to generate release notes."""
2 |
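3 | # Example invocation (the GH_ACCESS_TOKEN environment variable must be set for the
4 | # GitHub API requests below to be authorized):
5 | #
6 | #     python scripts/release_notes_generator.py --version v1.0.1 --date 2025-01-01
7 |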
3 | import argparse
4 | import os
5 | from collections import defaultdict
6 |
7 | import requests
8 |
9 | LABEL_TO_HEADER = {
10 | 'feature request': 'New Features',
11 | 'bug': 'Bugs Fixed',
12 | 'internal': 'Internal',
13 | 'maintenance': 'Maintenance',
14 | 'customer success': 'Customer Success',
15 | 'documentation': 'Documentation',
16 | 'misc': 'Miscellaneous',
17 | }
18 | ISSUE_LABELS = [
19 | 'documentation',
20 | 'maintenance',
21 | 'internal',
22 | 'bug',
23 | 'feature request',
24 | 'customer success',
25 | ]
26 | ISSUE_LABELS_ORDERED_BY_IMPORTANCE = [
27 | 'feature request',
28 | 'customer success',
29 | 'bug',
30 | 'documentation',
31 | 'internal',
32 | 'maintenance',
33 | ]
34 | NEW_LINE = '\n'
35 | GITHUB_URL = 'https://api.github.com/repos/sdv-dev/ctgan'
36 | GITHUB_TOKEN = os.getenv('GH_ACCESS_TOKEN')
37 |
38 |
39 | def _get_milestone_number(milestone_title):
40 | url = f'{GITHUB_URL}/milestones'
41 | headers = {'Authorization': f'Bearer {GITHUB_TOKEN}'}
42 | query_params = {'milestone': milestone_title, 'state': 'all', 'per_page': 100}
43 | response = requests.get(url, headers=headers, params=query_params, timeout=10)
44 | body = response.json()
45 | if response.status_code != 200:
46 | raise Exception(str(body))
47 |
48 | milestones = body
49 | for milestone in milestones:
50 | if milestone.get('title') == milestone_title:
51 | return milestone.get('number')
52 |
53 | raise ValueError(f'Milestone {milestone_title} not found in past 100 milestones.')
54 |
55 |
56 | def _get_issues_by_milestone(milestone):
57 | headers = {'Authorization': f'Bearer {GITHUB_TOKEN}'}
58 | # get milestone number
59 | milestone_number = _get_milestone_number(milestone)
60 | url = f'{GITHUB_URL}/issues'
61 | page = 1
62 | query_params = {'milestone': milestone_number, 'state': 'all'}
63 | issues = []
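64 | # Page through the milestone's issues; GitHub returns an empty page once exhausted.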
64 | while True:
65 | query_params['page'] = page
66 | response = requests.get(url, headers=headers, params=query_params, timeout=10)
67 | body = response.json()
68 | if response.status_code != 200:
69 | raise Exception(str(body))
70 |
71 | issues_on_page = body
72 | if not issues_on_page:
73 | break
74 |
75 | # Filter out PRs
76 | issues_on_page = [issue for issue in issues_on_page if issue.get('pull_request') is None]
77 | issues.extend(issues_on_page)
78 | page += 1
79 |
80 | return issues
81 |
82 |
83 | def _get_issues_by_category(release_issues):
84 | category_to_issues = defaultdict(list)
85 |
86 | for issue in release_issues:
87 | issue_title = issue['title']
88 | issue_number = issue['number']
89 | issue_url = issue['html_url']
90 | line = f'* {issue_title} - Issue [#{issue_number}]({issue_url})'
91 | assignee = issue.get('assignee')
92 | if assignee:
93 | login = assignee['login']
94 | line += f' by @{login}'
95 |
96 | # Check if any known label is marked on the issue
97 | labels = [label['name'] for label in issue['labels']]
98 | found_category = False
99 | for category in ISSUE_LABELS:
100 | if category in labels:
101 | category_to_issues[category].append(line)
102 | found_category = True
103 | break
104 |
105 | if not found_category:
106 | category_to_issues['misc'].append(line)
107 |
108 | return category_to_issues
109 |
110 |
111 | def _create_release_notes(issues_by_category, version, date):
112 | title = f'## v{version} - {date}'
113 | release_notes = f'{title}{NEW_LINE}{NEW_LINE}'
114 |
115 | for category in ISSUE_LABELS_ORDERED_BY_IMPORTANCE + ['misc']:
116 | issues = issues_by_category.get(category)
117 | if issues:
118 | section_text = (
119 | f'### {LABEL_TO_HEADER[category]}{NEW_LINE}{NEW_LINE}'
120 | f'{NEW_LINE.join(issues)}{NEW_LINE}{NEW_LINE}'
121 | )
122 |
123 | release_notes += section_text
124 |
125 | return release_notes
126 |
127 |
128 | def update_release_notes(release_notes):
129 | """Add the release notes for the new release to the ``HISTORY.md``."""
130 | file_path = 'HISTORY.md'
131 | with open(file_path, 'r') as history_file:
132 | history = history_file.read()
133 |
134 | token = '# History\n\n'
135 | split_index = history.find(token) + len(token) + 1
136 | header = history[:split_index]
137 | new_notes = f'{header}{release_notes}{history[split_index:]}'
138 |
139 | with open(file_path, 'w') as new_history_file:
140 | new_history_file.write(new_notes)
141 |
142 |
143 | if __name__ == '__main__':
144 | parser = argparse.ArgumentParser()
145 | parser.add_argument('-v', '--version', type=str, help='Release version number (ie. v1.0.1)')
146 | parser.add_argument('-d', '--date', type=str, help='Date of release in format YYYY-MM-DD')
147 | args = parser.parse_args()
148 | release_number = args.version
149 | release_issues = _get_issues_by_milestone(release_number)
150 | issues_by_category = _get_issues_by_category(release_issues)
151 | release_notes = _create_release_notes(issues_by_category, release_number, args.date)
152 | update_release_notes(release_notes)
153 |
--------------------------------------------------------------------------------
/static_code_analysis.txt:
--------------------------------------------------------------------------------
1 | Run started:2025-02-25 21:10:33.223731
2 |
3 | Test results:
4 | >> Issue: [B101:assert_used] Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
5 | Severity: Low Confidence: High
6 | CWE: CWE-703 (https://cwe.mitre.org/data/definitions/703.html)
7 | More Info: https://bandit.readthedocs.io/en/1.7.7/plugins/b101_assert_used.html
8 | Location: ./ctgan/__main__.py:121:8
9 | 120 if args.sample_condition_column is not None:
10 | 121 assert args.sample_condition_column_value is not None
11 | 122
12 |
13 | --------------------------------------------------
14 | >> Issue: [B101:assert_used] Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
15 | Severity: Low Confidence: High
16 | CWE: CWE-703 (https://cwe.mitre.org/data/definitions/703.html)
17 | More Info: https://bandit.readthedocs.io/en/1.7.7/plugins/b101_assert_used.html
18 | Location: ./ctgan/data.py:48:12
19 | 47 else:
20 | 48 assert item[0] == 'D'
21 | 49 discrete.append(idx)
22 |
23 | --------------------------------------------------
24 | >> Issue: [B101:assert_used] Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
25 | Severity: Low Confidence: High
26 | CWE: CWE-703 (https://cwe.mitre.org/data/definitions/703.html)
27 | More Info: https://bandit.readthedocs.io/en/1.7.7/plugins/b101_assert_used.html
28 | Location: ./ctgan/data.py:69:16
29 | 68 else:
30 | 69 assert idx in discrete
31 | 70 row.append(column_info[idx].index(col))
32 |
33 | --------------------------------------------------
34 | >> Issue: [B101:assert_used] Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
35 | Severity: Low Confidence: High
36 | CWE: CWE-703 (https://cwe.mitre.org/data/definitions/703.html)
37 | More Info: https://bandit.readthedocs.io/en/1.7.7/plugins/b101_assert_used.html
38 | Location: ./ctgan/data.py:85:20
39 | 84 else:
40 | 85 assert idx in meta['discrete_columns']
41 | 86 print(meta['column_info'][idx][int(col)], end=' ', file=f)
42 |
43 | --------------------------------------------------
44 | >> Issue: [B101:assert_used] Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
45 | Severity: Low Confidence: High
46 | CWE: CWE-703 (https://cwe.mitre.org/data/definitions/703.html)
47 | More Info: https://bandit.readthedocs.io/en/1.7.7/plugins/b101_assert_used.html
48 | Location: ./ctgan/data_sampler.py:40:8
49 | 39 st += sum([span_info.dim for span_info in column_info])
50 | 40 assert st == data.shape[1]
51 | 41
52 |
53 | --------------------------------------------------
54 | >> Issue: [B101:assert_used] Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
55 | Severity: Low Confidence: High
56 | CWE: CWE-703 (https://cwe.mitre.org/data/definitions/703.html)
57 | More Info: https://bandit.readthedocs.io/en/1.7.7/plugins/b101_assert_used.html
58 | Location: ./ctgan/synthesizers/ctgan.py:60:8
59 | 59 """Apply the Discriminator to the `input_`."""
60 | 60 assert input_.size()[0] % self.pac == 0
61 | 61 return self.seq(input_.view(-1, self.pacdim))
62 |
63 | --------------------------------------------------
64 | >> Issue: [B101:assert_used] Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
65 | Severity: Low Confidence: High
66 | CWE: CWE-703 (https://cwe.mitre.org/data/definitions/703.html)
67 | More Info: https://bandit.readthedocs.io/en/1.7.7/plugins/b101_assert_used.html
68 | Location: ./ctgan/synthesizers/ctgan.py:164:8
69 | 163 ):
70 | 164 assert batch_size % 2 == 0
71 | 165
72 |
73 | --------------------------------------------------
74 | >> Issue: [B101:assert_used] Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
75 | Severity: Low Confidence: High
76 | CWE: CWE-703 (https://cwe.mitre.org/data/definitions/703.html)
77 | More Info: https://bandit.readthedocs.io/en/1.7.7/plugins/b101_assert_used.html
78 | Location: ./ctgan/synthesizers/tvae.py:100:4
79 | 99
80 | 100 assert st == recon_x.size()[1]
81 | 101 KLD = -0.5 * torch.sum(1 + logvar - mu**2 - logvar.exp())
82 |
83 | --------------------------------------------------
84 |
85 | Code scanned:
86 | Total lines of code: 1414
87 | Total lines skipped (#nosec): 0
88 | Total potential issues skipped due to specifically being disabled (e.g., #nosec BXXX): 0
89 |
90 | Run metrics:
91 | Total issues (by severity):
92 | Undefined: 0
93 | Low: 8
94 | Medium: 0
95 | High: 0
96 | Total issues (by confidence):
97 | Undefined: 0
98 | Low: 0
99 | Medium: 0
100 | High: 8
101 | Files skipped (0):
102 |
--------------------------------------------------------------------------------
/tasks.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import operator
3 | import os
4 | import shutil
5 | import stat
6 | import sys
7 | from pathlib import Path
8 |
9 | import tomli
10 | from invoke import task
11 | from packaging.requirements import Requirement
12 | from packaging.version import Version
13 |
14 | COMPARISONS = {'>=': operator.ge, '>': operator.gt, '<': operator.lt, '<=': operator.le}
15 |
16 | if not hasattr(inspect, 'getargspec'):
17 | inspect.getargspec = inspect.getfullargspec
18 |
19 |
20 | @task
21 | def check_dependencies(c):
22 | c.run('python -m pip check')
23 |
24 |
25 | @task
26 | def unit(c):
27 | c.run('python -m pytest ./tests/unit --cov=ctgan --cov-report=xml:./unit_cov.xml')
28 |
29 |
30 | @task
31 | def integration(c):
32 | c.run('python -m pytest ./tests/integration --reruns 3 --cov=ctgan --cov-report=xml:./integration_cov.xml')
33 |
34 |
35 | @task
36 | def readme(c):
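37 | """Run the README.md code blocks with rundoc inside a throwaway test directory."""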
37 | test_path = Path('tests/readme_test')
38 | if test_path.exists() and test_path.is_dir():
39 | shutil.rmtree(test_path)
40 |
41 | cwd = os.getcwd()
42 | os.makedirs(test_path, exist_ok=True)
43 | shutil.copy('README.md', test_path / 'README.md')
44 | os.chdir(test_path)
45 | c.run('rundoc run --single-session python3 -t python3 README.md')
46 | os.chdir(cwd)
47 | shutil.rmtree(test_path)
48 |
49 |
50 | def _get_minimum_versions(dependencies, python_version):
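51 | """Pin each dependency to its minimum version for the given Python version.
52 |
53 | URL requirements (``name @ url``) are turned into pip-installable URLs; for the
54 | rest, the version of the first `>=` or `==` specifier whose environment marker
55 | matches `python_version` is pinned with `==`, keeping the highest such minimum
56 | when a package appears more than once.
57 | """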
51 | min_versions = {}
52 | for dependency in dependencies:
53 | if '@' in dependency:
54 | name, url = dependency.split(' @ ')
55 | min_versions[name] = f'{url}#egg={name}'
56 | continue
57 |
58 | req = Requirement(dependency)
59 | if ';' in dependency:
60 | marker = req.marker
61 | if marker and not marker.evaluate({'python_version': python_version}):
62 | continue # Skip this dependency if the marker does not apply to the current Python version
63 |
64 | if req.name not in min_versions:
65 | min_version = next(
66 | (spec.version for spec in req.specifier if spec.operator in ('>=', '==')), None
67 | )
68 | if min_version:
69 | min_versions[req.name] = f'{req.name}=={min_version}'
70 |
71 | elif '@' not in min_versions[req.name]:
72 | existing_version = Version(min_versions[req.name].split('==')[1])
73 | new_version = next(
74 | (spec.version for spec in req.specifier if spec.operator in ('>=', '==')),
75 | existing_version,
76 | )
77 | if new_version > existing_version:
78 | min_versions[req.name] = (
79 | f'{req.name}=={new_version}' # Change when a valid newer version is found
80 | )
81 |
82 | return list(min_versions.values())
83 |
84 |
85 | @task
86 | def install_minimum(c):
87 | with open('pyproject.toml', 'rb') as pyproject_file:
88 | pyproject_data = tomli.load(pyproject_file)
89 |
90 | dependencies = pyproject_data.get('project', {}).get('dependencies', [])
91 | python_version = '.'.join(map(str, sys.version_info[:2]))
92 | minimum_versions = _get_minimum_versions(dependencies, python_version)
93 |
94 | if minimum_versions:
95 | install_deps = ' '.join(minimum_versions)
96 | c.run(f'python -m pip install {install_deps}')
97 |
98 |
99 | @task
100 | def minimum(c):
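101 | """Install the minimum supported dependency versions, then run the checks and test suites."""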
101 | install_minimum(c)
102 | check_dependencies(c)
103 | unit(c)
104 | integration(c)
105 |
106 |
107 | @task
108 | def lint(c):
109 | check_dependencies(c)
110 | c.run('ruff check .')
111 | c.run('ruff format --check --diff .')
112 |
113 |
114 | @task
115 | def fix_lint(c):
116 | check_dependencies(c)
117 | c.run('ruff check --fix .')
118 | c.run('ruff format .')
119 |
120 |
121 | def remove_readonly(func, path, _):
122 | """Clear the readonly bit and reattempt the removal"""
123 | os.chmod(path, stat.S_IWRITE)
124 | func(path)
125 |
126 |
127 | @task
128 | def rmdir(c, path):
129 | try:
130 | shutil.rmtree(path, onerror=remove_readonly)
131 | except PermissionError:
132 | pass
133 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | """CTGAN tests."""
2 |
--------------------------------------------------------------------------------
/tests/integration/__init__.py:
--------------------------------------------------------------------------------
1 | """Integration testing subpackage."""
2 |
--------------------------------------------------------------------------------
/tests/integration/synthesizer/__init__.py:
--------------------------------------------------------------------------------
1 | """Subpackage for integration testing of synthesizers."""
2 |
--------------------------------------------------------------------------------
/tests/integration/synthesizer/test_ctgan.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | """Integration tests for ctgan.
5 |
6 | These tests only ensure that the software does not crash and that
7 | the API works as expected in terms of input and output data formats,
8 | but correctness of the data values and the internal behavior of the
9 | model are not checked.
10 | """
11 |
12 | import tempfile as tf
13 |
14 | import numpy as np
15 | import pandas as pd
16 | import pytest
17 |
18 | from ctgan.errors import InvalidDataError
19 | from ctgan.synthesizers.ctgan import CTGAN
20 |
21 |
22 | def test_ctgan_no_categoricals():
23 | """Test the CTGAN with no categorical values."""
24 | data = pd.DataFrame({'continuous': np.random.random(1000)})
25 |
26 | ctgan = CTGAN(epochs=1)
27 | ctgan.fit(data, [])
28 |
29 | sampled = ctgan.sample(100)
30 |
31 | assert sampled.shape == (100, 1)
32 | assert isinstance(sampled, pd.DataFrame)
33 | assert set(sampled.columns) == {'continuous'}
34 | assert len(ctgan.loss_values) == 1
35 | assert list(ctgan.loss_values.columns) == ['Epoch', 'Generator Loss', 'Discriminator Loss']
36 |
37 |
38 | def test_ctgan_dataframe():
39 | """Test the CTGAN when passed a dataframe."""
40 | data = pd.DataFrame({
41 | 'continuous': np.random.random(100),
42 | 'discrete': np.random.choice(['a', 'b', 'c'], 100),
43 | })
44 | discrete_columns = ['discrete']
45 |
46 | ctgan = CTGAN(epochs=1)
47 | ctgan.fit(data, discrete_columns)
48 |
49 | sampled = ctgan.sample(100)
50 |
51 | assert sampled.shape == (100, 2)
52 | assert isinstance(sampled, pd.DataFrame)
53 | assert set(sampled.columns) == {'continuous', 'discrete'}
54 | assert set(sampled['discrete'].unique()) == {'a', 'b', 'c'}
55 | assert len(ctgan.loss_values) == 1
56 | assert list(ctgan.loss_values.columns) == ['Epoch', 'Generator Loss', 'Discriminator Loss']
57 |
58 |
59 | def test_ctgan_numpy():
60 | """Test the CTGAN when passed a numpy array."""
61 | data = pd.DataFrame({
62 | 'continuous': np.random.random(100),
63 | 'discrete': np.random.choice(['a', 'b', 'c'], 100),
64 | })
65 | discrete_columns = [1]
66 |
67 | ctgan = CTGAN(epochs=1)
68 | ctgan.fit(data.to_numpy(), discrete_columns)
69 |
70 | sampled = ctgan.sample(100)
71 |
72 | assert sampled.shape == (100, 2)
73 | assert isinstance(sampled, np.ndarray)
74 | assert set(np.unique(sampled[:, 1])) == {'a', 'b', 'c'}
75 | assert len(ctgan.loss_values) == 1
76 | assert list(ctgan.loss_values.columns) == ['Epoch', 'Generator Loss', 'Discriminator Loss']
77 |
78 |
79 | def test_log_frequency():
80 | """Test the CTGAN with no `log_frequency` set to False."""
81 | data = pd.DataFrame({
82 | 'continuous': np.random.random(1000),
83 | 'discrete': np.repeat(['a', 'b', 'c'], [950, 25, 25]),
84 | })
85 |
86 | discrete_columns = ['discrete']
87 |
88 | ctgan = CTGAN(epochs=100)
89 | ctgan.fit(data, discrete_columns)
90 |
91 | assert len(ctgan.loss_values) == 100
92 | assert list(ctgan.loss_values.columns) == ['Epoch', 'Generator Loss', 'Discriminator Loss']
93 | pd.testing.assert_series_equal(ctgan.loss_values['Epoch'], pd.Series(range(100), name='Epoch'))
94 |
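95 | # With the default log_frequency=True, the conditional vector samples categories by
96 | # log-scaled frequency, so the majority category 'a' (95% of the training data)
97 | # should be noticeably under-represented in the output.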
95 | sampled = ctgan.sample(10000)
96 | counts = sampled['discrete'].value_counts()
97 | assert counts['a'] < 6500
98 |
99 | ctgan = CTGAN(log_frequency=False, epochs=100)
100 | ctgan.fit(data, discrete_columns)
101 |
102 | assert len(ctgan.loss_values) == 100
103 | assert list(ctgan.loss_values.columns) == ['Epoch', 'Generator Loss', 'Discriminator Loss']
104 | pd.testing.assert_series_equal(ctgan.loss_values['Epoch'], pd.Series(range(100), name='Epoch'))
105 |
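106 | # With log_frequency=False, categories are sampled by their raw frequencies, so 'a'
107 | # should keep roughly its 95% share.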
106 | sampled = ctgan.sample(10000)
107 | counts = sampled['discrete'].value_counts()
108 | assert counts['a'] > 9000
109 |
110 |
111 | def test_categorical_nan():
112 | """Test the CTGAN with no categorical values."""
113 | data = pd.DataFrame({
114 | 'continuous': np.random.random(30),
115 | # This must be a list (not a np.array) or NaN will be cast to a string.
116 | 'discrete': [np.nan, 'b', 'c'] * 10,
117 | })
118 | discrete_columns = ['discrete']
119 |
120 | ctgan = CTGAN(epochs=1)
121 | ctgan.fit(data, discrete_columns)
122 |
123 | sampled = ctgan.sample(100)
124 |
125 | assert sampled.shape == (100, 2)
126 | assert isinstance(sampled, pd.DataFrame)
127 | assert set(sampled.columns) == {'continuous', 'discrete'}
128 |
129 | # since np.nan != np.nan, we need to be careful here
130 | values = set(sampled['discrete'].unique())
131 | assert len(values) == 3
132 | assert any(pd.isna(x) for x in values)
133 | assert {'b', 'c'}.issubset(values)
134 |
135 |
136 | def test_continuous_nan():
137 | """Test the CTGAN with missing numerical values."""
138 | # Setup
139 | data = pd.DataFrame({
140 | 'continuous': [np.nan, 1.0, 2.0] * 10,
141 | 'discrete': ['a', 'b', 'c'] * 10,
142 | })
143 | discrete_columns = ['discrete']
144 | error_message = (
145 | 'CTGAN does not support null values in the continuous training data. '
146 | 'Please remove all null values from your continuous training data.'
147 | )
148 |
149 | # Run and Assert
150 | ctgan = CTGAN(epochs=1)
151 | with pytest.raises(InvalidDataError, match=error_message):
152 | ctgan.fit(data, discrete_columns)
153 |
154 |
155 | def test_synthesizer_sample():
156 | """Test the CTGAN samples the correct datatype."""
157 | data = pd.DataFrame({'discrete': np.random.choice(['a', 'b', 'c'], 100)})
158 | discrete_columns = ['discrete']
159 |
160 | ctgan = CTGAN(epochs=1)
161 | ctgan.fit(data, discrete_columns)
162 |
163 | samples = ctgan.sample(1000, 'discrete', 'a')
164 | assert isinstance(samples, pd.DataFrame)
165 |
166 |
167 | def test_save_load():
168 | """Test the CTGAN load/save methods."""
169 | data = pd.DataFrame({
170 | 'continuous': np.random.random(100),
171 | 'discrete': np.random.choice(['a', 'b', 'c'], 100),
172 | })
173 | discrete_columns = ['discrete']
174 |
175 | ctgan = CTGAN(epochs=1)
176 | ctgan.fit(data, discrete_columns)
177 |
178 | with tf.TemporaryDirectory() as temporary_directory:
179 | ctgan.save(temporary_directory + '/test_tvae.pkl')
180 | ctgan = CTGAN.load(temporary_directory + '/test_tvae.pkl')
181 |
182 | sampled = ctgan.sample(1000)
183 | assert set(sampled.columns) == {'continuous', 'discrete'}
184 | assert set(sampled['discrete'].unique()) == {'a', 'b', 'c'}
185 |
186 |
187 | def test_wrong_discrete_columns_dataframe():
188 | """Test the CTGAN correctly crashes when passed non-existing discrete columns."""
189 | data = pd.DataFrame({'discrete': ['a', 'b']})
190 | discrete_columns = ['b', 'c']
191 |
192 | ctgan = CTGAN(epochs=1)
193 | with pytest.raises(ValueError, match="Invalid columns found: {'.*', '.*'}"):
194 | ctgan.fit(data, discrete_columns)
195 |
196 |
197 | def test_wrong_discrete_columns_numpy():
198 | """Test the CTGAN correctly crashes when passed non-existing discrete columns."""
199 | data = pd.DataFrame({'discrete': ['a', 'b']})
200 | discrete_columns = [0, 1]
201 |
202 | ctgan = CTGAN(epochs=1)
203 | with pytest.raises(ValueError, match=r'Invalid columns found: \[1\]'):
204 | ctgan.fit(data.to_numpy(), discrete_columns)
205 |
206 |
207 | def test_wrong_sampling_conditions():
208 | """Test the CTGAN correctly crashes when passed incorrect sampling conditions."""
209 | data = pd.DataFrame({
210 | 'continuous': np.random.random(100),
211 | 'discrete': np.random.choice(['a', 'b', 'c'], 100),
212 | })
213 | discrete_columns = ['discrete']
214 |
215 | ctgan = CTGAN(epochs=1)
216 | ctgan.fit(data, discrete_columns)
217 |
218 | with pytest.raises(ValueError, match="The column_name `cardinal` doesn't exist in the data."):
219 | ctgan.sample(1, 'cardinal', "doesn't matter")
220 |
221 | with pytest.raises(ValueError): # noqa: RDT currently incorrectly raises a tuple instead of a string
222 | ctgan.sample(1, 'discrete', 'd')
223 |
224 |
225 | def test_fixed_random_seed():
226 | """Test the CTGAN with a fixed seed.
227 |
228 | Expect that when the random seed is reset with the same seed, the same sequence
229 | of data will be produced. Expect that the data generated with the seed is
230 | different than randomly sampled data.
231 | """
232 | # Setup
233 | data = pd.DataFrame({
234 | 'continuous': np.random.random(100),
235 | 'discrete': np.random.choice(['a', 'b', 'c'], 100),
236 | })
237 | discrete_columns = ['discrete']
238 |
239 | ctgan = CTGAN(epochs=1, cuda=False)
240 |
241 | # Run
242 | ctgan.fit(data, discrete_columns)
243 | sampled_random = ctgan.sample(10)
244 |
245 | ctgan.set_random_state(0)
246 | sampled_0_0 = ctgan.sample(10)
247 | sampled_0_1 = ctgan.sample(10)
248 |
249 | ctgan.set_random_state(0)
250 | sampled_1_0 = ctgan.sample(10)
251 | sampled_1_1 = ctgan.sample(10)
252 |
253 | # Assert
254 | assert not np.array_equal(sampled_random, sampled_0_0)
255 | assert not np.array_equal(sampled_random, sampled_0_1)
256 | np.testing.assert_array_equal(sampled_0_0, sampled_1_0)
257 | np.testing.assert_array_equal(sampled_0_1, sampled_1_1)
258 |
259 |
260 | def test_ctgan_save_and_load(tmpdir):
261 | """Test that the ``CTGAN`` model can be saved and loaded."""
262 | # Setup
263 | data = pd.DataFrame({
264 | 'continuous': np.random.random(100),
265 | 'discrete': np.random.choice(['a', 'b', 'c'], 100),
266 | })
267 | discrete_columns = [1]
268 |
269 | ctgan = CTGAN(epochs=1)
270 | ctgan.fit(data.to_numpy(), discrete_columns)
271 | ctgan.set_random_state(0)
272 |
273 | ctgan.sample(100)
274 | model_path = tmpdir / 'model.pkl'
275 |
276 | # Save
277 | ctgan.save(str(model_path))
278 |
279 | # Load
280 | loaded_instance = CTGAN.load(str(model_path))
281 | loaded_instance.sample(100)
282 |
--------------------------------------------------------------------------------
/tests/integration/synthesizer/test_tvae.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | """Integration tests for tvae.
5 |
6 | These tests only ensure that the software does not crash and that
7 | the API works as expected in terms of input and output data formats,
8 | but correctness of the data values and the internal behavior of the
9 | model are not checked.
10 | """
11 |
12 | import numpy as np
13 | import pandas as pd
14 |
15 | from ctgan.synthesizers.tvae import TVAE
16 |
17 |
18 | def test_drop_last_false():
19 | """Test the TVAE predicts the correct values."""
20 | data = pd.DataFrame({'1': ['a', 'b', 'c'] * 150, '2': ['a', 'b', 'c'] * 150})
21 |
22 | tvae = TVAE(epochs=300)
23 | tvae.fit(data, ['1', '2'])
24 |
25 | sampled = tvae.sample(100)
26 | correct = 0
27 | for _, row in sampled.iterrows():
28 | if row['1'] == row['2']:
29 | correct += 1
30 |
31 | assert correct >= 95
32 |
33 |
34 | def test__loss_function():
35 | """Test the TVAE produces average values similar to the training data."""
36 | data = pd.DataFrame({
37 | '1': [float(i) for i in range(1000)],
38 | '2': [float(2 * i) for i in range(1000)],
39 | })
40 |
41 | tvae = TVAE(epochs=300)
42 | tvae.fit(data)
43 |
44 | num_samples = 1000
45 | sampled = tvae.sample(num_samples)
46 | error = 0
47 | for _, row in sampled.iterrows():
48 | error += abs(2 * row['1'] - row['2'])
49 |
50 | avg_error = error / num_samples
51 |
52 | assert avg_error < 400
53 |
54 |
55 | def test_fixed_random_seed():
56 | """Test the TVAE with a fixed seed.
57 |
58 | Expect that when the random seed is reset with the same seed, the same sequence
59 | of data will be produced. Expect that the data generated with the seed is
60 | different than randomly sampled data.
61 | """
62 | # Setup
63 | data = pd.DataFrame({
64 | 'continuous': np.random.random(100),
65 | 'discrete': np.random.choice(['a', 'b', 'c'], 100),
66 | })
67 | discrete_columns = ['discrete']
68 |
69 | tvae = TVAE(epochs=1)
70 |
71 | # Run
72 | tvae.fit(data, discrete_columns)
73 | sampled_random = tvae.sample(10)
74 |
75 | tvae.set_random_state(0)
76 | sampled_0_0 = tvae.sample(10)
77 | sampled_0_1 = tvae.sample(10)
78 |
79 | tvae.set_random_state(0)
80 | sampled_1_0 = tvae.sample(10)
81 | sampled_1_1 = tvae.sample(10)
82 |
83 | # Assert
84 | assert not np.array_equal(sampled_random, sampled_0_0)
85 | assert not np.array_equal(sampled_random, sampled_0_1)
86 | np.testing.assert_array_equal(sampled_0_0, sampled_1_0)
87 | np.testing.assert_array_equal(sampled_0_1, sampled_1_1)
88 |
89 |
90 | def test_tvae_save(tmpdir, capsys):
91 | """Test that the ``TVAE`` model can be saved and loaded."""
92 | # Setup
93 | data = pd.DataFrame({
94 | 'continuous': np.random.random(100),
95 | 'discrete': np.random.choice(['a', 'b', 'c'], 100),
96 | })
97 | discrete_columns = ['discrete']
98 |
99 | tvae = TVAE(epochs=10, verbose=True)
100 | tvae.fit(data, discrete_columns)
101 | captured_out = capsys.readouterr().err
102 | tvae.set_random_state(0)
103 |
104 | tvae.sample(100)
105 | model_path = tmpdir / 'model.pkl'
106 |
107 | # Save
108 | tvae.save(str(model_path))
109 |
110 | # Load
111 | loaded_instance = TVAE.load(str(model_path))
112 | sampled = loaded_instance.sample(100)
113 |
114 | # Assert
115 | assert sampled.shape == (100, 2)
116 | assert isinstance(sampled, pd.DataFrame)
117 | assert set(sampled.columns) == set(data.columns)
118 | assert set(sampled.dtypes) == set(data.dtypes)
119 | loss_values = tvae.loss_values
120 | assert len(loss_values) == 10
121 | assert set(loss_values.columns) == {'Epoch', 'Batch', 'Loss'}
122 | assert all(loss_values['Batch'] == 0)
123 | last_loss_val = loss_values['Loss'].iloc[-1]
124 | assert f'Loss: {round(last_loss_val, 3):.3f}: 100%' in captured_out
125 |
--------------------------------------------------------------------------------
/tests/integration/test_data_transformer.py:
--------------------------------------------------------------------------------
1 | """Data transformer intergration testing module."""
2 |
3 | from unittest import TestCase
4 |
5 | import numpy as np
6 | import pandas as pd
7 |
8 | from ctgan.data_transformer import DataTransformer
9 |
10 |
11 | class TestDataTransformer(TestCase):
12 | def test_constant(self):
13 | """Test transforming a dataframe containing constant values."""
14 | # Setup
15 | data = pd.DataFrame({'cnt': [123] * 1000})
16 | transformer = DataTransformer()
17 |
18 | # Run
19 | transformer.fit(data, [])
20 | new_data = transformer.transform(data)
21 | transformer.inverse_transform(new_data)
22 |
23 | # Assert transformed values are between -1 and 1
24 | assert (new_data[:, 0] > -np.ones(len(new_data))).all()
25 | assert (new_data[:, 0] < np.ones(len(new_data))).all()
26 |
27 | # Assert transformed values are a Gaussian centered at 0 with std ~ 0
28 | assert -0.1 < np.mean(new_data[:, 0]) < 0.1
29 | assert 0 <= np.std(new_data[:, 0]) < 0.1
30 |
31 | # Assert there are at most `max_columns=10` one hot columns
32 | assert new_data.shape[0] == 1000
33 | assert new_data.shape[1] <= 11
34 | assert np.isin(new_data[:, 1:], [0, 1]).all()
35 |
36 | def test_df_continuous(self):
37 | """Test transforming a dataframe containing only continuous values."""
38 | # Setup
39 | data = pd.DataFrame({'col': np.random.normal(size=1000)})
40 | transformer = DataTransformer()
41 |
42 | # Run
43 | transformer.fit(data, [])
44 | new_data = transformer.transform(data)
45 | transformer.inverse_transform(new_data)
46 |
47 | # Assert transformed values are between -1 and 1
48 | assert (new_data[:, 0] > -np.ones(len(new_data))).all()
49 | assert (new_data[:, 0] < np.ones(len(new_data))).all()
50 |
51 | # Assert transformed values are a Gaussian centered at 0 with std = 1/4
52 | assert -0.1 < np.mean(new_data[:, 0]) < 0.1
53 | assert 0.2 < np.std(new_data[:, 0]) < 0.3
54 |
55 | # Assert there are at most `max_columns=10` one hot columns
56 | assert new_data.shape[0] == 1000
57 | assert new_data.shape[1] <= 11
58 | assert np.isin(new_data[:, 1:], [0, 1]).all()
59 |
60 | def test_df_categorical_constant(self):
61 | """Test transforming a dataframe containing only constant categorical values."""
62 | # Setup
63 | data = pd.DataFrame({'cnt': [123] * 1000})
64 | transformer = DataTransformer()
65 |
66 | # Run
67 | transformer.fit(data, ['cnt'])
68 | new_data = transformer.transform(data)
69 | transformer.inverse_transform(new_data)
70 |
71 | # Assert there is only 1 one hot vector
72 | assert np.array_equal(new_data, np.ones((len(data), 1)))
73 |
74 | def test_df_categorical(self):
75 | """Test transforming a dataframe containing only categorical values."""
76 | # Setup
77 | data = pd.DataFrame({'cat': np.random.choice(['a', 'b', 'c'], size=1000)})
78 | transformer = DataTransformer()
79 |
80 | # Run
81 | transformer.fit(data, ['cat'])
82 | new_data = transformer.transform(data)
83 | transformer.inverse_transform(new_data)
84 |
85 | # Assert there are 3 one hot vectors
86 | assert new_data.shape[0] == 1000
87 | assert new_data.shape[1] == 3
88 | assert np.isin(new_data[:, 1:], [0, 1]).all()
89 |
90 | def test_df_mixed(self):
91 | """Test transforming a dataframe containing mixed data types."""
92 | # Setup
93 | data = pd.DataFrame({
94 | 'num': np.random.normal(size=1000),
95 | 'cat': np.random.choice(['a', 'b', 'c'], size=1000),
96 | })
97 | transformer = DataTransformer()
98 |
99 | # Run
100 | transformer.fit(data, ['cat'])
101 | new_data = transformer.transform(data)
102 | transformer.inverse_transform(new_data)
103 |
104 | # Assert transformed numerical values are between -1 and 1
105 | assert (new_data[:, 0] > -np.ones(len(new_data))).all()
106 | assert (new_data[:, 0] < np.ones(len(new_data))).all()
107 |
108 | # Assert transformed numerical values are a Gaussian centered at 0 with std = 1/4
109 | assert -0.1 < np.mean(new_data[:, 0]) < 0.1
110 | assert 0.2 < np.std(new_data[:, 0]) < 0.3
111 |
112 | # Assert there are at most `max_columns=10` one hot columns for the numerical values
113 | # and 3 for the categorical ones
114 | assert new_data.shape[0] == 1000
115 | assert 5 <= new_data.shape[1] <= 17
116 | assert np.isin(new_data[:, 1:], [0, 1]).all()
117 |
118 | def test_numpy(self):
119 | """Test transforming a numpy array."""
120 | # Setup
121 | data = pd.DataFrame({
122 | 'num': np.random.normal(size=1000),
123 | 'cat': np.random.choice(['a', 'b', 'c'], size=1000),
124 | })
125 | data = np.array(data)
126 | transformer = DataTransformer()
127 |
128 | # Run
129 | transformer.fit(data, [1])
130 | new_data = transformer.transform(data)
131 | transformer.inverse_transform(new_data)
132 |
133 | # Assert transformed numerical values are between -1 and 1
134 | assert (new_data[:, 0] > -np.ones(len(new_data))).all()
135 | assert (new_data[:, 0] < np.ones(len(new_data))).all()
136 |
137 | # Assert transformed numerical values are a Gaussian centered at 0 with std = 1/4
138 | assert -0.1 < np.mean(new_data[:, 0]) < 0.1
139 | assert 0.2 < np.std(new_data[:, 0]) < 0.3
140 |
141 | # Assert there are at most `max_columns=10` one hot columns for the numerical values
142 | # and 3 for the categorical ones
143 | assert new_data.shape[0] == 1000
144 | assert 5 <= new_data.shape[1] <= 17
145 | assert np.isin(new_data[:, 1:], [0, 1]).all()
146 |
--------------------------------------------------------------------------------
/tests/integration/test_load_demo.py:
--------------------------------------------------------------------------------
1 | from ctgan import CTGAN, load_demo
2 |
3 |
4 | def test_load_demo():
5 | """End-to-end test to load and synthesize data."""
6 | # Setup
7 | discrete_columns = [
8 | 'workclass',
9 | 'education',
10 | 'marital-status',
11 | 'occupation',
12 | 'relationship',
13 | 'race',
14 | 'sex',
15 | 'native-country',
16 | 'income',
17 | ]
18 | ctgan = CTGAN(epochs=1)
19 |
20 | # Run
21 | data = load_demo()
22 | ctgan.fit(data, discrete_columns)
23 | samples = ctgan.sample(1000, condition_column='native-country', condition_value='United-States')
24 |
25 | # Assert
26 | assert samples.shape == (1000, 15)
27 | assert all([col[0] != ' ' for col in samples.columns])
28 | assert not samples.isna().any().any()
29 |
--------------------------------------------------------------------------------
/tests/test_tasks.py:
--------------------------------------------------------------------------------
1 | """Tests for the ``tasks.py`` file."""
2 |
3 | from tasks import _get_minimum_versions
4 |
5 |
6 | def test_get_minimum_versions():
7 | """Test the ``_get_minimum_versions`` method.
8 |
9 | The method should return the minimum versions of the dependencies for the given python version.
10 | If a library is linked to a URL, the minimum version should be the URL.
11 | """
12 | # Setup
13 | dependencies = [
14 | "numpy>=1.20.0,<2;python_version<'3.10'",
15 | "numpy>=1.23.3,<2;python_version>='3.10'",
16 | "pandas>=1.2.0,<2;python_version<'3.10'",
17 | "pandas>=1.3.0,<2;python_version>='3.10'",
18 | 'humanfriendly>=8.2,<11',
19 | 'pandas @ git+https://github.com/pandas-dev/pandas.git@master',
20 | ]
21 |
22 | # Run
23 | minimum_versions_39 = _get_minimum_versions(dependencies, '3.9')
24 | minimum_versions_310 = _get_minimum_versions(dependencies, '3.10')
25 |
26 | # Assert
27 | expected_versions_39 = [
28 | 'numpy==1.20.0',
29 | 'git+https://github.com/pandas-dev/pandas.git@master#egg=pandas',
30 | 'humanfriendly==8.2',
31 | ]
32 | expected_versions_310 = [
33 | 'numpy==1.23.3',
34 | 'git+https://github.com/pandas-dev/pandas.git@master#egg=pandas',
35 | 'humanfriendly==8.2',
36 | ]
37 |
38 | assert minimum_versions_39 == expected_versions_39
39 | assert minimum_versions_310 == expected_versions_310
40 |
--------------------------------------------------------------------------------
/tests/unit/__init__.py:
--------------------------------------------------------------------------------
1 | """Unit testing module."""
2 |
--------------------------------------------------------------------------------
/tests/unit/synthesizer/__init__.py:
--------------------------------------------------------------------------------
1 | """CTGAN testing module."""
2 |
--------------------------------------------------------------------------------
/tests/unit/synthesizer/test_base.py:
--------------------------------------------------------------------------------
1 | """BaseSynthesizer unit testing module."""
2 |
3 | from unittest.mock import MagicMock, call, patch
4 |
5 | import numpy as np
6 | import torch
7 |
8 | from ctgan.synthesizers.base import BaseSynthesizer, random_state
9 |
10 |
11 | @patch('ctgan.synthesizers.base.torch')
12 | @patch('ctgan.synthesizers.base.np.random')
13 | def test_valid_random_state(random_mock, torch_mock):
14 | """Test the ``random_state`` attribute with a valid random state.
15 |
16 | Expect that the decorated function uses the random_state attribute.
17 | """
18 | # Setup
19 | my_function = MagicMock()
20 | instance = MagicMock()
21 |
22 | random_state_mock = MagicMock()
23 | random_state_mock.get_state.return_value = 'desired numpy state'
24 | torch_generator_mock = MagicMock()
25 | torch_generator_mock.get_state.return_value = 'desired torch state'
26 | instance.random_states = (random_state_mock, torch_generator_mock)
27 |
28 | args = {'some', 'args'}
29 | kwargs = {'keyword': 'value'}
30 |
31 | random_mock.RandomState.return_value = random_state_mock
32 | random_mock.get_state.return_value = 'random state'
33 | torch_mock.Generator.return_value = torch_generator_mock
34 | torch_mock.get_rng_state.return_value = 'torch random state'
35 |
36 | # Run
37 | decorated_function = random_state(my_function)
38 | decorated_function(instance, *args, **kwargs)
39 |
40 | # Assert
41 | my_function.assert_called_once_with(instance, *args, **kwargs)
42 |
43 | instance.assert_not_called()
44 | assert random_mock.get_state.call_count == 2
45 | assert torch_mock.get_rng_state.call_count == 2
46 | random_mock.RandomState.assert_has_calls([
47 | call().get_state(),
48 | call(),
49 | call().set_state('random state'),
50 | ])
51 | random_mock.set_state.assert_has_calls([call('desired numpy state'), call('random state')])
52 | torch_mock.set_rng_state.assert_has_calls([
53 | call('desired torch state'),
54 | call('torch random state'),
55 | ])
56 |
57 |
58 | @patch('ctgan.synthesizers.base.torch')
59 | @patch('ctgan.synthesizers.base.np.random')
60 | def test_no_random_seed(random_mock, torch_mock):
61 | """Test the ``random_state`` attribute with no random state.
62 |
63 | Expect that the decorated function calls the original function
64 | when there is no random state.
65 | """
66 | # Setup
67 | my_function = MagicMock()
68 | instance = MagicMock()
69 | instance.random_states = None
70 |
71 | args = {'some', 'args'}
72 | kwargs = {'keyword': 'value'}
73 |
74 | # Run
75 | decorated_function = random_state(my_function)
76 | decorated_function(instance, *args, **kwargs)
77 |
78 | # Assert
79 | my_function.assert_called_once_with(instance, *args, **kwargs)
80 |
81 | instance.assert_not_called()
82 | random_mock.get_state.assert_not_called()
83 | random_mock.RandomState.assert_not_called()
84 | random_mock.set_state.assert_not_called()
85 | torch_mock.get_rng_state.assert_not_called()
86 | torch_mock.Generator.assert_not_called()
87 | torch_mock.set_rng_state.assert_not_called()
88 |
89 |
90 | class TestBaseSynthesizer:
91 | def test_set_random_state(self):
92 | """Test ``set_random_state`` works as expected."""
93 | # Setup
94 | instance = BaseSynthesizer()
95 |
96 | # Run
97 | instance.set_random_state(3)
98 |
99 | # Assert
100 | assert isinstance(instance.random_states, tuple)
101 | assert isinstance(instance.random_states[0], np.random.RandomState)
102 | assert isinstance(instance.random_states[1], torch.Generator)
103 |
104 | def test_set_random_state_with_none(self):
105 | """Test ``set_random_state`` with None."""
106 | # Setup
107 | instance = BaseSynthesizer()
108 |
109 | # Run and assert
110 | instance.set_random_state(3)
111 | assert instance.random_states is not None
112 |
113 | instance.set_random_state(None)
114 | assert instance.random_states is None
115 |
--------------------------------------------------------------------------------
/tests/unit/synthesizer/test_ctgan.py:
--------------------------------------------------------------------------------
1 | """CTGAN unit testing module."""
2 |
3 | from unittest import TestCase
4 | from unittest.mock import Mock
5 |
6 | import numpy as np
7 | import pandas as pd
8 | import pytest
9 | import torch
10 |
11 | from ctgan.data_transformer import SpanInfo
12 | from ctgan.errors import InvalidDataError
13 | from ctgan.synthesizers.ctgan import CTGAN, Discriminator, Generator, Residual
14 |
15 |
16 | class TestDiscriminator(TestCase):
17 | def test___init__(self):
18 | """Test `__init__` for a generic case.
19 |
20 | Make sure `self.seq` has length 3 * len(`discriminator_dim`) + 1.
21 |
22 | Setup:
23 | - Create Discriminator
24 |
25 | Input:
26 | - input_dim = positive integer
27 | - discriminator_dim = list of integers
28 | - pac = positive integer
29 |
30 | Output:
31 | - None
32 |
33 | Side Effects:
34 | - Set `self.seq`, `self.pac` and `self.pacdim`
35 | """
36 | discriminator_dim = [1, 2, 3]
37 | discriminator = Discriminator(input_dim=50, discriminator_dim=discriminator_dim, pac=7)
38 |
39 | assert discriminator.pac == 7
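40 | # pacdim packs `pac` samples into one discriminator input: input_dim * pac = 50 * 7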
40 | assert discriminator.pacdim == 350
41 | assert len(discriminator.seq) == 3 * len(discriminator_dim) + 1
42 |
43 | def test_forward(self):
44 | """Test `test_forward` for a generic case.
45 |
46 | Check that the output shapes are correct.
47 | We can also test that all parameters have a gradient attached to them
48 | by running `discriminator.parameters()`. To do that, we just need to use `loss.backward()`
49 | for some loss, like `loss = torch.mean(output)`. Note that input_dim equals the input size.
50 |
51 | Setup:
52 | - initialize with input_size, discriminator_dim, pac
53 | - Create random tensor as input
54 |
55 | Input:
56 | - input = random tensor of shape (N, input_size)
57 |
58 | Output:
59 | - tensor of shape (N/pac, 1)
60 | """
61 | discriminator = Discriminator(input_dim=50, discriminator_dim=[100, 200, 300], pac=7)
62 | output = discriminator(torch.randn(70, 50))
63 | assert output.shape == (10, 1)
64 |
65 | # Check to make sure no gradients attached
66 | for parameter in discriminator.parameters():
67 | assert parameter.grad is None
68 |
69 | # Backpropagate
70 | output.mean().backward()
71 |
72 | # Check to make sure all parameters have gradients
73 | for parameter in discriminator.parameters():
74 | assert parameter.grad is not None
75 |
76 |
77 | class TestResidual(TestCase):
78 | def test_forward(self):
79 | """Test `test_forward` for a generic case.
80 |
81 | Check that the output shapes are correct.
82 | We can also test that all parameters have a gradient attached to them
83 | by running `residual.parameters()`. To do that, we just need to use `loss.backward()`
84 | for some loss, like `loss = torch.mean(output)`.
85 |
86 | Setup:
87 | - initialize with input_size, output_size
88 | - Create random tensor as input
89 |
90 | Input:
91 | - input = random tensor of shape (N, input_size)
92 |
93 | Output:
94 | - tensor of shape (N, input_size + output_size)
95 | """
96 | residual = Residual(10, 2)
97 | output = residual(torch.randn(100, 10))
98 | assert output.shape == (100, 12)
99 |
100 | # Check to make sure no gradients attached
101 | for parameter in residual.parameters():
102 | assert parameter.grad is None
103 |
104 | # Backpropagate
105 | output.mean().backward()
106 |
107 | # Check to make sure all parameters have gradients
108 | for parameter in residual.parameters():
109 | assert parameter.grad is not None
110 |
111 |
112 | class TestGenerator(TestCase):
113 | def test___init__(self):
114 | """Test `__init__` for a generic case.
115 |
116 | Make sure `self.seq` has length len(`generator_dim`) + 1.
117 |
118 | Setup:
119 | - Create Generator
120 |
121 | Input:
122 | - embedding_dim = positive integer
123 | - generator_dim = list of integers
124 | - data_dim = positive integer
125 |
126 | Output:
127 | - None
128 |
129 | Side Effects:
130 | - Set `self.seq`
131 | """
132 | generator_dim = [1, 2, 3]
133 | generator = Generator(embedding_dim=50, generator_dim=generator_dim, data_dim=7)
134 |
135 | assert len(generator.seq) == len(generator_dim) + 1
136 |
137 | def test_forward(self):
138 | """Test `test_forward` for a generic case.
139 |
140 | Check that the output shapes are correct.
141 | We can also test that all parameters have a gradient attached to them
142 | by running `generator.parameters()`. To do that, we just need to use `loss.backward()`
143 | for some loss, like `loss = torch.mean(output)`.
144 |
145 | Setup:
146 | - initialize with embedding_dim, generator_dim, data_dim
147 | - Create random tensor as input
148 |
149 | Input:
150 | - input = random tensor of shape (N, input_size)
151 |
152 | Output:
153 | - tensor of shape (N, data_dim)
154 | """
155 | generator = Generator(embedding_dim=60, generator_dim=[100, 200, 300], data_dim=500)
156 | output = generator(torch.randn(70, 60))
157 | assert output.shape == (70, 500)
158 |
159 | # Check to make sure no gradients attached
160 | for parameter in generator.parameters():
161 | assert parameter.grad is None
162 |
163 | # Backpropagate
164 | output.mean().backward()
165 |
166 | # Check to make sure all parameters have gradients
167 | for parameter in generator.parameters():
168 | assert parameter.grad is not None
169 |
170 |
171 | def _assert_is_between(data, lower, upper):
172 | """Assert all values of the tensor 'data' are within range."""
173 | assert all((data >= lower).numpy().tolist())
174 | assert all((data <= upper).numpy().tolist())
175 |
176 |
177 | class TestCTGAN(TestCase):
178 | def test__apply_activate(self):
179 | """Test `_apply_activate` for tables with both continuous and categoricals.
180 |
181 | Check every continuous column has all values between -1 and 1
182 | (since they are normalized), and check every categorical column adds up to 1.
183 |
184 | Setup:
185 | - Mock `self._transformer.output_info_list`
186 |
187 | Input:
188 | - data = tensor of shape (N, data_dims)
189 |
190 | Output:
191 | - tensor = tensor of shape (N, data_dims)
192 | """
193 | model = CTGAN()
194 | model._transformer = Mock()
195 | model._transformer.output_info_list = [
196 | [SpanInfo(3, 'softmax')],
197 | [SpanInfo(1, 'tanh'), SpanInfo(2, 'softmax')],
198 | ]
199 |
200 | data = torch.randn(100, 6)
201 | result = model._apply_activate(data)
202 |
203 | assert result.shape == (100, 6)
204 | _assert_is_between(result[:, 0:3], 0.0, 1.0)
205 | _assert_is_between(result[:, 3], -1.0, 1.0)
206 | _assert_is_between(result[:, 4:6], 0.0, 1.0)
207 |
208 | def test__cond_loss(self):
209 | """Test `_cond_loss`.
210 |
211 | Test that the loss is purely a function of the target categorical.
212 |
213 | Setup:
214 | - mock transformer.output_info_list
215 | - create two categoricals, one continuous
216 | - compute the conditional loss, conditioned on the 1st categorical
217 | - compare the loss to the cross-entropy of the 1st categorical, manually computed
218 |
219 | Input:
220 | data - the synthetic data generated by the model
221 | c - the conditional vector (the concatenated one-hot encodings of the categorical
222 | columns), with only the one-hot vector of the target column filled in
223 | m - binary mask used to select the categorical column to condition on
224 |
225 | Output:
226 | loss scalar; this should only be affected by the target column
227 |
228 | Note:
229 | - even though the implementation of this is probably right, I'm not sure if the idea
230 | behind it is correct
231 | """
232 | model = CTGAN()
233 | model._transformer = Mock()
234 | model._transformer.output_info_list = [
235 | [SpanInfo(1, 'tanh'), SpanInfo(2, 'softmax')],
236 | [SpanInfo(3, 'softmax')], # this is the categorical column we are conditioning on
237 | [SpanInfo(2, 'softmax')], # this is the categorical column we are not conditioning on
238 | ]
239 |
240 | data = torch.tensor([
241 | # first 3 dims ignored, next 3 dims are the prediction, last 2 dims are ignored
242 | [0.0, -1.0, 0.0, 0.05, 0.05, 0.9, 0.1, 0.4],
243 | ])
244 |
245 | c = torch.tensor([
246 | # first 3 dims are a one-hot for the categorical,
247 | # next 2 are for a different categorical that we are not conditioning on
248 | # (continuous values are not stored in this tensor)
249 | [0.0, 0.0, 1.0, 0.0, 0.0],
250 | ])
251 |
252 | # this indicates that we are conditioning on the first categorical
253 | m = torch.tensor([[1, 0]])
254 |
255 | result = model._cond_loss(data, c, m)
256 | expected = torch.nn.functional.cross_entropy(
257 | torch.tensor([
258 | [0.05, 0.05, 0.9], # 3 categories, one hot
259 | ]),
260 | torch.tensor([2]),
261 | )
262 |
263 | assert (result - expected).abs() < 1e-3
264 |
265 | def test__validate_discrete_columns(self):
266 | """Test `_validate_discrete_columns` if the discrete column doesn't exist.
267 |
268 | Check the appropriate error is raised if `discrete_columns` is invalid, both
269 | for numpy arrays and dataframes.
270 |
271 | Setup:
272 | - Create dataframe with a discrete column
273 | - Define `discrete_columns` as something not in the dataframe
274 |
275 | Input:
276 | - train_data = 2-dimensional numpy array or a pandas.DataFrame
277 | - discrete_columns = list of strings or integers
278 |
279 | Output:
280 | None
281 |
282 | Side Effects:
283 | - Raises error if the discrete column is invalid.
284 |
285 | Note:
286 | - could create another function for numpy array
287 | """
288 | data = pd.DataFrame({'discrete': ['a', 'b']})
289 | discrete_columns = ['doesnt exist']
290 |
291 | ctgan = CTGAN(epochs=1)
292 | with pytest.raises(ValueError, match=r'Invalid columns found: {\'doesnt exist\'}'):
293 | ctgan.fit(data, discrete_columns)
294 |
295 | def test__validate_null_data(self):
296 | """Test `_validate_null_data` with pandas and numpy data.
297 |
298 | Check the appropriate error is raised if null values are present in
299 | continuous columns, both for numpy arrays and dataframes.
300 | """
301 | # Setup
302 | discrete_df = pd.DataFrame({'discrete': ['a', 'b']})
303 | discrete_array = np.array([['a'], ['b']])
304 | continuous_no_nulls_df = pd.DataFrame({'continuous': [0, 1]})
305 | continuous_no_nulls_array = np.array([[0], [1]])
306 | continuous_with_null_df = pd.DataFrame({'continuous': [1, np.nan]})
307 | continuous_with_null_array = np.array([[1], [np.nan]])
308 | ctgan = CTGAN(epochs=1)
309 | error_message = (
310 | 'CTGAN does not support null values in the continuous training data. '
311 | 'Please remove all null values from your continuous training data.'
312 | )
313 |
314 | # Test discrete DataFrame fits without error
315 | ctgan.fit(discrete_df, ['discrete'])
316 |
317 | # Test discrete array fits without error
318 | ctgan.fit(discrete_array, [0])
319 |
320 | # Test continuous DataFrame without nulls fits without error
321 | ctgan.fit(continuous_no_nulls_df)
322 |
323 | # Test continuous array without nulls fits without error
324 | ctgan.fit(continuous_no_nulls_array)
325 |
326 | # Test nulls in continuous columns DataFrame errors on fit
327 | with pytest.raises(InvalidDataError, match=error_message):
328 | ctgan.fit(continuous_with_null_df)
329 |
330 | # Test nulls in continuous columns array errors on fit
331 | with pytest.raises(InvalidDataError, match=error_message):
332 | ctgan.fit(continuous_with_null_array)
333 |
--------------------------------------------------------------------------------
/tests/unit/synthesizer/test_tvae.py:
--------------------------------------------------------------------------------
1 | """TVAE unit testing module."""
2 |
3 | from unittest.mock import MagicMock, Mock, call, patch
4 |
5 | import pandas as pd
6 |
7 | from ctgan.synthesizers import TVAE
8 |
9 |
10 | class TestTVAE:
11 | @patch('ctgan.synthesizers.tvae._loss_function')
12 | @patch('ctgan.synthesizers.tvae.tqdm')
13 | def test_fit_verbose(self, tqdm_mock, loss_func_mock):
14 | """Test verbose parameter prints progress bar."""
15 | # Setup
16 | epochs = 1
17 |
18 | def mock_iter():
19 | for i in range(epochs):
20 | yield i
21 |
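22 | # every `loss_1 + loss_2` during training returns a mock whose
23 | # detach().cpu().item() yields 1.23456789, which verbose mode should
24 | # render rounded to three decimals ('Loss: 1.235')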
22 | def mock_add(a, b):
23 | mock_loss = Mock()
24 | mock_loss.detach().cpu().item.return_value = 1.23456789
25 | return mock_loss
26 |
27 | loss_mock = MagicMock()
28 | loss_mock.__add__ = mock_add
29 | loss_func_mock.return_value = (loss_mock, loss_mock)
30 |
31 | iterator_mock = MagicMock()
32 | iterator_mock.__iter__.side_effect = mock_iter
33 | tqdm_mock.return_value = iterator_mock
34 | synth = TVAE(epochs=epochs, verbose=True)
35 | train_data = pd.DataFrame({'col1': [0, 1, 2, 3, 4], 'col2': [10, 11, 12, 13, 14]})
36 |
37 | # Run
38 | synth.fit(train_data)
39 |
40 | # Assert
41 | tqdm_mock.assert_called_once_with(range(epochs), disable=False)
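42 | # the description is set once before training starts ('Loss: 0.000') and
43 | # once per epoch with the rounded epoch loss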
42 | assert iterator_mock.set_description.call_args_list[0] == call('Loss: 0.000')
43 | assert iterator_mock.set_description.call_args_list[1] == call('Loss: 1.235')
44 | assert iterator_mock.set_description.call_count == 2
45 |
--------------------------------------------------------------------------------
/tests/unit/test_data_transformer.py:
--------------------------------------------------------------------------------
1 | """Data transformer unit testing module."""
2 |
3 | from unittest import TestCase
4 | from unittest.mock import Mock, patch
5 |
6 | import numpy as np
7 | import pandas as pd
8 |
9 | from ctgan.data_transformer import ColumnTransformInfo, DataTransformer, SpanInfo
10 |
11 |
12 | class TestDataTransformer(TestCase):
13 | @patch('ctgan.data_transformer.ClusterBasedNormalizer')
14 | def test___fit_continuous(self, MockCBN):
15 | """Test ``_fit_continuous`` on a simple continuous column.
16 |
17 | A ``ClusterBasedNormalizer`` will be created and fit with some ``data``.
18 |
19 | Setup:
20 | - Mock the ``ClusterBasedNormalizer`` with ``valid_component_indicator`` as
21 | ``[True, False, True]``.
22 | - Initialize a ``DataTransformer``.
23 |
24 | Input:
25 | - A dataframe with only one column containing random float values.
26 |
27 | Output:
28 | - A ``ColumnTransformInfo`` object where:
29 | - ``column_name`` matches the column of the data.
30 | - ``transform`` is the ``ClusterBasedNormalizer`` instance.
31 | - ``output_dimensions`` is 3 (matches size of ``valid_component_indicator``).
32 | - ``output_info`` assigns the correct activation functions.
33 |
34 | Side Effects:
35 | - ``fit`` should be called with the data.
36 | """
37 | # Setup
38 | cbn_instance = MockCBN.return_value
39 | cbn_instance.valid_component_indicator = [True, False, True]
40 | transformer = DataTransformer()
41 | data = pd.DataFrame(np.random.normal(size=(100, 1)), columns=['column'])
42 |
43 | # Run
44 | info = transformer._fit_continuous(data)
45 |
46 | # Assert
47 | assert info.column_name == 'column'
48 | assert info.transform == cbn_instance
49 | assert info.output_dimensions == 3
50 | assert info.output_info[0].dim == 1
51 | assert info.output_info[0].activation_fn == 'tanh'
52 | assert info.output_info[1].dim == 2
53 | assert info.output_info[1].activation_fn == 'softmax'
54 |
55 | @patch('ctgan.data_transformer.ClusterBasedNormalizer')
56 | def test__fit_continuous_max_clusters(self, MockCBN):
57 | """Test ``_fit_continuous`` with data that has less than 10 rows.
58 |
59 | Expect that a ``ClusterBasedNormalizer`` is created with the max number of clusters
60 | set to the length of the data.
61 |
62 | Input:
63 | - Data with fewer than 10 rows.
64 |
65 | Side Effects:
66 | - A ``ClusterBasedNormalizer`` is created with the max number of clusters set to the
67 | length of the data.
68 | """
69 | # Setup
70 | data = pd.DataFrame(np.random.normal(size=(7, 1)), columns=['column'])
71 | transformer = DataTransformer()
72 |
73 | # Run
74 | transformer._fit_continuous(data)
75 |
76 | # Assert
77 | MockCBN.assert_called_once_with(
78 | missing_value_generation='from_column', max_clusters=len(data), weight_threshold=0.005
79 | )
80 |
81 | @patch('ctgan.data_transformer.OneHotEncoder')
82 | def test___fit_discrete(self, MockOHE):
83 | """Test ``_fit_discrete_`` on a simple discrete column.
84 |
85 | A ``OneHotEncoder`` will be created and fit with the ``data``.
86 |
87 | Setup:
88 | - Mock the ``OneHotEncoder``.
89 | - Create ``DataTransformer``.
90 |
91 | Input:
92 | - A dataframe with only one column containing ``['a', 'b']`` values.
93 |
94 | Output:
95 | - A ``ColumnTransformInfo`` object where:
96 | - ``column_name`` matches the column of the data.
97 | - ``transform`` is the ``OneHotEncoder`` instance.
98 | - ``output_dimensions`` is 2.
99 | - ``output_info`` assigns the correct activation function.
100 |
101 | Side Effects:
102 | - ``fit`` should be called with the data.
103 | """
104 | # Setup
105 | ohe_instance = MockOHE.return_value
106 | ohe_instance.dummies = ['a', 'b']
107 | transformer = DataTransformer()
108 | data = pd.DataFrame(np.array(['a', 'b'] * 100), columns=['column'])
109 |
110 | # Run
111 | info = transformer._fit_discrete(data)
112 |
113 | # Assert
114 | assert info.column_name == 'column'
115 | assert info.transform == ohe_instance
116 | assert info.output_dimensions == 2
117 | assert info.output_info[0].dim == 2
118 | assert info.output_info[0].activation_fn == 'softmax'
119 |
120 | def test_fit(self):
121 | """Test ``fit`` on a np.ndarray with one continuous and one discrete columns.
122 |
123 | The ``fit`` method should:
124 | - Set ``self.dataframe`` to ``False``.
125 | - Set ``self._column_raw_dtypes`` to the appropirate dtypes.
126 | - Use the appropriate ``_fit`` type for each column.
127 | - Update ``self.output_info_list``, ``self.output_dimensions`` and
128 | ``self._column_transform_info_list`` appropriately.
129 |
130 | Setup:
131 | - Create ``DataTransformer``.
132 | - Mock ``_fit_discrete``.
133 | - Mock ``_fit_continuous``.
134 |
135 | Input:
136 | - A table with one continuous and one discrete column.
137 | - A list with the name of the discrete column.
138 |
139 | Side Effects:
140 | - ``_fit_discrete`` and ``_fit_continuous`` should each be called once.
141 | - Assigns ``self._column_raw_dtypes`` the appropriate dtypes.
142 | - Assigns ``self.output_info_list`` the appropriate ``output_info``.
143 | - Assigns ``self.output_dimensions`` the appropriate ``output_dimensions``.
144 | - Assigns ``self._column_transform_info_list`` the appropriate
145 | ``column_transform_info``.
146 | """
147 | # Setup
148 | transformer = DataTransformer()
149 | transformer._fit_continuous = Mock()
150 | transformer._fit_continuous.return_value = ColumnTransformInfo(
151 | column_name='x',
152 | column_type='continuous',
153 | transform=None,
154 | output_info=[SpanInfo(1, 'tanh'), SpanInfo(3, 'softmax')],
155 | output_dimensions=1 + 3,
156 | )
157 |
158 | transformer._fit_discrete = Mock()
159 | transformer._fit_discrete.return_value = ColumnTransformInfo(
160 | column_name='y',
161 | column_type='discrete',
162 | transform=None,
163 | output_info=[SpanInfo(2, 'softmax')],
164 | output_dimensions=2,
165 | )
166 |
167 | data = pd.DataFrame({
168 | 'x': np.random.random(size=100),
169 | 'y': np.random.choice(['yes', 'no'], size=100),
170 | })
171 |
172 | # Run
173 | transformer.fit(data, discrete_columns=['y'])
174 |
175 | # Assert
176 | transformer._fit_discrete.assert_called_once()
177 | transformer._fit_continuous.assert_called_once()
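178 | # 6 = 4 dims from the continuous column (1 tanh + 3 softmax)
179 | # + 2 dims from the discrete column's one-hot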
178 | assert transformer.output_dimensions == 6
179 |
180 | @patch('ctgan.data_transformer.ClusterBasedNormalizer')
181 | def test__transform_continuous(self, MockCBN):
182 | """Test ``_transform_continuous``.
183 |
184 | Setup:
185 | - Mock the ``ClusterBasedNormalizer`` with the transform method returning
186 | some dataframe.
187 | - Create ``DataTransformer``.
188 |
189 | Input:
190 | - ``ColumnTransformInfo`` object.
191 | - A dataframe containing a continuous column.
192 |
193 | Output:
194 | - A np.array where the first column contains the normalized part
195 | of the mocked transform, and the other columns are a one hot encoding
196 | representation of the component part of the mocked transform.
197 | """
198 | # Setup
199 | cbn_instance = MockCBN.return_value
200 | cbn_instance.transform.return_value = pd.DataFrame({
201 | 'x.normalized': [0.1, 0.2, 0.3],
202 | 'x.component': [0.0, 1.0, 1.0],
203 | })
204 |
205 | transformer = DataTransformer()
206 | data = pd.DataFrame({'x': np.array([0.1, 0.3, 0.5])})
207 | column_transform_info = ColumnTransformInfo(
208 | column_name='x',
209 | column_type='continuous',
210 | transform=cbn_instance,
211 | output_info=[SpanInfo(1, 'tanh'), SpanInfo(3, 'softmax')],
212 | output_dimensions=1 + 3,
213 | )
214 |
215 | # Run
216 | result = transformer._transform_continuous(column_transform_info, data)
217 |
218 | # Assert
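219 | # the normalized values pass through unchanged, while the component labels
220 | # (0, 1, 1) are one-hot encoded across the 3 softmax dimensions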
219 | expected = np.array([
220 | [0.1, 1, 0, 0],
221 | [0.2, 0, 1, 0],
222 | [0.3, 0, 1, 0],
223 | ])
224 | np.testing.assert_array_equal(result, expected)
225 |
226 | def test_transform(self):
227 | """Test ``transform`` on a dataframe with one continuous and one discrete columns.
228 |
229 | It should use the appropriate ``_transform`` type for each column and should return
230 | them concatenated appropriately.
231 |
232 | Setup:
233 | - Initialize a ``DataTransformer`` with a ``column_transform_info`` detailing
234 | a continuous and a discrete column.
235 | - Mock the ``_transform_discrete`` and ``_transform_continuous`` methods.
236 |
237 | Input:
238 | - A table with one continuous and one discrete column.
239 |
240 | Output:
241 | - np.array containing the transformed columns.
242 | """
243 | # Setup
244 | data = pd.DataFrame({'x': np.array([0.1, 0.3, 0.5]), 'y': np.array(['yes', 'yes', 'no'])})
245 |
246 | transformer = DataTransformer()
247 | transformer._column_transform_info_list = [
248 | ColumnTransformInfo(
249 | column_name='x',
250 | column_type='continuous',
251 | transform=None,
252 | output_info=[SpanInfo(1, 'tanh'), SpanInfo(3, 'softmax')],
253 | output_dimensions=1 + 3,
254 | ),
255 | ColumnTransformInfo(
256 | column_name='y',
257 | column_type='discrete',
258 | transform=None,
259 | output_info=[SpanInfo(2, 'softmax')],
260 | output_dimensions=2,
261 | ),
262 | ]
263 |
264 | transformer._transform_continuous = Mock()
265 | selected_normalized_value = np.array([[0.1], [0.3], [0.5]])
266 | selected_component_onehot = np.array([
267 | [1, 0, 0],
268 | [0, 1, 0],
269 | [0, 1, 0],
270 | ])
271 | return_value = np.concatenate(
272 | (selected_normalized_value, selected_component_onehot), axis=1
273 | )
274 | transformer._transform_continuous.return_value = return_value
275 |
276 | transformer._transform_discrete = Mock()
277 | transformer._transform_discrete.return_value = np.array([
278 | [0, 1],
279 | [0, 1],
280 | [1, 0],
281 | ])
282 |
283 | # Run
284 | result = transformer.transform(data)
285 |
286 | # Assert
287 | expected = np.array([
288 | [0.1, 1, 0, 0, 0, 1],
289 | [0.3, 0, 1, 0, 0, 1],
290 | [0.5, 0, 1, 0, 1, 0],
291 | ])
292 | assert result.shape == (3, 6)
293 | assert (result[:, 0] == expected[:, 0]).all(), 'continuous-cdf'
294 | assert (result[:, 1:4] == expected[:, 1:4]).all(), 'continuous-softmax'
295 | assert (result[:, 4:6] == expected[:, 4:6]).all(), 'discrete'
296 |
297 | def test_parallel_sync_transform_same_output(self):
298 | """Test ``_parallel_transform`` and ``_synchronous_transform`` on a dataframe.
299 |
300 | The output of ``_parallel_transform`` should be the same as the output of
301 | ``_synchronous_transform``.
302 |
303 | Setup:
304 | - Initialize a ``DataTransformer`` with a ``column_transform_info`` detailing
305 | a continuous and a discrete column.
306 | - Mock the ``_transform_discrete`` and ``_transform_continuous`` methods.
307 |
308 | Input:
309 | - A table with one continuous and one discrete column.
310 |
311 | Output:
312 | - A list containing the transformed columns.
313 | """
314 | # Setup
315 | data = pd.DataFrame({'x': np.array([0.1, 0.3, 0.5]), 'y': np.array(['yes', 'yes', 'no'])})
316 |
317 | transformer = DataTransformer()
318 | transformer._column_transform_info_list = [
319 | ColumnTransformInfo(
320 | column_name='x',
321 | column_type='continuous',
322 | transform=None,
323 | output_info=[SpanInfo(1, 'tanh'), SpanInfo(3, 'softmax')],
324 | output_dimensions=1 + 3,
325 | ),
326 | ColumnTransformInfo(
327 | column_name='y',
328 | column_type='discrete',
329 | transform=None,
330 | output_info=[SpanInfo(2, 'softmax')],
331 | output_dimensions=2,
332 | ),
333 | ]
334 |
335 | transformer._transform_continuous = Mock()
336 | selected_normalized_value = np.array([[0.1], [0.3], [0.5]])
337 | selected_component_onehot = np.array([
338 | [1, 0, 0],
339 | [0, 1, 0],
340 | [0, 1, 0],
341 | ])
342 | return_value = np.concatenate(
343 | (selected_normalized_value, selected_component_onehot), axis=1
344 | )
345 | transformer._transform_continuous.return_value = return_value
346 |
347 | transformer._transform_discrete = Mock()
348 | transformer._transform_discrete.return_value = np.array([
349 | [0, 1],
350 | [0, 1],
351 | [1, 0],
352 | ])
353 |
354 | # Run
355 | parallel_result = transformer._parallel_transform(
356 | data, transformer._column_transform_info_list
357 | )
358 | sync_result = transformer._synchronous_transform(
359 | data, transformer._column_transform_info_list
360 | )
361 | parallel_result_np = np.concatenate(parallel_result, axis=1).astype(float)
362 | sync_result_np = np.concatenate(sync_result, axis=1).astype(float)
363 |
364 | # Assert
365 | assert len(parallel_result) == len(sync_result)
366 | np.testing.assert_array_equal(parallel_result_np, sync_result_np)
367 |
368 | @patch('ctgan.data_transformer.ClusterBasedNormalizer')
369 | def test__inverse_transform_continuous(self, MockCBN):
370 | """Test ``_inverse_transform_continuous``.
371 |
372 | Setup:
373 | - Create ``DataTransformer``.
374 | - Mock the ``ClusterBasedNormalizer`` where:
375 | - ``get_output_sdtypes`` returns the appropriate dictionary.
376 | - ``reverse_transform`` returns some dataframe.
377 |
378 | Input:
379 | - A ``ColumnTransformInfo`` object.
380 | - A np.ndarray where:
381 | - The first column contains the normalized value
382 | - The remaining columns correspond to the one-hot
383 | - sigmas = np.ndarray of floats
384 | - st = index of the sigmas ndarray
385 |
386 | Output:
387 | - Dataframe where the first column contains floats and the second is a label encoding.
388 |
389 | Side Effects:
390 | - The ``reverse_transform`` method should be called with a dataframe
391 | where the first column contains floats and the second is a label encoding.
392 | """
393 | # Setup
394 | cbn_instance = MockCBN.return_value
395 | cbn_instance.get_output_sdtypes.return_value = {
396 | 'x.normalized': 'numerical',
397 | 'x.component': 'numerical',
398 | }
399 |
400 | cbn_instance.reverse_transform.return_value = pd.DataFrame({
401 | 'x.normalized': [0.1, 0.2, 0.3],
402 | 'x.component': [0.0, 1.0, 1.0],
403 | })
404 |
405 | transformer = DataTransformer()
406 | column_data = np.array([
407 | [0.1, 1, 0, 0],
408 | [0.3, 0, 1, 0],
409 | [0.5, 0, 1, 0],
410 | ])
411 |
412 | column_transform_info = ColumnTransformInfo(
413 | column_name='x',
414 | column_type='continuous',
415 | transform=cbn_instance,
416 | output_info=[SpanInfo(1, 'tanh'), SpanInfo(3, 'softmax')],
417 | output_dimensions=1 + 3,
418 | )
419 |
420 | # Run
421 | result = transformer._inverse_transform_continuous(
422 | column_transform_info, column_data, None, None
423 | )
424 |
425 | # Assert
426 | expected = pd.DataFrame({'x.normalized': [0.1, 0.2, 0.3], 'x.component': [0.0, 1.0, 1.0]})
427 |
428 | np.testing.assert_array_equal(result, expected)
429 |
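430 | # `reverse_transform` should receive the normalized column alongside the
431 | # argmax of each one-hot row, i.e. the component labels 0, 1, 1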
430 | expected_data = pd.DataFrame({'x.normalized': [0.1, 0.3, 0.5], 'x.component': [0, 1, 1]})
431 |
432 | pd.testing.assert_frame_equal(cbn_instance.reverse_transform.call_args[0][0], expected_data)
433 |
434 | def test_convert_column_name_value_to_id(self):
435 | """Test ``convert_column_name_value_to_id`` on a simple ``_column_transform_info_list``.
436 |
437 | Tests that the appropriate indexes are returned when a table of three columns,
438 | discrete, continuous, discrete, is passed as '_column_transform_info_list'.
439 |
440 | Setup:
441 | - Mock ``_column_transform_info_list``.
442 |
443 | Input:
444 | - column_name = the name of a discrete column
445 | - value = the categorical value
446 |
447 | Output:
448 | - dictionary containing:
449 | - ``discrete_column_id`` = the index of the target column,
450 | when considering only discrete columns
451 | - ``column_id`` = the index of the target column
452 | (e.g. 3 = the third column in the data)
453 | - ``value_id`` = the index of the indicator value in the one-hot encoding
454 | """
455 | # Setup
456 | ohe = Mock()
457 | ohe.transform.return_value = pd.DataFrame([
458 | [0, 1] # one hot encoding, second dimension
459 | ])
460 | transformer = DataTransformer()
461 | transformer._column_transform_info_list = [
462 | ColumnTransformInfo(
463 | column_name='x',
464 | column_type='continuous',
465 | transform=None,
466 | output_info=[SpanInfo(1, 'tanh'), SpanInfo(3, 'softmax')],
467 | output_dimensions=1 + 3,
468 | ),
469 | ColumnTransformInfo(
470 | column_name='y',
471 | column_type='discrete',
472 | transform=ohe,
473 | output_info=[SpanInfo(2, 'softmax')],
474 | output_dimensions=2,
475 | ),
476 | ]
477 |
478 | # Run
479 | result = transformer.convert_column_name_value_to_id('y', 'yes')
480 |
481 | # Assert
482 | assert result['column_id'] == 1 # this is the 2nd column
483 | assert result['discrete_column_id'] == 0 # this is the 1st discrete column
484 | assert result['value_id'] == 1 # this is the 2nd dimension in the one hot encoding
485 |
486 | def test_convert_column_name_value_to_id_multiple(self):
487 | """Test ``convert_column_name_value_to_id``."""
488 | # Setup
489 | ohe = Mock()
490 | ohe.transform.return_value = pd.DataFrame([
491 | [0, 1, 0] # one hot encoding, second dimension
492 | ])
493 | transformer = DataTransformer()
494 | transformer._column_transform_info_list = [
495 | ColumnTransformInfo(
496 | column_name='x',
497 | column_type='continuous',
498 | transform=None,
499 | output_info=[SpanInfo(1, 'tanh'), SpanInfo(3, 'softmax')],
500 | output_dimensions=1 + 3,
501 | ),
502 | ColumnTransformInfo(
503 | column_name='y',
504 | column_type='discrete',
505 | transform=ohe,
506 | output_info=[SpanInfo(2, 'softmax')],
507 | output_dimensions=2,
508 | ),
509 | ColumnTransformInfo(
510 | column_name='z',
511 | column_type='discrete',
512 | transform=ohe,
513 | output_info=[SpanInfo(2, 'softmax')],
514 | output_dimensions=2,
515 | ),
516 | ]
517 |
518 | # Run
519 | result = transformer.convert_column_name_value_to_id('z', 'yes')
520 |
521 | # Assert
522 | assert result['column_id'] == 2 # this is the 3rd column
523 | assert result['discrete_column_id'] == 1 # this is the 2nd discrete column
524 | assert result['value_id'] == 1 # this is the 2nd dimension in the one hot encoding
525 |
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | [tox]
2 | envlist = py39-lint, py3{8,9,10,11,12,13}-{unit,integration,readme}
3 |
4 | [testenv]
5 | skipsdist = false
6 | skip_install = false
7 | deps =
8 | invoke
9 | readme: rundoc
10 | extras =
11 | lint: dev
12 | unit: test
13 | integration: test
14 | commands =
15 | lint: invoke lint
16 | unit: invoke unit
17 | integration: invoke integration
18 | readme: invoke readme
19 | invoke rmdir --path {envdir}
20 |
--------------------------------------------------------------------------------