├── .coveragerc ├── .editorconfig ├── .github ├── CODEOWNERS ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── feature_request.md │ └── question.md ├── PULL_REQUEST_TEMPLATE │ └── new_transformer.md ├── auto_assign.yml └── workflows │ ├── dependency_checker.yml │ ├── integration.yml │ ├── lint.yml │ ├── minimum.yml │ ├── prepare_release.yml │ ├── readme.yml │ ├── release.yml │ └── unit.yml ├── .gitignore ├── AUTHORS.rst ├── CONTRIBUTING.rst ├── DEVELOPMENT.rst ├── HISTORY.md ├── INSTALL.md ├── LICENSE ├── Makefile ├── README.md ├── RELEASE.md ├── codecov.yml ├── latest_requirements.txt ├── pyproject.toml ├── rdt ├── __init__.py ├── _utils.py ├── errors.py ├── hyper_transformer.py ├── performance │ ├── __init__.py │ ├── datasets │ │ ├── __init__.py │ │ ├── base.py │ │ ├── boolean.py │ │ ├── categorical.py │ │ ├── datetime.py │ │ ├── id.py │ │ ├── numerical.py │ │ ├── pii.py │ │ ├── text.py │ │ └── utils.py │ ├── performance.py │ └── profiling.py └── transformers │ ├── __init__.py │ ├── _validators.py │ ├── base.py │ ├── boolean.py │ ├── categorical.py │ ├── datetime.py │ ├── id.py │ ├── null.py │ ├── numerical.py │ ├── pii │ ├── __init__.py │ ├── anonymization.py │ ├── anonymizer.py │ └── utils.py │ ├── text.py │ └── utils.py ├── scripts ├── check_for_prereleases.py └── release_notes_generator.py ├── static_code_analysis.txt ├── tasks.py └── tests ├── __init__.py ├── code_style.py ├── contributing.py ├── datasets └── tests │ ├── test_boolean.py │ ├── test_categorical.py │ ├── test_datetime.py │ ├── test_numerical.py │ └── test_utils.py ├── integration ├── __init__.py ├── test_hyper_transformer.py ├── test_transformers.py └── transformers │ ├── __init__.py │ ├── pii │ ├── __init__.py │ ├── test_anonymization.py │ └── test_anonymizer.py │ ├── test_base.py │ ├── test_boolean.py │ ├── test_categorical.py │ ├── test_datetime.py │ ├── test_id.py │ └── test_numerical.py ├── performance ├── README.md ├── __init__.py ├── test_performance.py └── tests │ ├── __init__.py │ └── test_profiling.py ├── test_scripts.py ├── test_tasks.py └── unit ├── __init__.py ├── test___init__.py ├── test__utils.py ├── test_hyper_transformer.py └── transformers ├── __init__.py ├── pii ├── __init__.py ├── test_anonymization.py ├── test_anonymizer.py └── test_utils.py ├── test___init__.py ├── test__validators.py ├── test_base.py ├── test_boolean.py ├── test_categorical.py ├── test_datetime.py ├── test_id.py ├── test_null.py ├── test_numerical.py ├── test_text.py └── test_utils.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | omit = rdt/performance/* -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # http://editorconfig.org 2 | 3 | root = true 4 | 5 | [*] 6 | indent_style = space 7 | indent_size = 4 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | charset = utf-8 11 | end_of_line = lf 12 | 13 | [*.py] 14 | max_line_length = 99 15 | 16 | [*.bat] 17 | indent_style = tab 18 | end_of_line = crlf 19 | 20 | [LICENSE] 21 | insert_final_newline = false 22 | 23 | [Makefile] 24 | indent_style = tab 25 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Global rule: 2 | * @sdv-dev/core-contributors 3 | -------------------------------------------------------------------------------- 
/.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Report an error that you found when using RDT 4 | title: '' 5 | labels: bug, new 6 | assignees: '' 7 | 8 | --- 9 | 10 | ### Environment Details 11 | 12 | Please indicate the following details about the environment in which you found the bug: 13 | 14 | * RDT version: 15 | * Python version: 16 | * Operating System: 17 | 18 | ### Error Description 19 | 20 | 22 | 23 | ### Steps to reproduce 24 | 25 | 29 | 30 | ``` 31 | Paste the command(s) you ran and the output. 32 | If there was a crash, please include the traceback here. 33 | ``` 34 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Request a new feature that you would like to see implemented in RDT 4 | title: '' 5 | labels: new feature, new 6 | assignees: '' 7 | 8 | --- 9 | 10 | ### Problem Description 11 | 12 | 14 | 15 | ### Expected behavior 16 | 17 | 20 | 21 | ### Additional context 22 | 23 | 25 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Question 3 | about: Doubts about RDT usage 4 | title: '' 5 | labels: question, new 6 | assignees: '' 7 | 8 | --- 9 | 10 | ### Environment details 11 | 12 | If you are already running RDT, please indicate the following details about the environment in 13 | which you are running it: 14 | 15 | * RDT version: 16 | * Python version: 17 | * Operating System: 18 | 19 | ### Problem description 20 | 21 | 24 | 25 | ### What I already tried 26 | 27 | 29 | 30 | ``` 31 | Paste the command(s) you ran and the output. 32 | If there was a crash, please include the traceback here. 33 | ``` 34 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE/new_transformer.md: -------------------------------------------------------------------------------- 1 | ### New Transformer Description 2 | 3 | __sdtype__: (replace this with the sdtype that your new transformer handles) 4 | 5 | (replace this text with your description) 6 | 7 | ### Checklist 8 | - [ ] Does this PR resolve an existing issue? If not, please open a new issue. Add the issue number here: Resolves #xxx 9 | - [ ] Is this PR introducing one (and only one) new Transformer? 10 | - [ ] Have you implemented Unit tests for all the methods in your Transformer, and successfully run them? 11 | - [ ] If the sdtype that this Transformer addresses is new, have you added Dataset Generators and Real World Datasets for this new sdtype? 12 | - [ ] Have you successfully run the Integration tests on your Transformer? 13 | - [ ] Have you run the Performance tests, and optimized your transformer according to the [Common Performance Pitfalls](https://github.com/sdv-dev/RDT/blob/main/CONTRIBUTING.rst#common-performance-pitfalls)? 14 | 15 | ---- 16 | Please follow the [Contributing Guide](https://github.com/sdv-dev/RDT/blob/main/CONTRIBUTING.rst#contributing) to add a new transformer. 
17 | -------------------------------------------------------------------------------- /.github/auto_assign.yml: -------------------------------------------------------------------------------- 1 | # Set to true to add assignees to pull requests 2 | addAssignees: true -------------------------------------------------------------------------------- /.github/workflows/dependency_checker.yml: -------------------------------------------------------------------------------- 1 | name: Dependency Checker 2 | on: 3 | workflow_dispatch: 4 | schedule: 5 | - cron: '0 0 * * 1' 6 | jobs: 7 | build: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@v4 11 | - name: Set up latest Python 12 | uses: actions/setup-python@v5 13 | with: 14 | python-version-file: 'pyproject.toml' 15 | - name: Install dependencies 16 | run: | 17 | python -m pip install .[dev] 18 | make check-deps OUTPUT_FILEPATH=latest_requirements.txt 19 | make fix-lint 20 | - name: Create pull request 21 | id: cpr 22 | uses: peter-evans/create-pull-request@v4 23 | with: 24 | token: ${{ secrets.GH_ACCESS_TOKEN }} 25 | commit-message: Update latest dependencies 26 | author: "github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>" 27 | committer: "github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>" 28 | title: Automated Latest Dependency Updates 29 | body: "This is an auto-generated PR with **latest** dependency updates." 30 | branch: latest-dependency-update 31 | branch-suffix: short-commit-hash 32 | base: main 33 | -------------------------------------------------------------------------------- /.github/workflows/integration.yml: -------------------------------------------------------------------------------- 1 | name: Integration Tests 2 | 3 | on: 4 | push: 5 | pull_request: 6 | types: [opened, reopened] 7 | 8 | concurrency: 9 | group: ${{ github.workflow }}-${{ github.ref }} 10 | cancel-in-progress: true 11 | 12 | jobs: 13 | integration: 14 | runs-on: ${{ matrix.os }} 15 | strategy: 16 | matrix: 17 | python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', '3.13'] 18 | os: [ubuntu-latest, windows-latest] 19 | include: 20 | - os: macos-latest 21 | python-version: '3.8' 22 | - os: macos-latest 23 | python-version: '3.13' 24 | steps: 25 | - uses: actions/checkout@v4 26 | - name: Set up Python ${{ matrix.python-version }} 27 | uses: actions/setup-python@v5 28 | with: 29 | python-version: ${{ matrix.python-version }} 30 | - name: Install dependencies 31 | run: | 32 | python -m pip install --upgrade pip 33 | python -m pip install invoke .[test] 34 | - name: Run integration tests 35 | run: invoke integration 36 | 37 | - if: matrix.os == 'ubuntu-latest' && matrix.python-version == 3.13 38 | name: Upload integration codecov report 39 | uses: codecov/codecov-action@v4 40 | with: 41 | flags: integration 42 | file: ${{ github.workspace }}/integration_cov.xml 43 | fail_ci_if_error: true 44 | token: ${{ secrets.CODECOV_TOKEN }} 45 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Style Checks 2 | 3 | on: 4 | push: 5 | pull_request: 6 | types: [opened, reopened] 7 | 8 | concurrency: 9 | group: ${{ github.workflow }}-${{ github.ref }} 10 | cancel-in-progress: true 11 | 12 | jobs: 13 | lint: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v4 17 | - name: Set up latest Python 18 | uses: actions/setup-python@v5 19 | with: 20 | 
python-version-file: 'pyproject.toml' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | python -m pip install invoke .[dev] 25 | - name: Run lint checks 26 | run: invoke lint 27 | -------------------------------------------------------------------------------- /.github/workflows/minimum.yml: -------------------------------------------------------------------------------- 1 | name: Unit Tests Minimum Versions 2 | 3 | on: 4 | push: 5 | pull_request: 6 | types: [opened, reopened] 7 | 8 | concurrency: 9 | group: ${{ github.workflow }}-${{ github.ref }} 10 | cancel-in-progress: true 11 | 12 | jobs: 13 | minimum: 14 | runs-on: ${{ matrix.os }} 15 | strategy: 16 | matrix: 17 | python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', '3.13'] 18 | os: [ubuntu-latest, windows-latest] 19 | include: 20 | - os: macos-latest 21 | python-version: '3.8' 22 | - os: macos-latest 23 | python-version: '3.13' 24 | steps: 25 | - uses: actions/checkout@v4 26 | - name: Set up Python ${{ matrix.python-version }} 27 | uses: actions/setup-python@v5 28 | with: 29 | python-version: ${{ matrix.python-version }} 30 | - name: Install dependencies 31 | run: | 32 | python -m pip install --upgrade pip 33 | python -m pip install invoke .[test] 34 | - name: Run tests with minimum versions 35 | run: invoke minimum 36 | -------------------------------------------------------------------------------- /.github/workflows/prepare_release.yml: -------------------------------------------------------------------------------- 1 | name: Release Prep 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | branch: 7 | description: 'Branch to merge release notes and code analysis into.' 8 | required: true 9 | default: 'main' 10 | version: 11 | description: 12 | 'Version to use for the release. Must be in format: X.Y.Z.' 13 | date: 14 | description: 15 | 'Date of the release. Must be in format YYYY-MM-DD.' 16 | 17 | jobs: 18 | preparerelease: 19 | runs-on: ubuntu-latest 20 | steps: 21 | - uses: actions/checkout@v4 22 | - name: Set up latest Python 23 | uses: actions/setup-python@v5 24 | with: 25 | python-version-file: 'pyproject.toml' 26 | 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | python -m pip install requests==2.31.0 31 | python -m pip install bandit==1.7.7 32 | python -m pip install packaging 33 | python -m pip install .[test] 34 | 35 | - name: Check for prerelease dependencies 36 | run: python scripts/check_for_prereleases.py 37 | 38 | - name: Generate release notes 39 | env: 40 | GH_ACCESS_TOKEN: ${{ secrets.GH_ACCESS_TOKEN }} 41 | run: > 42 | python scripts/release_notes_generator.py 43 | -v ${{ inputs.version }} 44 | -d ${{ inputs.date }} 45 | 46 | - name: Save static code analysis 47 | run: bandit -r . -x ./tests,./scripts,./build -f txt -o static_code_analysis.txt --exit-zero 48 | 49 | - name: Create pull request 50 | id: cpr 51 | uses: peter-evans/create-pull-request@v4 52 | with: 53 | token: ${{ secrets.GH_ACCESS_TOKEN }} 54 | commit-message: Prepare release for v${{ inputs.version }} 55 | author: "github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>" 56 | committer: "github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>" 57 | title: v${{ inputs.version }} Release Preparation 58 | body: "This is an auto-generated PR to prepare the release." 
59 | branch: prepared-release 60 | branch-suffix: short-commit-hash 61 | base: ${{ inputs.branch }} 62 | -------------------------------------------------------------------------------- /.github/workflows/readme.yml: -------------------------------------------------------------------------------- 1 | name: Test README 2 | 3 | on: 4 | push: 5 | pull_request: 6 | types: [opened, reopened] 7 | 8 | concurrency: 9 | group: ${{ github.workflow }}-${{ github.ref }} 10 | cancel-in-progress: true 11 | 12 | jobs: 13 | readme: 14 | runs-on: ${{ matrix.os }} 15 | strategy: 16 | matrix: 17 | python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', '3.13'] 18 | os: [ubuntu-latest, macos-latest] # skip windows bc rundoc fails 19 | steps: 20 | - uses: actions/checkout@v4 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v5 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | python -m pip install invoke rundoc . 29 | python -m pip install tomli 30 | python -m pip install packaging 31 | - name: Run the README.md 32 | run: invoke readme 33 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | on: 3 | release: 4 | types: [published] 5 | branches: 6 | - main 7 | - stable 8 | 9 | workflow_dispatch: 10 | inputs: 11 | candidate: 12 | description: 'Release candidate.' 13 | required: true 14 | type: boolean 15 | default: true 16 | test_pypi: 17 | description: 'Test PyPI.' 18 | type: boolean 19 | default: false 20 | jobs: 21 | release: 22 | runs-on: ubuntu-latest 23 | permissions: 24 | id-token: write 25 | steps: 26 | - uses: actions/checkout@v4 27 | with: 28 | ref: ${{ inputs.candidate && 'main' || 'stable' }} 29 | 30 | - name: Set up latest Python 31 | uses: actions/setup-python@v5 32 | with: 33 | python-version-file: 'pyproject.toml' 34 | 35 | - name: Install dependencies 36 | run: | 37 | python -m pip install --upgrade pip 38 | python -m pip install .[dev] 39 | 40 | - name: Create wheel 41 | run: | 42 | make dist 43 | 44 | - name: Publish a Python distribution to PyPI 45 | uses: pypa/gh-action-pypi-publish@release/v1 46 | with: 47 | repository-url: ${{ inputs.test_pypi && 'https://test.pypi.org/legacy/' || 'https://upload.pypi.org/legacy/' }} 48 | 49 | - name: Bump version to next candidate 50 | if: ${{ inputs.candidate && !inputs.test_pypi }} 51 | run: | 52 | git config user.name "github-actions[bot]" 53 | git config user.email "41898282+github-actions[bot]@users.noreply.github.com" 54 | bump-my-version bump candidate --no-tag --no-commit 55 | 56 | - name: Create pull request 57 | if: ${{ inputs.candidate && !inputs.test_pypi }} 58 | id: cpr 59 | uses: peter-evans/create-pull-request@v4 60 | with: 61 | token: ${{ secrets.GH_ACCESS_TOKEN }} 62 | commit-message: bumpversion-candidate 63 | committer: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> 64 | author: ${{ github.actor }} <${{ github.actor_id }}+${{ github.actor }}@users.noreply.github.com> 65 | signoff: false 66 | delete-branch: true 67 | title: Automated Bump Version Candidate 68 | body: "This is an auto-generated PR that bumps the version to the next candidate." 
69 | branch: bumpversion-candidate-update 70 | branch-suffix: short-commit-hash 71 | add-paths: | 72 | rdt/__init__.py 73 | pyproject.toml 74 | draft: false 75 | base: main 76 | 77 | - name: Enable Pull Request Automerge 78 | if: ${{ steps.cpr.outputs.pull-request-operation == 'created' }} 79 | run: gh pr merge "${{ steps.cpr.outputs.pull-request-number }}" --squash --admin 80 | env: 81 | GH_TOKEN: ${{ secrets.GH_ACCESS_TOKEN }} 82 | -------------------------------------------------------------------------------- /.github/workflows/unit.yml: -------------------------------------------------------------------------------- 1 | name: Unit Tests 2 | 3 | on: 4 | push: 5 | pull_request: 6 | types: [opened, reopened] 7 | 8 | concurrency: 9 | group: ${{ github.workflow }}-${{ github.ref }} 10 | cancel-in-progress: true 11 | 12 | jobs: 13 | unit: 14 | runs-on: ${{ matrix.os }} 15 | strategy: 16 | matrix: 17 | python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', '3.13'] 18 | os: [ubuntu-latest, windows-latest] 19 | include: 20 | - os: macos-latest 21 | python-version: '3.8' 22 | - os: macos-latest 23 | python-version: '3.13' 24 | steps: 25 | - uses: actions/checkout@v4 26 | - name: Set up Python ${{ matrix.python-version }} 27 | uses: actions/setup-python@v5 28 | with: 29 | python-version: ${{ matrix.python-version }} 30 | - name: Install dependencies 31 | run: | 32 | python -m pip install --upgrade pip 33 | python -m pip install invoke .[test] 34 | - name: Check dependencies 35 | run: invoke check-dependencies 36 | - name: Run Unit tests 37 | run: invoke unit 38 | 39 | - if: matrix.os == 'ubuntu-latest' && matrix.python-version == 3.13 40 | name: Upload unit codecov report 41 | uses: codecov/codecov-action@v4 42 | with: 43 | flags: unit 44 | file: ${{ github.workspace }}/unit_cov.xml 45 | fail_ci_if_error: true 46 | token: ${{ secrets.CODECOV_TOKEN }} 47 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .github/.tmp/ 2 | tests/readme_test/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | env/ 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | *_cov.xml 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | docs/api/ 72 | docs/rdt.rst 73 | docs/rdt.*.rst 74 | docs/modules.rst 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # celery beat schedule file 86 | celerybeat-schedule 87 | 88 | # SageMath parsed files 89 | *.sage.py 90 | 91 | # dotenv 92 | .env 93 | 94 | # virtualenv 95 | .venv 96 | venv/ 97 | ENV/ 98 | 99 | # Spyder project settings 100 | .spyderproject 101 | .spyproject 102 | 103 | # Rope project settings 104 | .ropeproject 105 | 106 | # mkdocs documentation 107 | /site 108 | 109 | # mypy 110 | .mypy_cache/ 111 | 112 | # other 113 | .DS_Store 114 | 115 | # Vim 116 | *.swp 117 | -------------------------------------------------------------------------------- /AUTHORS.rst: -------------------------------------------------------------------------------- 1 | See: https://github.com/sdv-dev/RDT/graphs/contributors 2 | -------------------------------------------------------------------------------- /CONTRIBUTING.rst: -------------------------------------------------------------------------------- 1 | ============ 2 | Contributing 3 | ============ 4 | 5 | We love our community’s interest in the SDV ecosystem. The SDV software 6 | (and its related libraries) is owned and maintained by DataCebo, Inc. 7 | It is available under the `Business Source License`_ for you to browse. 8 | 9 | We support a large set of users with enterprise-specific intricacies and 10 | reliability needs. This has required us to be deliberate about setting 11 | the roadmap for SDV libraries. As a result, we are unable to prioritize 12 | reviewing and accepting external pull requests. As a policy, we will 13 | not be able to accept external contributions. 14 | 15 | **Would you like a bug or feature request to be addressed?** If you haven't 16 | already, we would greatly appreciate it if you could `file an issue`_ 17 | instead with the overall description of your problem. We can determine 18 | whether it’s aligned with our framework. Once discussed, our team 19 | typically resolves smaller issues within a few release cycles. 20 | We appreciate your understanding. 21 | 22 | 23 | .. _Business Source License: https://github.com/sdv-dev/RDT/blob/main/LICENSE 24 | .. _file an issue: https://github.com/sdv-dev/RDT/issues 25 | -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | # Installing RDT 2 | 3 | ## Requirements 4 | 5 | **RDT** has been developed and tested on 6 | [Python 3.8, 3.9, 3.10, 3.11, 3.12, and 3.13](https://www.python.org/downloads/) 7 | 8 | Also, although it is not strictly required, the usage of a [virtualenv]( 9 | https://virtualenv.pypa.io/en/latest/) is highly recommended in order to avoid 10 | interfering with other software installed in the system where **RDT** is run. 
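For example, a minimal environment setup might look like this (the environment name `rdt-env` is only an illustration):

```bash
# Create and activate an isolated environment before installing RDT
python -m venv rdt-env
source rdt-env/bin/activate    # on Windows: rdt-env\Scripts\activate
```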
11 | 12 | ## Install with pip 13 | 14 | The easiest and recommended way to install **RDT** is using [pip]( 15 | https://pip.pypa.io/en/stable/): 16 | 17 | ```bash 18 | pip install rdt 19 | ``` 20 | 21 | This will pull and install the latest stable release from [PyPI](https://pypi.org/). 22 | 23 | ## Install with conda 24 | 25 | **RDT** can also be installed using [conda](https://docs.conda.io/en/latest/): 26 | 27 | ```bash 28 | conda install -c sdv-dev -c conda-forge rdt 29 | ``` 30 | 31 | This will pull and install the latest stable release from [Anaconda](https://anaconda.org/). 32 | 33 | ## Install from source 34 | 35 | If you want to install **RDT** from source, you need to first clone the repository 36 | and then execute the `make install` command inside the `stable` branch. Note that this 37 | command works only on Unix-based systems like GNU/Linux and macOS: 38 | 39 | ```bash 40 | git clone https://github.com/sdv-dev/RDT 41 | cd RDT 42 | git checkout stable 43 | make install 44 | ``` 45 | 46 | ## Install for development 47 | 48 | If you intend to modify the source code or contribute to the project, you will need to 49 | install it from source using the `make install-develop` command. In this case, we 50 | recommend branching from `main` first: 51 | 52 | ```bash 53 | git clone git@github.com:sdv-dev/RDT 54 | cd RDT 55 | git checkout main 56 | git checkout -b <name-of-your-branch> 57 | make install-develop 58 | ``` 59 | 60 | For more details about how to contribute to the project, please visit the [Contributing Guide]( 61 | CONTRIBUTING.rst). 62 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Business Source License 1.1 2 | 3 | Parameters 4 | 5 | Licensor: DataCebo, Inc. 6 | 7 | Licensed Work: RDT 8 | The Licensed Work is (c) DataCebo, Inc. 9 | 10 | Additional Use Grant: You may make use of the Licensed Work, and derivatives of the Licensed 11 | Work, provided that you do not use the Licensed Work, or derivatives of 12 | the Licensed Work, for a Data Transform Service. 13 | 14 | A "Data Transform Service" is a commercial offering 15 | that allows third parties (other than your employees and 16 | contractors) to access the functionality of the Licensed 17 | Work so that such third parties directly benefit from the 18 | reversible data transformation for data formatting, statistical 19 | processing, anonymization or contextual extraction features of 20 | the Licensed Work. 21 | 22 | Change Date: Change date is four years from release date. 23 | Please see https://github.com/sdv-dev/RDT/releases 24 | for exact dates. 25 | 26 | Change License: MIT License 27 | 28 | 29 | Notice 30 | 31 | The Business Source License (this document, or the "License") is not an Open 32 | Source license. However, the Licensed Work will eventually be made available 33 | under an Open Source License, as stated in this License. 34 | 35 | License text copyright (c) 2017 MariaDB Corporation Ab, All Rights Reserved. 36 | "Business Source License" is a trademark of MariaDB Corporation Ab. 37 | 38 | ----------------------------------------------------------------------------- 39 | 40 | Business Source License 1.1 41 | 42 | Terms 43 | 44 | The Licensor hereby grants you the right to copy, modify, create derivative 45 | works, redistribute, and make non-production use of the Licensed Work. The 46 | Licensor may make an Additional Use Grant, above, permitting limited 47 | production use.
48 | 49 | Effective on the Change Date, or the fourth anniversary of the first publicly 50 | available distribution of a specific version of the Licensed Work under this 51 | License, whichever comes first, the Licensor hereby grants you rights under 52 | the terms of the Change License, and the rights granted in the paragraph 53 | above terminate. 54 | 55 | If your use of the Licensed Work does not comply with the requirements 56 | currently in effect as described in this License, you must purchase a 57 | commercial license from the Licensor, its affiliated entities, or authorized 58 | resellers, or you must refrain from using the Licensed Work. 59 | 60 | All copies of the original and modified Licensed Work, and derivative works 61 | of the Licensed Work, are subject to this License. This License applies 62 | separately for each version of the Licensed Work and the Change Date may vary 63 | for each version of the Licensed Work released by Licensor. 64 | 65 | You must conspicuously display this License on each original or modified copy 66 | of the Licensed Work. If you receive the Licensed Work in original or 67 | modified form from a third party, the terms and conditions set forth in this 68 | License apply to your use of that work. 69 | 70 | Any use of the Licensed Work in violation of this License will automatically 71 | terminate your rights under this License for the current and all other 72 | versions of the Licensed Work. 73 | 74 | This License does not grant you any right in any trademark or logo of 75 | Licensor or its affiliates (provided that you may use a trademark or logo of 76 | Licensor as expressly required by this License). 77 | 78 | TO THE EXTENT PERMITTED BY APPLICABLE LAW, THE LICENSED WORK IS PROVIDED ON 79 | AN "AS IS" BASIS. LICENSOR HEREBY DISCLAIMS ALL WARRANTIES AND CONDITIONS, 80 | EXPRESS OR IMPLIED, INCLUDING (WITHOUT LIMITATION) WARRANTIES OF 81 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, AND 82 | TITLE. 83 | 84 | MariaDB hereby grants you permission to use this License’s text to license 85 | your works, and to refer to it using the trademark "Business Source License", 86 | as long as you comply with the Covenants of Licensor below. 87 | 88 | Covenants of Licensor 89 | 90 | In consideration of the right to use this License’s text and the "Business 91 | Source License" name and trademark, Licensor covenants to MariaDB, and to all 92 | other recipients of the licensed work to be provided by Licensor: 93 | 94 | 1. To specify as the Change License the GPL Version 2.0 or any later version, 95 | or a license that is compatible with GPL Version 2.0 or a later version, 96 | where "compatible" means that software provided under the Change License can 97 | be included in a program with software provided under GPL Version 2.0 or a 98 | later version. Licensor may specify additional Change Licenses without 99 | limitation. 100 | 101 | 2. To either: (a) specify an additional grant of rights to use that does not 102 | impose any additional restriction on the right granted in this License, as 103 | the Additional Use Grant; or (b) insert the text "None". 104 | 105 | 3. To specify a Change Date. 106 | 107 | 4. Not to modify this License in any other way. 
108 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .DEFAULT_GOAL := help 2 | 3 | define BROWSER_PYSCRIPT 4 | import os, webbrowser, sys 5 | 6 | try: 7 | from urllib import pathname2url 8 | except: 9 | from urllib.request import pathname2url 10 | 11 | webbrowser.open("file://" + pathname2url(os.path.abspath(sys.argv[1]))) 12 | endef 13 | export BROWSER_PYSCRIPT 14 | 15 | define PRINT_HELP_PYSCRIPT 16 | import re, sys 17 | 18 | for line in sys.stdin: 19 | match = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$$', line) 20 | if match: 21 | target, help = match.groups() 22 | print("%-20s %s" % (target, help)) 23 | endef 24 | export PRINT_HELP_PYSCRIPT 25 | 26 | BROWSER := python -c "$$BROWSER_PYSCRIPT" 27 | 28 | .PHONY: help 29 | help: 30 | @python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST) 31 | 32 | 33 | # CLEAN TARGETS 34 | 35 | .PHONY: clean-build 36 | clean-build: ## remove build artifacts 37 | rm -fr build/ 38 | rm -fr dist/ 39 | rm -fr .eggs/ 40 | find . -name '*.egg-info' -exec rm -fr {} + 41 | find . -name '*.egg' -exec rm -f {} + 42 | 43 | .PHONY: clean-pyc 44 | clean-pyc: ## remove Python file artifacts 45 | find . -name '*.pyc' -exec rm -f {} + 46 | find . -name '*.pyo' -exec rm -f {} + 47 | find . -name '*~' -exec rm -f {} + 48 | find . -name '__pycache__' -exec rm -fr {} + 49 | 50 | .PHONY: clean-coverage 51 | clean-coverage: ## remove coverage artifacts 52 | rm -f .coverage 53 | rm -f .coverage.* 54 | rm -fr htmlcov/ 55 | 56 | .PHONY: clean-test 57 | clean-test: ## remove test artifacts 58 | rm -fr .tox/ 59 | rm -fr .pytest_cache 60 | 61 | .PHONY: clean 62 | clean: clean-build clean-pyc clean-test clean-coverage ## remove all build, test, coverage and Python artifacts 63 | 64 | 65 | # INSTALL TARGETS 66 | 67 | .PHONY: install 68 | install: clean-build clean-pyc ## install the package to the active Python's site-packages 69 | pip install . 
70 | 71 | .PHONY: install-test 72 | install-test: clean-build clean-pyc ## install the package and test dependencies 73 | pip install .[test] 74 | 75 | .PHONY: install-develop 76 | install-develop: clean-build clean-pyc ## install the package in editable mode and dependencies for development 77 | pip install -e .[dev] 78 | 79 | .PHONY: install-readme 80 | install-readme: clean-build clean-pyc ## install the package in editable mode and readme dependencies for development 81 | pip install -e .[readme] 82 | 83 | # LINT TARGETS 84 | 85 | .PHONY: lint 86 | lint: 87 | invoke lint 88 | 89 | .PHONY: fix-lint 90 | fix-lint: 91 | invoke fix-lint 92 | 93 | 94 | # TEST TARGETS 95 | 96 | .PHONY: test-unit 97 | test-unit: ## run tests quickly with the default Python 98 | invoke unit 99 | 100 | .PHONY: test-integration 101 | test-integration: ## run integration tests 102 | invoke integration 103 | 104 | .PHONY: test-readme 105 | test-readme: ## run the readme snippets 106 | invoke readme 107 | 108 | .PHONY: test-performance 109 | test-performance: ## run performance tests 110 | invoke performance 111 | 112 | .PHONY: test 113 | test: test-unit test-integration test-readme ## test everything that needs test dependencies 114 | 115 | .PHONY: test-repo 116 | test-repo: lint test-unit test-integration test-readme test-performance ## test everything 117 | 118 | .PHONY: coverage 119 | coverage: ## check code coverage quickly with the default Python 120 | coverage run --source rdt -m pytest tests/unit 121 | coverage report -m 122 | coverage html 123 | $(BROWSER) htmlcov/index.html 124 | 125 | 126 | # RELEASE TARGETS 127 | 128 | .PHONY: dist 129 | dist: clean ## build source and wheel packages 130 | python -m build --wheel --sdist 131 | ls -l dist 132 | 133 | .PHONY: publish-confirm 134 | publish-confirm: 135 | @echo "WARNING: This will irreversibly upload a new version to PyPI!"
136 | @echo -n "Please type 'confirm' to proceed: " \ 137 | && read answer \ 138 | && [ "$${answer}" = "confirm" ] 139 | 140 | .PHONY: publish-test 141 | publish-test: dist publish-confirm ## package and upload a release on TestPyPI 142 | twine upload --repository-url https://test.pypi.org/legacy/ dist/* 143 | 144 | .PHONY: publish 145 | publish: dist publish-confirm ## package and upload a release 146 | twine upload dist/* 147 | 148 | .PHONY: git-merge-main-stable 149 | git-merge-main-stable: ## Merge main into stable 150 | git checkout stable || git checkout -b stable 151 | git merge --no-ff main -m"make release-tag: Merge branch 'main' into stable" 152 | 153 | .PHONY: git-merge-stable-main 154 | git-merge-stable-main: ## Merge stable into main 155 | git checkout main 156 | git merge stable 157 | 158 | .PHONY: git-push 159 | git-push: ## Simply push the repository to GitHub 160 | git push 161 | 162 | .PHONY: git-push-tags-stable 163 | git-push-tags-stable: ## Push tags and stable to GitHub 164 | git push --tags origin stable 165 | 166 | .PHONY: bumpversion-release 167 | bumpversion-release: ## Bump the version to the next release 168 | bump-my-version bump release --no-tag 169 | 170 | .PHONY: bumpversion-patch 171 | bumpversion-patch: ## Bump the version to the next patch 172 | bump-my-version bump --no-tag patch 173 | 174 | .PHONY: bumpversion-candidate 175 | bumpversion-candidate: ## Bump the version to the next candidate 176 | bump-my-version bump candidate --no-tag 177 | 178 | .PHONY: bumpversion-minor 179 | bumpversion-minor: ## Bump the version to the next minor, skipping the release 180 | bump-my-version bump --no-tag minor 181 | 182 | .PHONY: bumpversion-major 183 | bumpversion-major: ## Bump the version to the next major, skipping the release 184 | bump-my-version bump --no-tag major 185 | 186 | .PHONY: bumpversion-revert 187 | bumpversion-revert: ## Undo a previous bumpversion-release 188 | git tag --delete $(shell git tag --points-at HEAD) 189 | git checkout main 190 | git branch -D stable 191 | 192 | CLEAN_DIR := $(shell git status --short | grep -v ??)
193 | CURRENT_BRANCH := $(shell git rev-parse --abbrev-ref HEAD 2>/dev/null) 194 | CURRENT_VERSION := $(shell grep "^current_version" pyproject.toml | grep -o "dev[0-9]*") 195 | CHANGELOG_LINES := $(shell git diff HEAD..origin/stable HISTORY.md 2>&1 | wc -l) 196 | 197 | .PHONY: check-clean 198 | check-clean: ## Check if the directory has uncommitted changes 199 | ifneq ($(CLEAN_DIR),) 200 | $(error There are uncommitted changes) 201 | endif 202 | 203 | .PHONY: check-main 204 | check-main: ## Check if we are on the main branch 205 | ifneq ($(CURRENT_BRANCH),main) 206 | $(error Please make the release from the main branch\n) 207 | endif 208 | 209 | .PHONY: check-candidate 210 | check-candidate: ## Check if a release candidate has been made 211 | ifeq ($(CURRENT_VERSION),dev0) 212 | $(error Please make a release candidate and test it before attempting a release) 213 | endif 214 | 215 | .PHONY: check-history 216 | check-history: ## Check if HISTORY.md has been modified 217 | ifeq ($(CHANGELOG_LINES),0) 218 | $(error Please insert the release notes in HISTORY.md before releasing) 219 | endif 220 | 221 | .PHONY: check-deps 222 | check-deps: # Dependency targets 223 | $(eval allow_list='numpy=|pandas=|scikit-learn=|scipy=|Faker=|copulas=') 224 | pip freeze | grep -v "RDT.git" | grep -E $(allow_list) | sort > $(OUTPUT_FILEPATH) 225 | 226 | .PHONY: check-release 227 | check-release: check-clean check-candidate check-main check-history ## Check if the release can be made 228 | @echo "A new release can be made" 229 | 230 | .PHONY: release 231 | release: check-release git-merge-main-stable bumpversion-release git-push-tags-stable \ 232 | git-merge-stable-main bumpversion-patch git-push 233 | 234 | .PHONY: release-test 235 | release-test: check-release git-merge-main-stable bumpversion-release bumpversion-revert 236 | 237 | .PHONY: release-candidate 238 | release-candidate: check-main publish bumpversion-candidate git-push 239 | 240 | .PHONY: release-candidate-test 241 | release-candidate-test: check-clean check-main publish-test 242 | 243 | .PHONY: release-minor 244 | release-minor: check-release bumpversion-minor release 245 | 246 | .PHONY: release-major 247 | release-major: check-release bumpversion-major release 248 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
4 | This repository is part of The Synthetic Data Vault Project, a project from DataCebo. 5 |

6 | 7 | [![Development Status](https://img.shields.io/badge/Development%20Status-5%20--%20Production/Stable-green)](https://pypi.org/search/?q=&o=&c=Development+Status+%3A%3A+5+-+Production%2FStable) 8 | [![PyPi Shield](https://img.shields.io/pypi/v/RDT.svg)](https://pypi.python.org/pypi/RDT) 9 | [![Unit Tests](https://github.com/sdv-dev/RDT/actions/workflows/unit.yml/badge.svg)](https://github.com/sdv-dev/RDT/actions/workflows/unit.yml) 10 | [![Downloads](https://pepy.tech/badge/rdt)](https://pepy.tech/project/rdt) 11 | [![Coverage Status](https://codecov.io/gh/sdv-dev/RDT/branch/main/graph/badge.svg)](https://codecov.io/gh/sdv-dev/RDT) 12 | [![Slack](https://img.shields.io/badge/Community-Slack-blue?style=plastic&logo=slack)](https://bit.ly/sdv-slack-invite) 13 | 14 |
23 | 24 | 25 | # Overview 26 | 27 | RDT (Reversible Data Transforms) is a Python library that transforms raw data into fully numerical 28 | data, ready for data science. The transforms are reversible, allowing you to convert from numerical 29 | data back into your original format. 30 | 31 | 32 | 33 | 34 | # Install 35 | 36 | Install **RDT** using ``pip`` or ``conda``. We recommend using a virtual environment to avoid 37 | conflicts with other software on your device. 38 | 39 | ```bash 40 | pip install rdt 41 | ``` 42 | 43 | ```bash 44 | conda install -c conda-forge rdt 45 | ``` 46 | 47 | For more information about using reversible data transformations, visit the [RDT Documentation](https://docs.sdv.dev/rdt). 48 | 49 | 50 | # Quickstart 51 | 52 | In this short tutorial we will guide you through a series of steps that will 53 | help you get started using **RDT** to transform columns, tables and datasets. 54 | 55 | ## Load the demo data 56 | 57 | After you have installed RDT, you can get started using the demo dataset. 58 | 59 | ```python3 60 | from rdt import get_demo 61 | 62 | customers = get_demo() 63 | ``` 64 | 65 | This dataset contains some randomly generated values that describe the customers of an online 66 | marketplace. 67 | 68 | ``` 69 | last_login email_optin credit_card age dollars_spent 70 | 0 2021-06-26 False VISA 29 99.99 71 | 1 2021-02-10 False VISA 18 NaN 72 | 2 NaT False AMEX 21 2.50 73 | 3 2020-09-26 True NaN 45 25.00 74 | 4 2020-12-22 NaN DISCOVER 32 19.99 75 | ``` 76 | 77 | Let's transform this data so that each column is converted to fully numerical data ready for data 78 | science. 79 | 80 | ## Creating the HyperTransformer & config 81 | 82 | The ``HyperTransformer`` is capable of transforming multi-column datasets. 83 | 84 | ```python3 85 | from rdt import HyperTransformer 86 | 87 | ht = HyperTransformer() 88 | ``` 89 | 90 | The `HyperTransformer` needs to know about the columns in your dataset and which transformers to 91 | apply to each. These are described by a config. We can ask the `HyperTransformer` to automatically 92 | detect it based on the data we plan to use. 93 | 94 | ```python3 95 | ht.detect_initial_config(data=customers) 96 | ``` 97 | 98 | This will create and set the config. 99 | 100 | ``` 101 | Config: 102 | { 103 | "sdtypes": { 104 | "last_login": "datetime", 105 | "email_optin": "boolean", 106 | "credit_card": "categorical", 107 | "age": "numerical", 108 | "dollars_spent": "numerical" 109 | }, 110 | "transformers": { 111 | "last_login": "UnixTimestampEncoder()", 112 | "email_optin": "BinaryEncoder()", 113 | "credit_card": "FrequencyEncoder()", 114 | "age": "FloatFormatter()", 115 | "dollars_spent": "FloatFormatter()" 116 | } 117 | } 118 | ``` 119 | 120 | The `sdtypes` dictionary describes the semantic data types of each of your columns and the 121 | `transformers` dictionary describes which transformer to use for each column. You can customize the 122 | transformers and their settings; a short customization sketch appears further below. (See the [Transformers Glossary](https://docs.sdv.dev/rdt/transformers-glossary/browse-transformers) for more information.) 123 | 124 | ## Fitting & using the HyperTransformer 125 | 126 | The `HyperTransformer` references the config while learning the data during the `fit` stage. 127 | 128 | ```python3 129 | ht.fit(customers) 130 | ``` 131 | 132 | Once the transformer is fit, it's ready to use. Use the transform method to transform all columns 133 | of your dataset at once.
134 | 135 | ```python3 136 | transformed_data = ht.transform(customers) 137 | ``` 138 | 139 | ``` 140 | last_login.value email_optin.value credit_card.value age.value dollars_spent.value 141 | 0 1.624666e+18 0.0 0.2 29 99.99 142 | 1 1.612915e+18 0.0 0.2 18 36.87 143 | 2 1.611814e+18 0.0 0.5 21 2.50 144 | 3 1.601078e+18 1.0 0.7 45 25.00 145 | 4 1.608595e+18 0.0 0.9 32 19.99 146 | ``` 147 | 148 | The ``HyperTransformer`` applied the assigned transformer to each individual column. Each column 149 | now contains fully numerical data that you can use for your project! 150 | 151 | When you're done with your project, you can also transform the data back to the original format 152 | using the `reverse_transform` method. 153 | 154 | ```python3 155 | original_format_data = ht.reverse_transform(transformed_data) 156 | ``` 157 | 158 | ``` 159 | last_login email_optin credit_card age dollars_spent 160 | 0 NaT False VISA 29 99.99 161 | 1 2021-02-10 False VISA 18 NaN 162 | 2 NaT False AMEX 21 NaN 163 | 3 2020-09-26 True NaN 45 25.00 164 | 4 2020-12-22 False DISCOVER 32 19.99 165 | ``` 166 | 167 | # What's Next? 168 | 169 | To learn more about reversible data transformations, visit the [RDT Documentation](https://docs.sdv.dev/rdt). 170 | 171 | 172 | --- 173 | 174 | 175 |
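Going further with the customization mentioned in the quickstart above, you can swap out the transformer assigned to any column before fitting. A minimal sketch, assuming the `update_transformers` method and the `LabelEncoder` transformer provided by recent RDT versions:

```python3
from rdt.transformers.categorical import LabelEncoder

# Replace the transformer assigned to the 'credit_card' column,
# then re-fit so the new transformer learns the data.
ht.update_transformers(column_name_to_transformer={
    'credit_card': LabelEncoder(),
})
ht.fit(customers)
transformed_data = ht.transform(customers)
```

Any transformer from the Transformers Glossary can be assigned the same way, as long as it handles the column's sdtype.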
181 | [The Synthetic Data Vault Project](https://sdv.dev) was first created at MIT's [Data to AI Lab]( 182 | https://dai.lids.mit.edu/) in 2016. After 4 years of research and traction with enterprise, we 183 | created [DataCebo](https://datacebo.com) in 2020 with the goal of growing the project. 184 | Today, DataCebo is the proud developer of SDV, the largest ecosystem for 185 | synthetic data generation & evaluation. It is home to multiple libraries that support synthetic 186 | data, including: 187 | 188 | * 🔄 Data discovery & transformation. Reverse the transforms to reproduce realistic data. 189 | * 🧠 Multiple machine learning models -- ranging from Copulas to Deep Learning -- to create tabular, 190 | multi-table and time series data. 191 | * 📊 Measuring quality and privacy of synthetic data, and comparing different synthetic data 192 | generation models. 193 | 194 | [Get started using the SDV package](https://sdv.dev/SDV/getting_started/install.html) -- a fully 195 | integrated solution and your one-stop shop for synthetic data. Or, use the standalone libraries 196 | for specific needs. 197 | -------------------------------------------------------------------------------- /RELEASE.md: -------------------------------------------------------------------------------- 1 | # Release workflow 2 | 3 | The process of releasing a new version involves several steps: 4 | 5 | 1. [Install RDT from source](#install-rdt-from-source) 6 | 7 | 2. [Linting and tests](#linting-and-tests) 8 | 9 | 3. [Make a release candidate](#make-a-release-candidate) 10 | 11 | 4. [Integration with SDV](#integration-with-sdv) 12 | 13 | 5. [Milestone](#milestone) 14 | 15 | 6. [Update HISTORY](#update-history) 16 | 17 | 7. [Check the release](#check-the-release) 18 | 19 | 8. [Update stable branch and bump version](#update-stable-branch-and-bump-version) 20 | 21 | 9. [Create the Release on GitHub](#create-the-release-on-github) 22 | 23 | 10. [Close milestone and create new milestone](#close-milestone-and-create-new-milestone) 24 | 25 | ## Install RDT from source 26 | 27 | Clone the project and install the development requirements before starting the release process, with your virtualenv activated: 28 | 29 | ```bash 30 | git clone https://github.com/sdv-dev/RDT.git 31 | cd RDT 32 | git checkout main 33 | make install-develop 34 | ``` 35 | 36 | ## Linting and tests 37 | 38 | Execute the tests and linting. The tests must end with no errors: 39 | 40 | ```bash 41 | make test && make lint 42 | ``` 43 | 44 | And you will see something like this: 45 | 46 | ``` 47 | Coverage XML written to file ./integration_cov.xml 48 | ============ 242 passed, 166 warnings, 134 subtests passed in 9.88s ============ 49 | .... 50 | invoke lint 51 | No broken requirements found. 52 | All checks passed! 53 | 80 files already formatted 54 | ``` 55 | 56 | The execution finished with no errors, no tests skipped, and 166 warnings. 57 | 58 | ## Make a release candidate 59 | 60 | 1. On the RDT GitHub page, navigate to the [Actions][actions] tab. 61 | 2. Select the `Release` action. 62 | 3. Run it on the main branch. Make sure `Release candidate` is checked and `Test PyPI` is not. 63 | 4. Check on [PyPI][rdt-pypi] to ensure the release candidate was successfully uploaded.
64 | - You should see X.Y.ZdevN PRE-RELEASE 65 | 66 | [actions]: https://github.com/sdv-dev/RDT/actions 67 | [rdt-pypi]: https://pypi.org/project/RDT/#history 68 | 69 | ## Integration with SDV 70 | 71 | ### Create a branch on SDV to test the candidate 72 | 73 | Before doing the actual release, we need to test that the candidate works with SDV. To do this, we can create a branch on SDV that points to the release candidate we just created using the following steps: 74 | 75 | 1. Create a new branch on the SDV repository. 76 | 77 | ```bash 78 | git checkout -b test-rdt-X.Y.Z 79 | ``` 80 | 81 | 2. Update the pyproject.toml to set the minimum version of RDT to be the same as the version of the release. For example, 82 | 83 | ```toml 84 | 'rdt>=X.Y.Z.dev0' 85 | ``` 86 | 87 | 3. Push this branch. This should trigger all the tests to run. 88 | 89 | ```bash 90 | git push --set-upstream origin test-rdt-X.Y.Z 91 | ``` 92 | 93 | 4. Check the [Actions][sdv-actions] tab on SDV to make sure all the tests pass. 94 | 95 | [sdv-actions]: https://github.com/sdv-dev/SDV/actions 96 | 97 | ## Milestone 98 | 99 | It's important to check that the GitHub issues and milestones are up to date with the release. 100 | 101 | You need to check that: 102 | 103 | - The milestone for the current release exists. 104 | - All the issues closed since the latest release are associated with the milestone. If they are not, associate them. 105 | - All the issues associated with the milestone are closed. If there are open issues but the milestone needs to 106 | be released anyway, move them to the next milestone. 107 | - All the issues in the milestone are assigned to at least one person. 108 | - All the pull requests closed since the latest release are associated with an issue. If necessary, create issues 109 | and assign them to the milestone. Also assign the person who opened the issue to them. 110 | 111 | ## Update HISTORY 112 | Run the [Release Prep](https://github.com/sdv-dev/RDT/actions/workflows/prepare_release.yml) workflow. This workflow will create a pull request with updates to HISTORY.md. 113 | 114 | Make sure HISTORY.md is updated with the issues of the milestone: 115 | 116 | ``` 117 | # History 118 | 119 | ## X.Y.Z (YYYY-MM-DD) 120 | 121 | ### New Features 122 | 123 | * - [Issue #](https://github.com/sdv-dev/RDT/issues/) by @resolver 124 | 125 | ### General Improvements 126 | 127 | * - [Issue #](https://github.com/sdv-dev/RDT/issues/) by @resolver 128 | 129 | ### Bugs Fixed 130 | 131 | * - [Issue #](https://github.com/sdv-dev/RDT/issues/) by @resolver 132 | ``` 133 | 134 | The issue list per milestone can be found [here][milestones]. 135 | 136 | [milestones]: https://github.com/sdv-dev/RDT/milestones 137 | 138 | Put the pull request up for review and get 2 approvals to merge into `main`. 139 | 140 | ## Check the release 141 | Once HISTORY.md has been updated on `main`, check if the release can be made: 142 | 143 | 144 | ```bash 145 | make check-release 146 | ``` 147 | 148 | ## Update stable branch and bump version 149 | The `stable` branch needs to be updated with the changes from `main` and the version needs to be bumped. 150 | Depending on the type of release, run one of the following: 151 | 152 | * `make release`: This will release the version that has already been bumped (patch, minor, or major). By default, this is typically a patch release. Use this when the changes are bugfixes or enhancements that do not modify the existing user API.
Changes that modify the user API to add new features but that do not modify the usage of the previous features can also be released as a patch. 153 | * `make release-minor`: This will bump and release the next minor version. Use this if the changes modify the existing user API in any way, even if it is backwards compatible. Minor backwards incompatible changes can also be released as minor versions while the library is still in beta state. After the major version v1.0.0 has been released, minor versions can only be used to add backwards compatible API changes. 154 | * `make release-major`: This will bump and release the next major version. Use this if the changes modify the user API in a backwards incompatible way after the major version v1.0.0 has been released. 155 | 156 | Running one of these will **push commits directly** to `main`. 157 | At the end, you should see the 3 commits on `main` (from oldest to newest): 158 | - `make release-tag: Merge branch 'main' into stable` 159 | - `Bump version: X.Y.Z.devN → X.Y.Z` 160 | - `Bump version: X.Y.Z → X.Y.A.dev0` 161 | 162 | ## Create the Release on GitHub 163 | 164 | After the update to HISTORY.md is merged into `main` and the version is bumped, it is time to [create the release on GitHub](https://github.com/sdv-dev/RDT/releases/new). 165 | - Create a new tag with the version number with a v prefix (e.g. v0.3.1) 166 | - The target should be the `stable` branch 167 | - Release title is the same as the tag (e.g. v0.3.1) 168 | - This is not a pre-release (`Set as a pre-release` should be unchecked) 169 | 170 | Click `Publish release`, which will kick off the release workflow and automatically upload the package to [public PyPI](https://pypi.org/project/rdt/). 171 | 172 | ## Close milestone and create new milestone 173 | 174 | Finally, **close the milestone** and, if it does not exist, **create the next milestone**.
175 | 176 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | precision: 2 3 | range: "90...100" 4 | status: 5 | project: 6 | default: false 7 | patch: 8 | default: false -------------------------------------------------------------------------------- /latest_requirements.txt: -------------------------------------------------------------------------------- 1 | Faker==37.4.0 2 | copulas==0.12.3 3 | numpy==2.3.1 4 | pandas==2.3.0 5 | scikit-learn==1.7.0 6 | scipy==1.16.0 7 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = 'rdt' 3 | description = 'Reversible Data Transforms' 4 | authors = [{ name = 'DataCebo, Inc.', email = 'info@sdv.dev' }] 5 | classifiers = [ 6 | 'Development Status :: 5 - Production/Stable', 7 | 'Intended Audience :: Developers', 8 | 'License :: Free for non-commercial use', 9 | 'Natural Language :: English', 10 | 'Programming Language :: Python :: 3', 11 | 'Programming Language :: Python :: 3.8', 12 | 'Programming Language :: Python :: 3.9', 13 | 'Programming Language :: Python :: 3.10', 14 | 'Programming Language :: Python :: 3.11', 15 | 'Programming Language :: Python :: 3.12', 16 | 'Programming Language :: Python :: 3.13', 17 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 18 | ] 19 | keywords = ['machine learning', 'synthetic data generation', 'benchmark', 'generative models'] 20 | dynamic = ['version'] 21 | license = { text = 'BSL-1.1' } 22 | requires-python = '>=3.8,<3.14' 23 | readme = 'README.md' 24 | dependencies = [ 25 | "numpy>=1.21.0;python_version<'3.10'", 26 | "numpy>=1.23.3;python_version>='3.10' and python_version<'3.12'", 27 | "numpy>=1.26.0;python_version>='3.12' and python_version<'3.13'", 28 | "numpy>=2.1.0;python_version>='3.13'", 29 | "pandas>=1.4.0;python_version<'3.11'", 30 | "pandas>=1.5.0;python_version>='3.11' and python_version<'3.12'", 31 | "pandas>=2.1.1;python_version>='3.12' and python_version<'3.13'", 32 | "pandas>=2.2.3;python_version>='3.13'", 33 | "scipy>=1.7.3;python_version<'3.10'", 34 | "scipy>=1.9.2;python_version>='3.10' and python_version<'3.12'", 35 | "scipy>=1.12.0;python_version>='3.12' and python_version<'3.13'", 36 | "scipy>=1.14.1;python_version>='3.13'", 37 | "scikit-learn>=1.0.2;python_version<'3.10'", 38 | "scikit-learn>=1.1.0;python_version>='3.10' and python_version<'3.11'", 39 | "scikit-learn>=1.1.3;python_version>='3.11' and python_version<'3.12'", 40 | "scikit-learn>=1.3.1;python_version>='3.12' and python_version<'3.13'", 41 | "scikit-learn>=1.5.2;python_version>='3.13'", 42 | 'Faker>=17', 43 | 'python-dateutil>=2.9', 44 | ] 45 | 46 | [project.urls] 47 | "Source Code"= "https://github.com/sdv-dev/RDT/" 48 | "Issue Tracker" = "https://github.com/sdv-dev/RDT/issues" 49 | "Changes" = "https://github.com/sdv-dev/RDT/blob/main/HISTORY.md" 50 | "Twitter" = "https://twitter.com/sdv_dev" 51 | "Chat" = "https://bit.ly/sdv-slack-invite" 52 | 53 | [project.entry-points] 54 | rdt = { main = 'rdt.cli.__main__:main' } 55 | 56 | [project.optional-dependencies] 57 | copulas = ['copulas>=0.12.1',] 58 | pyarrow = ['pyarrow>=17.0.0',] 59 | test = [ 60 | 'rdt[pyarrow]', 61 | 'rdt[copulas]', 62 | 63 | 'pytest>=3.4.2', 64 | 'pytest-cov>=2.6.0', 65 | 'jupyter>=1.0.0,<2', 66 | 'pytest-subtests>=0.5,<1.0', 67 | 'pytest-runner >= 
2.11.1', 68 | 'tomli>=2.0.0,<3', 69 | ] 70 | dev = [ 71 | 'rdt[test]', 72 | 73 | # general 74 | 'build>=1.0.0,<2', 75 | 'bump-my-version>=0.18.3', 76 | 'pip>=9.0.1', 77 | 'watchdog>=1.0.1,<5', 78 | 79 | # style check 80 | 'ruff>=0.3.2,<1', 81 | 82 | # distribute on PyPI 83 | 'twine>=1.10.0', 84 | 'wheel>=0.30.0', 85 | 86 | # Advanced testing 87 | 'coverage>=4.5.12,<8', 88 | 'tabulate>=0.8.9,<1', 89 | 90 | # Invoking test commands 91 | 'invoke', 92 | ] 93 | readme = ['rundoc>=0.4.3,<0.5',] 94 | 95 | [tool.setuptools] 96 | include-package-data = true 97 | license-files = ['LICENSE'] 98 | 99 | [tool.setuptools.packages.find] 100 | include = ['rdt', 'rdt.*'] 101 | namespaces = false 102 | 103 | [tool.setuptools.package-data] 104 | '*' = [ 105 | 'AUTHORS.rst', 106 | 'CONTRIBUTING.rst', 107 | 'HISTORY.md', 108 | 'README.md', 109 | 'RELEASE.md', 110 | '*.md', 111 | '*.rst', 112 | 'conf.py', 113 | 'Makefile', 114 | 'make.bat', 115 | '*.jpg', 116 | '*.png', 117 | '*.gif' 118 | ] 119 | 120 | [tool.setuptools.exclude-package-data] 121 | '*' = [ 122 | '* __pycache__', 123 | '*.py[co]', 124 | 'static_code_analysis.txt', 125 | ] 126 | 127 | [tool.setuptools.dynamic] 128 | version = {attr = 'rdt.__version__'} 129 | 130 | [tool.isort] 131 | line_length = 99 132 | lines_between_types = 0 133 | multi_line_output = 4 134 | use_parentheses = true 135 | 136 | [tool.pydocstyle] 137 | convention = 'google' 138 | add-ignore = ['D107', 'D407', 'D417'] 139 | 140 | [tool.pytest.ini_options] 141 | collect_ignore = ['pyproject.toml'] 142 | 143 | [tool.coverage.report] 144 | exclude_lines = ['NotImplementedError()'] 145 | 146 | [tool.bumpversion] 147 | current_version = "1.17.2.dev0" 148 | parse = '(?P\d+)\.(?P\d+)\.(?P\d+)(\.(?P[a-z]+)(?P\d+))?' 149 | serialize = [ 150 | '{major}.{minor}.{patch}.{release}{candidate}', 151 | '{major}.{minor}.{patch}' 152 | ] 153 | search = '{current_version}' 154 | replace = '{new_version}' 155 | regex = false 156 | ignore_missing_version = false 157 | tag = true 158 | sign_tags = false 159 | tag_name = 'v{new_version}' 160 | tag_message = 'Bump version: {current_version} → {new_version}' 161 | allow_dirty = false 162 | commit = true 163 | message = 'Bump version: {current_version} → {new_version}' 164 | commit_args = '' 165 | 166 | [tool.bumpversion.parts.release] 167 | first_value = 'dev' 168 | optional_value = 'release' 169 | values = [ 170 | 'dev', 171 | 'release' 172 | ] 173 | 174 | [[tool.bumpversion.files]] 175 | filename = "rdt/__init__.py" 176 | search = "__version__ = '{current_version}'" 177 | replace = "__version__ = '{new_version}'" 178 | 179 | [build-system] 180 | requires = ['setuptools', 'wheel'] 181 | build-backend = 'setuptools.build_meta' 182 | 183 | [tool.ruff] 184 | preview = true 185 | line-length = 100 186 | indent-width = 4 187 | src = ["rdt"] 188 | exclude = [ 189 | 'docs', 190 | '.tox', 191 | '.git', 192 | '__pycache__', 193 | '*.ipynb', 194 | '.ipynb_checkpoints', 195 | 'tasks.py', 196 | 'tests/contributing.py' 197 | ] 198 | 199 | [tool.ruff.lint] 200 | select = [ 201 | # Pyflakes 202 | "F", 203 | # Pycodestyle 204 | "E", 205 | "W", 206 | # pydocstyle 207 | "D", 208 | # isort 209 | "I001", 210 | # print statements 211 | "T201", 212 | # pandas-vet 213 | "PD", 214 | # numpy 2.0 215 | "NPY201" 216 | ] 217 | ignore = [ 218 | # pydocstyle 219 | "D107", # Missing docstring in __init__ 220 | "D417", # Missing argument descriptions in the docstring, this is a bug from pydocstyle: https://github.com/PyCQA/pydocstyle/issues/449 221 | "PD901", 222 | "PD101", 223 | ] 224 
| 225 | [tool.ruff.format] 226 | quote-style = "single" 227 | indent-style = "space" 228 | preview = true 229 | docstring-code-format = true 230 | docstring-code-line-length = "dynamic" 231 | 232 | [tool.ruff.lint.isort] 233 | known-first-party = ["rdt"] 234 | lines-between-types = 0 235 | 236 | [tool.ruff.lint.per-file-ignores] 237 | "__init__.py" = ["F401", "E402", "F403", "F405", "E501", "I001"] 238 | "errors.py" = ["D105"] 239 | "tests/**.py" = ["D"] 240 | "tests/contributing.py" = ["T201"] 241 | 242 | [tool.ruff.lint.pydocstyle] 243 | convention = "google" 244 | 245 | [tool.ruff.lint.pycodestyle] 246 | max-doc-length = 100 247 | max-line-length = 100 248 | -------------------------------------------------------------------------------- /rdt/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Top-level package for RDT.""" 4 | 5 | __author__ = 'DataCebo, Inc.' 6 | __email__ = 'info@sdv.dev' 7 | __version__ = '1.17.2.dev0' 8 | 9 | 10 | import sys 11 | import warnings 12 | from importlib.metadata import entry_points 13 | from operator import attrgetter 14 | from types import ModuleType 15 | 16 | import numpy as np 17 | import pandas as pd 18 | 19 | from rdt import transformers 20 | from rdt.hyper_transformer import HyperTransformer 21 | 22 | __all__ = ['HyperTransformer', 'transformers'] 23 | 24 | RANDOM_SEED = 42 25 | 26 | 27 | def get_demo(num_rows=5): 28 | """Generate demo data with multiple sdtypes. 29 | 30 | The first five rows are hard coded. The rest are randomly generated 31 | using ``np.random.seed(42)``. 32 | 33 | Args: 34 | num_rows (int): 35 | Number of data rows to generate. Defaults to 5. 36 | 37 | Returns: 38 | pd.DataFrame 39 | """ 40 | # Hard code first five rows 41 | login_dates = pd.Series( 42 | ['2021-06-26', '2021-02-10', 'NAT', '2020-09-26', '2020-12-22'], 43 | dtype='datetime64[ns]', 44 | ) 45 | email_optin = pd.Series([False, False, False, True, np.nan], dtype='object') 46 | credit_card = ['VISA', 'VISA', 'AMEX', np.nan, 'DISCOVER'] 47 | age = [29, 18, 21, 45, 32] 48 | dollars_spent = [99.99, np.nan, 2.50, 25.00, 19.99] 49 | 50 | data = pd.DataFrame({ 51 | 'last_login': login_dates, 52 | 'email_optin': email_optin, 53 | 'credit_card': credit_card, 54 | 'age': age, 55 | 'dollars_spent': dollars_spent, 56 | }) 57 | 58 | if num_rows <= 5: 59 | return data.iloc[:num_rows] 60 | 61 | # Randomly generate the remaining rows 62 | random_state = np.random.get_state() 63 | np.random.set_state(np.random.RandomState(RANDOM_SEED).get_state()) 64 | try: 65 | num_rows -= 5 66 | 67 | login_dates = np.array( 68 | [ 69 | np.datetime64('2000-01-01') + np.timedelta64(np.random.randint(0, 10000), 'D') 70 | for _ in range(num_rows) 71 | ], 72 | dtype='datetime64[ns]', 73 | ) 74 | login_dates[np.random.random(size=num_rows) > 0.8] = np.datetime64('NaT') 75 | 76 | email_optin = pd.Series([True, False, np.nan], dtype='object').sample( 77 | num_rows, replace=True 78 | ) 79 | credit_card = np.random.choice(['VISA', 'AMEX', np.nan, 'DISCOVER'], size=num_rows) 80 | age = np.random.randint(18, 100, size=num_rows) 81 | 82 | dollars_spent = np.around(np.random.uniform(0, 100, size=num_rows), decimals=2) 83 | dollars_spent[np.random.random(size=num_rows) > 0.8] = np.nan 84 | 85 | finally: 86 | np.random.set_state(random_state) 87 | 88 | return pd.concat( 89 | [ 90 | data, 91 | pd.DataFrame({ 92 | 'last_login': login_dates, 93 | 'email_optin': email_optin, 94 | 'credit_card': credit_card, 95 | 'age': age, 96 | 
'dollars_spent': dollars_spent, 97 | }), 98 | ], 99 | ignore_index=True, 100 | ) 101 | 102 | 103 | def _get_addon_target(addon_path_name): 104 | """Find the target object for the add-on. 105 | 106 | Args: 107 | addon_path_name (str): 108 | The add-on's name. It should be the full path of valid Python 109 | identifiers (i.e. importable.module:object.attr). 110 | 111 | Returns: 112 | tuple: 113 | * object: 114 | The base module or object the add-on should be added to. 115 | * str: 116 | The name the add-on should be added to under the module or object. 117 | """ 118 | module_path, _, object_path = addon_path_name.partition(':') 119 | module_path = module_path.split('.') 120 | 121 | if module_path[0] != __name__: 122 | msg = f"expected base module to be '{__name__}', found '{module_path[0]}'" 123 | raise AttributeError(msg) 124 | 125 | target_base = sys.modules[__name__] 126 | for submodule in module_path[1:-1]: 127 | target_base = getattr(target_base, submodule) 128 | 129 | addon_name = module_path[-1] 130 | if object_path: 131 | if len(module_path) > 1 and not hasattr(target_base, module_path[-1]): 132 | msg = f"cannot add '{object_path}' to unknown submodule '{'.'.join(module_path)}'" 133 | raise AttributeError(msg) 134 | 135 | if len(module_path) > 1: 136 | target_base = getattr(target_base, module_path[-1]) 137 | 138 | split_object = object_path.split('.') 139 | addon_name = split_object[-1] 140 | 141 | if len(split_object) > 1: 142 | target_base = attrgetter('.'.join(split_object[:-1]))(target_base) 143 | 144 | return target_base, addon_name 145 | 146 | 147 | def _find_addons(): 148 | """Find and load all RDT add-ons. 149 | 150 | If the add-on is a module, we add it both to the target module and to 151 | ``sys.modules`` so that they can be imported from the top of a file as follows: 152 | 153 | from top_module.addon_module import x 154 | """ 155 | group = 'rdt_modules' 156 | try: 157 | eps = entry_points(group=group) # pylint: disable=E1123 158 | except TypeError: 159 | # Load-time selection requires Python >= 3.10 or importlib_metadata >= 3.6 160 | eps = entry_points().get(group, []) # pylint: disable=E1101 161 | 162 | for entry_point in eps: 163 | try: 164 | addon = entry_point.load() 165 | except Exception: # pylint: disable=broad-exception-caught 166 | msg = f'Failed to load "{entry_point.name}" from "{entry_point.value}".' 167 | warnings.warn(msg) 168 | continue 169 | 170 | try: 171 | addon_target, addon_name = _get_addon_target(entry_point.name) 172 | except AttributeError as error: 173 | msg = f"Failed to set '{entry_point.name}': {error}." 174 | warnings.warn(msg) 175 | continue 176 | 177 | if isinstance(addon, ModuleType): 178 | addon_module_name = f'{addon_target.__name__}.{addon_name}' 179 | if addon_module_name not in sys.modules: 180 | sys.modules[addon_module_name] = addon 181 | 182 | setattr(addon_target, addon_name, addon) 183 | 184 | 185 | _find_addons() 186 | -------------------------------------------------------------------------------- /rdt/_utils.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | from rdt.errors import InvalidConfigError 4 | 5 | 6 | def _validate_unique_transformer_instances(column_name_to_transformer): 7 | """Validate that the transformer instance for each field is unique. 8 | 9 | Args: 10 | column_name_to_transformer (dict): 11 | A dictionary mapping column names to their current transformer.
12 | 13 | Raises: 14 | - ``InvalidConfigError`` if transformers in ``column_name_to_transformer`` are repeated. 15 | """ 16 | seen_transformers = defaultdict(set) 17 | for column_name, transformer in column_name_to_transformer.items(): 18 | if transformer is not None: 19 | seen_transformers[transformer].add(column_name) 20 | 21 | duplicated_transformers = { 22 | transformer: columns 23 | for transformer, columns in seen_transformers.items() 24 | if len(columns) > 1 25 | } 26 | if duplicated_transformers: 27 | duplicated_column_messages = [] 28 | for duplicated_columns in duplicated_transformers.values(): 29 | columns = ', '.join( 30 | sorted([ 31 | str(columns) if not isinstance(columns, str) else f"'{columns}'" 32 | for columns in duplicated_columns 33 | ]) 34 | ) 35 | duplicated_column_messages.append(f'columns ({columns})') 36 | 37 | if len(duplicated_column_messages) > 1: 38 | plurality = 'instances are' 39 | else: 40 | plurality = 'instance is' 41 | 42 | error_message = ( 43 | f'The same transformer {plurality} being assigned to ' 44 | f'{", ".join(duplicated_column_messages)}. Please create different transformer objects ' 45 | 'for each assignment.' 46 | ) 47 | raise InvalidConfigError(error_message) 48 | -------------------------------------------------------------------------------- /rdt/errors.py: -------------------------------------------------------------------------------- 1 | """RDT Exceptions.""" 2 | 3 | 4 | class ConfigNotSetError(Exception): 5 | """Error to use when no config has been set or detected.""" 6 | 7 | 8 | class InvalidConfigError(Exception): 9 | """Error to raise when something is incorrect about the config.""" 10 | 11 | 12 | class InvalidDataError(Exception): 13 | """Error to raise when the data is ill-formed in some way.""" 14 | 15 | 16 | class NotFittedError(Exception): 17 | """Error to raise when ``transform`` or ``reverse_transform`` are used before fitting.""" 18 | 19 | 20 | class TransformerInputError(Exception): 21 | """Error to raise when ``HyperTransformer`` receives an incorrect input.""" 22 | 23 | 24 | class TransformerProcessingError(Exception): 25 | """Error to raise when a transformer fails to complete some process (e.g. anonymization).""" 26 | -------------------------------------------------------------------------------- /rdt/performance/__init__.py: -------------------------------------------------------------------------------- 1 | """Functions to evaluate and test the performance of the RDT Transformers.""" 2 | 3 | from rdt.performance import profiling 4 | from rdt.performance.performance import evaluate_transformer_performance 5 | 6 | __all__ = [ 7 | 'evaluate_transformer_performance', 8 | 'profiling', 9 | ] 10 | -------------------------------------------------------------------------------- /rdt/performance/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | """Dataset Generators to test the RDT Transformers.""" 2 | 3 | from collections import defaultdict 4 | 5 | from rdt.performance.datasets import ( 6 | boolean, 7 | categorical, 8 | datetime, 9 | id, 10 | numerical, 11 | pii, 12 | text, 13 | ) 14 | from rdt.performance.datasets.base import BaseDatasetGenerator 15 | 16 | __all__ = [ 17 | 'boolean', 18 | 'categorical', 19 | 'datetime', 20 | 'id', 21 | 'numerical', 22 | 'pii', 23 | 'text', 24 | 'BaseDatasetGenerator', 25 | ] 26 | 27 | 28 | def get_dataset_generators_by_type(): 29 | """Build a ``dict`` mapping sdtypes to dataset generators.
30 | 31 | Returns: 32 | dict: 33 | Mapping of sdtype to a list of dataset generators that produce 34 | data of that sdtype. 35 | """ 36 | dataset_generators = defaultdict(list) 37 | for dataset_generator in BaseDatasetGenerator.get_subclasses(): 38 | dataset_generators[dataset_generator.SDTYPE].append(dataset_generator) 39 | 40 | return dataset_generators 41 | -------------------------------------------------------------------------------- /rdt/performance/datasets/base.py: -------------------------------------------------------------------------------- 1 | """Base class for all the Dataset Generators.""" 2 | 3 | from abc import ABC, abstractmethod 4 | 5 | 6 | class BaseDatasetGenerator(ABC): 7 | """Parent class for all the Dataset Generators.""" 8 | 9 | SDTYPE = None 10 | 11 | @staticmethod 12 | @abstractmethod 13 | def generate(num_rows): 14 | """Return an array of data. This method serves as a template for dataset generators. 15 | 16 | Args: 17 | num_rows (int): 18 | Number of rows to generate. 19 | 20 | Returns: 21 | numpy.ndarray of size ``num_rows`` 22 | """ 23 | raise NotImplementedError() 24 | 25 | @classmethod 26 | def get_subclasses(cls): 27 | """Recursively find subclasses of this class. 28 | 29 | Returns: 30 | list: 31 | List of all subclasses of this class. 32 | """ 33 | subclasses = [] 34 | for subclass in cls.__subclasses__(): 35 | if ABC not in subclass.__bases__: 36 | subclasses.append(subclass) 37 | 38 | subclasses += subclass.get_subclasses() 39 | 40 | return subclasses 41 | 42 | @staticmethod 43 | @abstractmethod 44 | def get_performance_thresholds(): 45 | """Return the expected thresholds.""" 46 | raise NotImplementedError() 47 | -------------------------------------------------------------------------------- /rdt/performance/datasets/boolean.py: -------------------------------------------------------------------------------- 1 | """Dataset Generators for boolean transformers.""" 2 | 3 | from abc import ABC 4 | 5 | import numpy as np 6 | 7 | from rdt.performance.datasets.base import BaseDatasetGenerator 8 | 9 | MAX_PERCENT_NULL = 50 # cap the percentage of null values at 50% 10 | MIN_PERCENT = 20 # the minimum percentage of true or false is 20% 11 | 12 | 13 | class BooleanGenerator(BaseDatasetGenerator, ABC): 14 | """Base class for generators that generate boolean data.""" 15 | 16 | SDTYPE = 'boolean' 17 | 18 | 19 | class RandomBooleanGenerator(BooleanGenerator): 20 | """Generator that creates a dataset of random booleans.""" 21 | 22 | @staticmethod 23 | def generate(num_rows): 24 | """Generate an array of random booleans. 25 | 26 | Args: 27 | num_rows (int): 28 | Number of rows of booleans to generate. 29 | 30 | Returns: 31 | numpy.ndarray of size ``num_rows`` containing random booleans.
32 | """ 33 | return np.random.choice(a=[True, False], size=num_rows) 34 | 35 | @staticmethod 36 | def get_performance_thresholds(): 37 | """Return the expected thresholds.""" 38 | return { 39 | 'fit': {'time': 2e-5, 'memory': 400.0}, 40 | 'transform': {'time': 1e-5, 'memory': 400.0}, 41 | 'reverse_transform': { 42 | 'time': 5e-5, 43 | 'memory': 500.0, 44 | }, 45 | } 46 | 47 | 48 | class RandomBooleanNaNsGenerator(BooleanGenerator): 49 | """Generator that creates an array of random booleans with nulls.""" 50 | 51 | @staticmethod 52 | def generate(num_rows): 53 | """Generate a ``num_rows`` number of rows.""" 54 | percent_null = np.random.randint(MIN_PERCENT, MAX_PERCENT_NULL) 55 | percent_true = (100 - percent_null) / 2 56 | percent_false = 100 - percent_true - percent_null 57 | 58 | return np.random.choice( 59 | a=[True, False, None], 60 | size=num_rows, 61 | p=[percent_true / 100, percent_false / 100, percent_null / 100], 62 | ) 63 | 64 | @staticmethod 65 | def get_performance_thresholds(): 66 | """Return the expected thresholds.""" 67 | return { 68 | 'fit': {'time': 2e-5, 'memory': 400.0}, 69 | 'transform': {'time': 1e-5, 'memory': 1000.0}, 70 | 'reverse_transform': { 71 | 'time': 5e-5, 72 | 'memory': 1000.0, 73 | }, 74 | } 75 | 76 | 77 | class RandomSkewedBooleanGenerator(BooleanGenerator): 78 | """Generator that creates dataset of random booleans.""" 79 | 80 | @staticmethod 81 | def generate(num_rows): 82 | """Generate a ``num_rows`` number of rows.""" 83 | percent_true = np.random.randint(MIN_PERCENT, 100 - MIN_PERCENT) 84 | 85 | return np.random.choice( 86 | a=[True, False], 87 | size=num_rows, 88 | p=[percent_true / 100, (100 - percent_true) / 100], 89 | ) 90 | 91 | @staticmethod 92 | def get_performance_thresholds(): 93 | """Return the expected thresholds.""" 94 | return { 95 | 'fit': {'time': 1e-5, 'memory': 400.0}, 96 | 'transform': {'time': 1e-5, 'memory': 400.0}, 97 | 'reverse_transform': { 98 | 'time': 5e-5, 99 | 'memory': 500.0, 100 | }, 101 | } 102 | 103 | 104 | class RandomSkewedBooleanNaNsGenerator(BooleanGenerator): 105 | """Generator that creates an array of random booleans with nulls.""" 106 | 107 | @staticmethod 108 | def generate(num_rows): 109 | """Generate a ``num_rows`` number of rows.""" 110 | percent_null = np.random.randint(MIN_PERCENT, MAX_PERCENT_NULL) 111 | percent_true = np.random.randint(MIN_PERCENT, 100 - percent_null - MIN_PERCENT) 112 | percent_false = 100 - percent_null - percent_true 113 | 114 | return np.random.choice( 115 | a=[True, False, None], 116 | size=num_rows, 117 | p=[percent_true / 100, percent_false / 100, percent_null / 100], 118 | ) 119 | 120 | @staticmethod 121 | def get_performance_thresholds(): 122 | """Return the expected thresholds.""" 123 | return { 124 | 'fit': {'time': 1e-5, 'memory': 400.0}, 125 | 'transform': {'time': 1e-5, 'memory': 1000.0}, 126 | 'reverse_transform': { 127 | 'time': 5e-5, 128 | 'memory': 1000.0, 129 | }, 130 | } 131 | 132 | 133 | class ConstantBooleanGenerator(BooleanGenerator): 134 | """Generator that creates a constant array with either True or False.""" 135 | 136 | @staticmethod 137 | def generate(num_rows): 138 | """Generate a ``num_rows`` number of rows.""" 139 | constant = np.random.choice([True, False]) 140 | return np.full(num_rows, constant) 141 | 142 | @staticmethod 143 | def get_performance_thresholds(): 144 | """Return the expected thresholds.""" 145 | return { 146 | 'fit': {'time': 1e-5, 'memory': 400.0}, 147 | 'transform': {'time': 1e-5, 'memory': 400.0}, 148 | 'reverse_transform': { 149 | 
'time': 5e-5, 150 | 'memory': 500.0, 151 | }, 152 | } 153 | 154 | 155 | class ConstantBooleanNaNsGenerator(BooleanGenerator): 156 | """Generator that creates a constant array with either True or False and some nulls.""" 157 | 158 | @staticmethod 159 | def generate(num_rows): 160 | """Generate a ``num_rows`` number of rows.""" 161 | constant = np.random.choice([True, False]) 162 | percent_null = np.random.randint(MIN_PERCENT, MAX_PERCENT_NULL) 163 | 164 | return np.random.choice( 165 | a=[constant, None], 166 | size=num_rows, 167 | p=[(100 - percent_null) / 100, percent_null / 100], 168 | ) 169 | 170 | @staticmethod 171 | def get_performance_thresholds(): 172 | """Return the expected thresholds.""" 173 | return { 174 | 'fit': {'time': 1e-5, 'memory': 400.0}, 175 | 'transform': {'time': 1e-5, 'memory': 1000.0}, 176 | 'reverse_transform': { 177 | 'time': 5e-5, 178 | 'memory': 1000.0, 179 | }, 180 | } 181 | -------------------------------------------------------------------------------- /rdt/performance/datasets/datetime.py: -------------------------------------------------------------------------------- 1 | """Dataset Generators for datetime transformers.""" 2 | 3 | import datetime 4 | from abc import ABC 5 | 6 | import numpy as np 7 | import pandas as pd 8 | 9 | from rdt.performance.datasets.base import BaseDatasetGenerator 10 | from rdt.performance.datasets.utils import add_nans 11 | 12 | 13 | class DatetimeGenerator(BaseDatasetGenerator, ABC): 14 | """Base class for generators that generate datetime data.""" 15 | 16 | SDTYPE = 'datetime' 17 | 18 | 19 | class RandomGapDatetimeGenerator(DatetimeGenerator): 20 | """Generator that creates dates with random gaps between them.""" 21 | 22 | @staticmethod 23 | def generate(num_rows): 24 | """Generate a ``num_rows`` number of rows.""" 25 | today = datetime.datetime.today() 26 | delta = datetime.timedelta(days=1) 27 | dates = [(np.random.random() * delta + today) for i in range(num_rows)] 28 | return np.array(dates, dtype='datetime64') 29 | 30 | @staticmethod 31 | def get_performance_thresholds(): 32 | """Return the expected thresholds.""" 33 | return { 34 | 'fit': {'time': 5e-05, 'memory': 500.0}, 35 | 'transform': {'time': 5e-05, 'memory': 350.0}, 36 | 'reverse_transform': { 37 | 'time': 5e-05, 38 | 'memory': 1000.0, 39 | }, 40 | } 41 | 42 | 43 | class RandomGapSecondsDatetimeGenerator(DatetimeGenerator): 44 | """Generator that creates dates with random gaps of seconds between them.""" 45 | 46 | @staticmethod 47 | def generate(num_rows): 48 | """Generate a ``num_rows`` number of rows.""" 49 | today = datetime.datetime.today() 50 | delta = datetime.timedelta(seconds=1) 51 | dates = [(np.random.random() * delta + today) for i in range(num_rows)] 52 | return np.array(dates, dtype='datetime64') 53 | 54 | @staticmethod 55 | def get_performance_thresholds(): 56 | """Return the expected thresholds.""" 57 | return { 58 | 'fit': {'time': 5e-05, 'memory': 500.0}, 59 | 'transform': {'time': 5e-05, 'memory': 350.0}, 60 | 'reverse_transform': { 61 | 'time': 5e-05, 62 | 'memory': 1000.0, 63 | }, 64 | } 65 | 66 | 67 | class RandomGapDatetimeNaNsGenerator(DatetimeGenerator): 68 | """Generator that creates dates with random gaps and NaNs.""" 69 | 70 | @staticmethod 71 | def generate(num_rows): 72 | """Generate a ``num_rows`` number of rows.""" 73 | dates = RandomGapDatetimeGenerator.generate(num_rows) 74 | return add_nans(dates.astype('O')) 75 | 76 | @staticmethod 77 | def get_performance_thresholds(): 78 | """Return the expected thresholds.""" 79 | return { 80
'fit': {'time': 5e-05, 'memory': 500.0}, 81 | 'transform': {'time': 5e-05, 'memory': 1000.0}, 82 | 'reverse_transform': { 83 | 'time': 5e-05, 84 | 'memory': 1000.0, 85 | }, 86 | } 87 | 88 | 89 | class EqualGapHoursDatetimeGenerator(DatetimeGenerator): 90 | """Generator that creates dates with hour gaps between them.""" 91 | 92 | @staticmethod 93 | def generate(num_rows): 94 | """Generate a ``num_rows`` number of rows.""" 95 | today = datetime.datetime.today() 96 | delta = datetime.timedelta 97 | dates = [delta(hours=i) + today for i in range(num_rows)] 98 | return np.array(dates, dtype='datetime64') 99 | 100 | @staticmethod 101 | def get_performance_thresholds(): 102 | """Return the expected thresholds.""" 103 | return { 104 | 'fit': {'time': 5e-05, 'memory': 500.0}, 105 | 'transform': {'time': 5e-05, 'memory': 350.0}, 106 | 'reverse_transform': { 107 | 'time': 5e-05, 108 | 'memory': 1000.0, 109 | }, 110 | } 111 | 112 | 113 | class EqualGapDaysDatetimeGenerator(DatetimeGenerator): 114 | """Generator that creates dates with 1 day gaps between them.""" 115 | 116 | @staticmethod 117 | def generate(num_rows): 118 | """Generate a ``num_rows`` number of rows.""" 119 | today = datetime.datetime.today() 120 | delta = datetime.timedelta 121 | 122 | today = min(datetime.datetime.today(), pd.Timestamp.max - delta(num_rows)) 123 | dates = [delta(i) + today for i in range(num_rows)] 124 | 125 | return np.array(dates, dtype='datetime64') 126 | 127 | @staticmethod 128 | def get_performance_thresholds(): 129 | """Return the expected thresholds.""" 130 | return { 131 | 'fit': {'time': 5e-05, 'memory': 500.0}, 132 | 'transform': {'time': 5e-05, 'memory': 350.0}, 133 | 'reverse_transform': { 134 | 'time': 5e-05, 135 | 'memory': 1000.0, 136 | }, 137 | } 138 | 139 | 140 | class EqualGapWeeksDatetimeGenerator(DatetimeGenerator): 141 | """Generator that creates dates with 1 week gaps between them.""" 142 | 143 | @staticmethod 144 | def generate(num_rows): 145 | """Generate a ``num_rows`` number of rows.""" 146 | today = datetime.datetime.today() 147 | delta = datetime.timedelta 148 | 149 | today = datetime.datetime.today() 150 | dates = [min(delta(weeks=i) + today, pd.Timestamp.max) for i in range(num_rows)] 151 | 152 | return np.array(dates, dtype='datetime64') 153 | 154 | @staticmethod 155 | def get_performance_thresholds(): 156 | """Return the expected thresholds.""" 157 | return { 158 | 'fit': {'time': 5e-05, 'memory': 500.0}, 159 | 'transform': {'time': 5e-05, 'memory': 350.0}, 160 | 'reverse_transform': { 161 | 'time': 5e-05, 162 | 'memory': 1000.0, 163 | }, 164 | } 165 | -------------------------------------------------------------------------------- /rdt/performance/datasets/id.py: -------------------------------------------------------------------------------- 1 | """Dataset Generators for ID transformers.""" 2 | 3 | from abc import ABC 4 | 5 | import numpy as np 6 | 7 | from rdt.performance.datasets.base import BaseDatasetGenerator 8 | from rdt.performance.datasets.utils import add_nans 9 | 10 | 11 | class RegexGeneratorGenerator(BaseDatasetGenerator, ABC): 12 | """Base class for generators that generate ID data.""" 13 | 14 | SDTYPE = 'id' 15 | 16 | 17 | class RandomStringGenerator(RegexGeneratorGenerator): 18 | """Generator that creates an array of random strings.""" 19 | 20 | @staticmethod 21 | def generate(num_rows): 22 | """Generate a ``num_rows`` number of rows.""" 23 | categories = ['Alice', 'Bob', 'Charlie', 'Dave', 'Eve'] 24 | return np.random.choice(a=categories, size=num_rows) 25 | 26 | 
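    # Illustrative usage sketch (not from the original file): the thresholds returned
    # below are per-row budgets, because ``evaluate_transformer_performance`` divides the
    # measured totals by the dataset size before comparing them. For example:
    #
    #     from rdt.performance import evaluate_transformer_performance
    #     from rdt.transformers import RegexGenerator
    #
    #     results = evaluate_transformer_performance(RegexGenerator, RandomStringGenerator)
    #     assert results['Transform Time'] <= 1e-05  # seconds per row, matching the dict below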
@staticmethod 27 | def get_performance_thresholds(): 28 | """Return the expected thresholds.""" 29 | return { 30 | 'fit': {'time': 1e-05, 'memory': 500.0}, 31 | 'transform': {'time': 1e-05, 'memory': 500.0}, 32 | 'reverse_transform': { 33 | 'time': 2e-05, 34 | 'memory': 1000.0, 35 | }, 36 | } 37 | 38 | 39 | class RandomStringNaNsGenerator(RegexGeneratorGenerator): 40 | """Generator that creates an array of random strings with nans.""" 41 | 42 | @staticmethod 43 | def generate(num_rows): 44 | """Generate a ``num_rows`` number of rows.""" 45 | return add_nans(RandomStringGenerator.generate(num_rows).astype('O')) 46 | 47 | @staticmethod 48 | def get_performance_thresholds(): 49 | """Return the expected thresholds.""" 50 | return { 51 | 'fit': {'time': 1e-05, 'memory': 400.0}, 52 | 'transform': {'time': 1e-05, 'memory': 1000.0}, 53 | 'reverse_transform': { 54 | 'time': 2e-05, 55 | 'memory': 1000.0, 56 | }, 57 | } 58 | -------------------------------------------------------------------------------- /rdt/performance/datasets/numerical.py: -------------------------------------------------------------------------------- 1 | """Dataset Generators for numerical transformers.""" 2 | 3 | from abc import ABC 4 | 5 | import numpy as np 6 | 7 | from rdt.performance.datasets.base import BaseDatasetGenerator 8 | from rdt.performance.datasets.utils import add_nans 9 | 10 | 11 | class NumericalGenerator(BaseDatasetGenerator, ABC): 12 | """Base class for generators that create numerical data.""" 13 | 14 | SDTYPE = 'numerical' 15 | 16 | 17 | class RandomIntegerGenerator(NumericalGenerator): 18 | """Generator that creates an array of random integers.""" 19 | 20 | @staticmethod 21 | def generate(num_rows): 22 | """Generate a ``num_rows`` number of rows.""" 23 | ii32 = np.iinfo(np.int32) 24 | return np.random.randint(ii32.min, ii32.max, num_rows) 25 | 26 | @staticmethod 27 | def get_performance_thresholds(): 28 | """Return the expected thresholds.""" 29 | return { 30 | 'fit': {'time': 1e-03, 'memory': 2500.0}, 31 | 'transform': {'time': 5e-05, 'memory': 400.0}, 32 | 'reverse_transform': { 33 | 'time': 5e-05, 34 | 'memory': 400.0, 35 | }, 36 | } 37 | 38 | 39 | class RandomIntegerNaNsGenerator(NumericalGenerator): 40 | """Generator that creates an array of random integers with nans.""" 41 | 42 | @staticmethod 43 | def generate(num_rows): 44 | """Generate a ``num_rows`` number of rows.""" 45 | return add_nans(RandomIntegerGenerator.generate(num_rows).astype(float)) 46 | 47 | @staticmethod 48 | def get_performance_thresholds(): 49 | """Return the expected thresholds.""" 50 | return { 51 | 'fit': {'time': 1e-03, 'memory': 2500.0}, 52 | 'transform': {'time': 4e-05, 'memory': 400.0}, 53 | 'reverse_transform': { 54 | 'time': 2e-05, 55 | 'memory': 350.0, 56 | }, 57 | } 58 | 59 | 60 | class ConstantIntegerGenerator(NumericalGenerator): 61 | """Generator that creates a constant array with a random integer.""" 62 | 63 | @staticmethod 64 | def generate(num_rows): 65 | """Generate a ``num_rows`` number of rows.""" 66 | ii32 = np.iinfo(np.int32) 67 | constant = np.random.randint(ii32.min, ii32.max) 68 | return np.full(num_rows, constant) 69 | 70 | @staticmethod 71 | def get_performance_thresholds(): 72 | """Return the expected thresholds.""" 73 | return { 74 | 'fit': {'time': 1e-03, 'memory': 400.0}, 75 | 'transform': {'time': 1e-05, 'memory': 400.0}, 76 | 'reverse_transform': { 77 | 'time': 5e-05, 78 | 'memory': 400.0, 79 | }, 80 | } 81 | 82 | 83 | class ConstantIntegerNaNsGenerator(NumericalGenerator): 84 | """Generator that 
creates a constant array with a random integer and some nans.""" 85 | 86 | @staticmethod 87 | def generate(num_rows): 88 | """Generate a ``num_rows`` number of rows.""" 89 | return add_nans(ConstantIntegerGenerator.generate(num_rows).astype(float)) 90 | 91 | @staticmethod 92 | def get_performance_thresholds(): 93 | """Return the expected thresholds.""" 94 | return { 95 | 'fit': {'time': 1e-03, 'memory': 600.0}, 96 | 'transform': {'time': 3e-05, 'memory': 400.0}, 97 | 'reverse_transform': { 98 | 'time': 2e-05, 99 | 'memory': 350.0, 100 | }, 101 | } 102 | 103 | 104 | class AlmostConstantIntegerGenerator(NumericalGenerator): 105 | """Generator that creates an array with only 2 values, one of them repeated.""" 106 | 107 | @staticmethod 108 | def generate(num_rows): 109 | """Generate a ``num_rows`` number of rows.""" 110 | ii32 = np.iinfo(np.int32) 111 | values = np.random.randint(ii32.min, ii32.max, size=2) 112 | additional_values = np.full(num_rows - 2, values[1]) 113 | array = np.concatenate([values, additional_values]) 114 | np.random.shuffle(array) 115 | return array 116 | 117 | @staticmethod 118 | def get_performance_thresholds(): 119 | """Return the expected thresholds.""" 120 | return { 121 | 'fit': {'time': 1e-03, 'memory': 2500.0}, 122 | 'transform': {'time': 1e-05, 'memory': 2000.0}, 123 | 'reverse_transform': { 124 | 'time': 5e-05, 125 | 'memory': 2000.0, 126 | }, 127 | } 128 | 129 | 130 | class AlmostConstantIntegerNaNsGenerator(NumericalGenerator): 131 | """Generator that creates an array with only 2 values, one of them repeated, and NaNs.""" 132 | 133 | @staticmethod 134 | def generate(num_rows): 135 | """Generate a ``num_rows`` number of rows.""" 136 | ii32 = np.iinfo(np.int32) 137 | values = np.random.randint(ii32.min, ii32.max, size=2) 138 | additional_values = np.full(num_rows - 2, values[1]).astype(float) 139 | array = np.concatenate([values, add_nans(additional_values)]) 140 | np.random.shuffle(array) 141 | return array 142 | 143 | @staticmethod 144 | def get_performance_thresholds(): 145 | """Return the expected thresholds.""" 146 | return { 147 | 'fit': {'time': 1e-03, 'memory': 2500.0}, 148 | 'transform': {'time': 3e-05, 'memory': 1000.0}, 149 | 'reverse_transform': { 150 | 'time': 2e-05, 151 | 'memory': 1000.0, 152 | }, 153 | } 154 | 155 | 156 | class NormalGenerator(NumericalGenerator): 157 | """Generator that creates an array of normally distributed float values.""" 158 | 159 | @staticmethod 160 | def generate(num_rows): 161 | """Generate a ``num_rows`` number of rows.""" 162 | return np.random.normal(size=num_rows) 163 | 164 | @staticmethod 165 | def get_performance_thresholds(): 166 | """Return the expected thresholds.""" 167 | return { 168 | 'fit': {'time': 1e-03, 'memory': 2500.0}, 169 | 'transform': {'time': 1e-05, 'memory': 400.0}, 170 | 'reverse_transform': { 171 | 'time': 1e-05, 172 | 'memory': 400.0, 173 | }, 174 | } 175 | 176 | 177 | class NormalNaNsGenerator(NumericalGenerator): 178 | """Generator that creates an array of normally distributed float values, with NaNs.""" 179 | 180 | @staticmethod 181 | def generate(num_rows): 182 | """Generate a ``num_rows`` number of rows.""" 183 | return add_nans(NormalGenerator.generate(num_rows)) 184 | 185 | @staticmethod 186 | def get_performance_thresholds(): 187 | """Return the expected thresholds.""" 188 | return { 189 | 'fit': {'time': 1e-03, 'memory': 2500.0}, 190 | 'transform': {'time': 4e-05, 'memory': 400.0}, 191 | 'reverse_transform': { 192 | 'time': 5e-05, 193 | 'memory': 350.0, 194 | }, 195 | } 196 | 197
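# Illustrative usage sketch (not from the original file): generators in this module are
# discovered automatically through ``BaseDatasetGenerator.get_subclasses()``, e.g.:
#
#     from rdt.performance.datasets import get_dataset_generators_by_type
#
#     numerical_generators = get_dataset_generators_by_type()['numerical']
#     sample = numerical_generators[0].generate(1000)  # numpy array of 1000 values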
198 | class BigNormalGenerator(NumericalGenerator): 199 | """Generator that creates an array of big normally distributed float values.""" 200 | 201 | @staticmethod 202 | def generate(num_rows): 203 | """Generate a ``num_rows`` number of rows.""" 204 | return np.random.normal(scale=1e10, size=num_rows) 205 | 206 | @staticmethod 207 | def get_performance_thresholds(): 208 | """Return the expected thresholds.""" 209 | return { 210 | 'fit': {'time': 1e-03, 'memory': 2500.0}, 211 | 'transform': {'time': 5e-05, 'memory': 400.0}, 212 | 'reverse_transform': { 213 | 'time': 5e-05, 214 | 'memory': 400.0, 215 | }, 216 | } 217 | 218 | 219 | class BigNormalNaNsGenerator(NumericalGenerator): 220 | """Generator that creates an array of big normally distributed float values, with NaNs.""" 221 | 222 | @staticmethod 223 | def generate(num_rows): 224 | """Generate a ``num_rows`` number of rows.""" 225 | return add_nans(BigNormalGenerator.generate(num_rows)) 226 | 227 | @staticmethod 228 | def get_performance_thresholds(): 229 | """Return the expected thresholds.""" 230 | return { 231 | 'fit': {'time': 1e-03, 'memory': 2500.0}, 232 | 'transform': {'time': 3e-05, 'memory': 400.0}, 233 | 'reverse_transform': { 234 | 'time': 2e-05, 235 | 'memory': 350.0, 236 | }, 237 | } 238 | -------------------------------------------------------------------------------- /rdt/performance/datasets/pii.py: -------------------------------------------------------------------------------- 1 | """Dataset Generators for Personally Identifiable Information transformers.""" 2 | 3 | from abc import ABC 4 | 5 | import numpy as np 6 | 7 | from rdt.performance.datasets.base import BaseDatasetGenerator 8 | from rdt.performance.datasets.utils import add_nans 9 | 10 | 11 | class PIIGenerator(BaseDatasetGenerator, ABC): 12 | """Base class for generators that generate PII data.""" 13 | 14 | SDTYPE = 'pii' 15 | 16 | 17 | class RandomStringGenerator(PIIGenerator): 18 | """Generator that creates an array of random strings.""" 19 | 20 | @staticmethod 21 | def generate(num_rows): 22 | """Generate a ``num_rows`` number of rows.""" 23 | categories = ['Alice', 'Bob', 'Charlie', 'Dave', 'Eve'] 24 | return np.random.choice(a=categories, size=num_rows) 25 | 26 | @staticmethod 27 | def get_performance_thresholds(): 28 | """Return the expected thresholds.""" 29 | return { 30 | 'fit': {'time': 1e-05, 'memory': 500.0}, 31 | 'transform': {'time': 1e-05, 'memory': 500.0}, 32 | 'reverse_transform': { 33 | 'time': 3e-05, 34 | 'memory': 1000.0, 35 | }, 36 | } 37 | 38 | 39 | class RandomStringNaNsGenerator(PIIGenerator): 40 | """Generator that creates an array of random strings with nans.""" 41 | 42 | @staticmethod 43 | def generate(num_rows): 44 | """Generate a ``num_rows`` number of rows.""" 45 | return add_nans(RandomStringGenerator.generate(num_rows).astype('O')) 46 | 47 | @staticmethod 48 | def get_performance_thresholds(): 49 | """Return the expected thresholds.""" 50 | return { 51 | 'fit': {'time': 1e-05, 'memory': 400.0}, 52 | 'transform': {'time': 1e-05, 'memory': 1000.0}, 53 | 'reverse_transform': { 54 | 'time': 3e-05, 55 | 'memory': 1000.0, 56 | }, 57 | } 58 | -------------------------------------------------------------------------------- /rdt/performance/datasets/text.py: -------------------------------------------------------------------------------- 1 | """Dataset Generators for 'text' transformers.""" 2 | 3 | from abc import ABC 4 | 5 | import numpy as np 6 | 7 | from rdt.performance.datasets.base import BaseDatasetGenerator 8 | from
rdt.performance.datasets.utils import add_nans 9 | 10 | 11 | class RegexGeneratorGenerator(BaseDatasetGenerator, ABC): 12 | """Base class for generators that generate text data.""" 13 | 14 | SDTYPE = 'text' 15 | 16 | 17 | class RandomStringGenerator(RegexGeneratorGenerator): 18 | """Generator that creates an array of random strings.""" 19 | 20 | @staticmethod 21 | def generate(num_rows): 22 | """Generate a ``num_rows`` number of rows.""" 23 | categories = ['Alice', 'Bob', 'Charlie', 'Dave', 'Eve'] 24 | return np.random.choice(a=categories, size=num_rows) 25 | 26 | @staticmethod 27 | def get_performance_thresholds(): 28 | """Return the expected thresholds.""" 29 | return { 30 | 'fit': {'time': 1e-05, 'memory': 500.0}, 31 | 'transform': {'time': 1e-05, 'memory': 500.0}, 32 | 'reverse_transform': { 33 | 'time': 2e-05, 34 | 'memory': 1000.0, 35 | }, 36 | } 37 | 38 | 39 | class RandomStringNaNsGenerator(RegexGeneratorGenerator): 40 | """Generator that creates an array of random strings with nans.""" 41 | 42 | @staticmethod 43 | def generate(num_rows): 44 | """Generate a ``num_rows`` number of rows.""" 45 | return add_nans(RandomStringGenerator.generate(num_rows).astype('O')) 46 | 47 | @staticmethod 48 | def get_performance_thresholds(): 49 | """Return the expected thresholds.""" 50 | return { 51 | 'fit': {'time': 1e-05, 'memory': 400.0}, 52 | 'transform': {'time': 1e-05, 'memory': 1000.0}, 53 | 'reverse_transform': { 54 | 'time': 2e-05, 55 | 'memory': 1000.0, 56 | }, 57 | } 58 | -------------------------------------------------------------------------------- /rdt/performance/datasets/utils.py: -------------------------------------------------------------------------------- 1 | """Utils for the dataset generators.""" 2 | 3 | import numpy as np 4 | 5 | 6 | def add_nans(array): 7 | """Add a random amount of NaN values to the given array. 8 | 9 | Args: 10 | array (np.array): 11 | 1 dimensional numpy array. 12 | 13 | Returns: 14 | np.array: 15 | The same array with some values replaced by NaNs. 16 | """ 17 | if array.dtype.kind == 'i': 18 | array = array.astype(float) 19 | 20 | length = len(array) 21 | num_nulls = np.random.randint(1, length) 22 | nulls = np.random.choice(range(length), num_nulls) 23 | array[nulls] = np.nan 24 | return array 25 | -------------------------------------------------------------------------------- /rdt/performance/performance.py: -------------------------------------------------------------------------------- 1 | """Functions for evaluating transformer performance.""" 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from rdt.performance.profiling import profile_transformer 7 | 8 | DATASET_SIZES = [1000, 10000, 100000] 9 | 10 | # Additional arguments for transformers 11 | TRANSFORMER_ARGS = { 12 | 'BinaryEncoder': { 13 | 'missing_value_replacement': -1, 14 | 'missing_value_generation': 'from_column', 15 | }, 16 | 'UnixTimestampEncoder': {'missing_value_generation': 'from_column'}, 17 | 'OptimizedTimestampEncoder': {'missing_value_generation': 'from_column'}, 18 | 'FloatFormatter': {'missing_value_generation': 'from_column'}, 19 | 'GaussianNormalizer': {'missing_value_generation': 'from_column'}, 20 | 'ClusterBasedNormalizer': {'missing_value_generation': 'from_column'}, 21 | } 22 | 23 | 24 | def _get_dataset_sizes(sdtype): 25 | """Get a list of (fit_size, transform_size) for each dataset generator. 26 | 27 | Based on the sdtype of the dataset generator, return the list of 28 | sizes to run performance tests on.
Each element in this list is a tuple 29 | of (fit_size, transform_size). 30 | 31 | Args: 32 | sdtype (str): 33 | The type of data that the generator returns. 34 | 35 | Returns: 36 | sizes (list[tuple]): 37 | A list of (fit_size, transform_size) configs to run tests on. 38 | """ 39 | sizes = [(s, s) for s in DATASET_SIZES] 40 | 41 | if sdtype == 'categorical': 42 | sizes = [(s, max(s, 1000)) for s in DATASET_SIZES if s <= 10000] 43 | 44 | return sizes 45 | 46 | 47 | def evaluate_transformer_performance(transformer, dataset_generator, verbose=False): 48 | """Evaluate the given transformer's performance against the given dataset generator. 49 | 50 | Args: 51 | transformer (rdt.transformers.BaseTransformer): 52 | The transformer to evaluate. 53 | dataset_generator (rdt.performance.datasets.BaseDatasetGenerator): 54 | The dataset generator to performance test against. 55 | verbose (bool): 56 | Whether or not to add extra columns about the dataset and transformer, 57 | and return data for all dataset sizes. If false, it will only return 58 | the max performance values of all the dataset sizes used. 59 | 60 | Returns: 61 | pandas.DataFrame: 62 | The performance test results. 63 | """ 64 | transformer_args = TRANSFORMER_ARGS.get(transformer.get_name(), {}) 65 | transformer_instance = transformer(**transformer_args) 66 | 67 | sizes = _get_dataset_sizes(dataset_generator.SDTYPE) 68 | 69 | out = [] 70 | for fit_size, transform_size in sizes: 71 | performance = profile_transformer( 72 | transformer=transformer_instance, 73 | dataset_generator=dataset_generator, 74 | fit_size=fit_size, 75 | transform_size=transform_size, 76 | ) 77 | size = np.array([fit_size, transform_size, transform_size] * 2) 78 | performance = performance / size 79 | if verbose: 80 | performance = performance.rename(lambda x: x + ' (s)' if 'Time' in x else x + ' (B)') 81 | performance['Number of fit rows'] = fit_size 82 | performance['Number of transform rows'] = transform_size 83 | performance['Dataset'] = dataset_generator.__name__ 84 | performance['Transformer'] = f'{transformer.__module__}.{transformer.get_name()}' 85 | 86 | out.append(performance) 87 | 88 | summary = pd.DataFrame(out) 89 | if verbose: 90 | return summary 91 | 92 | return summary.max(axis=0) 93 | -------------------------------------------------------------------------------- /rdt/performance/profiling.py: -------------------------------------------------------------------------------- 1 | """Functions to profile performance of RDT Transformers.""" 2 | 3 | # pylint: disable=W0212 4 | 5 | import multiprocessing as mp 6 | import timeit 7 | import tracemalloc 8 | from copy import deepcopy 9 | 10 | import pandas as pd 11 | 12 | 13 | def _profile_time(transformer, method_name, dataset, column=None, iterations=10, copy=False): 14 | total_time = 0 15 | for _ in range(iterations): 16 | if copy: 17 | transformer_copy = deepcopy(transformer) 18 | method = getattr(transformer_copy, method_name) 19 | 20 | else: 21 | method = getattr(transformer, method_name) 22 | 23 | start_time = timeit.default_timer() 24 | if column: 25 | method(dataset, column) 26 | else: 27 | method(dataset) 28 | total_time += timeit.default_timer() - start_time 29 | 30 | return total_time / iterations 31 | 32 | 33 | def _set_memory_for_method(method, dataset, column, peak_memory): 34 | tracemalloc.start() 35 | if column: 36 | method(dataset, column) 37 | else: 38 | method(dataset) 39 | 40 | peak_memory.value = tracemalloc.get_traced_memory()[1] 41 | tracemalloc.stop() 42 | tracemalloc.clear_traces() 43 | 44
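# Note (illustrative, not from the original file): peak memory is measured in a separate
# 'spawn' subprocess (see ``_profile_memory`` below) so that tracemalloc only traces the
# allocations made by the profiled method, isolated from memory already held by the
# parent process. The child process essentially runs:
#
#     tracemalloc.start()
#     method(dataset)                            # or method(dataset, column)
#     peak = tracemalloc.get_traced_memory()[1]  # (current, peak) -> peak bytes
#     tracemalloc.stop()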
| 45 | def _profile_memory(method, dataset, column=None): 46 | ctx = mp.get_context('spawn') 47 | peak_memory = ctx.Value('i', 0) 48 | profiling_process = ctx.Process( 49 | target=_set_memory_for_method, 50 | args=(method, dataset, column, peak_memory), 51 | ) 52 | profiling_process.start() 53 | profiling_process.join() 54 | return peak_memory.value 55 | 56 | 57 | def profile_transformer(transformer, dataset_generator, transform_size, fit_size=None): 58 | """Profile a Transformer on a dataset. 59 | 60 | This function will get the total time and peak memory 61 | for the ``fit``, ``transform`` and ``reverse_transform`` 62 | methods of the provided transformer against the provided 63 | dataset. 64 | 65 | Args: 66 | transformer (Transformer): 67 | Transformer instance. 68 | dataset_generator (DatasetGenerator): 69 | DatasetGenerator instance. 70 | transform_size (int): 71 | Number of rows to generate for ``transform`` and ``reverse_transform``. 72 | fit_size (int or None): 73 | Number of rows to generate for ``fit``. If None, use ``transform_size``. 74 | 75 | Returns: 76 | pandas.Series: 77 | Series containing the time and memory taken by ``fit``, ``transform``, 78 | and ``reverse_transform`` for the transformer. 79 | """ 80 | fit_size = fit_size or transform_size 81 | fit_dataset = pd.DataFrame({'test': dataset_generator.generate(fit_size)}) 82 | replace = transform_size > fit_size 83 | transform_dataset = fit_dataset.sample(transform_size, replace=replace) 84 | 85 | fit_time = _profile_time(transformer, 'fit', fit_dataset, column='test', copy=True) 86 | fit_memory = _profile_memory(transformer.fit, fit_dataset, column='test') 87 | transformer.fit(fit_dataset, 'test') 88 | 89 | transform_time = _profile_time(transformer, 'transform', transform_dataset) 90 | transform_memory = _profile_memory(transformer.transform, transform_dataset) 91 | 92 | reverse_dataset = transformer.transform(transform_dataset) 93 | reverse_time = _profile_time(transformer, 'reverse_transform', reverse_dataset) 94 | reverse_memory = _profile_memory(transformer.reverse_transform, reverse_dataset) 95 | 96 | return pd.Series({ 97 | 'Fit Time': fit_time, 98 | 'Fit Memory': fit_memory, 99 | 'Transform Time': transform_time, 100 | 'Transform Memory': transform_memory, 101 | 'Reverse Transform Time': reverse_time, 102 | 'Reverse Transform Memory': reverse_memory, 103 | }) 104 | -------------------------------------------------------------------------------- /rdt/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | """Transformers module.""" 2 | 3 | import importlib 4 | import inspect 5 | import warnings 6 | from collections import defaultdict 7 | from copy import deepcopy 8 | from functools import lru_cache 9 | 10 | from rdt.transformers.base import BaseMultiColumnTransformer, BaseTransformer 11 | from rdt.transformers.boolean import BinaryEncoder 12 | from rdt.transformers.categorical import ( 13 | CustomLabelEncoder, 14 | FrequencyEncoder, 15 | LabelEncoder, 16 | OneHotEncoder, 17 | OrderedLabelEncoder, 18 | OrderedUniformEncoder, 19 | UniformEncoder, 20 | ) 21 | from rdt.transformers.datetime import ( 22 | OptimizedTimestampEncoder, 23 | UnixTimestampEncoder, 24 | ) 25 | from rdt.transformers.id import IDGenerator, IndexGenerator, RegexGenerator 26 | from rdt.transformers.null import NullTransformer 27 | from rdt.transformers.numerical import ( 28 | ClusterBasedNormalizer, 29 | FloatFormatter, 30 | GaussianNormalizer, 31 | LogScaler, 32 | LogitScaler, 33 | ) 34 | 
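# Illustrative usage sketch (not from the original file): the names imported in this
# module make up the public transformer API, e.g.:
#
#     from rdt.transformers import get_default_transformer
#
#     transformer = get_default_transformer('datetime')  # a UnixTimestampEncoder instance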
from rdt.transformers.pii.anonymizer import ( 35 | AnonymizedFaker, 36 | PseudoAnonymizedFaker, 37 | ) 38 | from rdt.transformers.utils import WarnDict 39 | 40 | __all__ = [ 41 | 'BaseTransformer', 42 | 'BaseMultiColumnTransformer', 43 | 'BinaryEncoder', 44 | 'ClusterBasedNormalizer', 45 | 'CustomLabelEncoder', 46 | 'OrderedLabelEncoder', 47 | 'FloatFormatter', 48 | 'FrequencyEncoder', 49 | 'GaussianNormalizer', 50 | 'LabelEncoder', 51 | 'LogScaler', 52 | 'NullTransformer', 53 | 'OneHotEncoder', 54 | 'OptimizedTimestampEncoder', 55 | 'UnixTimestampEncoder', 56 | 'RegexGenerator', 57 | 'AnonymizedFaker', 58 | 'PseudoAnonymizedFaker', 59 | 'IDGenerator', 60 | 'IndexGenerator', 61 | 'get_transformer_name', 62 | 'get_transformer_class', 63 | 'get_transformers_by_type', 64 | 'get_default_transformers', 65 | 'get_default_transformer', 66 | 'UniformEncoder', 67 | 'OrderedUniformEncoder', 68 | ] 69 | 70 | 71 | def get_transformer_name(transformer): 72 | """Return the fully qualified path of the transformer. 73 | 74 | Args: 75 | transformer: 76 | A transformer class. 77 | 78 | Raises: 79 | ValueError: 80 | Raised when the transformer is not passed as a class. 81 | 82 | Returns: 83 | string: 84 | The path of the transformer. 85 | """ 86 | if inspect.isclass(transformer): 87 | return transformer.__module__ + '.' + transformer.get_name() 88 | 89 | raise ValueError(f'The transformer {transformer} must be passed as a class.') 90 | 91 | 92 | TRANSFORMERS = { 93 | get_transformer_name(transformer): transformer 94 | for transformer in BaseTransformer.get_subclasses() 95 | } 96 | 97 | DEFAULT_TRANSFORMERS = WarnDict( 98 | boolean=UniformEncoder(), 99 | categorical=UniformEncoder(), 100 | datetime=UnixTimestampEncoder(), 101 | id=RegexGenerator(), 102 | numerical=FloatFormatter(), 103 | pii=AnonymizedFaker(), 104 | text=RegexGenerator(), 105 | ) 106 | 107 | 108 | @lru_cache() 109 | def get_class_by_transformer_name(): 110 | """Return a mapping of transformer names to transformer classes. 111 | 112 | The mapping is built from all ``BaseTransformer`` subclasses that are currently 113 | defined, keyed by transformer name ('LabelEncoder', 'FloatFormatter', etc). 114 | 115 | Returns: 116 | dict: 117 | Mapping of transformer names to ``BaseTransformer`` subclass 118 | class objects. 119 | """ 120 | return {class_.get_name(): class_ for class_ in BaseTransformer.get_subclasses()} 121 | 122 | 123 | def get_transformer_class(transformer): 124 | """Return a ``transformer`` class from a ``str``. 125 | 126 | Args: 127 | transformer (str): 128 | Python path. 129 | 130 | Returns: 131 | BaseTransformer: 132 | BaseTransformer subclass class object. 133 | """ 134 | if transformer in TRANSFORMERS: 135 | return TRANSFORMERS[transformer] 136 | 137 | package, name = transformer.rsplit('.', 1) 138 | return getattr(importlib.import_module(package), name) 139 | 140 | 141 | @lru_cache() 142 | def get_transformers_by_type(): 143 | """Build a ``dict`` mapping sdtypes to valid existing transformers for that sdtype. 144 | 145 | Returns: 146 | dict: 147 | Mapping of sdtypes to a list of existing transformers that take that 148 | sdtype as an input. 149 | """ 150 | sdtype_transformers = defaultdict(list) 151 | transformer_classes = BaseTransformer.get_subclasses() 152 | for transformer in transformer_classes: 153 | for sdtype in transformer.get_supported_sdtypes(): 154 | sdtype_transformers[sdtype].append(transformer) 155 | 156 | return sdtype_transformers 157 | 158 | 159 | @lru_cache() 160 | def get_default_transformers(): 161 | """Build a ``dict`` mapping sdtypes to a default transformer for that sdtype.
162 | 163 | Returns: 164 | dict: 165 | Mapping of sdtypes to a transformer. 166 | """ 167 | transformers_by_type = get_transformers_by_type() 168 | defaults = deepcopy(DEFAULT_TRANSFORMERS) 169 | for sdtype, transformers in transformers_by_type.items(): 170 | if sdtype not in defaults: 171 | defaults[sdtype] = transformers[0]() 172 | 173 | return defaults 174 | 175 | 176 | @lru_cache() 177 | def get_default_transformer(sdtype): 178 | """Get the default transformer for an sdtype. 179 | 180 | Returns: 181 | Transformer: 182 | Default transformer for the sdtype. 183 | """ 184 | default_transformers = get_default_transformers() 185 | return default_transformers[sdtype] 186 | -------------------------------------------------------------------------------- /rdt/transformers/_validators.py: -------------------------------------------------------------------------------- 1 | """Validations for multi-column transformers.""" 2 | 3 | import importlib 4 | 5 | from rdt.errors import TransformerInputError 6 | 7 | 8 | class BaseValidator: 9 | """Base validation class. 10 | 11 | The validation classes ensure that the input data is compatible with the transformers 12 | and that they can be imported. 13 | """ 14 | 15 | SUPPORTED_SDTYPES = [] 16 | VALIDATION_TYPE = None 17 | 18 | @classmethod 19 | def _validate_supported_sdtypes(cls, columns_to_sdtypes): 20 | message = '' 21 | for column, sdtype in columns_to_sdtypes.items(): 22 | if sdtype not in cls.SUPPORTED_SDTYPES: 23 | message += f"Column '{column}' has an unsupported sdtype '{sdtype}'.\n" 24 | 25 | if message: 26 | message += ( 27 | f'Please provide a column that is compatible with {cls.VALIDATION_TYPE} data.' 28 | ) 29 | raise TransformerInputError(message) 30 | 31 | @classmethod 32 | def validate_sdtypes(cls, columns_to_sdtypes): 33 | """Validate the columns to sdtypes mapping. 34 | 35 | This method aims to call all other sdtype validation methods in the class. 36 | 37 | Args: 38 | columns_to_sdtypes (dict): 39 | Mapping of column names to sdtypes. 40 | """ 41 | raise NotImplementedError 42 | 43 | @classmethod 44 | def validate_imports(cls): 45 | """Check that the transformers can be imported.""" 46 | raise NotImplementedError 47 | 48 | @classmethod 49 | def validate(cls, columns_to_sdtypes): 50 | """Validate the input data. 51 | 52 | Args: 53 | columns_to_sdtypes (dict): 54 | Mapping of column names to sdtypes. 55 | """ 56 | cls.validate_sdtypes(columns_to_sdtypes) 57 | cls.validate_imports() 58 | 59 | 60 | class AddressValidator(BaseValidator): 61 | """Validation class for Address data.""" 62 | 63 | SUPPORTED_SDTYPES = [ 64 | 'country_code', 65 | 'administrative_unit', 66 | 'city', 67 | 'postcode', 68 | 'street_address', 69 | 'secondary_address', 70 | 'state', 71 | 'state_abbr', 72 | ] 73 | VALIDATION_TYPE = 'Address' 74 | 75 | @classmethod 76 | def _validate_number_columns(cls, columns_to_sdtypes): 77 | if len(columns_to_sdtypes) > 7: 78 | raise TransformerInputError( 79 | f'{cls.VALIDATION_TYPE} transformer takes up to 7 columns to transform. ' 80 | 'Please provide address data with valid fields.'
81 | ) 82 | 83 | @staticmethod 84 | def _validate_uniqueness_sdtype(columns_to_sdtypes): 85 | sdtypes_to_columns = {} 86 | for column, sdtype in columns_to_sdtypes.items(): 87 | if sdtype not in sdtypes_to_columns: 88 | sdtypes_to_columns[sdtype] = [] 89 | 90 | sdtypes_to_columns[sdtype].append(column) 91 | 92 | duplicate_fields = { 93 | value: keys for value, keys in sdtypes_to_columns.items() if len(keys) > 1 94 | } 95 | 96 | if duplicate_fields: 97 | message = '' 98 | for sdtype, columns in duplicate_fields.items(): 99 | to_print = "', '".join(columns) 100 | message += f"Columns '{to_print}' have the same sdtype '{sdtype}'.\n" 101 | 102 | message += 'Your address data cannot have duplicate fields.' 103 | raise TransformerInputError(message) 104 | 105 | @classmethod 106 | def _validate_administrative_unit(cls, columns_to_sdtypes): 107 | num_column_administrative_unit = sum( 108 | 1 for itm in columns_to_sdtypes.values() if itm in ['administrative_unit', 'state'] 109 | ) 110 | if num_column_administrative_unit > 1: 111 | raise TransformerInputError( 112 | f"The {cls.__name__} can have up to 1 column with sdtype 'state'" 113 | f" or 'administrative_unit'. Please provide address data with valid fields." 114 | ) 115 | 116 | @classmethod 117 | def validate_sdtypes(cls, columns_to_sdtypes): 118 | """Validate the columns to sdtypes mapping.""" 119 | cls._validate_supported_sdtypes(columns_to_sdtypes) 120 | cls._validate_number_columns(columns_to_sdtypes) 121 | cls._validate_uniqueness_sdtype(columns_to_sdtypes) 122 | cls._validate_administrative_unit(columns_to_sdtypes) 123 | 124 | @classmethod 125 | def validate_imports(cls): 126 | """Check that the address transformers can be imported.""" 127 | error_message = ( 128 | 'You must have SDV Enterprise with the address add-on to use the address features.' 129 | ) 130 | 131 | try: 132 | address_module = importlib.import_module('rdt.transformers.address') 133 | except ModuleNotFoundError: 134 | raise ImportError(error_message) from None 135 | 136 | required_classes = ['RandomLocationGenerator', 'RegionalAnonymizer'] 137 | for class_name in required_classes: 138 | if not hasattr(address_module, class_name): 139 | raise ImportError(error_message) 140 | 141 | 142 | class GPSValidator(BaseValidator): 143 | """Validation class for GPS data.""" 144 | 145 | SUPPORTED_SDTYPES = ['latitude', 'longitude'] 146 | VALIDATION_TYPE = 'GPS' 147 | 148 | @staticmethod 149 | def _validate_uniqueness_sdtype(columns_to_sdtypes): 150 | sdtypes_to_columns = {sdtype: column for column, sdtype in columns_to_sdtypes.items()} 151 | if len(sdtypes_to_columns) != 2: 152 | raise TransformerInputError( 153 | 'The GPS columns must have one latitude and one longitude column sdtype. ' 154 | 'Please provide GPS data with valid fields.' 155 | ) 156 | 157 | @classmethod 158 | def validate_sdtypes(cls, columns_to_sdtypes): 159 | """Validate the columns to sdtypes mapping.""" 160 | cls._validate_supported_sdtypes(columns_to_sdtypes) 161 | cls._validate_uniqueness_sdtype(columns_to_sdtypes) 162 | 163 | @classmethod 164 | def validate_imports(cls): 165 | """Check that the GPS transformers can be imported.""" 166 | error_message = 'You must have SDV Enterprise with the gps add-on to use the GPS features.'
167 | 168 | try: 169 | gps_module = importlib.import_module('rdt.transformers.gps') 170 | except ModuleNotFoundError: 171 | raise ImportError(error_message) from None 172 | 173 | required_classes = [ 174 | 'RandomLocationGenerator', 175 | 'MetroAreaAnonymizer', 176 | 'GPSNoiser', 177 | ] 178 | for class_name in required_classes: 179 | if not hasattr(gps_module, class_name): 180 | raise ImportError(error_message) 181 | -------------------------------------------------------------------------------- /rdt/transformers/boolean.py: -------------------------------------------------------------------------------- 1 | """Transformer for boolean data.""" 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from rdt.transformers.base import BaseTransformer 7 | from rdt.transformers.null import NullTransformer 8 | 9 | 10 | class BinaryEncoder(BaseTransformer): 11 | """Transformer for boolean data. 12 | 13 | This transformer replaces boolean values with their integer representation 14 | transformed to float. 15 | 16 | Null values are replaced using a ``NullTransformer``. 17 | 18 | Args: 19 | missing_value_replacement (object): 20 | Indicate what to replace the null values with. If the string ``'mode'`` is given, 21 | replace them with the most common value. 22 | Defaults to ``mode``. 23 | model_missing_values (bool): 24 | **DEPRECATED** Whether to create a new column to indicate which values were null or 25 | not. The column will be created only if there are null values. If ``True``, create 26 | the new column if there are null values. If ``False``, do not create the new column 27 | even if there are null values. Defaults to ``False``. 28 | missing_value_generation (str or None): 29 | The way missing values are being handled. There are three strategies: 30 | 31 | * ``random``: Randomly generates missing values based on the percentage of 32 | missing values. 33 | * ``from_column``: Creates a binary column that describes whether the original 34 | value was missing. Then use it to recreate missing values. 35 | * ``None``: Do nothing with the missing values on the reverse transform. Simply 36 | pass whatever data we get through. 37 | """ 38 | 39 | INPUT_SDTYPE = 'boolean' 40 | null_transformer = None 41 | 42 | def __init__( 43 | self, 44 | missing_value_replacement='mode', 45 | model_missing_values=None, 46 | missing_value_generation='random', 47 | ): 48 | super().__init__() 49 | self.missing_value_replacement = missing_value_replacement 50 | self._set_missing_value_generation(missing_value_generation) 51 | if model_missing_values is not None: 52 | self._set_model_missing_values(model_missing_values) 53 | 54 | def _fit(self, data): 55 | """Fit the transformer to the data. 56 | 57 | Args: 58 | data (pandas.Series): 59 | Data to fit to. 60 | """ 61 | self.null_transformer = NullTransformer( 62 | self.missing_value_replacement, self.missing_value_generation 63 | ) 64 | self.null_transformer.fit(data) 65 | if self.null_transformer.models_missing_values(): 66 | self.output_properties['is_null'] = { 67 | 'sdtype': 'float', 68 | 'next_transformer': None, 69 | } 70 | 71 | def _transform(self, data): 72 | """Transform boolean to float. 73 | 74 | The boolean values will be replaced by the corresponding integer 75 | representations as float values. 76 | 77 | Args: 78 | data (pandas.Series): 79 | Data to transform. 
80 | 
81 |         Returns:
82 |             np.ndarray
83 |         """
84 |         data = pd.to_numeric(data, errors='coerce')
85 |         return self.null_transformer.transform(data).astype(float)
86 | 
87 |     def _reverse_transform(self, data):
88 |         """Transform float values back to the original boolean values.
89 | 
90 |         Args:
91 |             data (pandas.DataFrame or pandas.Series):
92 |                 Data to revert.
93 | 
94 |         Returns:
95 |             pandas.Series:
96 |                 Reverted data.
97 |         """
98 |         if not isinstance(data, np.ndarray):
99 |             data = data.to_numpy()
100 | 
101 |         data = self.null_transformer.reverse_transform(data)
102 |         if isinstance(data, np.ndarray):
103 |             if data.ndim == 2:
104 |                 data = data[:, 0]
105 | 
106 |             data = pd.Series(data)
107 | 
108 |         isna = data.isna()
109 |         data = np.round(data).clip(0, 1).astype('boolean').astype('object')
110 |         data[isna] = np.nan
111 | 
112 |         return data
113 | 
114 |     def _set_fitted_parameters(self, column_name, null_transformer):
115 |         """Manually set the parameters on the transformer to get it into a fitted state.
116 | 
117 |         Args:
118 |             column_name (str):
119 |                 The name of the column to use for the transformer.
120 |             null_transformer (NullTransformer):
121 |                 A fitted null transformer instance that can be used to generate
122 |                 null values for the column.
123 |         """
124 |         self.reset_randomization()
125 |         self.columns = [column_name]
126 |         self.output_columns = [column_name]
127 |         self.null_transformer = null_transformer
128 |         if self.null_transformer.models_missing_values():
129 |             self.output_columns.append(column_name + '.is_null')
130 | 
--------------------------------------------------------------------------------
/rdt/transformers/null.py:
--------------------------------------------------------------------------------
1 | """Transformer for data that contains Null values."""
2 | 
3 | import logging
4 | 
5 | import numpy as np
6 | import pandas as pd
7 | 
8 | from rdt.errors import TransformerInputError
9 | 
10 | LOGGER = logging.getLogger(__name__)
11 | 
12 | 
13 | class NullTransformer:
14 |     """Transformer for data that contains Null values.
15 | 
16 |     Args:
17 |         missing_value_replacement (object or None):
18 |             Indicate what to do with the null values. If an integer, float or string is given,
19 |             replace them with the given value. If the strings ``'mean'`` or ``'mode'`` are given,
20 |             replace them with the corresponding aggregation (``'mean'`` only works for numerical
21 |             values). If ``'random'`` is given, replace each null value with a random value in
22 |             the data range. If ``None`` is given, do not replace them. Defaults to ``None``.
23 |         missing_value_generation (str or None):
24 |             The way missing values are being handled. There are three strategies:
25 | 
26 |             * ``random``: Randomly generates missing values based on the percentage of
27 |               missing values.
28 |             * ``from_column``: Creates a binary column that describes whether the original
29 |               value was missing. Then use it to recreate missing values.
30 |             * ``None``: Do nothing with the missing values on the reverse transform. Simply
31 |               pass whatever data we get through.
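
    Example:
        A minimal usage sketch, where ``data`` stands for any ``pandas.Series``
        being transformed::

            transformer = NullTransformer(
                missing_value_replacement='mean',
                missing_value_generation='random',
            )
            transformer.fit(data)
            transformed = transformer.transform(data)
            restored = transformer.reverse_transform(transformed)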
32 | """ 33 | 34 | nulls = None 35 | _missing_value_generation = None 36 | _missing_value_replacement = None 37 | _null_percentage = None 38 | 39 | def __init__(self, missing_value_replacement=None, missing_value_generation='random'): 40 | self._missing_value_replacement = missing_value_replacement 41 | if missing_value_generation not in (None, 'from_column', 'random'): 42 | raise TransformerInputError( 43 | "'missing_value_generation' must be one of the following values: " 44 | "None, 'from_column' or 'random'." 45 | ) 46 | 47 | self._missing_value_generation = missing_value_generation 48 | self._min_value = None 49 | self._max_value = None 50 | 51 | def models_missing_values(self): 52 | """Indicate whether this transformer creates a null column on transform. 53 | 54 | Returns: 55 | bool: 56 | Whether a null column is created on transform. 57 | """ 58 | return self._missing_value_generation == 'from_column' 59 | 60 | def _get_missing_value_replacement(self, data): 61 | """Get the fill value to use for the given data. 62 | 63 | Args: 64 | data (pd.Series): 65 | The data that is being transformed. 66 | 67 | Return: 68 | object: 69 | The fill value that needs to be used. 70 | 71 | Raise: 72 | TransformerInputError: 73 | Error raised when data only contains nans and ``_missing_value_replacement`` 74 | is set to 'mean' or 'mode'. 75 | """ 76 | if self._missing_value_replacement is None: 77 | return None 78 | 79 | if self._missing_value_replacement in {'mean', 'mode', 'random'} and pd.isna(data).all(): 80 | msg = ( 81 | f"'missing_value_replacement' cannot be set to '{self._missing_value_replacement}'" 82 | ' when the provided data only contains NaNs. Using 0 instead.' 83 | ) 84 | LOGGER.info(msg) 85 | return 0 86 | 87 | if self._missing_value_replacement == 'mean': 88 | return data.mean() 89 | 90 | if self._missing_value_replacement == 'mode': 91 | return data.mode(dropna=True)[0] 92 | 93 | return self._missing_value_replacement 94 | 95 | def fit(self, data): 96 | """Fit the transformer to the data. 97 | 98 | Evaluate if the transformer has to create the null column or not. 99 | 100 | Args: 101 | data (pandas.Series): 102 | Data to transform. 103 | """ 104 | self._missing_value_replacement = self._get_missing_value_replacement(data) 105 | if self._missing_value_replacement == 'random': 106 | self._min_value = data.min() 107 | self._max_value = data.max() 108 | 109 | if self._missing_value_generation is not None: 110 | null_values = data.isna().to_numpy() 111 | self.nulls = null_values.any() 112 | 113 | if not self.nulls and self.models_missing_values(): 114 | self._missing_value_generation = None 115 | guidance_message = ( 116 | f'Guidance: There are no missing values in column {data.name}. ' 117 | 'Extra column not created.' 118 | ) 119 | LOGGER.info(guidance_message) 120 | 121 | if self._missing_value_generation == 'random': 122 | self._null_percentage = null_values.sum() / len(data) 123 | 124 | def _set_fitted_parameters(self, null_ratio): 125 | """Manually set the parameters on the transformer to get it into a fitted state. 126 | 127 | Args: 128 | null_ratio (float): 129 | The fraction of values to replace with null values. 130 | """ 131 | if null_ratio < 0 or null_ratio > 1.0: 132 | raise ValueError('null_ratio should be a value between 0 and 1.') 133 | 134 | if null_ratio != 0: 135 | self.nulls = True 136 | self._null_percentage = null_ratio 137 | 138 | def transform(self, data): 139 | """Replace null values with the indicated ``missing_value_replacement``. 
140 | 141 | If required, create the null indicator column. 142 | 143 | Args: 144 | data (pandas.Series or numpy.ndarray): 145 | Data to transform. 146 | 147 | Returns: 148 | numpy.ndarray 149 | """ 150 | isna = data.isna() 151 | if self._missing_value_replacement == 'random': 152 | data_mask = list( 153 | np.random.uniform(low=self._min_value, high=self._max_value, size=len(data)) 154 | ) 155 | data = data.mask(data.isna(), data_mask) 156 | 157 | elif isna.any() and self._missing_value_replacement is not None: 158 | data = data.infer_objects().fillna(self._missing_value_replacement) 159 | 160 | if self._missing_value_generation == 'from_column': 161 | return pd.concat([data, isna.astype(np.float64)], axis=1).to_numpy() 162 | 163 | return data.to_numpy() 164 | 165 | def reverse_transform(self, data): 166 | """Restore null values to the data. 167 | 168 | If a null indicator column was created during fit, use it as a reference. 169 | Otherwise, randomly replace values with ``np.nan``. The percentage of values 170 | that will be replaced is the percentage of null values seen in the fitted data. 171 | 172 | Args: 173 | data (numpy.ndarray): 174 | Data to transform. 175 | 176 | Returns: 177 | pandas.Series 178 | """ 179 | data = data.copy() 180 | if self._missing_value_generation == 'from_column': 181 | if self.nulls: 182 | isna = data[:, 1] > 0.5 183 | 184 | data = data[:, 0] 185 | 186 | elif self.nulls: 187 | isna = np.random.random((len(data),)) < self._null_percentage 188 | 189 | data = pd.Series(data) 190 | 191 | if self.nulls and isna.any(): 192 | data.loc[isna] = np.nan 193 | 194 | return data 195 | -------------------------------------------------------------------------------- /rdt/transformers/pii/__init__.py: -------------------------------------------------------------------------------- 1 | """Personal Identifiable Information Transformers module.""" 2 | 3 | from rdt.transformers.pii.anonymizer import ( 4 | AnonymizedFaker, 5 | PseudoAnonymizedFaker, 6 | ) 7 | 8 | __all__ = [ 9 | 'AnonymizedFaker', 10 | 'PseudoAnonymizedFaker', 11 | ] 12 | -------------------------------------------------------------------------------- /rdt/transformers/pii/anonymization.py: -------------------------------------------------------------------------------- 1 | """Anonymization module for the RDT PII Transformer.""" 2 | 3 | import inspect 4 | import warnings 5 | from functools import lru_cache 6 | 7 | from faker import Faker 8 | from faker.config import AVAILABLE_LOCALES 9 | 10 | from rdt.transformers import AnonymizedFaker 11 | 12 | SDTYPE_ANONYMIZERS = { 13 | 'address': {'provider_name': 'address', 'function_name': 'address'}, 14 | 'email': {'provider_name': 'internet', 'function_name': 'email'}, 15 | 'ipv4_address': {'provider_name': 'internet', 'function_name': 'ipv4'}, 16 | 'ipv6_address': {'provider_name': 'internet', 'function_name': 'ipv6'}, 17 | 'mac_address': { 18 | 'provider_name': 'internet', 19 | 'function_name': 'mac_address', 20 | }, 21 | 'name': {'provider_name': 'person', 'function_name': 'name'}, 22 | 'phone_number': { 23 | 'provider_name': 'phone_number', 24 | 'function_name': 'phone_number', 25 | }, 26 | 'ssn': {'provider_name': 'ssn', 'function_name': 'ssn'}, 27 | 'user_agent_string': { 28 | 'provider_name': 'user_agent', 29 | 'function_name': 'user_agent', 30 | }, 31 | } 32 | 33 | 34 | @lru_cache() 35 | def get_faker_instance(): 36 | """Return a ``faker.Faker`` instance with all the locales.""" 37 | return Faker(AVAILABLE_LOCALES) 38 | 39 | 40 | def 
is_faker_function(function_name):
41 |     """Return whether or not the function name is a valid Faker function.
42 | 
43 |     Args:
44 |         function_name (str):
45 |             String representing predefined ``sdtype`` or a ``faker`` function.
46 | 
47 |     Returns:
48 |         True if the ``function_name`` is known to ``Faker``, otherwise False.
49 |     """
50 |     try:
51 |         with warnings.catch_warnings():
52 |             warnings.filterwarnings('ignore', module='faker')
53 |             getattr(get_faker_instance(), function_name)
54 |     except AttributeError:
55 |         return False
56 | 
57 |     return True
58 | 
59 | 
60 | def _detect_provider_name(function_name, locales=None):
61 |     function_name = getattr(Faker(locale=locales), function_name)
62 |     module = inspect.getmodule(function_name).__name__
63 |     module = module.split('.')
64 |     if len(module) == 2:
65 |         return 'BaseProvider'
66 |     return '.'.join(module[2:])
67 | 
68 | 
69 | def get_anonymized_transformer(function_name, transformer_kwargs=None):
70 |     """Get an ``AnonymizedFaker`` instance for the given ``function_name``.
71 | 
72 |     Args:
73 |         function_name (str):
74 |             String representing predefined ``sdtype`` or a ``faker`` function.
75 |         transformer_kwargs (dict):
76 |             Keyword args to pass into AnonymizedFaker transformer. Optional.
77 |     """
78 |     transformer_kwargs = transformer_kwargs or {}
79 |     locales = transformer_kwargs.get('locales')
80 |     if function_name in SDTYPE_ANONYMIZERS:
81 |         transformer_kwargs.update(SDTYPE_ANONYMIZERS[function_name])
82 |         return AnonymizedFaker(**transformer_kwargs)
83 | 
84 |     provider_name = _detect_provider_name(function_name, locales=locales)
85 |     transformer_kwargs.update({
86 |         'function_name': function_name,
87 |         'provider_name': provider_name,
88 |     })
89 | 
90 |     return AnonymizedFaker(**transformer_kwargs)
--------------------------------------------------------------------------------
/rdt/transformers/pii/utils.py:
--------------------------------------------------------------------------------
1 | """Utility functions for anonymization."""
2 | 
3 | import inspect
4 | 
5 | from faker import Faker
6 | 
7 | 
8 | def get_provider_name(function_name):
9 |     """Return the ``faker`` provider name for a given ``function_name``.
10 | 
11 |     Args:
12 |         function_name (str):
13 |             String representing a ``faker`` function.
14 | 
15 |     Returns:
16 |         provider_name (str):
17 |             String representing the provider name of the faker function.
18 |     """
19 |     function_name = getattr(Faker(), function_name)
20 |     module = inspect.getmodule(function_name).__name__
21 |     module = module.split('.')
22 |     if len(module) == 2:
23 |         return 'BaseProvider'
24 | 
25 |     return module[-1]
--------------------------------------------------------------------------------
/rdt/transformers/text.py:
--------------------------------------------------------------------------------
1 | """Transformers for text data."""
2 | 
3 | import warnings
4 | 
5 | from rdt.transformers.id import IDGenerator, RegexGenerator  # noqa: F401
6 | 
7 | warnings.warn(
8 |     "Importing 'IDGenerator' or 'RegexGenerator' for ID columns from 'rdt.transformers.text' "
9 |     "is deprecated. 
Please use 'rdt.transformers.id' instead.", 10 | DeprecationWarning, 11 | stacklevel=2, 12 | ) 13 | -------------------------------------------------------------------------------- /scripts/check_for_prereleases.py: -------------------------------------------------------------------------------- 1 | """Script that checks project requirements for pre-release versions.""" 2 | 3 | from pathlib import Path 4 | 5 | import tomllib 6 | from packaging.requirements import Requirement 7 | 8 | 9 | def get_dev_dependencies(dependency_list): 10 | """Return list of dependencies with prerelease specifiers.""" 11 | prereleases = [] 12 | for dependency in dependency_list: 13 | requirement = Requirement(dependency) 14 | if requirement.specifier.prereleases or requirement.url: 15 | prereleases.append(dependency) 16 | 17 | return prereleases 18 | 19 | 20 | if __name__ == '__main__': 21 | folder = Path(__file__).parent 22 | toml_path = folder.joinpath('..', 'pyproject.toml') 23 | 24 | with open(toml_path, 'rb') as f: 25 | pyproject = tomllib.load(f) 26 | 27 | dependencies = pyproject['project']['dependencies'] 28 | optional_dependencies = pyproject['project'].get('optional-dependencies', {}) 29 | for dependency_list in optional_dependencies.values(): 30 | dependencies.extend(dependency_list) 31 | dev_deps = get_dev_dependencies(dependencies) 32 | 33 | if dev_deps: 34 | raise RuntimeError(f'Found dev dependencies: {", ".join(dev_deps)}') 35 | -------------------------------------------------------------------------------- /scripts/release_notes_generator.py: -------------------------------------------------------------------------------- 1 | """Script to generate release notes.""" 2 | 3 | import argparse 4 | import os 5 | from collections import defaultdict 6 | 7 | import requests 8 | 9 | LABEL_TO_HEADER = { 10 | 'feature request': 'New Features', 11 | 'bug': 'Bugs Fixed', 12 | 'internal': 'Internal', 13 | 'maintenance': 'Maintenance', 14 | 'customer success': 'Customer Success', 15 | 'documentation': 'Documentation', 16 | 'misc': 'Miscellaneous', 17 | } 18 | ISSUE_LABELS = [ 19 | 'documentation', 20 | 'maintenance', 21 | 'internal', 22 | 'bug', 23 | 'feature request', 24 | 'customer success', 25 | ] 26 | ISSUE_LABELS_ORDERED_BY_IMPORTANCE = [ 27 | 'feature request', 28 | 'customer success', 29 | 'bug', 30 | 'documentation', 31 | 'internal', 32 | 'maintenance', 33 | ] 34 | NEW_LINE = '\n' 35 | GITHUB_URL = 'https://api.github.com/repos/sdv-dev/rdt' 36 | GITHUB_TOKEN = os.getenv('GH_ACCESS_TOKEN') 37 | 38 | 39 | def _get_milestone_number(milestone_title): 40 | url = f'{GITHUB_URL}/milestones' 41 | headers = {'Authorization': f'Bearer {GITHUB_TOKEN}'} 42 | query_params = {'milestone': milestone_title, 'state': 'all', 'per_page': 100} 43 | response = requests.get(url, headers=headers, params=query_params, timeout=10) 44 | body = response.json() 45 | if response.status_code != 200: 46 | raise Exception(str(body)) 47 | 48 | milestones = body 49 | for milestone in milestones: 50 | if milestone.get('title') == milestone_title: 51 | return milestone.get('number') 52 | 53 | raise ValueError(f'Milestone {milestone_title} not found in past 100 milestones.') 54 | 55 | 56 | def _get_issues_by_milestone(milestone): 57 | headers = {'Authorization': f'Bearer {GITHUB_TOKEN}'} 58 | # get milestone number 59 | milestone_number = _get_milestone_number(milestone) 60 | url = f'{GITHUB_URL}/issues' 61 | page = 1 62 | query_params = {'milestone': milestone_number, 'state': 'all'} 63 | issues = [] 64 | while True: 65 | 
query_params['page'] = page
66 |         response = requests.get(url, headers=headers, params=query_params, timeout=10)
67 |         body = response.json()
68 |         if response.status_code != 200:
69 |             raise Exception(str(body))
70 | 
71 |         issues_on_page = body
72 |         if not issues_on_page:
73 |             break
74 | 
75 |         # Filter out PRs
76 |         issues_on_page = [issue for issue in issues_on_page if issue.get('pull_request') is None]
77 |         issues.extend(issues_on_page)
78 |         page += 1
79 | 
80 |     return issues
81 | 
82 | 
83 | def _get_issues_by_category(release_issues):
84 |     category_to_issues = defaultdict(list)
85 | 
86 |     for issue in release_issues:
87 |         issue_title = issue['title']
88 |         issue_number = issue['number']
89 |         issue_url = issue['html_url']
90 |         line = f'* {issue_title} - Issue [#{issue_number}]({issue_url})'
91 |         assignee = issue.get('assignee')
92 |         if assignee:
93 |             login = assignee['login']
94 |             line += f' by @{login}'
95 | 
96 |         # Check if any known label is marked on the issue
97 |         labels = [label['name'] for label in issue['labels']]
98 |         found_category = False
99 |         for category in ISSUE_LABELS:
100 |             if category in labels:
101 |                 category_to_issues[category].append(line)
102 |                 found_category = True
103 |                 break
104 | 
105 |         if not found_category:
106 |             category_to_issues['misc'].append(line)
107 | 
108 |     return category_to_issues
109 | 
110 | 
111 | def _create_release_notes(issues_by_category, version, date):
112 |     title = f'## v{version} - {date}'
113 |     release_notes = f'{title}{NEW_LINE}{NEW_LINE}'
114 | 
115 |     for category in ISSUE_LABELS_ORDERED_BY_IMPORTANCE + ['misc']:
116 |         issues = issues_by_category.get(category)
117 |         if issues:
118 |             section_text = (
119 |                 f'### {LABEL_TO_HEADER[category]}{NEW_LINE}{NEW_LINE}'
120 |                 f'{NEW_LINE.join(issues)}{NEW_LINE}{NEW_LINE}'
121 |             )
122 | 
123 |             release_notes += section_text
124 | 
125 |     return release_notes
126 | 
127 | 
128 | def update_release_notes(release_notes):
129 |     """Add the release notes for the new release to the ``HISTORY.md``."""
130 |     file_path = 'HISTORY.md'
131 |     with open(file_path, 'r') as history_file:
132 |         history = history_file.read()
133 | 
134 |     token = '# HISTORY\n\n'
135 |     split_index = history.find(token) + len(token) + 1
136 |     header = history[:split_index]
137 |     new_notes = f'{header}{release_notes}{history[split_index:]}'
138 | 
139 |     with open(file_path, 'w') as new_history_file:
140 |         new_history_file.write(new_notes)
141 | 
142 | 
143 | if __name__ == '__main__':
144 |     parser = argparse.ArgumentParser()
145 |     parser.add_argument('-v', '--version', type=str, help='Release version number (e.g. v1.0.1)')
146 |     parser.add_argument('-d', '--date', type=str, help='Date of release in format YYYY-MM-DD')
147 |     args = parser.parse_args()
148 |     release_number = args.version
149 |     release_issues = _get_issues_by_milestone(release_number)
150 |     issues_by_category = _get_issues_by_category(release_issues)
151 |     release_notes = _create_release_notes(issues_by_category, release_number, args.date)
152 |     update_release_notes(release_notes)
153 | 
--------------------------------------------------------------------------------
/static_code_analysis.txt:
--------------------------------------------------------------------------------
1 | Run started:2025-06-26 21:09:47.571406
2 | 
3 | Test results:
4 | No issues identified.
5 | 6 | Code scanned: 7 | Total lines of code: 6290 8 | Total lines skipped (#nosec): 0 9 | Total potential issues skipped due to specifically being disabled (e.g., #nosec BXXX): 0 10 | 11 | Run metrics: 12 | Total issues (by severity): 13 | Undefined: 0 14 | Low: 0 15 | Medium: 0 16 | High: 0 17 | Total issues (by confidence): 18 | Undefined: 0 19 | Low: 0 20 | Medium: 0 21 | High: 0 22 | Files skipped (0): 23 | -------------------------------------------------------------------------------- /tasks.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import operator 3 | import os 4 | import shutil 5 | import stat 6 | import sys 7 | from pathlib import Path 8 | 9 | import tomli 10 | from invoke import task 11 | from packaging.requirements import Requirement 12 | from packaging.version import Version 13 | 14 | COMPARISONS = { 15 | '>=': operator.ge, 16 | '>': operator.gt, 17 | '<': operator.lt, 18 | '<=': operator.le, 19 | } 20 | 21 | 22 | if not hasattr(inspect, 'getargspec'): 23 | inspect.getargspec = inspect.getfullargspec 24 | 25 | 26 | @task 27 | def check_dependencies(c): 28 | c.run('python -m pip check') 29 | 30 | 31 | @task 32 | def unit(c): 33 | c.run( 34 | 'python -m pytest ./tests/unit ./tests/performance/tests ./tests/datasets/tests ' 35 | '--cov=rdt --cov-fail-under=100 --cov-report=xml:./unit_cov.xml' 36 | ) 37 | 38 | 39 | @task 40 | def integration(c): 41 | c.run('python -m pytest ./tests/integration --cov=rdt --cov-report=xml:./integration_cov.xml') 42 | 43 | 44 | @task 45 | def performance(c): 46 | c.run('python -m pytest -v ./tests/performance/test_performance.py') 47 | 48 | 49 | def _get_minimum_versions(dependencies, python_version): 50 | min_versions = {} 51 | for dependency in dependencies: 52 | if '@' in dependency: 53 | name, url = dependency.split(' @ ') 54 | min_versions[name] = f'{url}#egg={name}' 55 | continue 56 | 57 | req = Requirement(dependency) 58 | if ';' in dependency: 59 | marker = req.marker 60 | if marker and not marker.evaluate({'python_version': python_version}): 61 | continue # Skip this dependency if the marker does not apply to the current Python version 62 | 63 | if req.name not in min_versions: 64 | min_version = next( 65 | (spec.version for spec in req.specifier if spec.operator in ('>=', '==')), 66 | None, 67 | ) 68 | if min_version: 69 | min_versions[req.name] = f'{req.name}=={min_version}' 70 | 71 | elif '@' not in min_versions[req.name]: 72 | existing_version = Version(min_versions[req.name].split('==')[1]) 73 | new_version = next( 74 | (spec.version for spec in req.specifier if spec.operator in ('>=', '==')), 75 | existing_version, 76 | ) 77 | if new_version > existing_version: 78 | min_versions[req.name] = ( 79 | f'{req.name}=={new_version}' # Change when a valid newer version is found 80 | ) 81 | 82 | return list(min_versions.values()) 83 | 84 | 85 | @task 86 | def install_minimum(c): 87 | with open('pyproject.toml', 'rb') as pyproject_file: 88 | pyproject_data = tomli.load(pyproject_file) 89 | 90 | dependencies = pyproject_data.get('project', {}).get('dependencies', []) 91 | python_version = '.'.join(map(str, sys.version_info[:2])) 92 | minimum_versions = _get_minimum_versions(dependencies, python_version) 93 | 94 | if minimum_versions: 95 | install_deps = ' '.join(minimum_versions) 96 | c.run(f'python -m pip install {install_deps}') 97 | 98 | 99 | @task 100 | def minimum(c): 101 | install_minimum(c) 102 | check_dependencies(c) 103 | unit(c) 104 | integration(c) 105 | 106 | 107 | 
@task 108 | def readme(c): 109 | test_path = Path('tests/readme_test') 110 | if test_path.exists() and test_path.is_dir(): 111 | shutil.rmtree(test_path) 112 | 113 | cwd = os.getcwd() 114 | os.makedirs(test_path, exist_ok=True) 115 | shutil.copy('README.md', test_path / 'README.md') 116 | os.chdir(test_path) 117 | c.run('rundoc run --single-session python3 -t python3 README.md') 118 | os.chdir(cwd) 119 | shutil.rmtree(test_path) 120 | 121 | 122 | @task 123 | def lint(c): 124 | check_dependencies(c) 125 | c.run('ruff check .') 126 | c.run('ruff format --check --diff .') 127 | 128 | 129 | @task 130 | def fix_lint(c): 131 | check_dependencies(c) 132 | c.run('ruff check --fix .') 133 | c.run('ruff format .') 134 | 135 | 136 | def remove_readonly(func, path, _): 137 | "Clear the readonly bit and reattempt the removal" 138 | os.chmod(path, stat.S_IWRITE) 139 | func(path) 140 | 141 | 142 | @task 143 | def rmdir(c, path): 144 | try: 145 | shutil.rmtree(path, onerror=remove_readonly) 146 | except PermissionError: 147 | pass 148 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """RDT test module.""" 2 | 3 | 4 | def safe_compare_dataframes(first, second): 5 | """Compare two dataframes even if they have NaN values. 6 | 7 | Args: 8 | first (pandas.DataFrame): DataFrame to compare 9 | second (pandas.DataFrame): DataFrame to compare 10 | 11 | Returns: 12 | bool 13 | """ 14 | if first.isna().all().all(): 15 | return first.equals(second) 16 | 17 | else: 18 | nulls = (first.isna() == second.isna()).all().all() 19 | values = (first[~first.isna()] == second[~second.isna()]).all().all() 20 | return nulls and values 21 | -------------------------------------------------------------------------------- /tests/code_style.py: -------------------------------------------------------------------------------- 1 | """RDT code style module.""" 2 | 3 | import importlib 4 | import inspect 5 | from pathlib import Path 6 | from types import FunctionType 7 | 8 | import pytest 9 | 10 | from rdt.transformers import TRANSFORMERS, get_transformer_class 11 | from rdt.transformers.base import BaseTransformer 12 | 13 | 14 | def validate_transformer_subclass(transformer): 15 | """Test whether or not the ``Transformer`` is a subclass of ``BaseTransformer``.""" 16 | fail_message = 'Transformer must be a subclass of ``BaseTransformer``.' 17 | assert issubclass(transformer, BaseTransformer), fail_message 18 | 19 | 20 | def validate_transformer_module(transformer): 21 | """Test whether or not the ``Transformer`` is inside the right module.""" 22 | transformer_file = Path(inspect.getfile(transformer)) 23 | transformer_folder = transformer_file.parent 24 | is_valid = False 25 | 26 | if transformer_folder.match('transformers'): 27 | is_valid = True 28 | elif transformer_folder.parent.match('transformers'): 29 | is_valid = True 30 | 31 | assert is_valid, 'The transformer module is not placed inside a valid path.' 
32 | 
33 | 
34 | def validate_transformer_importable_from_parent_module(transformer):
35 |     """Validate whether the transformer can be imported from the parent module."""
36 |     name = transformer.get_name()
37 |     module = getattr(transformer, '__module__', '')
38 |     module = module.rsplit('.', 1)[0]
39 |     imported_transformer = getattr(importlib.import_module(module), name, None)
40 |     assert imported_transformer is not None, f'Could not import {name} from {module}'
41 | 
42 | 
43 | def get_test_location(transformer):
44 |     """Return the expected unit test location of a transformer."""
45 |     transformer_file = Path(inspect.getfile(transformer))
46 |     transformer_folder = transformer_file.parent
47 |     rdt_unit_test_path = Path(__file__).parent / 'unit'
48 | 
49 |     test_location = None
50 |     if transformer_folder.match('transformers'):
51 |         test_location = rdt_unit_test_path / 'transformers' / f'test_{transformer_file.name}'
52 | 
53 |     elif transformer_folder.parent.match('transformers'):
54 |         test_location = rdt_unit_test_path / 'transformers' / transformer_folder.name
55 |         test_location = test_location / f'test_{transformer_file.name}'
56 | 
57 |     return test_location
58 | 
59 | 
60 | def validate_test_location(transformer):
61 |     """Validate if the test file exists in the expected location."""
62 |     test_location = get_test_location(transformer)
63 |     fail_message = 'The expected test location was not found.'
64 |     assert test_location is not None, fail_message
65 | 
66 |     assert test_location.exists(), 'The expected test location does not exist.'
67 | 
68 | 
69 | def _load_module_from_path(path):
70 |     """Return the module from a given ``PosixPath``."""
71 |     assert path.exists(), 'The expected test module was not found.'
72 |     module_path = path.parent
73 |     module_name = path.name.split('.')[0]
74 |     if module_path.name == 'transformers':
75 |         module_path = f'rdt.transformers.{module_name}'
76 |     elif module_path.parent.name == 'transformers':
77 |         module_path = f'rdt.transformers.{module_path.name}.{module_name}'
78 | 
79 |     spec = importlib.util.spec_from_file_location(module_path, path)
80 |     module = importlib.util.module_from_spec(spec)
81 |     spec.loader.exec_module(module)
82 | 
83 |     return module
84 | 
85 | 
86 | def validate_test_names(transformer):
87 |     """Validate if the test methods are properly specified."""
88 |     test_file = get_test_location(transformer)
89 |     module = _load_module_from_path(test_file)
90 | 
91 |     test_class = getattr(module, f'Test{transformer.get_name()}', None)
92 |     assert test_class is not None, 'The expected test class was not found.'
93 | 
94 |     test_functions = inspect.getmembers(test_class, predicate=inspect.isfunction)
95 |     test_functions = [test for test, _ in test_functions if test.startswith('test')]
96 | 
97 |     assert test_functions, 'No test functions found within the test module.'
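    # Beyond existence, each test's name must also map back to a method on the
    # transformer, which the matching loop below enforces.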
98 | 
99 |     transformer_functions = [
100 |         name
101 |         for name, function in transformer.__dict__.items()
102 |         if isinstance(function, (FunctionType, classmethod, staticmethod))
103 |     ]
104 | 
105 |     valid_test_functions = []
106 |     for test in test_functions:
107 |         count = len(valid_test_functions)
108 |         for transformer_function in transformer_functions:
109 |             simple_test = f'test_{transformer_function}'
110 |             described_test = f'test_{transformer_function}_'
111 |             if test.startswith(described_test):
112 |                 valid_test_functions.append(test)
113 |             elif test.startswith(simple_test):
114 |                 valid_test_functions.append(test)
115 | 
116 |         fail_message = f'No function name was found for the test: {test}'
117 |         assert len(valid_test_functions) > count, fail_message
118 | 
119 | 
120 | @pytest.mark.parametrize('transformer', TRANSFORMERS.values(), ids=TRANSFORMERS.keys())  # noqa
121 | def test_transformer_code_style(transformer):
122 |     """Validate a transformer."""
123 |     if not inspect.isclass(transformer):
124 |         transformer = get_transformer_class(transformer)
125 | 
126 |     validate_transformer_subclass(transformer)
127 |     validate_transformer_module(transformer)
128 |     validate_test_location(transformer)
129 |     validate_test_names(transformer)
130 |     validate_transformer_importable_from_parent_module(transformer)
--------------------------------------------------------------------------------
/tests/datasets/tests/test_boolean.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | 
3 | from rdt.performance.datasets import boolean
4 | 
5 | NUM_ROWS = 50
6 | 
7 | 
8 | class TestRandomBooleanGenerator:
9 |     def test_generate(self):
10 |         """Test the `RandomBooleanGenerator.generate` method.
11 | 
12 |         Expect that the specified number of rows of booleans is generated, and
13 |         that there are 2 unique values (True and False).
14 | 
15 |         Input:
16 |             - the number of rows
17 |         Output:
18 |             - a random boolean array of the specified number of rows
19 |         """
20 |         output = boolean.RandomBooleanGenerator.generate(NUM_ROWS)
21 |         assert len(output) == NUM_ROWS
22 |         assert output.dtype == bool
23 |         assert len(pd.unique(output)) == 2
24 |         assert pd.isna(output).sum() == 0
25 | 
26 | 
27 | class TestRandomBooleanNaNsGenerator:
28 |     def test_generate(self):
29 |         """Test the `RandomBooleanNaNsGenerator.generate` method.
30 | 
31 |         Expect that the specified number of rows of booleans is generated, and
32 |         that there are 3 unique values (True, False, and None).
33 | 
34 |         Input:
35 |             - the number of rows
36 |         Output:
37 |             - a random boolean array of the specified number of rows, with null values
38 |         """
39 |         output = boolean.RandomBooleanNaNsGenerator.generate(NUM_ROWS)
40 |         assert len(output) == NUM_ROWS
41 |         assert output.dtype == 'O'
42 |         assert len(pd.unique(output)) == 3
43 |         assert pd.isna(output).sum() > 0
44 | 
45 | 
46 | class TestRandomSkewedBooleanGenerator:
47 |     def test_generate(self):
48 |         """Test the `RandomSkewedBooleanGenerator.generate` method.
49 | 
50 |         Expect that the specified number of rows of booleans is generated, and
51 |         that there are 2 unique values (True and False).
52 | 53 | Input: 54 | - the number of rows 55 | Output: 56 | - a skewed random boolean array of the specified number of rows 57 | """ 58 | output = boolean.RandomSkewedBooleanGenerator.generate(NUM_ROWS) 59 | assert len(output) == NUM_ROWS 60 | assert output.dtype == bool 61 | assert len(pd.unique(output)) == 2 62 | assert pd.isna(output).sum() == 0 63 | 64 | 65 | class TestRandomSkewedBooleanNaNsGenerator: 66 | def test_generate(self): 67 | """Test the `RandomSkewedBooleanNaNsGenerator.generate` method. 68 | 69 | Expect that the specified number of rows of booleans is generated, and 70 | that there are 3 unique values (True, False, and None). 71 | 72 | Input: 73 | - the number of rows 74 | Output: 75 | - a skewed random boolean array of the specified number of rows, 76 | with null values 77 | """ 78 | output = boolean.RandomSkewedBooleanNaNsGenerator.generate(NUM_ROWS) 79 | assert len(output) == NUM_ROWS 80 | assert output.dtype == 'O' 81 | assert len(pd.unique(output)) == 3 82 | assert pd.isna(output).sum() > 0 83 | 84 | 85 | class TestConstantBooleanGenerator: 86 | def test_generate(self): 87 | """Test the `ConstantBooleanGenerator.generate` method. 88 | 89 | Expect that the specified number of rows of booleans is generated, and 90 | that there is only one unique value (True or False). 91 | 92 | Input: 93 | - the number of rows 94 | Output: 95 | - a boolean array of the specified number of rows, with all values equal 96 | to either True or False 97 | """ 98 | output = boolean.ConstantBooleanGenerator.generate(NUM_ROWS) 99 | assert len(output) == NUM_ROWS 100 | assert output.dtype == bool 101 | assert len(pd.unique(output)) == 1 102 | assert pd.isna(output).sum() == 0 103 | 104 | 105 | class TestConstantBooleanNaNsGenerator: 106 | def test(self): 107 | output = boolean.ConstantBooleanNaNsGenerator.generate(NUM_ROWS) 108 | assert len(output) == NUM_ROWS 109 | assert output.dtype == 'O' 110 | assert len(pd.unique(output)) == 2 111 | assert pd.isna(output).sum() > 0 112 | -------------------------------------------------------------------------------- /tests/datasets/tests/test_categorical.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from pandas.api.types import is_integer_dtype 4 | 5 | from rdt.performance.datasets import categorical 6 | 7 | 8 | class TestRandomIntegerGenerator: 9 | def test(self): 10 | output = categorical.RandomIntegerGenerator.generate(10) 11 | assert len(output) == 10 12 | assert is_integer_dtype(output.dtype) 13 | assert len(pd.unique(output)) < 6 14 | assert np.isnan(output).sum() == 0 15 | 16 | 17 | class TestRandomIntegerNaNsGenerator: 18 | def test(self): 19 | output = categorical.RandomIntegerNaNsGenerator.generate(10) 20 | assert len(output) == 10 21 | assert output.dtype == float 22 | assert len(pd.unique(output)) < 7 23 | assert np.isnan(output).sum() > 0 24 | 25 | 26 | class TestRandomStringGenerator: 27 | def test(self): 28 | output = categorical.RandomStringGenerator.generate(10) 29 | assert len(output) == 10 30 | assert output.dtype.type == np.str_ 31 | assert len(pd.unique(output)) < 6 32 | assert pd.isna(output).sum() == 0 33 | 34 | 35 | class TestRandomStringNaNsGenerator: 36 | def test(self): 37 | output = categorical.RandomStringNaNsGenerator.generate(10) 38 | assert len(output) == 10 39 | assert output.dtype.type == np.object_ 40 | assert len(pd.unique(output)) < 7 41 | assert sum(pd.isna(output)) > 0 42 | 43 | 44 | class TestRandomMixedGenerator: 45 | def 
test(self): 46 | output = categorical.RandomMixedGenerator.generate(10) 47 | assert len(output) == 10 48 | assert output.dtype.type == np.object_ 49 | assert pd.isna(output).sum() == 0 50 | 51 | 52 | class TestRandomMixedNaNsGenerator: 53 | def test(self): 54 | output = categorical.RandomMixedNaNsGenerator.generate(10) 55 | assert len(output) == 10 56 | assert output.dtype.type == np.object_ 57 | assert sum(pd.isna(output)) > 0 58 | 59 | 60 | class TestSingleIntegerGenerator: 61 | def test(self): 62 | output = categorical.SingleIntegerGenerator.generate(10) 63 | assert len(output) == 10 64 | assert is_integer_dtype(output.dtype) 65 | assert len(pd.unique(output)) == 1 66 | assert np.isnan(output).sum() == 0 67 | 68 | 69 | class TestSingleIntegerNaNsGenerator: 70 | def test(self): 71 | output = categorical.SingleIntegerNaNsGenerator.generate(10) 72 | assert len(output) == 10 73 | assert output.dtype == float 74 | assert len(pd.unique(output)) == 2 75 | assert np.isnan(output).sum() >= 1 76 | 77 | 78 | class TestSingleStringGenerator: 79 | def test(self): 80 | output = categorical.SingleStringGenerator.generate(10) 81 | assert len(output) == 10 82 | assert output.dtype.type == np.str_ 83 | assert len(pd.unique(output)) == 1 84 | assert pd.isna(output).sum() == 0 85 | 86 | 87 | class TestSingleStringNaNsGenerator: 88 | def test(self): 89 | output = categorical.SingleStringNaNsGenerator.generate(10) 90 | assert len(output) == 10 91 | assert output.dtype.type == np.object_ 92 | assert len(pd.unique(output)) == 2 93 | assert sum(pd.isna(output)) >= 1 94 | 95 | 96 | class TestUniqueIntegerGenerator: 97 | def test(self): 98 | output = categorical.UniqueIntegerGenerator.generate(10) 99 | assert len(output) == 10 100 | assert is_integer_dtype(output.dtype) 101 | assert len(pd.unique(output)) == 10 102 | assert np.isnan(output).sum() == 0 103 | 104 | 105 | class TestUniqueIntegerNaNsGenerator: 106 | def test(self): 107 | output = categorical.UniqueIntegerNaNsGenerator.generate(10) 108 | nulls = np.isnan(output).sum() 109 | 110 | assert len(output) == 10 111 | assert output.dtype == float 112 | assert len(pd.unique(output)) == 10 - nulls + 1 113 | assert nulls > 0 114 | 115 | 116 | class TestUniqueStringGenerator: 117 | def test(self): 118 | output = categorical.UniqueStringGenerator.generate(10) 119 | assert len(output) == 10 120 | assert output.dtype.type == np.str_ 121 | assert len(pd.unique(output)) == 10 122 | assert pd.isna(output).sum() == 0 123 | 124 | 125 | class TestUniqueStringNaNsGenerator: 126 | def test(self): 127 | output = categorical.UniqueStringNaNsGenerator.generate(10) 128 | nulls = sum(pd.isna(output)) 129 | 130 | assert len(output) == 10 131 | assert output.dtype == np.object_ 132 | assert len(pd.unique(output)) == 10 - nulls + 1 133 | assert nulls > 0 134 | -------------------------------------------------------------------------------- /tests/datasets/tests/test_datetime.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from rdt.performance.datasets import datetime 7 | 8 | 9 | class TestRandomGapDatetimeGenerator: 10 | def test(self): 11 | output = datetime.RandomGapDatetimeGenerator.generate(10) 12 | assert len(output) == 10 13 | assert output.dtype == 'datetime64[us]' 14 | assert len(pd.unique(output)) > 1 15 | assert np.isnan(output).sum() == 0 16 | 17 | 18 | class TestRandomGapSecondsDatetimeGenerator: 19 | def test(self): 20 | output = 
datetime.RandomGapSecondsDatetimeGenerator.generate(10) 21 | assert len(output) == 10 22 | assert output.dtype == 'datetime64[us]' 23 | assert len(pd.unique(output)) > 1 24 | assert np.isnan(output).sum() == 0 25 | 26 | 27 | class TestRandomGapDatetimeNaNsGenerator: 28 | def test(self): 29 | output = datetime.RandomGapDatetimeNaNsGenerator.generate(10) 30 | assert len(output) == 10 31 | assert output.dtype == 'O' 32 | assert len(pd.unique(output)) > 1 33 | assert pd.isna(output).sum() > 0 34 | 35 | 36 | class TestEqualGapHoursDatetimeGenerator: 37 | def test(self): 38 | output = datetime.EqualGapHoursDatetimeGenerator.generate(10) 39 | assert len(output) == 10 40 | assert output.dtype == 'datetime64[us]' 41 | assert all(x == dt.timedelta(hours=1) for x in np.diff(output)) 42 | assert np.isnan(output).sum() == 0 43 | 44 | 45 | class TestEqualGapDaysDatetimeGenerator: 46 | def test(self): 47 | output = datetime.EqualGapDaysDatetimeGenerator.generate(10) 48 | assert len(output) == 10 49 | assert output.dtype == 'datetime64[us]' 50 | assert all(x == dt.timedelta(1) for x in np.diff(output)) 51 | assert np.isnan(output).sum() == 0 52 | 53 | 54 | class TestEqualGapWeeksDatetimeGenerator: 55 | def test(self): 56 | output = datetime.EqualGapWeeksDatetimeGenerator.generate(10) 57 | assert len(output) == 10 58 | assert output.dtype == 'datetime64[us]' 59 | assert all(x == dt.timedelta(weeks=1) for x in np.diff(output)) 60 | assert np.isnan(output).sum() == 0 61 | -------------------------------------------------------------------------------- /tests/datasets/tests/test_numerical.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from pandas.api.types import is_integer_dtype 4 | 5 | from rdt.performance.datasets import numerical 6 | 7 | 8 | class TestRandomIntegerGenerator: 9 | def test(self): 10 | output = numerical.RandomIntegerGenerator.generate(10) 11 | assert len(output) == 10 12 | assert is_integer_dtype(output.dtype) 13 | assert len(pd.unique(output)) > 1 14 | assert np.isnan(output).sum() == 0 15 | 16 | 17 | class TestRandomIntegerNaNsGenerator: 18 | def test(self): 19 | output = numerical.RandomIntegerNaNsGenerator.generate(10) 20 | assert len(output) == 10 21 | assert output.dtype == float 22 | assert len(pd.unique(output)) > 1 23 | assert np.isnan(output).sum() > 0 24 | 25 | 26 | class TestConstantIntegerGenerator: 27 | def test(self): 28 | output = numerical.ConstantIntegerGenerator.generate(10) 29 | assert len(output) == 10 30 | assert is_integer_dtype(output.dtype) 31 | assert len(pd.unique(output)) == 1 32 | assert np.isnan(output).sum() == 0 33 | 34 | 35 | class TestConstantIntegerNaNsGenerator: 36 | def test(self): 37 | output = numerical.ConstantIntegerNaNsGenerator.generate(10) 38 | assert len(output) == 10 39 | assert output.dtype == float 40 | assert len(pd.unique(output)) == 2 41 | assert np.isnan(output).sum() >= 1 42 | 43 | 44 | class TestAlmostConstantIntegerGenerator: 45 | def test(self): 46 | output = numerical.AlmostConstantIntegerGenerator.generate(10) 47 | assert len(output) == 10 48 | assert is_integer_dtype(output.dtype) 49 | assert len(pd.unique(output)) == 2 50 | assert np.isnan(output).sum() == 0 51 | 52 | 53 | class TestAlmostConstantIntegerNaNsGenerator: 54 | def test(self): 55 | output = numerical.AlmostConstantIntegerNaNsGenerator.generate(10) 56 | assert len(output) == 10 57 | assert output.dtype == float 58 | assert len(pd.unique(output)) == 3 59 | assert 
np.isnan(output).sum() >= 1 60 | 61 | 62 | class TestNormalGenerator: 63 | def test(self): 64 | output = numerical.NormalGenerator.generate(10) 65 | assert len(output) == 10 66 | assert output.dtype == float 67 | assert len(pd.unique(output)) == 10 68 | assert np.isnan(output).sum() == 0 69 | 70 | 71 | class TestNormalNaNsGenerator: 72 | def test(self): 73 | output = numerical.NormalNaNsGenerator.generate(10) 74 | assert len(output) == 10 75 | assert output.dtype == float 76 | assert 1 < len(pd.unique(output)) <= 10 77 | assert np.isnan(output).sum() >= 1 78 | 79 | 80 | class TestBigNormalGenerator: 81 | def test(self): 82 | output = numerical.BigNormalGenerator.generate(10) 83 | assert len(output) == 10 84 | assert output.dtype == float 85 | assert len(pd.unique(output)) == 10 86 | assert np.isnan(output).sum() == 0 87 | 88 | 89 | class TestBigNormalNaNsGenerator: 90 | def test(self): 91 | output = numerical.BigNormalNaNsGenerator.generate(10) 92 | assert len(output) == 10 93 | assert output.dtype == float 94 | assert 1 < len(pd.unique(output)) <= 10 95 | assert np.isnan(output).sum() >= 1 96 | -------------------------------------------------------------------------------- /tests/datasets/tests/test_utils.py: -------------------------------------------------------------------------------- 1 | """Tests for the datasets.utils module.""" 2 | 3 | import numpy as np 4 | 5 | from rdt.performance.datasets import utils 6 | 7 | 8 | def test_add_nulls_int(): 9 | array = np.arange(100) 10 | 11 | with_nans = utils.add_nans(array) 12 | 13 | assert len(with_nans) == 100 14 | assert 1 <= np.isnan(with_nans).sum() < 100 15 | 16 | nans = np.isnan(with_nans) 17 | np.testing.assert_array_equal(array[~nans], with_nans[~nans]) 18 | 19 | 20 | def test_add_nulls_float(): 21 | array = np.arange(100).astype(float) 22 | 23 | with_nans = utils.add_nans(array) 24 | 25 | assert len(with_nans) == 100 26 | assert 1 <= np.isnan(with_nans).sum() < 100 27 | 28 | nans = np.isnan(with_nans) 29 | np.testing.assert_array_equal(array[~nans], with_nans[~nans]) 30 | -------------------------------------------------------------------------------- /tests/integration/__init__.py: -------------------------------------------------------------------------------- 1 | """RDT integration testing package.""" 2 | 3 | from tests.integration.test_transformers import validate_transformer 4 | 5 | __all__ = [ 6 | 'validate_transformer', 7 | ] 8 | -------------------------------------------------------------------------------- /tests/integration/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | """RDT transformers integration testing package.""" 2 | -------------------------------------------------------------------------------- /tests/integration/transformers/pii/__init__.py: -------------------------------------------------------------------------------- 1 | """RDT Personal Identifiable Information transformers integration testing package.""" 2 | -------------------------------------------------------------------------------- /tests/integration/transformers/pii/test_anonymization.py: -------------------------------------------------------------------------------- 1 | from faker import Faker 2 | 3 | from rdt.transformers.pii.anonymization import is_faker_function 4 | 5 | 6 | def test_is_faker_function(): 7 | """Test is_faker_function checks if function is a valid Faker function.""" 8 | # Run 9 | result = is_faker_function('address') 10 | 11 | # Assert 12 | assert result is True 13 | 
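# ``is_faker_function`` checks names against a ``Faker`` instance created with
# all available locales (see ``get_faker_instance``), so locale-specific
# providers are recognized as well, which the next test verifies.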
14 | 
15 | def test_is_faker_function_non_default_locale():
16 |     """Test is_faker_function checks non-default locales."""
17 |     # Setup
18 |     function_name = 'postcode_in_province'
19 | 
20 |     # Run
21 |     result = is_faker_function(function_name)
22 | 
23 |     # Assert
24 |     assert result is True
25 |     assert not hasattr(Faker(), function_name)
--------------------------------------------------------------------------------
/tests/integration/transformers/test_boolean.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | 
4 | from rdt.transformers import BinaryEncoder
5 | from rdt.transformers.null import NullTransformer
6 | 
7 | 
8 | class TestBinaryEncoder:
9 |     def test_boolean_some_nans(self):
10 |         """Test BinaryEncoder on input with some nan values.
11 | 
12 |         Ensure that the BinaryEncoder can fit, transform, and reverse transform on boolean data
13 |         with Nones. Expect that the reverse transformed data is the same as the input, but None
14 |         becomes nan and the False/nan values can be interchanged.
15 | 
16 |         Also ensures that the intermediate transformed data is unchanged after reversing.
17 | 
18 |         Input:
19 |             - boolean data with None values
20 |         Output:
21 |             - The reversed transformed data
22 |         """
23 |         # Setup
24 |         data = pd.DataFrame([True, False, None, False], columns=['bool'])
25 |         column = 'bool'
26 |         transformer = BinaryEncoder()
27 | 
28 |         # Run
29 |         transformer.fit(data, column)
30 |         transformed = transformer.transform(data)
31 |         unchanged_transformed = transformed.copy()
32 |         reverse = transformer.reverse_transform(transformed)
33 | 
34 |         # Assert
35 |         np.testing.assert_array_equal(unchanged_transformed, transformed)
36 |         assert reverse['bool'][0] in {True, np.nan}
37 |         for value in reverse['bool'][1:]:
38 |             assert value is False or np.isnan(value)
39 | 
40 |     def test_boolean_missing_value_replacement_mode(self):
41 |         """Test BinaryEncoder when `missing_value_replacement` is set to 'mode'.
42 | 
43 |         Ensure that the BinaryEncoder can fit, transform, and reverse transform on
44 |         boolean data when `missing_value_replacement` is set to `'mode'` and
45 |         `missing_value_generation` is set to 'from_column'. Expect that the reverse
46 |         transformed data is the same as the input.
47 |         """
48 |         # Setup
49 |         data = pd.DataFrame([True, True, None, False], columns=['bool'])
50 |         column = 'bool'
51 |         transformer = BinaryEncoder(
52 |             missing_value_replacement='mode',
53 |             missing_value_generation='from_column',
54 |         )
55 | 
56 |         # Run
57 |         transformer.fit(data, column)
58 |         transformed = transformer.transform(data)
59 |         reverse = transformer.reverse_transform(transformed)
60 | 
61 |         # Assert
62 |         expected_transformed = pd.DataFrame({
63 |             'bool': [1.0, 1.0, 1.0, 0.0],
64 |             'bool.is_null': [0.0, 0.0, 1.0, 0.0],
65 |         })
66 |         pd.testing.assert_frame_equal(transformed, expected_transformed)
67 |         pd.testing.assert_frame_equal(reverse, data)
68 | 
69 |     def test_boolean_missing_value_generation_none(self):
70 |         """Test the BinaryEncoder when ``missing_value_generation`` is None.
71 | 
72 |         In this test, the nans should be replaced by the mode on the transformed data, so the reverse transform cannot restore them.
73 | """ 74 | # Setup 75 | data = pd.DataFrame([True, True, None, False], columns=['bool']) 76 | column = 'bool' 77 | transformer = BinaryEncoder(missing_value_replacement='mode', missing_value_generation=None) 78 | 79 | # Run 80 | transformer.fit(data, column) 81 | transformed = transformer.transform(data) 82 | reverse = transformer.reverse_transform(transformed) 83 | 84 | # Assert 85 | expected_transformed = pd.DataFrame({'bool': [1.0, 1.0, 1.0, 0.0]}) 86 | expected_reversed = pd.DataFrame({'bool': [True, True, True, False]}) 87 | pd.testing.assert_frame_equal(transformed, expected_transformed) 88 | pd.testing.assert_frame_equal(reverse, expected_reversed, check_dtype=False) 89 | 90 | def test__reverse_transform_from_manually_set_parameters_from_column(self): 91 | """Test the ``reverse_transform`` after manually setting parameters.""" 92 | # Setup 93 | data = pd.DataFrame([True, True, None, False], columns=['bool']) 94 | transformed = pd.DataFrame({ 95 | 'bool': [1.0, 1.0, 1.0, 0.0], 96 | 'bool.is_null': [0.0, 0.0, 1.0, 0.0], 97 | }) 98 | column_name = 'bool' 99 | transformer = BinaryEncoder() 100 | 101 | # Run 102 | null_transformer = NullTransformer('mode', missing_value_generation='from_column') 103 | null_transformer._set_fitted_parameters(0.25) 104 | transformer._set_fitted_parameters(column_name, null_transformer) 105 | reverse = transformer.reverse_transform(transformed) 106 | 107 | # Assert 108 | pd.testing.assert_frame_equal(reverse, data) 109 | 110 | def test__reverse_transform_from_manually_set_parameters_random(self): 111 | """Test the ``reverse_transform`` after manually setting parameters.""" 112 | # Setup 113 | data = pd.DataFrame([True, True, None, False, False, True, None, False], columns=['bool']) 114 | transformed = pd.DataFrame({ 115 | 'bool': [1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0], 116 | }) 117 | column_name = 'bool' 118 | transformer = BinaryEncoder() 119 | 120 | # Run 121 | null_transformer = NullTransformer('mode', missing_value_generation='random') 122 | null_transformer._set_fitted_parameters(0.2) 123 | transformer._set_fitted_parameters(column_name, null_transformer) 124 | reverse = transformer.reverse_transform(transformed) 125 | 126 | # Get indices that are not NaN/None as the transformer used a random ratio 127 | nan_indices_data = data[data.isna().any(axis=1)].index 128 | nan_indices_reverse = reverse[reverse.isna().any(axis=1)].index 129 | nan_indices = nan_indices_data.union(nan_indices_reverse) 130 | compare_data = data.drop(index=nan_indices) 131 | compare_reverse = reverse.drop(index=nan_indices) 132 | expected_reverse = pd.DataFrame({ 133 | 'bool': [np.nan, True, np.nan, False, False, True, False, False] 134 | }) 135 | 136 | # Assert 137 | pd.testing.assert_frame_equal(expected_reverse, reverse) 138 | pd.testing.assert_frame_equal(compare_data, compare_reverse) 139 | -------------------------------------------------------------------------------- /tests/performance/README.md: -------------------------------------------------------------------------------- 1 | # RDT Performance Tests 2 | 3 | This subpackage contains the performance tests for RDT. 
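
They can be run through the `performance` invoke task defined in `tasks.py`, which is equivalent to running `python -m pytest -v ./tests/performance/test_performance.py` directly.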
4 | 
--------------------------------------------------------------------------------
/tests/performance/__init__.py:
--------------------------------------------------------------------------------
1 | """Functions to evaluate and test the performance of the RDT Transformers."""
2 | 
3 | from tests.performance.test_performance import validate_performance
4 | 
5 | __all__ = [
6 |     'validate_performance',
7 | ]
--------------------------------------------------------------------------------
/tests/performance/test_performance.py:
--------------------------------------------------------------------------------
1 | """Test whether the performance of the Transformers is the expected one."""
2 | 
3 | import numpy as np
4 | import pandas as pd
5 | import pytest
6 | 
7 | from rdt.performance.datasets import get_dataset_generators_by_type
8 | from rdt.performance.performance import evaluate_transformer_performance
9 | from rdt.performance.profiling import profile_transformer
10 | from rdt.transformers import get_transformers_by_type
11 | from rdt.transformers.categorical import (
12 |     CustomLabelEncoder,
13 |     OrderedLabelEncoder,
14 |     OrderedUniformEncoder,
15 | )
16 | from rdt.transformers.numerical import ClusterBasedNormalizer
17 | 
18 | SANDBOX_TRANSFORMERS = [
19 |     ClusterBasedNormalizer,
20 |     OrderedLabelEncoder,
21 |     CustomLabelEncoder,
22 |     OrderedUniformEncoder,
23 | ]
24 | 
25 | 
26 | def _get_performance_test_cases():
27 |     """Get all the (transformer, dataset_generator) combinations for testing."""
28 |     all_test_cases = []
29 | 
30 |     dataset_generators = get_dataset_generators_by_type()
31 |     transformers = get_transformers_by_type()
32 | 
33 |     for sdtype, transformers_for_type in transformers.items():
34 |         dataset_generators_for_type = dataset_generators.get(sdtype, [])
35 | 
36 |         for transformer in transformers_for_type:
37 |             if transformer in SANDBOX_TRANSFORMERS:
38 |                 continue
39 | 
40 |             for dataset_generator in dataset_generators_for_type:
41 |                 all_test_cases.append((transformer, dataset_generator))
42 | 
43 |     return all_test_cases
44 | 
45 | 
46 | test_cases = _get_performance_test_cases()
47 | 
48 | 
49 | def validate_performance(performance, dataset_generator, should_assert=False):
50 |     """Validate the performance of all transformers for a dataset_generator.
51 | 
52 |     Args:
53 |         performance (pd.DataFrame):
54 |             The performance metrics of a transformer against a dataset_generator.
55 |         dataset_generator (rdt.tests.datasets.BaseDatasetGenerator):
56 |             The dataset generator to run performance tests against.
57 |         should_assert (bool):
58 |             Whether or not to raise AssertionErrors.
59 | 
60 |     Returns:
61 |         list[bool]:
62 |             A list indicating whether each performance metric was valid or not.
63 |     """
64 |     expected = dataset_generator.get_performance_thresholds()
65 |     out = []
66 |     for test_name, value in performance.items():
67 |         function, metric = test_name.lower().replace(' ', '_').rsplit('_', 1)
68 |         expected_metric = expected[function][metric]
69 |         valid = value < expected_metric
70 |         out.append(valid)
71 | 
72 |         if should_assert and not valid:
73 |             raise AssertionError(f'{function} {metric}: {value} > {expected_metric}')
74 | 
75 |     return out
76 | 
77 | 
78 | @pytest.mark.parametrize(('transformer', 'dataset_generator'), test_cases)
79 | def test_performance(transformer, dataset_generator):
80 |     """Run the performance tests for RDT.
81 | 
82 |     This test should find all relevant transformers for the given
83 |     dataset generator, evaluate their performance and assert
84 |     that the memory consumption and times are under the
85 |     maximum acceptable values.
86 | 
87 |     Input:
88 |         transformer (rdt.transformers.BaseTransformer):
89 |             The transformer to test.
90 |         dataset_generator (rdt.tests.datasets.BaseDatasetGenerator):
91 |             The dataset generator to run performance tests against.
92 |     """
93 |     performance = evaluate_transformer_performance(transformer, dataset_generator)
94 |     validate_performance(performance, dataset_generator, should_assert=True)
95 | 
96 | 
97 | def _round_to_magnitude(value):
98 |     if value == 0:
99 |         raise ValueError('Value cannot be exactly 0.')
100 | 
101 |     for digits in range(-15, 15):
102 |         rounded = np.round(value, digits)
103 |         if rounded != 0:
104 |             return rounded
105 | 
106 |     # We should never reach this line
107 |     raise ValueError('Value is too big')
108 | 
109 | 
110 | def find_transformer_boundaries(
111 |     transformer,
112 |     dataset_generator,
113 |     fit_size,
114 |     transform_size,
115 |     iterations=1,
116 |     multiplier=5,
117 | ):
118 |     """Helper function to find valid candidate boundaries for performance tests.
119 | 
120 |     The function works by:
121 |     - Running the profiling multiple times
122 |     - Averaging out the values for each metric
123 |     - Multiplying the found values by the given multiplier (default=5).
124 |     - Rounding to the found order of magnitude
125 | 
126 |     As an example, if a method took 0.012 seconds to run, the expected output
127 |     threshold will be set to 0.1, but if it took 0.016, it will be set to 0.2.
128 | 
129 |     Args:
130 |         transformer (Transformer):
131 |             Transformer instance to profile.
132 |         dataset_generator (type):
133 |             Dataset Generator class to use.
134 |         fit_size (int):
135 |             Number of values to use when fitting the transformer.
136 |         transform_size (int):
137 |             Number of values to use when transforming and reverse transforming.
138 |         iterations (int):
139 |             Number of iterations to perform.
140 |         multiplier (int):
141 |             The value used to multiply the average results before rounding them
142 |             up/down. Defaults to 5.
143 | 
144 |     Returns:
145 |         pd.Series:
146 |             Candidate values for each metric.
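
    Example:
        An illustrative call; ``BinaryEncoder`` and ``dataset_generator`` here
        are placeholders for any transformer instance and dataset generator
        class::

            boundaries = find_transformer_boundaries(
                transformer=BinaryEncoder(),
                dataset_generator=dataset_generator,
                fit_size=1000,
                transform_size=1000,
            )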
147 | """ 148 | results = [ 149 | profile_transformer(transformer, dataset_generator, transform_size, fit_size) 150 | for _ in range(iterations) 151 | ] 152 | means = pd.DataFrame(results).mean(axis=0) 153 | return (means * multiplier).apply(_round_to_magnitude) 154 | -------------------------------------------------------------------------------- /tests/performance/tests/__init__.py: -------------------------------------------------------------------------------- 1 | """RDT performance testing package.""" 2 | -------------------------------------------------------------------------------- /tests/performance/tests/test_profiling.py: -------------------------------------------------------------------------------- 1 | """Tests for the profiling module.""" 2 | 3 | from copy import deepcopy 4 | from unittest.mock import Mock, patch 5 | 6 | import numpy as np 7 | import pandas as pd 8 | 9 | from rdt.performance.datasets import BaseDatasetGenerator 10 | from rdt.performance.profiling import profile_transformer 11 | from rdt.transformers import FloatFormatter 12 | 13 | 14 | @patch('rdt.performance.profiling.mp') 15 | @patch('rdt.performance.profiling.deepcopy', spec_set=deepcopy) 16 | def test_profile_transformer(deepcopy_mock, multiprocessor_mock): 17 | """Test the ``profile_transformer`` function. 18 | 19 | The function should run the ``fit``, ``transform`` 20 | and ``reverse_transform`` method for the provided transformer 21 | with the dataset created by the provided generator. It should 22 | also output a DataFrame with the average time and peak memory 23 | for each method. 24 | 25 | Input: 26 | - Mock transformer 27 | - Mock dataset generator 28 | - transform size of 100 29 | 30 | Side effects: 31 | - ``fit``, ``transform`` and ``reverse_transform`` should be 32 | called with correct data 33 | 34 | Output: 35 | - DataFrame with times and memories 36 | """ 37 | # Setup 38 | transformer_mock = Mock(spec_set=FloatFormatter) 39 | dataset_gen_mock = Mock(spec_set=BaseDatasetGenerator) 40 | transformer_mock.return_value.transform.return_value = np.zeros(100) 41 | dataset_gen_mock.generate.return_value = np.ones(100) 42 | deepcopy_mock.return_value = transformer_mock.return_value 43 | 44 | # Run 45 | profiling_results = profile_transformer(transformer_mock.return_value, dataset_gen_mock, 100) 46 | 47 | # Assert 48 | expected_output_columns = [ 49 | 'Fit Time', 50 | 'Fit Memory', 51 | 'Transform Time', 52 | 'Transform Memory', 53 | 'Reverse Transform Time', 54 | 'Reverse Transform Memory', 55 | ] 56 | assert len(deepcopy_mock.mock_calls) == 10 57 | assert len(transformer_mock.return_value.fit.mock_calls) == 11 58 | assert len(transformer_mock.return_value.transform.mock_calls) == 11 59 | assert len(transformer_mock.return_value.reverse_transform.mock_calls) == 10 60 | 61 | all( 62 | np.testing.assert_array_equal(call[1][0], np.ones(100)) 63 | for call in transformer_mock.fit.mock_calls 64 | ) 65 | all( 66 | np.testing.assert_array_equal(call[1][0], np.ones(100)) 67 | for call in transformer_mock.transform.mock_calls 68 | ) 69 | all( 70 | np.testing.assert_array_equal(call[1][0], np.zeros(100)) 71 | for call in transformer_mock.reverse_transform.mock_calls 72 | ) 73 | 74 | assert expected_output_columns == list(profiling_results.index) 75 | 76 | process_mock = multiprocessor_mock.get_context().Process 77 | fit_call = process_mock.mock_calls[0] 78 | transform_call = process_mock.mock_calls[3] 79 | reverse_transform_call = process_mock.mock_calls[6] 80 | 81 | assert fit_call[2]['args'][0] == 
81 | assert fit_call[2]['args'][0] == transformer_mock.return_value.fit 82 | pd.testing.assert_frame_equal(fit_call[2]['args'][1], pd.DataFrame({'test': np.ones(100)})) 83 | assert transform_call[2]['args'][0] == transformer_mock.return_value.transform 84 | pd.testing.assert_frame_equal( 85 | transform_call[2]['args'][1].reset_index(drop=True), 86 | pd.DataFrame({'test': np.ones(100)}), 87 | ) 88 | assert reverse_transform_call[2]['args'][0] == transformer_mock.return_value.reverse_transform 89 | np.testing.assert_array_equal(reverse_transform_call[2]['args'][1], np.zeros(100)) 90 | -------------------------------------------------------------------------------- /tests/test_scripts.py: -------------------------------------------------------------------------------- 1 | from scripts.check_for_prereleases import get_dev_dependencies 2 | 3 | 4 | def test_get_dev_dependencies(): 5 | """Test get_dev_dependencies ignores regular releases.""" 6 | # Setup 7 | dependencies = ['rdt>=1.1.1', 'sdv>=1.0.2'] 8 | 9 | # Run 10 | dev_dependencies = get_dev_dependencies(dependency_list=dependencies) 11 | 12 | # Assert 13 | assert len(dev_dependencies) == 0 14 | 15 | 16 | def test_get_dev_dependencies_prereleases(): 17 | """Test get_dev_dependencies detects prereleases.""" 18 | # Setup 19 | dependencies = ['rdt>=1.1.1.dev0', 'sdv>=1.0.2.rc1'] 20 | 21 | # Run 22 | dev_dependencies = get_dev_dependencies(dependency_list=dependencies) 23 | 24 | # Assert 25 | assert dev_dependencies == dependencies 26 | 27 | 28 | def test_get_dev_dependencies_url(): 29 | """Test get_dev_dependencies detects url requirements.""" 30 | # Setup 31 | dependencies = ['rdt>=1.1.1', 'sdv @ git+https://github.com/sdv-dev/sdv.git@main'] 32 | 33 | # Run 34 | dev_dependencies = get_dev_dependencies(dependency_list=dependencies) 35 | 36 | # Assert 37 | assert dev_dependencies == ['sdv @ git+https://github.com/sdv-dev/sdv.git@main'] 38 | --------------------------------------------------------------------------------
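The tests above pin down the observable contract of ``scripts.check_for_prereleases.get_dev_dependencies``: regular pins pass through, while prerelease pins and URL requirements are flagged. The following is a rough re-implementation sketch of that contract using the ``packaging`` library; it is illustrative only and not the actual script.

from packaging.requirements import Requirement
from packaging.version import Version


def get_dev_dependencies(dependency_list):
    """Return the requirements that point at a URL or pin a prerelease version."""
    dev_dependencies = []
    for dependency in dependency_list:
        requirement = Requirement(dependency)
        # Direct references like 'sdv @ git+...' carry a URL, while prerelease
        # pins such as '>=1.1.1.dev0' or '>=1.0.2.rc1' have prerelease versions.
        pins_prerelease = any(
            Version(specifier.version).is_prerelease for specifier in requirement.specifier
        )
        if requirement.url or pins_prerelease:
            dev_dependencies.append(dependency)

    return dev_dependencies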
11 | """ 12 | # Setup 13 | dependencies = [ 14 | "numpy>=1.20.0,<2;python_version<'3.10'", 15 | "numpy>=1.23.3,<2;python_version>='3.10'", 16 | "pandas>=1.2.0,<2;python_version<'3.10'", 17 | "pandas>=1.3.0,<2;python_version>='3.10'", 18 | 'humanfriendly>=8.2,<11', 19 | 'pandas @ git+https://github.com/pandas-dev/pandas.git@master', 20 | ] 21 | 22 | # Run 23 | minimum_versions_39 = _get_minimum_versions(dependencies, '3.9') 24 | minimum_versions_310 = _get_minimum_versions(dependencies, '3.10') 25 | 26 | # Assert 27 | expected_versions_39 = [ 28 | 'numpy==1.20.0', 29 | 'git+https://github.com/pandas-dev/pandas.git@master#egg=pandas', 30 | 'humanfriendly==8.2', 31 | ] 32 | expected_versions_310 = [ 33 | 'numpy==1.23.3', 34 | 'git+https://github.com/pandas-dev/pandas.git@master#egg=pandas', 35 | 'humanfriendly==8.2', 36 | ] 37 | 38 | assert minimum_versions_39 == expected_versions_39 39 | assert minimum_versions_310 == expected_versions_310 40 | -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- 1 | """RDT unit testing package.""" 2 | -------------------------------------------------------------------------------- /tests/unit/test___init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from types import ModuleType 3 | from unittest.mock import Mock, patch 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import pytest 8 | 9 | import rdt 10 | from rdt import _find_addons, get_demo 11 | 12 | 13 | @pytest.fixture 14 | def mock_rdt(): 15 | rdt_module = sys.modules['rdt'] 16 | rdt_mock = Mock() 17 | rdt_mock.submodule.__name__ = 'rdt.submodule' 18 | sys.modules['rdt'] = rdt_mock 19 | yield rdt_mock 20 | sys.modules['rdt'] = rdt_module 21 | 22 | 23 | def test_get_demo(): 24 | demo = get_demo() 25 | 26 | assert list(demo.columns) == [ 27 | 'last_login', 28 | 'email_optin', 29 | 'credit_card', 30 | 'age', 31 | 'dollars_spent', 32 | ] 33 | assert len(demo) == 5 34 | assert list(demo.isna().sum(axis=0)) == [1, 1, 1, 0, 1] 35 | 36 | 37 | def test_get_demo_many_rows(): 38 | demo = get_demo(10) 39 | 40 | login_dates = pd.Series( 41 | [ 42 | '2021-06-26', 43 | '2021-02-10', 44 | 'NaT', 45 | '2020-09-26', 46 | '2020-12-22', 47 | '2019-11-27', 48 | '2002-05-10', 49 | '2014-10-04', 50 | '2014-03-19', 51 | '2015-09-13', 52 | ], 53 | dtype='datetime64[ns]', 54 | ) 55 | email_optin = [ 56 | False, 57 | False, 58 | False, 59 | True, 60 | np.nan, 61 | np.nan, 62 | False, 63 | True, 64 | False, 65 | False, 66 | ] 67 | credit_card = [ 68 | 'VISA', 69 | 'VISA', 70 | 'AMEX', 71 | np.nan, 72 | 'DISCOVER', 73 | 'AMEX', 74 | 'AMEX', 75 | 'DISCOVER', 76 | 'DISCOVER', 77 | 'VISA', 78 | ] 79 | age = [29, 18, 21, 45, 32, 50, 93, 75, 39, 66] 80 | dollars_spent = [ 81 | 99.99, 82 | np.nan, 83 | 2.50, 84 | 25.00, 85 | 19.99, 86 | 52.48, 87 | 39.99, 88 | 4.67, 89 | np.nan, 90 | 23.28, 91 | ] 92 | 93 | expected = pd.DataFrame({ 94 | 'last_login': login_dates, 95 | 'email_optin': email_optin, 96 | 'credit_card': credit_card, 97 | 'age': age, 98 | 'dollars_spent': dollars_spent, 99 | }) 100 | 101 | pd.testing.assert_frame_equal(demo, expected) 102 | 103 | 104 | @patch.object(rdt, 'entry_points') 105 | def test__find_addons_module(entry_points_mock, mock_rdt): 106 | """Test loading an add-on.""" 107 | # Setup 108 | add_on_mock = Mock(spec=ModuleType) 109 | entry_point = Mock() 110 | entry_point.name = 'rdt.submodule.entry_name' 111 | 
entry_point.load.return_value = add_on_mock 112 | entry_points_mock.return_value = [entry_point] 113 | 114 | # Run 115 | _find_addons() 116 | 117 | # Assert 118 | entry_points_mock.assert_called_once_with(group='rdt_modules') 119 | assert mock_rdt.submodule.entry_name == add_on_mock 120 | assert sys.modules['rdt.submodule.entry_name'] == add_on_mock 121 | 122 | 123 | @patch.object(rdt, 'entry_points') 124 | def test__find_addons_type_error(entry_points_mock): 125 | """Test it when entry_points raises a TypeError (happens for py38, py39).""" 126 | 127 | # Setup 128 | def side_effect(arg=None): 129 | if arg == 'rdt_modules': 130 | raise TypeError() 131 | return {arg: []} 132 | 133 | entry_points_mock.side_effect = side_effect 134 | 135 | # Run 136 | _find_addons() 137 | 138 | # Assert 139 | entry_points_mock.assert_called_with() 140 | 141 | 142 | @patch.object(rdt, 'entry_points') 143 | def test__find_addons_object(entry_points_mock, mock_rdt): 144 | """Test loading an add-on.""" 145 | # Setup 146 | entry_point = Mock() 147 | entry_point.name = 'rdt.submodule:entry_object.entry_method' 148 | entry_point.load.return_value = 'new_method' 149 | entry_points_mock.return_value = [entry_point] 150 | 151 | # Run 152 | _find_addons() 153 | 154 | # Assert 155 | entry_points_mock.assert_called_once_with(group='rdt_modules') 156 | assert mock_rdt.submodule.entry_object.entry_method == 'new_method' 157 | 158 | 159 | @patch('warnings.warn') 160 | @patch('rdt.entry_points') 161 | def test__find_addons_bad_addon(entry_points_mock, warning_mock): 162 | """Test failing to load an add-on generates a warning.""" 163 | 164 | # Setup 165 | def entry_point_error(): 166 | raise ValueError() 167 | 168 | bad_entry_point = Mock() 169 | bad_entry_point.name = 'bad_entry_point' 170 | bad_entry_point.value = 'bad_module' 171 | bad_entry_point.load.side_effect = entry_point_error 172 | entry_points_mock.return_value = [bad_entry_point] 173 | msg = 'Failed to load "bad_entry_point" from "bad_module".' 174 | 175 | # Run 176 | _find_addons() 177 | 178 | # Assert 179 | entry_points_mock.assert_called_once_with(group='rdt_modules') 180 | warning_mock.assert_called_once_with(msg) 181 | 182 | 183 | @patch('warnings.warn') 184 | @patch('rdt.entry_points') 185 | def test__find_addons_wrong_base(entry_points_mock, warning_mock): 186 | """Test incorrect add-on name generates a warning.""" 187 | # Setup 188 | bad_entry_point = Mock() 189 | bad_entry_point.name = 'bad_base.bad_entry_point' 190 | entry_points_mock.return_value = [bad_entry_point] 191 | msg = ( 192 | "Failed to set 'bad_base.bad_entry_point': expected base module to be 'rdt', found " 193 | "'bad_base'." 194 | ) 195 | 196 | # Run 197 | _find_addons() 198 | 199 | # Assert 200 | entry_points_mock.assert_called_once_with(group='rdt_modules') 201 | warning_mock.assert_called_once_with(msg) 202 | 203 | 204 | @patch('warnings.warn') 205 | @patch('rdt.entry_points') 206 | def test__find_addons_missing_submodule(entry_points_mock, warning_mock): 207 | """Test incorrect add-on name generates a warning.""" 208 | # Setup 209 | bad_entry_point = Mock() 210 | bad_entry_point.name = 'rdt.missing_submodule.new_submodule' 211 | entry_points_mock.return_value = [bad_entry_point] 212 | msg = ( 213 | "Failed to set 'rdt.missing_submodule.new_submodule': module 'rdt' has no attribute " 214 | "'missing_submodule'." 
215 | ) 216 | 217 | # Run 218 | _find_addons() 219 | 220 | # Assert 221 | entry_points_mock.assert_called_once_with(group='rdt_modules') 222 | warning_mock.assert_called_once_with(msg) 223 | 224 | 225 | @patch('warnings.warn') 226 | @patch('rdt.entry_points') 227 | def test__find_addons_module_and_object(entry_points_mock, warning_mock): 228 | """Test incorrect add-on name generates a warning.""" 229 | # Setup 230 | bad_entry_point = Mock() 231 | bad_entry_point.name = 'rdt.missing_submodule:new_object' 232 | entry_points_mock.return_value = [bad_entry_point] 233 | msg = ( 234 | "Failed to set 'rdt.missing_submodule:new_object': cannot add 'new_object' to unknown " 235 | "submodule 'rdt.missing_submodule'." 236 | ) 237 | 238 | # Run 239 | _find_addons() 240 | 241 | # Assert 242 | entry_points_mock.assert_called_once_with(group='rdt_modules') 243 | warning_mock.assert_called_once_with(msg) 244 | 245 | 246 | @patch('warnings.warn') 247 | @patch.object(rdt, 'entry_points') 248 | def test__find_addons_missing_object(entry_points_mock, warning_mock, mock_rdt): 249 | """Test incorrect add-on name generates a warning.""" 250 | # Setup 251 | bad_entry_point = Mock() 252 | bad_entry_point.name = 'rdt.submodule:missing_object.new_method' 253 | entry_points_mock.return_value = [bad_entry_point] 254 | msg = "Failed to set 'rdt.submodule:missing_object.new_method': missing_object." 255 | 256 | del mock_rdt.submodule.missing_object 257 | 258 | # Run 259 | _find_addons() 260 | 261 | # Assert 262 | entry_points_mock.assert_called_once_with(group='rdt_modules') 263 | warning_mock.assert_called_once_with(msg) 264 | -------------------------------------------------------------------------------- /tests/unit/test__utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import pytest 4 | 5 | from rdt._utils import _validate_unique_transformer_instances 6 | from rdt.errors import ( 7 | InvalidConfigError, 8 | ) 9 | from rdt.transformers import ( 10 | BaseMultiColumnTransformer, 11 | BaseTransformer, 12 | ) 13 | 14 | 15 | @pytest.fixture() 16 | def column_name_to_transformer(): 17 | return { 18 | 'colA': BaseTransformer(), 19 | 'colB': BaseTransformer(), 20 | 'colC': BaseTransformer(), 21 | ('colD', 'colE'): BaseMultiColumnTransformer(), 22 | ('colF', 'colG'): BaseMultiColumnTransformer(), 23 | 'colH': None, 24 | 'colI': None, 25 | } 26 | 27 | 28 | def test__validate_unique_transformer_instances_no_duplicates(column_name_to_transformer): 29 | """Test the function does not error when no duplicate transformers are present.""" 30 | # Run and Assert 31 | _validate_unique_transformer_instances(column_name_to_transformer) 32 | 33 | 34 | def test__validate_unique_transformer_instances_one_duplicate(column_name_to_transformer): 35 | """Test the function errors when one transformer instance is reused.""" 36 | # Setup 37 | column_name_to_transformer = column_name_to_transformer.copy() 38 | column_name_to_transformer['duped_column_1'] = column_name_to_transformer['colA'] 39 | column_name_to_transformer['duped_column_2'] = column_name_to_transformer['colA'] 40 | 41 | # Run and Assert 42 | expected_msg = re.escape( 43 | "The same transformer instance is being assigned to columns ('colA', 'duped_column_1', " 44 | "'duped_column_2'). Please create different transformer objects for each assignment." 
45 | ) 46 | with pytest.raises(InvalidConfigError, match=expected_msg): 47 | _validate_unique_transformer_instances(column_name_to_transformer) 48 | 49 | 50 | def test__validate_unique_transformer_instances_multi_column(column_name_to_transformer): 51 | """Test the function with multi-column transformers.""" 52 | # Setup 53 | column_name_to_transformer = column_name_to_transformer.copy() 54 | duplicate_transformer = column_name_to_transformer[('colD', 'colE')] 55 | column_name_to_transformer[('duped_column_1', 'duped_column_2')] = duplicate_transformer 56 | 57 | # Run and Assert 58 | expected_msg = re.escape( 59 | "The same transformer instance is being assigned to columns (('colD', 'colE'), " 60 | "('duped_column_1', 'duped_column_2')). Please create different transformer " 61 | 'objects for each assignment.' 62 | ) 63 | with pytest.raises(InvalidConfigError, match=expected_msg): 64 | _validate_unique_transformer_instances(column_name_to_transformer) 65 | 66 | 67 | def test__validate_unique_transformer_instances_multiple_duplicates(column_name_to_transformer): 68 | """Test the function errors when many transformer instances are reused.""" 69 | # Setup 70 | column_name_to_transformer = column_name_to_transformer.copy() 71 | column_name_to_transformer['duped_column_1'] = column_name_to_transformer['colA'] 72 | column_name_to_transformer['duped_column_2'] = column_name_to_transformer['colA'] 73 | column_name_to_transformer['duped_column_3'] = column_name_to_transformer['colB'] 74 | column_name_to_transformer['duped_column_4'] = column_name_to_transformer['colB'] 75 | 76 | # Run and Assert 77 | expected_msg = re.escape( 78 | "The same transformer instances are being assigned to columns ('colA', 'duped_column_1', " 79 | "'duped_column_2'), columns ('colB', 'duped_column_3', 'duped_column_4'). Please create " 80 | 'different transformer objects for each assignment.' 81 | ) 82 | with pytest.raises(InvalidConfigError, match=expected_msg): 83 | _validate_unique_transformer_instances(column_name_to_transformer) 84 | -------------------------------------------------------------------------------- /tests/unit/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | """RDT transformers unit testing package.""" 2 | -------------------------------------------------------------------------------- /tests/unit/transformers/pii/__init__.py: -------------------------------------------------------------------------------- 1 | """RDT Personal Identifiable Information testing module.""" 2 | -------------------------------------------------------------------------------- /tests/unit/transformers/pii/test_anonymization.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import Mock, patch 2 | 3 | from rdt.transformers.pii.anonymization import ( 4 | _detect_provider_name, 5 | get_anonymized_transformer, 6 | get_faker_instance, 7 | is_faker_function, 8 | ) 9 | 10 | 11 | class TestAnonymization: 12 | def test__detect_provider_name(self): 13 | """Test the ``_detect_provider_name`` method. 14 | 15 | Test that the function returns an expected provider name from the ``faker.Faker`` instance. 16 | If this is from the ``BaseProvider`` it should also return that name. 17 | 18 | Input: 19 | - Faker function name. 20 | 21 | Output: 22 | - The faker provider name for that function. 
23 | """ 24 | # Run / Assert 25 | email_provider = _detect_provider_name('email') 26 | lexify_provider = _detect_provider_name('lexify') 27 | state_provider = _detect_provider_name('state') 28 | 29 | assert email_provider == 'internet' 30 | assert lexify_provider == 'BaseProvider' 31 | assert state_provider == 'address.en_US' 32 | 33 | @patch('rdt.transformers.pii.anonymization.AnonymizedFaker') 34 | def test_get_anonymized_transformer_with_existing_sdtype(self, mock_anonymized_faker): 35 | """Test the ``get_anonymized_transformer`` method. 36 | 37 | Test that when calling with an existing ``sdtype`` / ``function_name`` from the 38 | ``SDTYPE_ANONYMZIERS`` dictionary, their ``provider_name`` and ``function_name`` are being 39 | used by default, and also other ``kwargs`` and provided locales are being passed to the 40 | ``AnonymizedFaker``. 41 | 42 | Input: 43 | - ``function_name`` from the ``SDTYPE_ANONYMIZERS``. 44 | - ``function_kwargs`` additional keyword arguments for that set of arguments. 45 | 46 | Mock: 47 | - Mock ``AnonymizedFaker`` and assert that has been called with the expected 48 | arguments. 49 | 50 | Output: 51 | - The return value must be the instance of ``AnonymizedFaker``. 52 | """ 53 | # Setup 54 | output = get_anonymized_transformer( 55 | 'email', 56 | transformer_kwargs={ 57 | 'function_kwargs': {'domain': '@gmail.com'}, 58 | 'locales': ['en_CA', 'fr_CA'], 59 | }, 60 | ) 61 | 62 | # Assert 63 | assert output == mock_anonymized_faker.return_value 64 | mock_anonymized_faker.assert_called_once_with( 65 | provider_name='internet', 66 | function_name='email', 67 | function_kwargs={'domain': '@gmail.com'}, 68 | locales=['en_CA', 'fr_CA'], 69 | ) 70 | 71 | @patch('rdt.transformers.pii.anonymization.AnonymizedFaker') 72 | def test_get_anonymized_transformer_with_custom_sdtype(self, mock_anonymized_faker): 73 | """Test the ``get_anonymized_transformer`` method. 74 | 75 | Test that when calling with a custom ``sdtype`` / ``function_name`` that does not belong 76 | to the ``SDTYPE_ANONYMZIERS`` dictionary. The ``provider_name`` is being found 77 | automatically other ``kwargs`` and provided locales are being passed to the 78 | ``AnonymizedFaker``. 79 | 80 | Input: 81 | - ``function_name`` color. 82 | - ``function_kwargs`` a dictionary with ``'hue': 'red'``. 83 | 84 | Mock: 85 | - Mock ``AnonymizedFaker`` and assert that has been called with the expected 86 | arguments. 87 | 88 | Output: 89 | - The return value must be the instance of ``AnonymizedFaker``. 90 | """ 91 | # Setup 92 | output = get_anonymized_transformer( 93 | 'color', 94 | transformer_kwargs={ 95 | 'function_kwargs': {'hue': 'red'}, 96 | 'locales': ['en_CA', 'fr_CA'], 97 | }, 98 | ) 99 | 100 | # Assert 101 | assert output == mock_anonymized_faker.return_value 102 | mock_anonymized_faker.assert_called_once_with( 103 | provider_name='color', 104 | function_name='color', 105 | function_kwargs={'hue': 'red'}, 106 | locales=['en_CA', 'fr_CA'], 107 | ) 108 | 109 | @patch('rdt.transformers.pii.anonymization.Faker') 110 | def test_is_faker_function(self, faker_mock): 111 | """Test that the method returns True if the ``function_name`` is a valid faker function. 112 | 113 | This test mocks the ``Faker`` method to make sure that the ``function_name`` is an 114 | attribute it has. 
115 | """ 116 | # Setup 117 | faker_mock.return_value = Mock(spec=['address']) 118 | 119 | # Run 120 | result = is_faker_function('address') 121 | 122 | # Assert 123 | assert result is True 124 | 125 | @patch('rdt.transformers.pii.anonymization.get_faker_instance') 126 | def test_is_faker_function_error(self, mock_get_faker_instance): 127 | """Test that the method returns False if ``function_name`` is not a valid faker function. 128 | 129 | If the ``function_name`` is not an attribute of ``Faker()`` then we should return false. 130 | This test mocks ``Faker`` to not have the attribute that is passed as ``function_name``. 131 | """ 132 | # Setup 133 | mock_get_faker_instance.return_value = Mock(spec=[]) 134 | 135 | # Run 136 | result = is_faker_function('blah') 137 | 138 | # Assert 139 | assert result is False 140 | mock_get_faker_instance.assert_called_once() 141 | 142 | @patch('rdt.transformers.pii.anonymization.Faker') 143 | def test_get_faker_instance(self, mock_faker): 144 | """Test that ``get_faker_instance`` returns the same object.""" 145 | # Setup 146 | first_instance = get_faker_instance() 147 | 148 | # Run 149 | second_instance = get_faker_instance() 150 | 151 | # Assert 152 | assert id(first_instance) == id(second_instance) 153 | -------------------------------------------------------------------------------- /tests/unit/transformers/pii/test_utils.py: -------------------------------------------------------------------------------- 1 | from rdt.transformers.pii.utils import get_provider_name 2 | 3 | 4 | def test_get_provider_name(): 5 | """Test the ``get_provider_name`` method. 6 | 7 | Test that the function returns an expected provider name from the ``faker.Faker`` instance. 8 | If this is from the ``BaseProvider`` it should also return that name. 9 | """ 10 | # Run 11 | email_provider = get_provider_name('email') 12 | lexify_provider = get_provider_name('lexify') 13 | 14 | # Assert 15 | assert email_provider == 'internet' 16 | assert lexify_provider == 'BaseProvider' 17 | -------------------------------------------------------------------------------- /tests/unit/transformers/test___init__.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from rdt.transformers import ( 4 | AnonymizedFaker, 5 | BinaryEncoder, 6 | FloatFormatter, 7 | RegexGenerator, 8 | UniformEncoder, 9 | UnixTimestampEncoder, 10 | get_default_transformers, 11 | get_transformer_class, 12 | get_transformer_name, 13 | ) 14 | 15 | 16 | def test_get_transformer_name(): 17 | """Test the ``get_transformer_name`` method. 18 | 19 | Validate the method returns the class path when passed the class. 20 | 21 | Input: 22 | - a class. 23 | 24 | Output: 25 | - the path of the class. 26 | """ 27 | # Setup 28 | transformer = BinaryEncoder 29 | 30 | # Run 31 | returned = get_transformer_name(transformer) 32 | 33 | # Assert 34 | assert returned == 'rdt.transformers.boolean.BinaryEncoder' 35 | 36 | 37 | def test_get_transformer_name_incorrect_input(): 38 | """Test the ``get_transformer_name`` method crashes. 39 | 40 | Validate the method raises a ``ValueError`` when passed a string. 41 | 42 | Input: 43 | - a string. 44 | 45 | Raises: 46 | - ``ValueError``, with the correct output message. 47 | """ 48 | # Setup 49 | transformer = 'rdt.transformers.boolean.BinaryEncoder' 50 | 51 | # Run / Assert 52 | error_msg = 'The transformer rdt.transformers.boolean.BinaryEncoder must be passed as a class.' 
53 | with pytest.raises(ValueError, match=error_msg): 54 | get_transformer_name(transformer) 55 | 56 | 57 | def test_get_transformer_class_transformer_path(): 58 | """Test the ``get_transformer_class`` method. 59 | 60 | Validate the method returns the correct class when passed the class path. 61 | 62 | Input: 63 | - a string describing the transformer path. 64 | 65 | Output: 66 | - the class corresponding to the transformer path. 67 | """ 68 | # Setup 69 | transformer_path = 'rdt.transformers.boolean.BinaryEncoder' 70 | 71 | # Run 72 | returned = get_transformer_class(transformer_path) 73 | 74 | # Assert 75 | assert returned == BinaryEncoder 76 | 77 | 78 | def test_get_transformer_class_partial_path(): 79 | """Test with a non-fully-specified path.""" 80 | # Run 81 | returned = get_transformer_class('rdt.transformers.BinaryEncoder') 82 | 83 | # Assert 84 | assert returned == BinaryEncoder 85 | 86 | 87 | def test_get_default_transformers(): 88 | """Test the ``get_default_transformers`` method. 89 | 90 | Check that the right default transformer is returned for each type. 91 | """ 92 | # Run 93 | default_transformer_dict = get_default_transformers() 94 | 95 | # Assert 96 | expected_dict = { 97 | 'numerical': FloatFormatter, 98 | 'categorical': UniformEncoder, 99 | 'boolean': UniformEncoder, 100 | 'datetime': UnixTimestampEncoder, 101 | 'id': RegexGenerator, 102 | 'pii': AnonymizedFaker, 103 | } 104 | 105 | for sdtype, transformer in expected_dict.items(): 106 | assert isinstance(default_transformer_dict[sdtype], transformer) 107 | -------------------------------------------------------------------------------- /tests/unit/transformers/test_boolean.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | from unittest.mock import Mock 3 | 4 | import numpy as np 5 | import pandas as pd 6 | 7 | from rdt.transformers import BinaryEncoder 8 | from rdt.transformers.null import NullTransformer 9 | 10 | 11 | class TestBinaryEncoder(TestCase): 12 | def test___init__(self): 13 | """Test default instance""" 14 | # Run 15 | transformer = BinaryEncoder() 16 | 17 | # Asserts 18 | error_message = 'Unexpected missing_value_replacement' 19 | error_generation = 'Unexpected missing_value_generation' 20 | assert transformer.missing_value_replacement == 'mode', error_message 21 | assert transformer.missing_value_generation == 'random', error_generation 22 | 23 | def test___init___model_missing_value_passed(self): 24 | """Test when model missing value is passed to the init.""" 25 | # Run 26 | transformer = BinaryEncoder(model_missing_values=True) 27 | 28 | # Assert 29 | assert transformer.missing_value_generation == 'from_column' 30 | 31 | def test__fit_missing_value_replacement_not_ignore(self): 32 | """Test _fit missing_value_replacement not equal to ignore""" 33 | # Setup 34 | data = pd.Series([False, True, True, False, True]) 35 | 36 | # Run 37 | transformer = BinaryEncoder(missing_value_replacement=0) 38 | transformer._fit(data) 39 | 40 | # Asserts 41 | error_msg = 'Unexpected fill value' 42 | assert transformer.null_transformer._missing_value_replacement == 0, error_msg 43 | 44 | def test__fit_array(self): 45 | """Test _fit with numpy.array""" 46 | # Setup 47 | data = pd.Series([False, True, True, False, True]) 48 | 49 | # Run 50 | transformer = BinaryEncoder(missing_value_replacement=0) 51 | transformer._fit(data) 52 | 53 | # Asserts 54 | error_msg = 'Unexpected fill value' 55 | assert transformer.null_transformer._missing_value_replacement == 0,
error_msg 56 | 57 | def test__fit_missing_value_generation_from_column(self): 58 | """Test output_properties contains 'is_null' column. 59 | 60 | When missing_value_generation is 'from_column', the expected output is to have an extra 61 | column. 62 | """ 63 | # Setup 64 | transformer = BinaryEncoder(missing_value_generation='from_column') 65 | data = pd.Series([True, np.nan]) 66 | 67 | # Run 68 | transformer._fit(data) 69 | 70 | # Assert 71 | assert transformer.output_properties == { 72 | None: {'sdtype': 'float', 'next_transformer': None}, 73 | 'is_null': {'sdtype': 'float', 'next_transformer': None}, 74 | } 75 | 76 | def test__transform_series(self): 77 | """Test transform pandas.Series""" 78 | # Setup 79 | data = pd.Series([False, True, None, True, False]) 80 | 81 | # Run 82 | transformer = Mock() 83 | BinaryEncoder._transform(transformer, data) 84 | 85 | # Asserts 86 | expect_call_count = 1 87 | expect_call_args = pd.Series([0.0, 1.0, None, 1.0, 0.0], dtype=float) 88 | 89 | error_msg = 'NullTransformer.transform must be called one time' 90 | assert transformer.null_transformer.transform.call_count == expect_call_count, error_msg 91 | pd.testing.assert_series_equal( 92 | transformer.null_transformer.transform.call_args[0][0], 93 | expect_call_args, 94 | ) 95 | 96 | def test__transform_array(self): 97 | """Test transform numpy.array""" 98 | # Setup 99 | data = pd.Series([False, True, None, True, False]) 100 | 101 | # Run 102 | transformer = Mock() 103 | BinaryEncoder._transform(transformer, data) 104 | 105 | # Asserts 106 | expect_call_count = 1 107 | expect_call_args = pd.Series([0.0, 1.0, None, 1.0, 0.0], dtype=float) 108 | 109 | error_msg = 'NullTransformer.transform must be called one time' 110 | assert transformer.null_transformer.transform.call_count == expect_call_count, error_msg 111 | pd.testing.assert_series_equal( 112 | transformer.null_transformer.transform.call_args[0][0], 113 | expect_call_args, 114 | ) 115 | 116 | def test__reverse_transform_missing_value_replacement_not_ignore(self): 117 | """Test _reverse_transform with missing_value_replacement not equal to ignore""" 118 | # Setup 119 | data = np.array([0.0, 1.0, 0.0, 1.0, 0.0]) 120 | transformed_data = np.array([0.0, 1.0, 0.0, 1.0, 0.0]) 121 | 122 | # Run 123 | transformer = Mock() 124 | transformer.missing_value_replacement = 0 125 | transformer.null_transformer.reverse_transform.return_value = transformed_data 126 | 127 | result = BinaryEncoder._reverse_transform(transformer, data) 128 | 129 | # Asserts 130 | expect = np.array([False, True, False, True, False]) 131 | expect_call_count = 1 132 | 133 | np.testing.assert_equal(result, expect) 134 | 135 | error_msg = ( 136 | 'NullTransformer.reverse_transform must be called ' 137 | 'exactly one time' 138 | ) 139 | reverse_transform_call_count = transformer.null_transformer.reverse_transform.call_count 140 | assert reverse_transform_call_count == expect_call_count, error_msg 141 | 142 | def test__reverse_transform_series(self): 143 | """Test when data is a Series.""" 144 | # Setup 145 | data = pd.Series([1.0, 0.0, 1.0]) 146 | 147 | # Run 148 | transformer = Mock() 149 | transformer.null_transformer.reverse_transform.return_value = data 150 | result = BinaryEncoder._reverse_transform(transformer, data) 151 | 152 | # Asserts 153 | expected = np.array([True, False, True]) 154 | assert isinstance(result, pd.Series) 155 | np.testing.assert_equal(result.array, expected) 156 | 157 | def test__reverse_transform_not_null_values(self): 158 | """Test
_reverse_transform handles non-null values correctly""" 159 | # Setup 160 | data = np.array([1.0, 0.0, 1.0]) 161 | 162 | # Run 163 | transformer = Mock() 164 | transformer.null_transformer.reverse_transform.return_value = data 165 | 166 | result = BinaryEncoder._reverse_transform(transformer, data) 167 | 168 | # Asserts 169 | expected = np.array([True, False, True]) 170 | 171 | assert isinstance(result, pd.Series) 172 | np.testing.assert_equal(result.array, expected) 173 | 174 | def test__reverse_transform_2d_ndarray(self): 175 | """Test _reverse_transform with a 2d ndarray.""" 176 | # Setup 177 | data = np.array([[1.0], [0.0], [1.0]]) 178 | 179 | # Run 180 | transformer = Mock() 181 | transformer.null_transformer.reverse_transform.return_value = data 182 | 183 | result = BinaryEncoder._reverse_transform(transformer, data) 184 | 185 | # Asserts 186 | expected = np.array([True, False, True]) 187 | 188 | assert isinstance(result, pd.Series) 189 | np.testing.assert_equal(result.array, expected) 190 | 191 | def test__reverse_transform_float_values(self): 192 | """Test the ``_reverse_transform`` method with decimals. 193 | 194 | Expect that the ``_reverse_transform`` method handles decimal inputs 195 | correctly by rounding them. 196 | 197 | Input: 198 | - Transformed data with decimal values. 199 | Output: 200 | - Reversed transformed data. 201 | """ 202 | # Setup 203 | data = np.array([1.2, 0.32, 1.01]) 204 | transformer = Mock() 205 | transformer.null_transformer.reverse_transform.return_value = data 206 | 207 | # Run 208 | result = BinaryEncoder._reverse_transform(transformer, data) 209 | 210 | # Asserts 211 | expected = np.array([True, False, True]) 212 | 213 | assert isinstance(result, pd.Series) 214 | np.testing.assert_equal(result.to_numpy(), expected) 215 | 216 | def test__reverse_transform_float_values_out_of_range(self): 217 | """Test the ``_reverse_transform`` method with decimals that are out of range. 218 | 219 | Expect that the ``_reverse_transform`` method handles decimal inputs 220 | correctly by rounding them. If the rounded decimal inputs are < 0 or > 1, 221 | expect them to be clipped. 222 | 223 | Input: 224 | - Transformed data with decimal values, some of which round to < 0 or > 1. 225 | Output: 226 | - Reversed transformed data. 227 | """ 228 | # Setup 229 | data = np.array([1.9, -0.7, 1.01]) 230 | transformer = Mock() 231 | transformer.null_transformer.reverse_transform.return_value = data 232 | 233 | # Run 234 | result = BinaryEncoder._reverse_transform(transformer, data) 235 | 236 | # Asserts 237 | expected = np.array([True, False, True]) 238 | 239 | assert isinstance(result, pd.Series) 240 | np.testing.assert_equal(result.array, expected) 241 | 242 | def test__reverse_transform_numpy_nan(self): 243 | """Test the ``_reverse_transform`` method with ``np.nan`` values. 244 | 245 | Expect that the output of ``_reverse_transform`` contains ``np.nan`` instead of 246 | another ``nan`` value. 247 | 248 | Input: 249 | - Transformed data with decimal values, some of which are ``np.nan``. 250 | 251 | Mock: 252 | - Mock the transformer and its ``null_transformer``. 253 | 254 | Output: 255 | - Reversed transformed data containing ``np.nan``.
256 | """ 257 | # Setup 258 | data = np.array([1.9, np.nan, 1.01]) 259 | transformer = Mock() 260 | transformer.null_transformer.reverse_transform.return_value = data 261 | 262 | # Run 263 | result = BinaryEncoder._reverse_transform(transformer, data) 264 | 265 | # Asserts 266 | assert np.isnan(result[1]) 267 | assert isinstance(result[1], float) 268 | 269 | def test__set_fitted_parameters(self): 270 | """Test ``_set_fitted_parameters`` sets the required parameters for transformer.""" 271 | # Setup 272 | transformer = BinaryEncoder() 273 | column_name = 'single_col' 274 | null_transformer = NullTransformer('mode') 275 | 276 | # Run 277 | transformer._set_fitted_parameters(column_name, null_transformer) 278 | 279 | # Assert 280 | assert transformer.columns == [column_name] 281 | assert transformer.output_columns == [column_name] 282 | assert transformer.null_transformer == null_transformer 283 | 284 | def test__set_fitted_parameters_from_column(self): 285 | """Test ``_set_fitted_parameters`` sets the required parameters for transformer.""" 286 | # Setup 287 | transformer = BinaryEncoder() 288 | column_name = 'single_col' 289 | bool_col_name = column_name + '.is_null' 290 | null_transformer = NullTransformer('mode', 'from_column') 291 | 292 | # Run 293 | transformer._set_fitted_parameters(column_name, null_transformer) 294 | 295 | # Assert 296 | assert transformer.columns == [column_name] 297 | assert transformer.output_columns == [column_name, bool_col_name] 298 | assert transformer.null_transformer == null_transformer 299 | -------------------------------------------------------------------------------- /tests/unit/transformers/test_text.py: -------------------------------------------------------------------------------- 1 | """Test Text Transformers.""" 2 | 3 | import pytest 4 | 5 | 6 | def test_deprecation_warning_is_raised(): 7 | """Test that a deprecation warning is raised when importing from this module.""" 8 | # Run and Assert 9 | expected_message = ( 10 | "Importing 'IDGenerator' or 'RegexGenerator' for ID columns from 'rdt.transformers.text' " 11 | "is deprecated. Please use 'rdt.transformers.id' instead." 12 | ) 13 | with pytest.warns(DeprecationWarning, match=expected_message): 14 | from rdt.transformers.text import IDGenerator, RegexGenerator # noqa: F401 15 | --------------------------------------------------------------------------------