├── .editorconfig ├── .github ├── FUNDING.yml ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── dependabot.yml ├── stale.yml └── workflows │ ├── contributors.yml │ ├── documentation_links.yml │ ├── label-conflicts.yml │ ├── releasing.yml │ └── testing.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yml ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── docs ├── .tmp │ └── exp_version_manager.yml ├── apidocs_common.md ├── apidocs_config.md ├── apidocs_coreclasses.md ├── apidocs_head.md ├── apidocs_model.md ├── apidocs_ssl.md ├── apidocs_utils.md ├── authors.md ├── contributing.md ├── data.md ├── development_manual.md ├── experiment_tracking.md ├── explainability.md ├── faq.md ├── gs_cite.md ├── gs_installation.md ├── gs_usage.md ├── history.md ├── imgs │ ├── auto_encoder.png │ ├── diataxis.webp │ ├── dndt.png │ ├── gflu.png │ ├── gflu_v2.png │ ├── gradient_histograms.png │ ├── gradient_norms.png │ ├── log_logits.png │ ├── model_stacking_concept.png │ ├── node_arch.png │ ├── node_dense_arch.png │ ├── pytorch_tabular_logo.png │ ├── pytorch_tabular_logo.svg │ ├── pytorch_tabular_logo_inv.png │ ├── pytorch_tabular_logo_small.svg │ ├── pytorch_tabular_logo_small_2.png │ ├── pytorch_tabular_logo_small_2.svg │ └── tabnet_architecture.png ├── index.md ├── models.md ├── optimizer.md ├── other_features.md ├── ssl_models.md ├── tabular_model.md ├── training.md └── tutorials │ ├── 01-Approaching Any Tabular Problem with PyTorch Tabular.ipynb │ ├── 02-Exploring Advanced Features with PyTorch Tabular.ipynb │ ├── 03-Neural Embedding in Scikit-Learn Workflows.ipynb │ ├── 04-Implementing New Architectures.ipynb │ ├── 05-Experiment_Tracking_using_WandB.ipynb │ ├── 05.1-Experiment_Tracking_using_Tensorboard.ipynb │ ├── 06-Imbalanced Classification.ipynb │ ├── 07-Probabilistic Regression with MDN.ipynb │ ├── 08-Self-Supervised Learning-DAE.ipynb │ ├── 09-Cross Validation.ipynb │ ├── 10-Hyperparameter Tuning.ipynb │ ├── 11-Test Time Augmentation.ipynb │ ├── 12-Bagged Predictions.ipynb │ ├── 13-Using Model Sweep as an initial Model Selection Tool.ipynb │ ├── 14-Explainability.ipynb │ ├── 15-Multi Target Classification.ipynb │ ├── 15-Search Best Architecture and Hyperparameter.ipynb │ ├── 16-Model Stacking.ipynb │ ├── imgs │ ├── prob_reg_eq_1.png │ ├── prob_reg_eq_2.png │ ├── prob_reg_fig_1.png │ ├── prob_reg_fig_2.png │ ├── prob_reg_fig_3.png │ ├── prob_reg_hist_1.png │ ├── prob_reg_hist_2.png │ ├── prob_reg_hist_3.png │ ├── prob_reg_hist_4.png │ ├── prob_reg_mdn_1.png │ ├── prob_reg_mdn_2.png │ ├── prob_reg_mdn_3.png │ ├── prob_reg_mixing12_3.png │ ├── prob_reg_mixing1_3.png │ ├── prob_reg_mixing2_3.png │ ├── prob_reg_non_mdn_2.png │ ├── prob_reg_non_mdn_3.png │ ├── prob_reg_pdfs_4.png │ ├── tensorboard_example.png │ ├── wandb_preview_1.png │ └── wandb_preview_2.png │ └── trainer_config.yml ├── examples ├── PyTorch Tabular with Bank Marketing Dataset.ipynb ├── README.md ├── __only_for_dev__ │ ├── adhoc_scaffold.py │ ├── runtime_benchmarks.txt │ ├── to_test_captum.py │ ├── to_test_classification.py │ ├── to_test_dae.py │ ├── to_test_node.py │ ├── to_test_regression.py │ └── to_test_regression_custom_models.py ├── covertype_classification.py ├── covertype_classification_using_yaml.py └── yaml_config │ ├── data_config.yml │ ├── gate_full_model_config.yml │ ├── gate_lite_mdn.yml │ ├── gate_lite_model_config.yml │ ├── optimizer_config.yml │ └── trainer_config.yml ├── mkdocs.yml ├── pyproject.toml ├── requirements ├── base.txt ├── dev.txt └── extra.txt ├── setup.cfg ├── 
setup.py ├── src └── pytorch_tabular │ ├── __init__.py │ ├── categorical_encoders.py │ ├── config │ ├── __init__.py │ └── config.py │ ├── feature_extractor.py │ ├── models │ ├── __init__.py │ ├── autoint │ │ ├── __init__.py │ │ ├── autoint.py │ │ └── config.py │ ├── base_model.py │ ├── category_embedding │ │ ├── __init__.py │ │ ├── category_embedding_model.py │ │ └── config.py │ ├── common │ │ ├── __init__.py │ │ ├── heads │ │ │ ├── __init__.py │ │ │ ├── blocks.py │ │ │ └── config.py │ │ └── layers │ │ │ ├── __init__.py │ │ │ ├── activations.py │ │ │ ├── batch_norm.py │ │ │ ├── embeddings.py │ │ │ ├── gated_units.py │ │ │ ├── misc.py │ │ │ ├── soft_trees.py │ │ │ └── transformers.py │ ├── danet │ │ ├── __init__.py │ │ ├── arch_blocks.py │ │ ├── config.py │ │ └── danet.py │ ├── ft_transformer │ │ ├── __init__.py │ │ ├── config.py │ │ └── ft_transformer.py │ ├── gandalf │ │ ├── __init__.py │ │ ├── config.py │ │ └── gandalf.py │ ├── gate │ │ ├── __init__.py │ │ ├── config.py │ │ └── gate_model.py │ ├── mixture_density │ │ ├── __init__.py │ │ ├── config.py │ │ └── mdn.py │ ├── node │ │ ├── __init__.py │ │ ├── architecture_blocks.py │ │ ├── config.py │ │ └── node_model.py │ ├── stacking │ │ ├── __init__.py │ │ ├── config.py │ │ └── stacking_model.py │ ├── tab_transformer │ │ ├── __init__.py │ │ ├── config.py │ │ └── tab_transformer.py │ └── tabnet │ │ ├── __init__.py │ │ ├── config.py │ │ └── tabnet_model.py │ ├── ssl_models │ ├── __init__.py │ ├── base_model.py │ ├── common │ │ ├── __init__.py │ │ ├── augmentations.py │ │ ├── heads.py │ │ ├── layers.py │ │ ├── noise_generators.py │ │ ├── ssl_losses.py │ │ ├── ssl_utils.py │ │ └── utils.py │ └── dae │ │ ├── __init__.py │ │ ├── config.py │ │ └── dae.py │ ├── tabular_datamodule.py │ ├── tabular_model.py │ ├── tabular_model_sweep.py │ ├── tabular_model_tuner.py │ └── utils │ ├── __init__.py │ ├── data_utils.py │ ├── logger.py │ ├── nn_utils.py │ └── python_utils.py └── tests ├── ___test_augmentations.py ├── __init__.py ├── conftest.py ├── test_autoint.py ├── test_categorical_embedding.py ├── test_common.py ├── test_danet.py ├── test_datamodule.py ├── test_ft_transformer.py ├── test_gandalf.py ├── test_gate.py ├── test_mdn.py ├── test_model_stacking.py ├── test_node.py ├── test_ssl.py ├── test_tabnet.py └── test_tabtransformer.py /.editorconfig: -------------------------------------------------------------------------------- 1 | # http://editorconfig.org 2 | 3 | root = true 4 | 5 | [*] 6 | indent_style = space 7 | indent_size = 4 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | charset = utf-8 11 | end_of_line = lf 12 | 13 | [*.bat] 14 | indent_style = tab 15 | end_of_line = crlf 16 | 17 | [LICENSE] 18 | insert_final_newline = false 19 | 20 | [Makefile] 21 | indent_style = tab 22 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: manujosephv 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # 
Replace with a single Otechie username 12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 13 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 14 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | --- 8 | 9 | **Describe the bug** 10 | A clear and concise description of what the bug is. 11 | 12 | **To Reproduce** 13 | Steps to reproduce the behavior: 14 | 15 | 1. Go to '...' 16 | 1. Click on '....' 17 | 1. Scroll down to '....' 18 | 1. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | 28 | - OS: [e.g. iOS] 29 | - Browser [e.g. chrome, safari] 30 | - Version [e.g. 22] 31 | 32 | **Smartphone (please complete the following information):** 33 | 34 | - Device: [e.g. iPhone6] 35 | - OS: [e.g. iOS8.1] 36 | - Browser [e.g. stock browser, safari] 37 | - Version [e.g. 22] 38 | 39 | **Additional context** 40 | Add any other context about the problem here. 41 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | --- 8 | 9 | **Is your feature request related to a problem? Please describe.** 10 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 11 | 12 | **Describe the solution you'd like** 13 | A clear and concise description of what you want to happen. 14 | 15 | **Describe alternatives you've considered** 16 | A clear and concise description of any alternative solutions or features you've considered. 17 | 18 | **Additional context** 19 | Add any other context or screenshots about the feature request here. 
20 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # Basic dependabot.yml file with minimum configuration for two package managers 2 | 3 | version: 2 4 | updates: 5 | # Enable version updates for python 6 | - package-ecosystem: "pip" 7 | # Look for a `requirements` file in the `root` directory 8 | directory: "/" 9 | # Check for updates once a month 10 | schedule: 11 | interval: "monthly" 12 | # Labels on pull requests for version updates only 13 | labels: 14 | - "dependencies" 15 | pull-request-branch-name: 16 | # Separate sections of the branch name with a hyphen for example, `dependabot-npm_and_yarn-next_js-acorn-6.4.1` 17 | separator: "-" 18 | # Allow up to 5 open pull requests for pip dependencies 19 | open-pull-requests-limit: 5 20 | reviewers: 21 | - "manujosephv" 22 | 23 | # Enable version updates for GitHub Actions 24 | - package-ecosystem: "github-actions" 25 | directory: "/" 26 | # Check for updates once a month 27 | schedule: 28 | interval: "monthly" 29 | # Labels on pull requests for version updates only 30 | labels: 31 | - "enhancement" 32 | pull-request-branch-name: 33 | # Separate sections of the branch name with a hyphen for example, `dependabot-npm_and_yarn-next_js-acorn-6.4.1` 34 | separator: "-" 35 | # Allow up to 5 open pull requests for GitHub Actions 36 | open-pull-requests-limit: 5 37 | reviewers: 38 | - "manujosephv" 39 | -------------------------------------------------------------------------------- /.github/stale.yml: -------------------------------------------------------------------------------- 1 | # Number of days of inactivity before an issue becomes stale 2 | daysUntilStale: 60 3 | # Number of days of inactivity before a stale issue is closed 4 | daysUntilClose: 7 5 | # Issues with these labels will never be considered stale 6 | exemptLabels: 7 | - pinned 8 | - security 9 | - enhancement 10 | - Open for Contribution 11 | # Label to use when marking an issue as stale 12 | staleLabel: wontfix 13 | # Comment to post when marking an issue as stale. Set to `false` to disable 14 | markComment: > 15 | This issue has been automatically marked as stale because it has not had 16 | recent activity. It will be closed if no further activity occurs. Thank you 17 | for your contributions. 18 | # Comment to post when closing a stale issue.
Set to `false` to disable 19 | closeComment: false 20 | -------------------------------------------------------------------------------- /.github/workflows/contributors.yml: -------------------------------------------------------------------------------- 1 | name: Add contributors 2 | on: 3 | workflow_dispatch: 4 | schedule: 5 | - cron: "0 0 1,15 * *" 6 | # push: 7 | # branches: 8 | # - master 9 | 10 | jobs: 11 | contrib-readme-job: 12 | runs-on: ubuntu-latest 13 | name: A job to automate contrib in readme 14 | steps: 15 | - name: Contribute List 16 | uses: akhilmhdh/contributors-readme-action@v2.3.10 17 | env: 18 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 19 | -------------------------------------------------------------------------------- /.github/workflows/documentation_links.yml: -------------------------------------------------------------------------------- 1 | name: Read the Docs Pull Request Preview 2 | on: 3 | pull_request_target: 4 | types: 5 | - opened 6 | 7 | permissions: 8 | pull-requests: write 9 | 10 | jobs: 11 | documentation-links: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: readthedocs/actions/preview@v1 15 | with: 16 | project-slug: "pytorch-tabular" 17 | -------------------------------------------------------------------------------- /.github/workflows/label-conflicts.yml: -------------------------------------------------------------------------------- 1 | name: Label merge conflicts 2 | 3 | on: 4 | push: 5 | branches: ["main"] 6 | pull_request_target: 7 | types: ["synchronize", "reopened", "opened"] 8 | 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }} 11 | cancel-in-progress: true 12 | 13 | jobs: 14 | triage-conflicts: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: mschilde/auto-label-merge-conflicts@591722e97f3c4142df3eca156ed0dcf2bcd362bd # Oct 25, 2021 18 | with: 19 | CONFLICT_LABEL_NAME: "has conflicts" 20 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 21 | MAX_RETRIES: 3 22 | WAIT_MS: 5000 23 | -------------------------------------------------------------------------------- /.github/workflows/releasing.yml: -------------------------------------------------------------------------------- 1 | name: PyPI Release 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | tags: ["v?[0-9]+.[0-9]+.[0-9]+"] 7 | pull_request: 8 | branches: [main] 9 | release: 10 | types: [published] 11 | 12 | jobs: 13 | # based on https://github.com/pypa/gh-action-pypi-publish 14 | release-pkg: 15 | runs-on: ubuntu-20.04 16 | timeout-minutes: 10 17 | steps: 18 | - name: Checkout 🛎️ 19 | uses: actions/checkout@v4 20 | - name: Set up Python 🐍 21 | uses: actions/setup-python@v5 22 | with: 23 | python-version: 3.8 24 | 25 | - name: Create package 📦 26 | run: | 27 | pip install "twine==5.1.1" setuptools wheel 28 | python setup.py sdist bdist_wheel 29 | ls -lh dist/ 30 | twine check dist/* 31 | 32 | - name: Upload to release 33 | if: github.event_name == 'release' 34 | uses: AButler/upload-release-assets@v3.0 35 | with: 36 | files: "dist/*" 37 | repo-token: ${{ secrets.GITHUB_TOKEN }} 38 | 39 | - name: Publish distribution 📦 to PyPI 40 | if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release' 41 | uses: pypa/gh-action-pypi-publish@v1.12.2 42 | with: 43 | user: __token__ 44 | password: ${{ secrets.pypi_password }} 45 | -------------------------------------------------------------------------------- /.github/workflows/testing.yml: -------------------------------------------------------------------------------- 1 | name: Testing 2 
| 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: {} 7 | 8 | concurrency: 9 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }} 10 | cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} 11 | 12 | jobs: 13 | pytest: 14 | runs-on: ${{ matrix.os }} 15 | strategy: 16 | fail-fast: false 17 | matrix: 18 | os: [ubuntu-latest] 19 | python-version: ["3.8", "3.9", "3.10"] 20 | include: 21 | - { os: "ubuntu-20.04", python-version: "3.8", requires: "oldest" } 22 | - { os: "ubuntu-20.04", python-version: "3.9", requires: "oldest" } 23 | 24 | env: 25 | TORCH_URL: "https://download.pytorch.org/whl/cpu/torch_stable.html" 26 | 27 | steps: 28 | - uses: actions/checkout@v4 29 | - name: Set up Python ${{ matrix.python-version }} 30 | uses: actions/setup-python@v5 31 | with: 32 | python-version: ${{ matrix.python-version }} 33 | 34 | - name: Set min. dependencies 35 | if: matrix.requires == 'oldest' 36 | run: | 37 | import os 38 | fname = 'requirements/base.txt' 39 | lines = [line.replace('>=', '==') for line in open(fname).readlines()] 40 | open(fname, 'w').writelines(lines) 41 | shell: python 42 | 43 | - name: Install main package & dependencies 44 | run: | 45 | pip install -e .[extra] -r requirements/dev.txt -f ${TORCH_URL} 46 | pip list 47 | 48 | - name: Restore test's datasets 49 | uses: actions/cache/restore@v4 50 | with: 51 | path: tests/.datasets 52 | key: test-datasets 53 | 54 | - name: Run test-suite 55 | run: python -m pytest -v 56 | 57 | - name: Save test's datasets 58 | uses: actions/cache/save@v4 59 | with: 60 | path: tests/.datasets 61 | key: test-datasets 62 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | #data 141 | data/ 142 | wandb/ 143 | saved_models/ 144 | output/ 145 | .vscode/ 146 | examples/basic/ 147 | .tmp/ 148 | exp_version_manager.yml 149 | .tmp/exp_version_manager.yml 150 | compare.py 151 | checkpoints/ 152 | docs/examples/basic/ 153 | examples/test_save/ 154 | 155 | # Ruff 156 | .ruff_cache/ 157 | tests/.datasets/ 158 | test.py 159 | lightning_logs/ 160 | docs/tutorials/examples/basic/ 161 | docs/tutorials/pytorch-tabular-covertype/ 162 | 163 | # Pycharm 164 | .idea/ 165 | test.ipynb 166 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | ci: 5 | autofix_prs: true 6 | autoupdate_commit_msg: "[pre-commit.ci] pre-commit suggestions" 7 | # submodules: true 8 | 9 | repos: 10 | - repo: https://github.com/pre-commit/pre-commit-hooks 11 | rev: v5.0.0 12 | hooks: 13 | - id: end-of-file-fixer 14 | exclude: "setup.cfg" 15 | - id: trailing-whitespace 16 | exclude: | 17 | (?x)( 18 | docs/| 19 | setup.cfg 20 | ) 21 | - id: check-case-conflict 22 | - id: check-yaml 23 | - id: check-toml 24 | - id: check-json 25 | - id: check-added-large-files 26 | - id: check-docstring-first 27 | - id: detect-private-key 28 | 29 | - repo: https://github.com/PyCQA/docformatter 30 | rev: 06907d0267368b49b9180eed423fae5697c1e909 # todo: fix for docformatter after last 1.7.5 31 | hooks: 32 | - id: docformatter 33 | additional_dependencies: [tomli] 34 | args: ["--in-place"] 35 | 36 | - repo: https://github.com/executablebooks/mdformat 37 | rev: 0.7.19 38 | hooks: 39 | - id: mdformat 40 | additional_dependencies: 41 | - mdformat-gfm 42 | - mdformat-black 43 | - mdformat_frontmatter 44 | exclude: | 45 | (?x)( 46 | docs/| 47 | README.md 48 | ) 49 | 50 | - repo: https://github.com/astral-sh/ruff-pre-commit 51 | rev: v0.8.3 52 | hooks: 53 | - id: ruff 54 | args: ["--fix"] 55 | - id: ruff-format 56 | - id: ruff 57 | 58 | - repo: https://github.com/pre-commit/mirrors-prettier 59 | rev: v4.0.0-alpha.8 60 | hooks: 61 | - id: prettier 62 | files: \.(json|yml|yaml|toml) 63 | # https://prettier.io/docs/en/options.html#print-width 64 | args: ["--print-width=120"] 65 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | # Read the Docs configuration file 3 | # See 
https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Build documentation in the docs/ directory with Sphinx 9 | # sphinx: 10 | # configuration: docs/conf.py 11 | 12 | # Build documentation with MkDocs 13 | mkdocs: 14 | configuration: mkdocs.yml 15 | 16 | # Optionally build your docs in additional formats such as PDF 17 | formats: 18 | - pdf 19 | 20 | build: 21 | os: ubuntu-20.04 22 | tools: 23 | python: "3.8" 24 | 25 | # Optionally set the version of Python and requirements required to build your docs 26 | python: 27 | install: 28 | - requirements: requirements/dev.txt 29 | - method: pip 30 | path: .[extra] 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020, Manu Joseph 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include README.md 3 | include requirements.txt 4 | include requirements *.txt 5 | 6 | recursive-include tests * 7 | recursive-exclude * __pycache__ 8 | recursive-exclude * *.py[co] 9 | 10 | recursive-include docs *.md *.jpg *.png *.gif 11 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: clean clean-test clean-pyc clean-build docs help 2 | .DEFAULT_GOAL := help 3 | 4 | define BROWSER_PYSCRIPT 5 | import os, webbrowser, sys 6 | 7 | from urllib.request import pathname2url 8 | 9 | webbrowser.open("file://" + pathname2url(os.path.abspath(sys.argv[1]))) 10 | endef 11 | export BROWSER_PYSCRIPT 12 | 13 | define PRINT_HELP_PYSCRIPT 14 | import re, sys 15 | 16 | for line in sys.stdin: 17 | match = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$$', line) 18 | if match: 19 | target, help = match.groups() 20 | print("%-20s %s" % (target, help)) 21 | endef 22 | export PRINT_HELP_PYSCRIPT 23 | 24 | BROWSER := python -c "$$BROWSER_PYSCRIPT" 25 | 26 | help: 27 | @python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST) 28 | 29 | clean: clean-build clean-pyc clean-test ## remove all build, test, coverage and Python artifacts 30 | 31 | clean-build: ## remove build artifacts 32 | rm -fr build/ 33 | rm -fr dist/ 34 | rm -fr .eggs/ 35 | find . -name '*.egg-info' -exec rm -fr {} + 36 | find . -name '*.egg' -exec rm -f {} + 37 | 38 | clean-pyc: ## remove Python file artifacts 39 | find . -name '*.pyc' -exec rm -f {} + 40 | find . -name '*.pyo' -exec rm -f {} + 41 | find . -name '*~' -exec rm -f {} + 42 | find . -name '__pycache__' -exec rm -fr {} + 43 | 44 | clean-test: ## remove test and coverage artifacts 45 | rm -fr .tox/ 46 | rm -f .coverage 47 | rm -fr htmlcov/ 48 | rm -fr .pytest_cache 49 | 50 | env: ## create a local virtual environment in .env/tabular_env 51 | mkdir -p .env 52 | # Create a virtual environment 53 | python3 -m venv .env/tabular_env 54 | # Each line of a make recipe runs in its own shell, so activate the 55 | # environment and install the dependencies in a single chained command 56 | . .env/tabular_env/bin/activate && pip install -r requirements/base.txt 57 | # The activation above does not persist after make exits 58 | @echo "Run 'source .env/tabular_env/bin/activate' to activate the environment" 59 | 60 | test: ## run tests quickly with the default Python 61 | pytest 62 | 63 | test-all: ## run tests on every Python version with tox 64 | tox 65 | 66 | coverage: ## check code coverage quickly with the default Python 67 | coverage run --source pytorch_tabular -m pytest 68 | coverage report -m 69 | coverage html 70 | $(BROWSER) htmlcov/index.html 71 | 72 | 73 | docs: ## generate mkdocs HTML documentation, including API docs 74 | rm -rf site 75 | mkdocs build 76 | $(BROWSER) site/index.html 77 | 78 | servedocs: docs ## compile the docs watching for changes 79 | mkdocs serve 80 | 81 | release: dist ## package and upload a release 82 | twine upload dist/* 83 | 84 | dist: clean ## builds source and wheel package 85 | python setup.py sdist 86 | python setup.py bdist_wheel 87 | ls -l dist 88 | 89 | install: clean env ## install the package to the active Python's site-packages 90 | pip install -e .[dev] 91 | -------------------------------------------------------------------------------- /docs/.tmp/exp_version_manager.yml: -------------------------------------------------------------------------------- 1 | classification: 17 2 | -------------------------------------------------------------------------------- /docs/apidocs_common.md: -------------------------------------------------------------------------------- 1 | ## Embeddings 2 | 3 | ::: pytorch_tabular.models.common.layers.Embedding1dLayer 4 | options: 5 | heading_level: 3 6 | ::: pytorch_tabular.models.common.layers.Embedding2dLayer 7 | options: 8 | heading_level: 3 9 | ::: pytorch_tabular.models.common.layers.PreEncoded1dLayer 10 | options: 11 | heading_level: 3 12 | ::: pytorch_tabular.models.common.layers.SharedEmbeddings 13 | options: 14 | heading_level: 3 15 | 16 | ## Gated Units 17 | ::: pytorch_tabular.models.common.layers.GatedFeatureLearningUnit 18 | options: 19 | heading_level: 3 20 | ::: pytorch_tabular.models.common.layers.GEGLU 21 | options: 22 | heading_level: 3 23 | ::: pytorch_tabular.models.common.layers.ReGLU 24 | options: 25 | heading_level: 3 26 | ::: pytorch_tabular.models.common.layers.SwiGLU 27 | options: 28 | heading_level: 3 29 | ::: pytorch_tabular.models.common.layers.PositionWiseFeedForward 30 | options: 31 | heading_level: 3 32 | 33 | ## Soft Trees 34 | ::: pytorch_tabular.models.common.layers.NeuralDecisionTree 35 | options: 36 | heading_level: 3 37 | ::: pytorch_tabular.models.common.layers.ODST 38 | options: 39 | heading_level: 3 40 | 41 | ## Transformers 42 | ::: pytorch_tabular.models.common.layers.AddNorm 43 | options: 44 | heading_level: 3 45 | 46 | ::: pytorch_tabular.models.common.layers.AppendCLSToken 47 | options: 48 | heading_level: 3 49 | ::: pytorch_tabular.models.common.layers.MultiHeadedAttention 50 | options: 51 | heading_level: 3 52 | ::: pytorch_tabular.models.common.layers.TransformerEncoderBlock 53 | options: 54 | heading_level: 3 55 | 56 | ## Miscellaneous 57 | ::: pytorch_tabular.models.common.layers.Lambda 58 | options: 59 | heading_level: 3 60 | ::: pytorch_tabular.models.common.layers.ModuleWithInit 61 | options: 62 | heading_level: 3 63 | ::: pytorch_tabular.models.common.layers.Residual 64 | options: 65 | heading_level: 3 66 | 67 | 68 | ## Activations 69 | ::: pytorch_tabular.models.common.activations.Entmoid15 70 | options: 71 | heading_level: 3 72 | ::: 
pytorch_tabular.models.common.activations.entmoid15 73 | options: 74 | heading_level: 3 75 | ::: pytorch_tabular.models.common.activations.entmax15 76 | options: 77 | heading_level: 3 78 | ::: pytorch_tabular.models.common.activations.sparsemax 79 | options: 80 | heading_level: 3 81 | ::: pytorch_tabular.models.common.activations.sparsemoid 82 | options: 83 | heading_level: 3 84 | ::: pytorch_tabular.models.common.activations.t_softmax 85 | options: 86 | heading_level: 3 87 | ::: pytorch_tabular.models.common.activations.TSoftmax 88 | options: 89 | heading_level: 3 90 | ::: pytorch_tabular.models.common.activations.RSoftmax 91 | options: 92 | heading_level: 3 93 | 94 | -------------------------------------------------------------------------------- /docs/apidocs_config.md: -------------------------------------------------------------------------------- 1 | # Configurations 2 | 3 | ## Core Configuration 4 | 5 | ::: pytorch_tabular.config.DataConfig 6 | options: 7 | heading_level: 3 8 | ::: pytorch_tabular.config.ModelConfig 9 | options: 10 | heading_level: 3 11 | ::: pytorch_tabular.config.SSLModelConfig 12 | options: 13 | heading_level: 3 14 | ::: pytorch_tabular.config.TrainerConfig 15 | options: 16 | heading_level: 3 17 | ::: pytorch_tabular.config.ExperimentConfig 18 | options: 19 | heading_level: 3 20 | ::: pytorch_tabular.config.OptimizerConfig 21 | options: 22 | heading_level: 3 23 | ::: pytorch_tabular.config.ExperimentRunManager 24 | options: 25 | heading_level: 3 26 | 27 | ## Head Configuration 28 | 29 | In addition to these core classes, we also have config classes for heads 30 | 31 | ::: pytorch_tabular.models.common.heads.LinearHeadConfig 32 | options: 33 | heading_level: 3 34 | ::: pytorch_tabular.models.common.heads.MixtureDensityHeadConfig 35 | options: 36 | heading_level: 3 37 | -------------------------------------------------------------------------------- /docs/apidocs_coreclasses.md: -------------------------------------------------------------------------------- 1 | # Core Classes 2 | 3 | ::: pytorch_tabular.TabularModel 4 | options: 5 | heading_level: 3 6 | ::: pytorch_tabular.TabularDatamodule 7 | options: 8 | heading_level: 3 9 | ::: pytorch_tabular.TabularModelTuner 10 | options: 11 | heading_level: 3 12 | ::: pytorch_tabular.model_sweep 13 | options: 14 | heading_level: 3 15 | -------------------------------------------------------------------------------- /docs/apidocs_head.md: -------------------------------------------------------------------------------- 1 | ## Configuration Classes 2 | 3 | ::: pytorch_tabular.models.common.heads.LinearHeadConfig 4 | options: 5 | heading_level: 3 6 | ::: pytorch_tabular.models.common.heads.MixtureDensityHeadConfig 7 | options: 8 | heading_level: 3 9 | 10 | ## Head Classes 11 | 12 | ::: pytorch_tabular.models.common.heads.LinearHead 13 | options: 14 | heading_level: 3 15 | ::: pytorch_tabular.models.common.heads.MixtureDensityHead 16 | options: 17 | heading_level: 3 18 | -------------------------------------------------------------------------------- /docs/apidocs_model.md: -------------------------------------------------------------------------------- 1 | ## Configuration Classes 2 | 3 | ::: pytorch_tabular.models.AutoIntConfig 4 | options: 5 | heading_level: 3 6 | ::: pytorch_tabular.models.CategoryEmbeddingModelConfig 7 | options: 8 | heading_level: 3 9 | ::: pytorch_tabular.models.DANetConfig 10 | options: 11 | heading_level: 3 12 | ::: pytorch_tabular.models.FTTransformerConfig 13 | options: 14 | heading_level: 3 15 | ::: 
pytorch_tabular.models.GANDALFConfig 16 | options: 17 | heading_level: 3 18 | ::: pytorch_tabular.models.GatedAdditiveTreeEnsembleConfig 19 | options: 20 | heading_level: 3 21 | ::: pytorch_tabular.models.MDNConfig 22 | options: 23 | heading_level: 3 24 | ::: pytorch_tabular.models.NodeConfig 25 | options: 26 | heading_level: 3 27 | ::: pytorch_tabular.models.TabNetModelConfig 28 | options: 29 | heading_level: 3 30 | ::: pytorch_tabular.models.TabTransformerConfig 31 | options: 32 | heading_level: 3 33 | ::: pytorch_tabular.models.StackingModelConfig 34 | options: 35 | heading_level: 3 36 | ::: pytorch_tabular.config.ModelConfig 37 | options: 38 | heading_level: 3 39 | 40 | ## Model Classes 41 | 42 | ::: pytorch_tabular.models.AutoIntModel 43 | options: 44 | heading_level: 3 45 | ::: pytorch_tabular.models.CategoryEmbeddingModel 46 | options: 47 | heading_level: 3 48 | ::: pytorch_tabular.models.DANetModel 49 | options: 50 | heading_level: 3 51 | ::: pytorch_tabular.models.FTTransformerModel 52 | options: 53 | heading_level: 3 54 | ::: pytorch_tabular.models.GANDALFModel 55 | options: 56 | heading_level: 3 57 | ::: pytorch_tabular.models.GatedAdditiveTreeEnsembleModel 58 | options: 59 | heading_level: 3 60 | ::: pytorch_tabular.models.MDNModel 61 | options: 62 | heading_level: 3 63 | ::: pytorch_tabular.models.NODEModel 64 | options: 65 | heading_level: 3 66 | ::: pytorch_tabular.models.TabNetModel 67 | options: 68 | heading_level: 3 69 | ::: pytorch_tabular.models.TabTransformerModel 70 | options: 71 | heading_level: 3 72 | ::: pytorch_tabular.models.StackingModel 73 | options: 74 | heading_level: 3 75 | ## Base Model Class 76 | ::: pytorch_tabular.models.BaseModel 77 | options: 78 | heading_level: 3 79 | -------------------------------------------------------------------------------- /docs/apidocs_ssl.md: -------------------------------------------------------------------------------- 1 | ## Configuration Classes 2 | 3 | ::: pytorch_tabular.ssl_models.DenoisingAutoEncoderConfig 4 | options: 5 | heading_level: 3 6 | 7 | ## Model Classes 8 | 9 | ::: pytorch_tabular.ssl_models.DenoisingAutoEncoderModel 10 | options: 11 | heading_level: 3 12 | 13 | ## Base Model Class 14 | ::: pytorch_tabular.ssl_models.SSLBaseModel 15 | options: 16 | heading_level: 3 17 | -------------------------------------------------------------------------------- /docs/apidocs_utils.md: -------------------------------------------------------------------------------- 1 | # Utilities 2 | 3 | ## Special Feature Classes 4 | ::: pytorch_tabular.CategoricalEmbeddingTransformer 5 | options: 6 | heading_level: 3 7 | ::: pytorch_tabular.DeepFeatureExtractor 8 | options: 9 | heading_level: 3 10 | 11 | ## Data Utilities 12 | ::: pytorch_tabular.utils.get_balanced_sampler 13 | options: 14 | heading_level: 3 15 | ::: pytorch_tabular.utils.get_class_weighted_cross_entropy 16 | options: 17 | heading_level: 3 18 | ::: pytorch_tabular.utils.get_gaussian_centers 19 | options: 20 | heading_level: 3 21 | ::: pytorch_tabular.utils.load_covertype_dataset 22 | options: 23 | heading_level: 3 24 | ::: pytorch_tabular.utils.make_mixed_dataset 25 | options: 26 | heading_level: 3 27 | ::: pytorch_tabular.utils.print_metrics 28 | options: 29 | heading_level: 3 30 | 31 | ## NN Utilities 32 | ::: pytorch_tabular.utils._initialize_layers 33 | options: 34 | heading_level: 3 35 | ::: pytorch_tabular.utils._initialize_kaiming 36 | options: 37 | heading_level: 3 38 | ::: pytorch_tabular.utils._linear_dropout_bn 39 | options: 40 | heading_level: 3 41 | 
::: pytorch_tabular.utils._make_ix_like 42 | options: 43 | heading_level: 3 44 | ::: pytorch_tabular.utils.reset_all_weights 45 | options: 46 | heading_level: 3 47 | ::: pytorch_tabular.utils.to_one_hot 48 | options: 49 | heading_level: 3 50 | ::: pytorch_tabular.utils.count_parameters 51 | options: 52 | heading_level: 3 53 | 54 | ## Python Utilities 55 | ::: pytorch_tabular.utils.getattr_nested 56 | options: 57 | heading_level: 3 58 | ::: pytorch_tabular.utils.ifnone 59 | options: 60 | heading_level: 3 61 | ::: pytorch_tabular.utils.check_numpy 62 | options: 63 | heading_level: 3 64 | ::: pytorch_tabular.utils.pl_load 65 | options: 66 | heading_level: 3 67 | ::: pytorch_tabular.utils.generate_doc_dataclass 68 | options: 69 | heading_level: 3 70 | ::: pytorch_tabular.utils.suppress_lightning_logs 71 | options: 72 | heading_level: 3 73 | ::: pytorch_tabular.utils.enable_lightning_logs 74 | options: 75 | heading_level: 3 76 | ::: pytorch_tabular.utils.int_to_human_readable 77 | options: 78 | heading_level: 3 79 | -------------------------------------------------------------------------------- /docs/authors.md: -------------------------------------------------------------------------------- 1 | # Credits 2 | 3 | ## Development Lead 4 | 5 | - [Manu Joseph](https://github.com/manujosephv) | Email: [manujosephv@gmail.com](mailto:manujosephv@gmail.com) | [LinkedIn](https://linkedin.com/in/manujosephv) | [Twitter](https://twitter.com/manujosephv) | [Blog](https://deep-and-shallow.com) 6 | 7 | ## Contributors 8 | 9 | - [Jirka Borovec](https://github.com/Borda) 10 | -------------------------------------------------------------------------------- /docs/contributing.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Contributions are welcome, and they are greatly appreciated! Every 4 | little bit helps, and credit will always be given. 5 | 6 | You can contribute in many ways: 7 | 8 | ## Types of Contributions 9 | 10 | ### Report Bugs 11 | 12 | Report bugs at <https://github.com/manujosephv/pytorch_tabular/issues>. 13 | 14 | If you are reporting a bug, please include: 15 | 16 | - Your operating system name and version. 17 | - Any details about your local setup that might be helpful in 18 | troubleshooting. 19 | - Detailed steps to reproduce the bug. 20 | 21 | ### Fix Bugs 22 | 23 | Look through the GitHub issues for bugs. Anything tagged with "bug" and 24 | "help wanted" is open to whoever wants to implement it. 25 | 26 | ### Implement Features 27 | 28 | Look through the GitHub issues for features. Anything tagged with 29 | "enhancement" and "help wanted" is open to whoever wants to implement 30 | it. 31 | 32 | ### Write Documentation 33 | 34 | PyTorch Tabular could always use more documentation, whether as part of the 35 | official PyTorch Tabular docs, in docstrings, or even on the web in blog 36 | posts, articles, and such. 37 | 38 | ### Submit Feedback 39 | 40 | The best way to send feedback is to file an issue at 41 | <https://github.com/manujosephv/pytorch_tabular/issues>. 42 | 43 | If you are proposing a feature: 44 | 45 | - Explain in detail how it would work. 46 | - Keep the scope as narrow as possible, to make it easier to 47 | implement. 48 | - Remember that this is a volunteer-driven project, and that 49 | contributions are welcome :) 50 | 51 | ## Get Started! 52 | 53 | Ready to contribute? Here's how to set up PyTorch Tabular for local 54 | development.
55 | 56 | * Fork the pytorch_tabular repo on GitHub. 57 | 58 | * Clone your fork locally and change directory to the checked out folder: 59 | 60 | ```bash 61 | git clone git@github.com:your_name_here/pytorch_tabular.git 62 | cd pytorch_tabular 63 | ``` 64 | 65 | * Set up a local environment (preferably a virtual environment). 66 | 67 | Using Python's native venv: 68 | 69 | ``` bash 70 | mkdir .env 71 | python3 -m venv .env/tabular_env 72 | source .env/tabular_env/bin/activate 73 | pip install -e .[dev] 74 | ``` 75 | 76 | * Create a branch for local development: 77 | 78 | ```bash 79 | git checkout -b name-of-your-bugfix-or-feature 80 | ``` 81 | 82 | Now you can make your changes locally. 83 | 84 | !!! warning 85 | 86 | Never work in the `master` branch! 87 | 88 | !!! tip 89 | 90 | Have meaningful commit messages. This will help with the review and further processing of the PR. 91 | 92 | * When you are done, run the `pytest` unit tests and check that everything passes. 93 | 94 | ```bash 95 | pytest tests/ 96 | ``` 97 | !!!note 98 | 99 | If you are adding a new feature, please add a test for it. 100 | 101 | * When you are done making changes and all test cases are passing, run `pre-commit` to make sure all the linting and formatting is done correctly. 102 | 103 | ```bash 104 | pre-commit run --all-files 105 | ``` 106 | Review and accept the changes, if any. 107 | 108 | !!!warning 109 | 110 | Do not commit pre-commit changes to `setup.cfg`. The file has been excluded from one hook for bump2version compatibility. For a complete and up-to-date list of excluded files, please check the `.pre-commit-config.yaml` file. 111 | 112 | * Commit your changes and push your branch to GitHub: 113 | ```bash 114 | git add . 115 | git commit -m "Your detailed description of your changes." 116 | git push origin name-of-your-bugfix-or-feature 117 | ``` 118 | * Submit a pull request through the GitHub website. 119 | 120 | ## Pull Request Guidelines 121 | 122 | Before you submit a pull request, check that it meets these guidelines: 123 | 124 | 1. The pull request should include tests. 125 | 1. If the pull request adds functionality, the docs should be updated. 126 | Put your new functionality into a function with a docstring. 127 | 128 | ## Tips 129 | 130 | To run a subset of tests: 131 | 132 | ```bash 133 | pytest tests/test_* 134 | ``` 135 | 136 | ## Deploying 137 | 138 | A reminder for the maintainers on how to deploy. Make sure all your 139 | changes are committed (including an entry in HISTORY.rst). Then run: 140 | 141 | ```bash 142 | bump2version patch # possible: major / minor / patch 143 | git push 144 | git push --tags 145 | ``` 146 | 147 | GitHub Actions will take care of the rest. 148 | -------------------------------------------------------------------------------- /docs/data.md: -------------------------------------------------------------------------------- 1 | PyTorch Tabular uses Pandas Dataframes as the container which holds data. As Pandas is the most popular way of handling tabular data, this was an obvious choice. Keeping ease of usability in mind, PyTorch Tabular accepts dataframes as is, i.e. there is no need to split the data into `X` and `y` as in scikit-learn. 2 | 3 | PyTorch Tabular handles this using a `DataConfig` object.
4 | 5 | ## Basic Usage 6 | 7 | - `target`: List\[str\]: A list of strings with the names of the target column(s) 8 | - `continuous_cols`: List\[str\]: Column names of the numeric fields. Defaults to \[\] 9 | - `categorical_cols`: List\[str\]: Column names of the categorical fields to treat differently 10 | 11 | ### Usage Example 12 | 13 | ```python 14 | data_config = DataConfig( 15 | target=["label"], 16 | continuous_cols=["feature_1", "feature_2"], 17 | categorical_cols=["cat_feature_1", "cat_feature_2"], 18 | ) 19 | ``` 20 | 21 | ## Advanced Usage 22 | 23 | ### Date Columns 24 | 25 | If you have date columns in the dataframe, mention the column names in the `date_columns` parameter and set `encode_date_columns` to `True`. This will extract relevant features like the month, week, quarter, etc. and add them to your feature list internally. 26 | 27 | `date_columns` is not just a list of column names, but a list of (column name, freq) tuples. The freq is a standard Pandas date frequency tag which denotes the lowest temporal granularity that is relevant for the problem. 28 | 29 | For example, if there is a Launch Date column for a product that launches only once a month, there is no sense in extracting features like week or day. So, we keep the frequency at `M` 30 | 31 | ```python 32 | date_columns = [("launch_date", "M")] 33 | ``` 34 | 35 | ### Feature Transformations 36 | 37 | Feature scaling is an almost essential step to get good performance from most machine learning algorithms, and deep learning is no exception. The `normalize_continuous_features` flag (which is `True` by default) scales the input continuous features using a `StandardScaler` 38 | 39 | Sometimes, changing the feature distributions using non-linear transformations helps the machine learning/deep learning algorithms. 40 | 41 | PyTorch Tabular offers 4 standard transformations using the `continuous_feature_transform` parameter: 42 | 43 | - `yeo-johnson` 44 | - `box-cox` 45 | - `quantile_uniform` 46 | - `quantile_normal` 47 | 48 | `yeo-johnson` and `box-cox` are a family of parametric, monotonic transformations that aim to map data from any distribution to as close to a Gaussian distribution as possible in order to stabilize variance and minimize skewness. `box-cox` can only be applied to *strictly positive* data. scikit-learn has a good [write-up](https://scikit-learn.org/stable/modules/preprocessing.html#mapping-to-a-gaussian-distribution) about them. 49 | 50 | `quantile_normal` and `quantile_uniform` are monotonic, non-parametric transformations which aim to transform the features to a normal distribution or a uniform distribution, respectively. By performing a rank transformation, a quantile transform smooths out unusual distributions and is less influenced by outliers than scaling methods. It does, however, distort correlations and distances within and across features.
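As a quick illustration of the options above, here is a minimal sketch of a `DataConfig` that combines date encoding with a feature transformation. The column names (`label`, `price`, `num_orders`, `store_id`, `launch_date`) are hypothetical placeholders, not columns from any particular dataset:

```python
from pytorch_tabular.config import DataConfig

data_config = DataConfig(
    target=["label"],  # hypothetical target column
    continuous_cols=["price", "num_orders"],  # hypothetical numeric columns
    categorical_cols=["store_id"],  # hypothetical categorical column
    # Extract month-level features (month, quarter, etc.) from the date column
    date_columns=[("launch_date", "M")],
    encode_date_columns=True,
    # Scale the continuous features with a StandardScaler (True by default)
    normalize_continuous_features=True,
    # Also map them to a roughly Gaussian distribution via a rank transform
    continuous_feature_transform="quantile_normal",
)
```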
51 | 52 | ::: pytorch_tabular.config.DataConfig 53 | options: 54 | show_root_heading: yes 55 | -------------------------------------------------------------------------------- /docs/development_manual.md: -------------------------------------------------------------------------------- 1 | # Bump Version and Trigger Build 2 | 3 | ## Bump Version (Patch Update) 4 | 5 | ```bash 6 | # Commit all changes 7 | bump2version patch 8 | ``` 9 | 10 | - e.g., 0.1.1 to 0.1.2-dev0 11 | - For minor patch updates 12 | - No tags are created 13 | 14 | ## Bump Version (Minor Update) 15 | 16 | ```bash 17 | # Commit all changes 18 | bump2version minor 19 | ``` 20 | 21 | - e.g., 0.1.1-dev0 to 0.2.0-dev0 (the lower version parts are reset) 22 | - For minor feature updates 23 | - No tags are created 24 | 25 | ## Bump Version (Major Update) 26 | 27 | ```bash 28 | # Commit all changes 29 | bump2version major 30 | ``` 31 | 32 | - e.g., 0.1.1-dev0 to 1.0.0-dev0 33 | - For major feature updates 34 | - No tags are created 35 | 36 | ## Bump Version (Release) 37 | 38 | ```bash 39 | # Add the new version and changelog to History.md 40 | # Commit all changes and run 41 | bump2version --tag release 42 | # Check that the tag is present 43 | git tag 44 | # Push the changes to GitHub 45 | git push origin <tag-name> 46 | 47 | ``` 48 | 49 | - e.g., 0.1.1-dev0 to 0.1.1 50 | - To trigger GitHub Actions to push to PyPI 51 | - Tags are created 52 | 53 | # Revert Version and Delete a Tag 54 | 55 | - Update the version numbers in 56 | 1. setup.py 57 | 1. setup.cfg 58 | 1. __init__.py 59 | - Delete the Git tag locally 60 | ```bash 61 | git tag -d <tag-name> 62 | ``` 63 | - Delete the tag from GitHub 64 | ```bash 65 | git push --delete origin <tag-name> 66 | ``` 67 | -------------------------------------------------------------------------------- /docs/experiment_tracking.md: -------------------------------------------------------------------------------- 1 | Experiment tracking is almost an essential part of machine learning. It is critical for upholding reproducibility. PyTorch Tabular embraces this and supports experiment tracking internally. Currently, PyTorch Tabular supports two experiment tracking frameworks: 2 | 3 | 1. Tensorboard 4 | 1. Weights and Biases 5 | 6 | Tensorboard logging is barebones. PyTorch Tabular just logs the losses and metrics to tensorboard. W&B tracking is much more feature-rich - in addition to tracking losses and metrics, it can also track the gradients of the different layers, logits of your model across epochs, etc. 7 | 8 | ## Basic Usage 9 | 10 | - `project_name`: str: The name of the project under which all runs will be logged. For Tensorboard this defines the folder under which the logs will be saved, and for W&B it defines the project name 11 | - `run_name`: str: The name of the run; a specific identifier to recognize the run. If left blank, an auto-generated name based on the task will be assigned. 12 | - `log_target`: str: Determines where logging happens - Tensorboard or W&B. Choices are `wandb` and `tensorboard`. Defaults to `tensorboard` 13 | 14 | ### Usage Example 15 | 16 | ```python 17 | experiment_config = ExperimentConfig(project_name="MyAwesomeProject", run_name="my_cool_new_model", log_target="wandb") 18 | ``` 19 | 20 | ## Advanced Usage 21 | 22 | ### Track Gradients 23 | 24 | It is a good idea to track gradients to monitor whether the model is learning as it is supposed to. There are two ways you can do that, as the sketch after this list shows. 25 | 26 | 1. The `exp_watch` parameter in `ExperimentConfig` can be set to `"gradients"`, with `log_target` set to `"wandb"`. This will track a histogram of gradients across epochs. 27 | 28 | ![Gradient histograms](imgs/gradient_histograms.png) 29 | 30 | 2. You can also set `track_grad_norm` to `1` or `2` for the L1 or L2 norm of the gradients. This works for both Tensorboard and W&B. 31 | 32 | ![Gradient norms](imgs/gradient_norms.png) 33 | 
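Putting the two approaches together, here is a hedged sketch: `exp_watch` and `log_target` are the `ExperimentConfig` parameters described above, while `track_grad_norm` is assumed here to be a `TrainerConfig` option that mirrors the PyTorch Lightning trainer argument of the same name:

```python
from pytorch_tabular.config import ExperimentConfig, TrainerConfig

experiment_config = ExperimentConfig(
    project_name="MyAwesomeProject",
    run_name="watch_gradients",  # hypothetical run name
    log_target="wandb",  # gradient histograms need the W&B logger
    exp_watch="gradients",  # log a histogram of gradients across epochs
)

trainer_config = TrainerConfig(
    batch_size=1024,
    max_epochs=100,
    # Assumed TrainerConfig option: 2 logs the L2 norm of the gradients
    # (use 1 for the L1 norm); this works for Tensorboard and W&B alike
    track_grad_norm=2,
)
```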
34 | ### Track Logits 35 | 36 | Sometimes, it also helps to track the logits of the model. As training progresses, the logits should become more concentrated around the target that we are trying to model. 37 | 38 | This can be done using the parameter `log_logits` in `ExperimentConfig`. 39 | 40 | ![Logged logits](imgs/log_logits.png) 41 | 42 | 43 | ::: pytorch_tabular.config.ExperimentConfig 44 | options: 45 | show_root_heading: yes 46 | -------------------------------------------------------------------------------- /docs/explainability.md: -------------------------------------------------------------------------------- 1 | The explainability features in PyTorch Tabular allow users to interpret and understand the predictions made by a tabular deep learning model. These features provide insights into the model's decision-making process and help identify the most influential features. Some of the explainability features are built into the models, and many others are based on the [Captum](https://captum.ai/) library. 2 | 3 | ## Native Feature Importance 4 | One of the most-loved features of GBDT models is feature importance. It helps us understand which features are the most important for the model. PyTorch Tabular provides a similar feature for some of the models - GANDALF, GATE, and FTTransformer - where the models natively support the extraction of feature importance. 5 | 6 | ``` python 7 | # tabular_model is the trained model of a supported model 8 | tabular_model.feature_importance() 9 | ``` 10 | 11 | ## Local Feature Attributions/Explanations 12 | Local feature attributions/explanations help us understand the contribution of each feature towards the prediction for a particular sample. PyTorch Tabular provides this feature for all the models except TabTransformer, TabNet, and Mixture Density Networks. It is based on the [Captum](https://captum.ai/) library. The library provides a lot of algorithms for computing feature attributions. PyTorch Tabular provides a wrapper around the library to make it easy to use. The following algorithms are supported: 13 | 14 | - GradientShap: https://captum.ai/api/gradient_shap.html 15 | - IntegratedGradients: https://captum.ai/api/integrated_gradients.html 16 | - DeepLift: https://captum.ai/api/deep_lift.html 17 | - DeepLiftShap: https://captum.ai/api/deep_lift_shap.html 18 | - InputXGradient: https://captum.ai/api/input_x_gradient.html 19 | - FeaturePermutation: https://captum.ai/api/feature_permutation.html 20 | - FeatureAblation: https://captum.ai/api/feature_ablation.html 21 | - KernelShap: https://captum.ai/api/kernel_shap.html 22 | 23 | `PyTorch Tabular` also supports explaining single instances as well as batches of instances, but larger datasets will take longer to explain. Exceptions are the `FeaturePermutation` and `FeatureAblation` methods, which are only meaningful for large batches of instances. 24 | 25 | Most of these explainability methods require a baseline. This is used to compare the attributions of the input with the attributions of the baseline. The baseline can be a scalar value, a tensor of the same shape as the input, or a special string like "b|10000" which means 10000 samples from the training data.
If the baseline is not provided, the default baseline (zero) is used. 26 | 27 | ``` python 28 | # tabular_model is the trained model of a supported model 29 | 30 | # Explain a single instance using the GradientShap method, with 10000 samples from the training data as the baseline 31 | tabular_model.explain(test.head(1), method="GradientShap", baselines="b|10000") 32 | 33 | # Explain a batch of instances using the IntegratedGradients method, with 0 as the baseline 34 | tabular_model.explain(test.head(10), method="IntegratedGradients", baselines=0) 35 | ``` 36 | 37 | Check out the [Captum documentation](https://captum.ai/docs/algorithms) for more details on the algorithms and the [Explainability Tutorial](tutorials/14-Explainability.ipynb) for example usage. 38 | 39 | ## API Reference 40 | ::: pytorch_tabular.TabularModel.explain 41 | options: 42 | show_root_heading: yes 43 | heading_level: 4 44 | -------------------------------------------------------------------------------- /docs/faq.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/faq.md -------------------------------------------------------------------------------- /docs/gs_cite.md: -------------------------------------------------------------------------------- 1 | If you use PyTorch Tabular for a scientific publication, we would appreciate citations to the published software and the following paper: 2 | 3 | - [arXiv Paper](https://arxiv.org/abs/2104.13638) 4 | 5 | ``` 6 | @misc{joseph2021pytorch, 7 | title={PyTorch Tabular: A Framework for Deep Learning with Tabular Data}, 8 | author={Manu Joseph}, 9 | year={2021}, 10 | eprint={2104.13638}, 11 | archivePrefix={arXiv}, 12 | primaryClass={cs.LG} 13 | } 14 | ``` 15 | 16 | - Zenodo Software Citation 17 | 18 | ``` 19 | @article{manujosephv_2021, 20 | title={manujosephv/pytorch_tabular: v0.5.0-alpha}, 21 | DOI={10.5281/zenodo.4732773}, 22 | abstractNote={
First Alpha Release
}, 23 | publisher={Zenodo}, 24 | author={manujosephv}, 25 | year={2021}, 26 | month={May} 27 | } 28 | ``` 29 | -------------------------------------------------------------------------------- /docs/gs_installation.md: -------------------------------------------------------------------------------- 1 | !!! note 2 | 3 | Although the installation includes PyTorch, the best and recommended way is to first install PyTorch from [here](https://pytorch.org/get-started/locally/), picking the right CUDA version for your machine. (PyTorch Version >1.3) 4 | 5 | Once you have PyTorch installed and working, just use: 6 | 7 | ```bash 8 | pip install "pytorch_tabular[extra]" 9 | ``` 10 | 11 | to install the complete library with extra dependencies: 12 | 13 | - Weights&Biases for experiment tracking 14 | - Plotly for some visualization 15 | - Captum for interpretability 16 | 17 | And: 18 | 19 | ``` bash 20 | pip install "pytorch_tabular" 21 | ``` 22 | 23 | for the bare essentials. 24 | 25 | The sources for `pytorch_tabular` can be downloaded from the GitHub repo. 26 | 27 | You can clone the public repository: 28 | 29 | ``` bash 30 | git clone git://github.com/manujosephv/pytorch_tabular 31 | ``` 32 | 33 | Once you have a copy of the source, you can install it with: 34 | 35 | ``` bash 36 | pip install . 37 | ``` 38 | 39 | or 40 | 41 | ``` bash 42 | python setup.py install 43 | ``` 44 | -------------------------------------------------------------------------------- /docs/gs_usage.md: -------------------------------------------------------------------------------- 1 | PyTorch Tabular comes with intelligent defaults that make it easy to get started with tabular deep learning. However, it also provides the flexibility to customize the model and pipeline to suit your needs. 2 | 3 | Here is a simple example of how to use PyTorch Tabular to train a model, evaluate it on new data, generate predictions, and save and load the model. 4 | 5 | ```python 6 | from pytorch_tabular import TabularModel 7 | from pytorch_tabular.models import CategoryEmbeddingModelConfig 8 | from pytorch_tabular.config import ( 9 | DataConfig, 10 | OptimizerConfig, 11 | TrainerConfig, 12 | ) 13 | 14 | data_config = DataConfig( 15 | target=[ 16 | "target" 17 | ], # target should always be a list. 18 | continuous_cols=num_col_names, 19 | categorical_cols=cat_col_names, 20 | ) 21 | trainer_config = TrainerConfig( 22 | auto_lr_find=True, # Runs the LRFinder to automatically derive a learning rate 23 | batch_size=1024, 24 | max_epochs=100, 25 | ) 26 | optimizer_config = OptimizerConfig() 27 | 28 | model_config = CategoryEmbeddingModelConfig( 29 | task="classification", 30 | layers="1024-512-512", # Number of nodes in each layer 31 | activation="LeakyReLU", # Activation between each layer 32 | learning_rate=1e-3, 33 | ) 34 | 35 | tabular_model = TabularModel( 36 | data_config=data_config, 37 | model_config=model_config, 38 | optimizer_config=optimizer_config, 39 | trainer_config=trainer_config, 40 | ) 41 | tabular_model.fit(train=train, validation=val) 42 | result = tabular_model.evaluate(test) 43 | pred_df = tabular_model.predict(test) 44 | tabular_model.save_model("examples/basic") 45 | loaded_model = TabularModel.load_model("examples/basic") 46 | ``` 47 | 48 | For more detailed tutorials and how-to guides, refer to the **Tutorials** and **How-To Guides** sections.
49 | -------------------------------------------------------------------------------- /docs/imgs/auto_encoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/imgs/auto_encoder.png -------------------------------------------------------------------------------- /docs/imgs/diataxis.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/imgs/diataxis.webp -------------------------------------------------------------------------------- /docs/imgs/dndt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/imgs/dndt.png -------------------------------------------------------------------------------- /docs/imgs/gflu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/imgs/gflu.png -------------------------------------------------------------------------------- /docs/imgs/gflu_v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/imgs/gflu_v2.png -------------------------------------------------------------------------------- /docs/imgs/gradient_histograms.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/imgs/gradient_histograms.png -------------------------------------------------------------------------------- /docs/imgs/gradient_norms.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/imgs/gradient_norms.png -------------------------------------------------------------------------------- /docs/imgs/log_logits.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/imgs/log_logits.png -------------------------------------------------------------------------------- /docs/imgs/model_stacking_concept.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/imgs/model_stacking_concept.png -------------------------------------------------------------------------------- /docs/imgs/node_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/imgs/node_arch.png -------------------------------------------------------------------------------- /docs/imgs/node_dense_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/imgs/node_dense_arch.png 
-------------------------------------------------------------------------------- /docs/imgs/pytorch_tabular_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/imgs/pytorch_tabular_logo.png -------------------------------------------------------------------------------- /docs/imgs/pytorch_tabular_logo_inv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/imgs/pytorch_tabular_logo_inv.png -------------------------------------------------------------------------------- /docs/imgs/pytorch_tabular_logo_small_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/imgs/pytorch_tabular_logo_small_2.png -------------------------------------------------------------------------------- /docs/imgs/tabnet_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/imgs/tabnet_architecture.png -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | ![PyTorch Tabular](imgs/pytorch_tabular_logo.png#only-light) 2 | ![PyTorch Tabular](imgs/pytorch_tabular_logo_inv.png#only-dark) 3 | 4 | [![pypi](https://img.shields.io/pypi/v/pytorch_tabular.svg)](https://pypi.python.org/pypi/pytorch_tabular) 5 | [![Testing](https://github.com/manujosephv/pytorch_tabular/actions/workflows/testing.yml/badge.svg?event=push)](https://github.com/manujosephv/pytorch_tabular/actions/workflows/testing.yml) 6 | [![documentation status](https://readthedocs.org/projects/pytorch_tabular/badge/?version=latest)](https://pytorch_tabular.readthedocs.io/en/latest/?badge=latest) 7 | [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/manujosephv/pytorch_tabular/main.svg)](https://results.pre-commit.ci/latest/github/manujosephv/pytorch_tabular/main) 8 | ![PyPI - Downloads](https://img.shields.io/pypi/dm/pytorch_tabular) 9 | [![DOI](https://zenodo.org/badge/321584367.svg)](https://zenodo.org/badge/latestdoi/321584367) 10 | [![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat-square)](https://github.com/manujosephv/pytorch_tabular/issues) 11 | 12 | 13 | **PyTorch Tabular** is a powerful library that aims to simplify and popularize the application of deep learning techniques to tabular data. Tabular deep learning has gained significant importance in the field of machine learning due to its ability to handle structured data, such as data in spreadsheets or databases. However, working with tabular data can be challenging, requiring expertise in both deep learning and data preprocessing. 14 | 15 | This is where **PyTorch Tabular** comes in. Built on the shoulders of giants like `PyTorch`, `PyTorch Lightning`, and `pandas`, PyTorch Tabular offers a **low resistance usability**, making it accessible to both real-world use cases and research projects. 
The library's core principles revolve around **easy customization**, allowing users to tailor their models and pipelines to specific requirements. Moreover, PyTorch Tabular provides **scalable and efficient tooling**, making it easier to deploy models in production environments. The underlying goodness of `PyTorch` makes designing deep learning architectures pythonic and intuitive, while `PyTorch Lightning` simplifies the training process. `pandas` is the de-facto standard for working with tabular data, and PyTorch Tabular leverages its strengths to simplify the preprocessing of tabular data. With PyTorch Tabular, data scientists and researchers can focus on the core aspects of their work, while the library takes care of the underlying complexities, enabling efficient and effective tabular deep learning. 16 | 17 | The documentation is organized taking inspiration from the Diátaxis system of documentation. 18 | 19 | > Diátaxis is a way of thinking about and doing documentation. Diátaxis identifies four distinct needs, and four corresponding forms of documentation - tutorials, how-to guides, technical reference and explanation. It places them in a systematic relationship, and proposes that documentation should itself be organised around the structures of those needs. Diátaxis solves problems related to documentation content (what to write), style (how to write it) and architecture (how to organise it). It is a system for thinking about documentation, and a system for doing documentation. - [Diátaxis](https://diataxis.fr/) 20 | 21 | ![Diátaxis System of Documentation](imgs/diataxis.webp) 22 | 23 | Taking cues from the system, the documentation is separated into five sections: 24 | 25 | - **Getting Started** - A quick introduction on how to install and get started with PyTorch Tabular. 26 | 27 | - **Tutorials** - Short and focused exercises to get you going quickly. 28 | 29 | - **How-to Guides** - Step-by-step guides covering key tasks, real-world operations, and common problems. 30 | 31 | - **Concepts** - Explanations of some of the larger concepts and intricacies of the library. 32 | 33 | - **API Reference** - The technical details of the library: all classes and functions, along with their parameters and return types. 34 | -------------------------------------------------------------------------------- /docs/optimizer.md: -------------------------------------------------------------------------------- 1 | The optimizer is at the heart of the gradient descent process and is a key component we need to train a good model. PyTorch Tabular uses the `Adam` optimizer with a learning rate of `1e-3` by default, a rule-of-thumb setting that provides a good starting point. 2 | 3 | Learning Rate Schedulers let you exercise finer control over how the learning rate is adjusted through the optimization process. By default, PyTorch Tabular applies no Learning Rate Scheduler. 4 | 5 | ## Basic Usage 6 | 7 | - `optimizer`: str: Any of the standard optimizers from [torch.optim](https://pytorch.org/docs/stable/optim.html#algorithms). Defaults to `Adam` 8 | - `optimizer_params`: Dict: The parameters for the optimizer. If left blank, will use default parameters. 9 | - `lr_scheduler`: str: The name of the LearningRateScheduler to use, if any, from [torch.optim.lr_scheduler](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate). If None, will not use any scheduler. Defaults to `None` 10 | - `lr_scheduler_params`: Dict: The parameters for the LearningRateScheduler.
If left blank, will use default parameters. 11 | - `lr_scheduler_monitor_metric`: str: Used with ReduceLROnPlateau, where the plateau is decided based on this metric. Defaults to `val_loss` 12 | 13 | ### Usage Example 14 | 15 | ```python 16 | optimizer_config = OptimizerConfig( 17 | optimizer="RMSprop", lr_scheduler="StepLR", lr_scheduler_params={"step_size": 10} 18 | ) 19 | ``` 20 | 21 | ## Advanced Usage 22 | 23 | While the Config object restricts you to the standard Optimizers and Learning Rate Schedulers in `torch.optim`, you can use any custom Optimizer or Learning Rate Scheduler, as long as it is a drop-in replacement for the standard ones. You can do this using the `fit` method of `TabularModel`, which allows you to override the optimizer and learning rate that are set through the config. 24 | 25 | ### Usage Example 26 | 27 | ```python 28 | from torch_optimizer import QHAdam 29 | 30 | tabular_model.fit( 31 | train=train, 32 | validation=val, 33 | optimizer=QHAdam, 34 | optimizer_params={"nus": (0.7, 1.0), "betas": (0.95, 0.998)}, 35 | ) 36 | ``` 37 | 38 | 39 | ::: pytorch_tabular.config.OptimizerConfig 40 | options: 41 | show_root_heading: yes 42 | -------------------------------------------------------------------------------- /docs/other_features.md: -------------------------------------------------------------------------------- 1 | Apart from training and using Deep Networks for tabular data, PyTorch Tabular also has some cool features that can help your classical ML/scikit-learn pipelines. 2 | 3 | ## Categorical Embeddings 4 | 5 | The CategoryEmbedding model can also be used to encode your categorical columns. Instead of using a one-hot encoder or a variant of target mean encoding, you can use a learned embedding to encode your categorical features. And all of this can be done using a scikit-learn style Transformer. 6 | 7 | ### Usage Example 8 | 9 | ```python 10 | # passing the trained model as an argument 11 | transformer = CategoricalEmbeddingTransformer(tabular_model) 12 | # passing the train dataframe to extract the embeddings and replace categorical features 13 | # defined in the trained tabular_model 14 | train_transformed = transformer.fit_transform(train) 15 | # using the extracted embeddings on new dataframe 16 | val_transformed = transformer.transform(val) 17 | ``` 18 | 19 | ::: pytorch_tabular.categorical_encoders.CategoricalEmbeddingTransformer 20 | options: 21 | show_root_heading: yes 22 | ## Feature Extractor 23 | 24 | What if you want to use the features learnt by the Neural Network in your ML model? PyTorch Tabular lets you do that as well, and with ease. Again, a scikit-learn style Transformer does the job for you.
25 | 26 | ### Usage Example 27 | ```python 28 | # passing the trained model as an argument 29 | dt = DeepFeatureExtractor(tabular_model) 30 | # passing the train dataframe to extract the last layer features 31 | # here `fit` is there only for compatibility and does not do anything 32 | enc_df = dt.fit_transform(train) 33 | # extracting the learned features from a new dataframe 34 | val_transformed = dt.transform(val) 35 | ``` 36 | 37 | ::: pytorch_tabular.feature_extractor.DeepFeatureExtractor 38 | options: 39 | show_root_heading: yes 40 | -------------------------------------------------------------------------------- /docs/tutorials/imgs/prob_reg_eq_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/tutorials/imgs/prob_reg_eq_1.png -------------------------------------------------------------------------------- /docs/tutorials/imgs/prob_reg_eq_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/tutorials/imgs/prob_reg_eq_2.png -------------------------------------------------------------------------------- /docs/tutorials/imgs/prob_reg_fig_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/tutorials/imgs/prob_reg_fig_1.png -------------------------------------------------------------------------------- /docs/tutorials/imgs/prob_reg_fig_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/tutorials/imgs/prob_reg_fig_2.png -------------------------------------------------------------------------------- /docs/tutorials/imgs/prob_reg_fig_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/tutorials/imgs/prob_reg_fig_3.png -------------------------------------------------------------------------------- /docs/tutorials/imgs/prob_reg_hist_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/tutorials/imgs/prob_reg_hist_1.png -------------------------------------------------------------------------------- /docs/tutorials/imgs/prob_reg_hist_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/tutorials/imgs/prob_reg_hist_2.png -------------------------------------------------------------------------------- /docs/tutorials/imgs/prob_reg_hist_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/tutorials/imgs/prob_reg_hist_3.png -------------------------------------------------------------------------------- /docs/tutorials/imgs/prob_reg_hist_4.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/tutorials/imgs/prob_reg_hist_4.png -------------------------------------------------------------------------------- /docs/tutorials/imgs/prob_reg_mdn_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/tutorials/imgs/prob_reg_mdn_1.png -------------------------------------------------------------------------------- /docs/tutorials/imgs/prob_reg_mdn_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/tutorials/imgs/prob_reg_mdn_2.png -------------------------------------------------------------------------------- /docs/tutorials/imgs/prob_reg_mdn_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/tutorials/imgs/prob_reg_mdn_3.png -------------------------------------------------------------------------------- /docs/tutorials/imgs/prob_reg_mixing12_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/tutorials/imgs/prob_reg_mixing12_3.png -------------------------------------------------------------------------------- /docs/tutorials/imgs/prob_reg_mixing1_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/tutorials/imgs/prob_reg_mixing1_3.png -------------------------------------------------------------------------------- /docs/tutorials/imgs/prob_reg_mixing2_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/tutorials/imgs/prob_reg_mixing2_3.png -------------------------------------------------------------------------------- /docs/tutorials/imgs/prob_reg_non_mdn_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/tutorials/imgs/prob_reg_non_mdn_2.png -------------------------------------------------------------------------------- /docs/tutorials/imgs/prob_reg_non_mdn_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/tutorials/imgs/prob_reg_non_mdn_3.png -------------------------------------------------------------------------------- /docs/tutorials/imgs/prob_reg_pdfs_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/tutorials/imgs/prob_reg_pdfs_4.png -------------------------------------------------------------------------------- /docs/tutorials/imgs/tensorboard_example.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/tutorials/imgs/tensorboard_example.png -------------------------------------------------------------------------------- /docs/tutorials/imgs/wandb_preview_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/tutorials/imgs/wandb_preview_1.png -------------------------------------------------------------------------------- /docs/tutorials/imgs/wandb_preview_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/manujosephv/pytorch_tabular/023db2776f96a0f2854e837eef62840be1a12a5e/docs/tutorials/imgs/wandb_preview_2.png -------------------------------------------------------------------------------- /docs/tutorials/trainer_config.yml: -------------------------------------------------------------------------------- 1 | batch_size: 1024 2 | fast_dev_run: false 3 | max_epochs: 20 4 | min_epochs: 1 5 | accelerator: "auto" 6 | devices: -1 7 | accumulate_grad_batches: 1 8 | auto_lr_find: true 9 | check_val_every_n_epoch: 1 10 | gradient_clip_val: 0.0 11 | overfit_batches: 0.0 12 | profiler: null 13 | early_stopping: null 14 | early_stopping_min_delta: 0.001 15 | early_stopping_mode: min 16 | early_stopping_patience: 3 17 | checkpoints: valid_loss 18 | checkpoints_path: saved_models 19 | checkpoints_mode: min 20 | checkpoints_save_top_k: 1 21 | load_best: true 22 | track_grad_norm: -1 23 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | These are a few example scripts. For more examples and tutorials, check the [documentation](https://pytorch-tabular.readthedocs.io/en/stable/tutorials/01-Basic_Usage/). 2 | 3 | Ignore the files in the `__only_for_dev__` folder. They are only for development purposes and may not work correctly.
4 | -------------------------------------------------------------------------------- /examples/__only_for_dev__/adhoc_scaffold.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import pandas as pd 4 | 5 | # os.chdir("..") 6 | from sklearn.datasets import make_classification 7 | from sklearn.metrics import accuracy_score, f1_score 8 | from sklearn.model_selection import train_test_split 9 | 10 | 11 | def make_mixed_classification(n_samples, n_features, n_categories): 12 | X, y = make_classification(n_samples=n_samples, n_features=n_features, random_state=42, n_informative=5) 13 | cat_cols = random.choices(list(range(X.shape[-1])), k=n_categories) 14 | num_cols = [i for i in range(X.shape[-1]) if i not in cat_cols] 15 | for col in cat_cols: 16 | X[:, col] = pd.qcut(X[:, col], q=4).codes.astype(int) 17 | col_names = [] 18 | num_col_names = [] 19 | cat_col_names = [] 20 | for i in range(X.shape[-1]): 21 | if i in cat_cols: 22 | col_names.append(f"cat_col_{i}") 23 | cat_col_names.append(f"cat_col_{i}") 24 | if i in num_cols: 25 | col_names.append(f"num_col_{i}") 26 | num_col_names.append(f"num_col_{i}") 27 | X = pd.DataFrame(X, columns=col_names) 28 | y = pd.Series(y, name="target") 29 | data = X.join(y) 30 | return data, cat_col_names, num_col_names 31 | 32 | 33 | def print_metrics(y_true, y_pred, tag): 34 | if isinstance(y_true, pd.DataFrame) or isinstance(y_true, pd.Series): 35 | y_true = y_true.values 36 | if isinstance(y_pred, pd.DataFrame) or isinstance(y_pred, pd.Series): 37 | y_pred = y_pred.values 38 | if y_true.ndim > 1: 39 | y_true = y_true.ravel() 40 | if y_pred.ndim > 1: 41 | y_pred = y_pred.ravel() 42 | val_acc = accuracy_score(y_true, y_pred) 43 | val_f1 = f1_score(y_true, y_pred) 44 | print(f"{tag} Acc: {val_acc} | {tag} F1: {val_f1}") 45 | 46 | 47 | data, cat_col_names, num_col_names = make_mixed_classification(n_samples=10000, n_features=20, n_categories=4) 48 | train, test = train_test_split(data, random_state=42) 49 | train, val = train_test_split(train, random_state=42) 50 | 51 | from pytorch_tabular import TabularModel # noqa: E402 52 | from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig # noqa: E402 53 | from pytorch_tabular.models import GatedAdditiveTreeEnsembleConfig # noqa: E402 54 | 55 | data_config = DataConfig( 56 | # target should always be a list. 
57 | target=["target"], 58 | continuous_cols=num_col_names, 59 | categorical_cols=cat_col_names, 60 | continuous_feature_transform="quantile_normal", 61 | normalize_continuous_features=True, 62 | ) 63 | trainer_config = TrainerConfig( 64 | auto_lr_find=True, # Runs the LRFinder to automatically derive a learning rate 65 | batch_size=32, 66 | max_epochs=5, 67 | # fast_dev_run=True, 68 | # profiler="simple", 69 | early_stopping=None, 70 | checkpoints=None, 71 | trainer_kwargs={"limit_train_batches": 10}, 72 | ) 73 | optimizer_config = OptimizerConfig() 74 | model_config = GatedAdditiveTreeEnsembleConfig( 75 | task="classification", 76 | gflu_stages=3, 77 | num_trees=0, 78 | tree_depth=2, 79 | binning_activation="sigmoid", 80 | feature_mask_function="t-softmax", 81 | # layers="4096-4096-512", # Number of nodes in each layer 82 | # activation="LeakyReLU", # Activation between each layers 83 | learning_rate=1e-3, 84 | metrics=["auroc"], 85 | metrics_prob_input=[True], 86 | ) 87 | tabular_model = TabularModel( 88 | data_config=data_config, 89 | model_config=model_config, 90 | optimizer_config=optimizer_config, 91 | trainer_config=trainer_config, 92 | ) 93 | 94 | tabular_model.fit(train=train, validation=val) 95 | # test.drop(columns=["target"], inplace=True) 96 | # pred_df = tabular_model.predict(test) 97 | # pred_df = tabular_model.predict(test, device="cpu") 98 | # pred_df = tabular_model.predict(test, device="cuda") 99 | # import torch 100 | 101 | # pred_df = tabular_model.predict(test, device=torch.device("cuda")) 102 | # tabular_model.fit(train=train, validation=val) 103 | # tabular_model.fit(train=train, validation=val, max_epochs=5) 104 | # tabular_model.fit(train=train, validation=val, max_epochs=5, reset=True) 105 | 106 | 107 | # t = torch.rand(128,200) 108 | # a = t.numpy() 109 | 110 | # start = time.time() 111 | # t.median(dim=-1) 112 | # end = time.time() 113 | # print("torch median", end - start) 114 | 115 | # start = time.time() 116 | # t.quantile(torch.rand(128), dim=-1) 117 | # end = time.time() 118 | # print("torch quant ", end - start) 119 | 120 | # start = time.time() 121 | # np.median(t.numpy(), axis=-1) 122 | # end = time.time() 123 | # print("numpy median", end - start) 124 | 125 | # start = time.time() 126 | # np.quantile(t.numpy(), np.random.rand(128), axis=-1) 127 | # end = time.time() 128 | # print("numpy quant ", end - start) 129 | 130 | # start = time.time() 131 | # st = torch.sort(t, dim=-1) 132 | # end = time.time() 133 | # print("torch sort", end - start) 134 | 135 | # start = time.time() 136 | # st = np.sort(t.numpy(), axis=-1) 137 | # end = time.time() 138 | # print("numpy sort", end - start) 139 | -------------------------------------------------------------------------------- /examples/__only_for_dev__/to_test_node.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Tests for `pytorch_tabular` package.""" 3 | 4 | import numpy as np 5 | import pandas as pd 6 | from sklearn.datasets import fetch_california_housing, fetch_covtype 7 | 8 | from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig 9 | from pytorch_tabular.models.node import NodeConfig 10 | from pytorch_tabular.tabular_model import TabularModel 11 | 12 | 13 | def regression_data(): 14 | dataset = fetch_california_housing(data_home="data", as_frame=True) 15 | df = dataset.frame.sample(5000) 16 | df["HouseAgeBin"] = pd.qcut(df["HouseAge"], q=4) 17 | df["HouseAgeBin"] = "age_" + df.HouseAgeBin.cat.codes.astype(str) 18 | 
test_idx = df.sample(int(0.2 * len(df)), random_state=42).index 19 | test = df[df.index.isin(test_idx)] 20 | train = df[~df.index.isin(test_idx)] 21 | return (train, test, dataset.target_names) 22 | 23 | 24 | def classification_data(): 25 | dataset = fetch_covtype(data_home="data") 26 | data = np.hstack([dataset.data, dataset.target.reshape(-1, 1)])[:10000, :] 27 | col_names = [f"feature_{i}" for i in range(data.shape[-1])] 28 | col_names[-1] = "target" 29 | data = pd.DataFrame(data, columns=col_names) 30 | data["feature_0_cat"] = pd.qcut(data["feature_0"], q=4) 31 | data["feature_0_cat"] = "feature_0_" + data.feature_0_cat.cat.codes.astype(str) 32 | test_idx = data.sample(int(0.2 * len(data)), random_state=42).index 33 | test = data[data.index.isin(test_idx)] 34 | train = data[~data.index.isin(test_idx)] 35 | return (train, test, ["target"]) 36 | 37 | 38 | def test_regression( 39 | regression_data, 40 | multi_target, 41 | continuous_cols, 42 | categorical_cols, 43 | continuous_feature_transform, 44 | normalize_continuous_features, 45 | ): 46 | (train, test, target) = regression_data 47 | if len(continuous_cols) + len(categorical_cols) == 0: 48 | assert True 49 | else: 50 | data_config = DataConfig( 51 | target=target + ["MedInc"] if multi_target else target, 52 | continuous_cols=continuous_cols, 53 | categorical_cols=categorical_cols, 54 | continuous_feature_transform=continuous_feature_transform, 55 | normalize_continuous_features=normalize_continuous_features, 56 | ) 57 | model_config_params = {"task": "regression", "depth": 2} 58 | model_config = NodeConfig(**model_config_params) 59 | # model_config_params = dict(task="regression") 60 | # model_config = NodeConfig(**model_config_params) 61 | 62 | trainer_config = TrainerConfig(max_epochs=1, checkpoints=None, early_stopping=None) 63 | optimizer_config = OptimizerConfig() 64 | 65 | tabular_model = TabularModel( 66 | data_config=data_config, 67 | model_config=model_config, 68 | optimizer_config=optimizer_config, 69 | trainer_config=trainer_config, 70 | ) 71 | tabular_model.fit(train=train) 72 | 73 | result = tabular_model.evaluate(test) 74 | if multi_target: 75 | assert result[0]["valid_loss"] < 30 76 | else: 77 | assert result[0]["valid_loss"] < 8 78 | pred_df = tabular_model.predict(test) 79 | assert pred_df.shape[0] == test.shape[0] 80 | 81 | 82 | def test_classification( 83 | classification_data, 84 | continuous_cols, 85 | categorical_cols, 86 | continuous_feature_transform, 87 | normalize_continuous_features, 88 | ): 89 | (train, test, target) = classification_data 90 | if len(continuous_cols) + len(categorical_cols) == 0: 91 | return 92 | data_config = DataConfig( 93 | target=target, 94 | continuous_cols=continuous_cols, 95 | categorical_cols=categorical_cols, 96 | continuous_feature_transform=continuous_feature_transform, 97 | normalize_continuous_features=normalize_continuous_features, 98 | ) 99 | model_config_params = {"task": "classification", "depth": 2} 100 | model_config = NodeConfig(**model_config_params) 101 | trainer_config = TrainerConfig(max_epochs=1, checkpoints=None, early_stopping=None) 102 | optimizer_config = OptimizerConfig() 103 | 104 | tabular_model = TabularModel( 105 | data_config=data_config, 106 | model_config=model_config, 107 | optimizer_config=optimizer_config, 108 | trainer_config=trainer_config, 109 | ) 110 | tabular_model.fit(train=train) 111 | 112 | result = tabular_model.evaluate(test) 113 | assert result[0]["valid_loss"] < 2.5 114 | pred_df = tabular_model.predict(test) 115 | assert pred_df.shape[0] 
== test.shape[0] 116 | 117 | 118 | test_regression( 119 | regression_data(), 120 | multi_target=False, 121 | continuous_cols=[ 122 | "AveRooms", 123 | "AveBedrms", 124 | "Population", 125 | "AveOccup", 126 | "Latitude", 127 | "Longitude", 128 | ], 129 | categorical_cols=["HouseAgeBin"], 130 | continuous_feature_transform=None, 131 | normalize_continuous_features=True, 132 | # target_range=True, 133 | ) 134 | 135 | # classification_data() 136 | -------------------------------------------------------------------------------- /examples/__only_for_dev__/to_test_regression.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import torch 3 | from sklearn.datasets import fetch_california_housing 4 | 5 | from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig 6 | from pytorch_tabular.models.category_embedding.config import CategoryEmbeddingModelConfig 7 | from pytorch_tabular.tabular_model import TabularModel 8 | 9 | # from pytorch_tabular.models.mixture_density import ( 10 | # CategoryEmbeddingMDNConfig, 11 | # MixtureDensityHeadConfig, 12 | # NODEMDNConfig, 13 | 14 | 15 | dataset = fetch_california_housing(data_home="data", as_frame=True) 16 | dataset.frame["HouseAgeBin"] = pd.qcut(dataset.frame["HouseAge"], q=4) 17 | dataset.frame.HouseAgeBin = "age_" + dataset.frame.HouseAgeBin.cat.codes.astype(str) 18 | dataset.frame["AveRoomsBin"] = pd.qcut(dataset.frame["AveRooms"], q=3) 19 | dataset.frame.AveRoomsBin = "av_rm_" + dataset.frame.AveRoomsBin.cat.codes.astype(str) 20 | 21 | test_idx = dataset.frame.sample(int(0.2 * len(dataset.frame)), random_state=42).index 22 | test = dataset.frame[dataset.frame.index.isin(test_idx)] 23 | train = dataset.frame[~dataset.frame.index.isin(test_idx)] 24 | 25 | data_config = DataConfig( 26 | target=dataset.target_names, 27 | continuous_cols=[ 28 | "AveRooms", 29 | "AveBedrms", 30 | "Population", 31 | "AveOccup", 32 | "Latitude", 33 | "Longitude", 34 | ], 35 | # continuous_cols=[], 36 | categorical_cols=["HouseAgeBin", "AveRoomsBin"], 37 | continuous_feature_transform=None, # "yeo-johnson", 38 | normalize_continuous_features=True, 39 | ) 40 | 41 | # mdn_config = dict(num_gaussian=2) 42 | # model_config = MDNConfig( 43 | # task="regression", 44 | # backbone_config_class = "CategoryEmbeddingModelConfig", 45 | # backbone_config_params = dict(task="backbone"), #TODO Add backbone task 46 | # head_config = mdn_config 47 | # ) 48 | # # model_config.validate() 49 | # model_config = CategoryEmbeddingModelConfig(task="regression") 50 | # model_config = AutoIntConfig( 51 | # task="regression", 52 | # deep_layers=True, 53 | # embedding_dropout=0.2, 54 | # batch_norm_continuous_input=True, 55 | # attention_pooling=True, 56 | # ) 57 | model_config = CategoryEmbeddingModelConfig(task="regression", dropout=0.2, head_config={"layers": "32-16"}) 58 | 59 | trainer_config = TrainerConfig( 60 | # checkpoints=None, 61 | max_epochs=2, 62 | profiler=None, 63 | fast_dev_run=False, 64 | auto_lr_find=True, 65 | ) 66 | # experiment_config = ExperimentConfig( 67 | # project_name="DeepGMM_test", 68 | # run_name="wand_debug", 69 | # log_target="wandb", 70 | # exp_watch="gradients", 71 | # log_logits=True 72 | # ) 73 | optimizer_config = OptimizerConfig() 74 | 75 | 76 | def fake_metric(y_hat, y): 77 | return (y_hat - y).mean() 78 | 79 | 80 | from sklearn.preprocessing import PowerTransformer # noqa: E402 81 | 82 | tr = PowerTransformer() 83 | tabular_model = TabularModel( 84 | data_config=data_config, 85 | 
model_config=model_config, 86 | optimizer_config=optimizer_config, 87 | trainer_config=trainer_config, 88 | # experiment_config=experiment_config, 89 | ) 90 | tabular_model.fit( 91 | train=train, 92 | test=test, 93 | metrics=[fake_metric], 94 | metrics_prob_inputs=[False], 95 | target_transform=tr, 96 | loss=torch.nn.L1Loss(), 97 | optimizer=torch.optim.Adagrad, 98 | optimizer_params={}, 99 | ) 100 | 101 | # dt = DeepFeatureExtractor(tabular_model) 102 | # enc_df = dt.fit_transform(test) 103 | # print(enc_df.head()) 104 | # tabular_model.save_model("examples/sample") 105 | result = tabular_model.evaluate(test) 106 | print(result) 107 | # # # print(result[0]['train_loss']) 108 | # new_mdl = TabularModel.load_model("examples/sample") 109 | # # TODO test none no test loader 110 | # result = new_mdl.evaluate(test) 111 | # print(result) 112 | # tabular_model.fit( 113 | # train=train, test=test, metrics=[fake_metric], target_transform=tr, max_epochs=2 114 | # ) 115 | # pred_df = tabular_model.predict(test, quantiles=[0.25], ret_logits=True) 116 | # print(pred_df.head()) 117 | 118 | # pred_df.to_csv("output/temp2.csv") 119 | -------------------------------------------------------------------------------- /examples/covertype_classification.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pandas as pd 4 | import wget 5 | from sklearn.model_selection import train_test_split 6 | 7 | from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig 8 | from pytorch_tabular.models import CategoryEmbeddingModelConfig 9 | from pytorch_tabular.models.common.heads import LinearHeadConfig 10 | from pytorch_tabular.tabular_model import TabularModel 11 | 12 | BASE_DIR = Path.home().joinpath("data") 13 | datafile = BASE_DIR.joinpath("covtype.data.gz") 14 | datafile.parent.mkdir(parents=True, exist_ok=True) 15 | url = "https://archive.ics.uci.edu/ml/machine-learning-databases/covtype/covtype.data.gz" 16 | if not datafile.exists(): 17 | wget.download(url, datafile.as_posix()) 18 | 19 | target_name = ["Covertype"] 20 | 21 | cat_col_names = [ 22 | "Wilderness_Area1", 23 | "Wilderness_Area2", 24 | "Wilderness_Area3", 25 | "Wilderness_Area4", 26 | "Soil_Type1", 27 | "Soil_Type2", 28 | "Soil_Type3", 29 | "Soil_Type4", 30 | "Soil_Type5", 31 | "Soil_Type6", 32 | "Soil_Type7", 33 | "Soil_Type8", 34 | "Soil_Type9", 35 | "Soil_Type10", 36 | "Soil_Type11", 37 | "Soil_Type12", 38 | "Soil_Type13", 39 | "Soil_Type14", 40 | "Soil_Type15", 41 | "Soil_Type16", 42 | "Soil_Type17", 43 | "Soil_Type18", 44 | "Soil_Type19", 45 | "Soil_Type20", 46 | "Soil_Type21", 47 | "Soil_Type22", 48 | "Soil_Type23", 49 | "Soil_Type24", 50 | "Soil_Type25", 51 | "Soil_Type26", 52 | "Soil_Type27", 53 | "Soil_Type28", 54 | "Soil_Type29", 55 | "Soil_Type30", 56 | "Soil_Type31", 57 | "Soil_Type32", 58 | "Soil_Type33", 59 | "Soil_Type34", 60 | "Soil_Type35", 61 | "Soil_Type36", 62 | "Soil_Type37", 63 | "Soil_Type38", 64 | "Soil_Type39", 65 | "Soil_Type40", 66 | ] 67 | 68 | num_col_names = [ 69 | "Elevation", 70 | "Aspect", 71 | "Slope", 72 | "Horizontal_Distance_To_Hydrology", 73 | "Vertical_Distance_To_Hydrology", 74 | "Horizontal_Distance_To_Roadways", 75 | "Hillshade_9am", 76 | "Hillshade_Noon", 77 | "Hillshade_3pm", 78 | "Horizontal_Distance_To_Fire_Points", 79 | ] 80 | 81 | feature_columns = num_col_names + cat_col_names + target_name 82 | 83 | df = pd.read_csv(datafile, header=None, names=feature_columns) 84 | train, test = train_test_split(df, 
random_state=42) 85 | train, val = train_test_split(train, random_state=42) 86 | num_classes = len(set(train[target_name].values.ravel())) 87 | 88 | data_config = DataConfig( 89 | target=target_name, 90 | continuous_cols=num_col_names, 91 | categorical_cols=cat_col_names, 92 | continuous_feature_transform=None, # "quantile_normal", 93 | normalize_continuous_features=True, 94 | ) 95 | head_config = LinearHeadConfig( 96 | layers="", 97 | dropout=0.1, 98 | initialization="kaiming", # No additional layer in head, just a mapping layer to output_dim 99 | ).__dict__ # Convert to dict to pass to the model config (OmegaConf doesn't accept objects) 100 | model_config = CategoryEmbeddingModelConfig( 101 | task="classification", 102 | metrics=["f1_score", "accuracy"], 103 | metrics_params=[{"num_classes": num_classes}, {}], 104 | metrics_prob_input=[True, False], 105 | ) 106 | trainer_config = TrainerConfig(auto_lr_find=True, fast_dev_run=False, max_epochs=5, batch_size=512) 107 | optimizer_config = OptimizerConfig() 108 | tabular_model = TabularModel( 109 | data_config=data_config, 110 | model_config=model_config, 111 | optimizer_config=optimizer_config, 112 | trainer_config=trainer_config, 113 | ) 114 | tabular_model.fit( 115 | train=train, 116 | validation=val, 117 | ) 118 | 119 | pred_df = tabular_model.predict(test) 120 | print(pred_df.head()) 121 | tabular_model.save_model("examples/test_save") 122 | -------------------------------------------------------------------------------- /examples/covertype_classification_using_yaml.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pandas as pd 4 | import wget 5 | from sklearn.model_selection import train_test_split 6 | 7 | from pytorch_tabular.tabular_model import TabularModel 8 | 9 | BASE_DIR = Path.home().joinpath("data") 10 | datafile = BASE_DIR.joinpath("covtype.data.gz") 11 | datafile.parent.mkdir(parents=True, exist_ok=True) 12 | url = "https://archive.ics.uci.edu/ml/machine-learning-databases/covtype/covtype.data.gz" 13 | if not datafile.exists(): 14 | wget.download(url, datafile.as_posix()) 15 | 16 | target_name = ["Covertype"] 17 | 18 | cat_col_names = [ 19 | "Wilderness_Area1", 20 | "Wilderness_Area2", 21 | "Wilderness_Area3", 22 | "Wilderness_Area4", 23 | "Soil_Type1", 24 | "Soil_Type2", 25 | "Soil_Type3", 26 | "Soil_Type4", 27 | "Soil_Type5", 28 | "Soil_Type6", 29 | "Soil_Type7", 30 | "Soil_Type8", 31 | "Soil_Type9", 32 | "Soil_Type10", 33 | "Soil_Type11", 34 | "Soil_Type12", 35 | "Soil_Type13", 36 | "Soil_Type14", 37 | "Soil_Type15", 38 | "Soil_Type16", 39 | "Soil_Type17", 40 | "Soil_Type18", 41 | "Soil_Type19", 42 | "Soil_Type20", 43 | "Soil_Type21", 44 | "Soil_Type22", 45 | "Soil_Type23", 46 | "Soil_Type24", 47 | "Soil_Type25", 48 | "Soil_Type26", 49 | "Soil_Type27", 50 | "Soil_Type28", 51 | "Soil_Type29", 52 | "Soil_Type30", 53 | "Soil_Type31", 54 | "Soil_Type32", 55 | "Soil_Type33", 56 | "Soil_Type34", 57 | "Soil_Type35", 58 | "Soil_Type36", 59 | "Soil_Type37", 60 | "Soil_Type38", 61 | "Soil_Type39", 62 | "Soil_Type40", 63 | ] 64 | 65 | num_col_names = [ 66 | "Elevation", 67 | "Aspect", 68 | "Slope", 69 | "Horizontal_Distance_To_Hydrology", 70 | "Vertical_Distance_To_Hydrology", 71 | "Horizontal_Distance_To_Roadways", 72 | "Hillshade_9am", 73 | "Hillshade_Noon", 74 | "Hillshade_3pm", 75 | "Horizontal_Distance_To_Fire_Points", 76 | ] 77 | 78 | feature_columns = num_col_names + cat_col_names + target_name 79 | 80 | df = pd.read_csv(datafile, header=None, 
names=feature_columns) 81 | train, test = train_test_split(df, random_state=42) 82 | train, val = train_test_split(train, random_state=42) 83 | num_classes = len(set(train[target_name].values.ravel())) 84 | 85 | gate_lite = TabularModel( 86 | data_config="examples/yaml_config/data_config.yml", 87 | model_config="examples/yaml_config/gate_lite_model_config.yml", 88 | optimizer_config="examples/yaml_config/optimizer_config.yml", 89 | trainer_config="examples/yaml_config/trainer_config.yml", 90 | ) 91 | 92 | datamodule = gate_lite.prepare_dataloader(train=train, validation=val, seed=42) 93 | model = gate_lite.prepare_model(datamodule) 94 | 95 | gate_lite.train(model, datamodule) 96 | 97 | pred_df = gate_lite.predict(test, include_input_features=False) 98 | pred_df["Model"] = "Gate Lite" 99 | print(pred_df.head()) 100 | 101 | gate_full = TabularModel( 102 | data_config="examples/yaml_config/data_config.yml", 103 | model_config="examples/yaml_config/gate_full_model_config.yml", 104 | optimizer_config="examples/yaml_config/optimizer_config.yml", 105 | trainer_config="examples/yaml_config/trainer_config.yml", 106 | ) 107 | gate_full_model = gate_full.prepare_model(datamodule) 108 | gate_full.train(gate_full_model, datamodule) 109 | 110 | 111 | pred_df_ = gate_full.predict(test, include_input_features=False) 112 | pred_df_["Model"] = "Gate Full" 113 | pred_df = pd.concat([pred_df, pred_df_]) 114 | print(pred_df_.head()) 115 | del pred_df_ 116 | -------------------------------------------------------------------------------- /examples/yaml_config/data_config.yml: -------------------------------------------------------------------------------- 1 | target: 2 | - Covertype 3 | continuous_cols: 4 | - "Elevation" 5 | - "Aspect" 6 | - "Slope" 7 | - "Horizontal_Distance_To_Hydrology" 8 | - "Vertical_Distance_To_Hydrology" 9 | - "Horizontal_Distance_To_Roadways" 10 | - "Hillshade_9am" 11 | - "Hillshade_Noon" 12 | - "Hillshade_3pm" 13 | - "Horizontal_Distance_To_Fire_Points" 14 | categorical_cols: 15 | - "Wilderness_Area1" 16 | - "Wilderness_Area2" 17 | - "Wilderness_Area3" 18 | - "Wilderness_Area4" 19 | - "Soil_Type1" 20 | - "Soil_Type2" 21 | - "Soil_Type3" 22 | - "Soil_Type4" 23 | - "Soil_Type5" 24 | - "Soil_Type6" 25 | - "Soil_Type7" 26 | - "Soil_Type8" 27 | - "Soil_Type9" 28 | - "Soil_Type10" 29 | - "Soil_Type11" 30 | - "Soil_Type12" 31 | - "Soil_Type13" 32 | - "Soil_Type14" 33 | - "Soil_Type15" 34 | - "Soil_Type16" 35 | - "Soil_Type17" 36 | - "Soil_Type18" 37 | - "Soil_Type19" 38 | - "Soil_Type20" 39 | - "Soil_Type21" 40 | - "Soil_Type22" 41 | - "Soil_Type23" 42 | - "Soil_Type24" 43 | - "Soil_Type25" 44 | - "Soil_Type26" 45 | - "Soil_Type27" 46 | - "Soil_Type28" 47 | - "Soil_Type29" 48 | - "Soil_Type30" 49 | - "Soil_Type31" 50 | - "Soil_Type32" 51 | - "Soil_Type33" 52 | - "Soil_Type34" 53 | - "Soil_Type35" 54 | - "Soil_Type36" 55 | - "Soil_Type37" 56 | - "Soil_Type38" 57 | - "Soil_Type39" 58 | - "Soil_Type40" 59 | date_cols: [] 60 | continuous_feature_transform: Null 61 | normalize_continuous_features: true 62 | -------------------------------------------------------------------------------- /examples/yaml_config/gate_full_model_config.yml: -------------------------------------------------------------------------------- 1 | task: classification 2 | gflu_stages: 6 3 | num_trees: 20 4 | tree_depth: 5 5 | chain_trees: True 6 | head: LinearHead 7 | head_config: 8 | layers: "" 9 | activation: LeakyReLU 10 | dropout: 0.1 11 | learning_rate: 0.001 12 | _module_src: models.gate 13 | _model_name:
GatedAdditiveTreeEnsembleModel 14 | _backbone_name: GatedAdditiveTreesBackbone 15 | _config_name: GatedAdditiveTreeEnsembleConfig 16 | -------------------------------------------------------------------------------- /examples/yaml_config/gate_lite_mdn.yml: -------------------------------------------------------------------------------- 1 | task: classification 2 | gflu_stages: 4 3 | num_trees: 30 4 | tree_depth: 5 5 | chain_trees: False 6 | head: MixtureDensityHead 7 | head_config: 8 | num_gaussian: 4 9 | learning_rate: 0.001 10 | _module_src: models.gate 11 | _model_name: GatedAdditiveTreeEnsembleModel 12 | _backbone_name: GatedAdditiveTreesBackbone 13 | _config_name: GatedAdditiveTreeEnsembleConfig 14 | -------------------------------------------------------------------------------- /examples/yaml_config/gate_lite_model_config.yml: -------------------------------------------------------------------------------- 1 | task: classification 2 | gflu_stages: 4 3 | num_trees: 30 4 | tree_depth: 5 5 | chain_trees: False 6 | head: LinearHead 7 | head_config: 8 | layers: "" 9 | activation: LeakyReLU 10 | dropout: 0.1 11 | learning_rate: 0.001 12 | _module_src: models.gate 13 | _model_name: GatedAdditiveTreeEnsembleModel 14 | _backbone_name: GatedAdditiveTreesBackbone 15 | _config_name: GatedAdditiveTreeEnsembleConfig 16 | -------------------------------------------------------------------------------- /examples/yaml_config/optimizer_config.yml: -------------------------------------------------------------------------------- 1 | optimizer: AdamW 2 | optimizer_params: 3 | weight_decay: 1e-5 4 | lr_scheduler: CosineAnnealingWarmRestarts 5 | lr_scheduler_params: 6 | T_0: 100 7 | T_mult: 1 8 | lr_scheduler_monitor_metric: val_loss 9 | -------------------------------------------------------------------------------- /examples/yaml_config/trainer_config.yml: -------------------------------------------------------------------------------- 1 | batch_size: 512 2 | fast_dev_run: true 3 | max_epochs: 5 # Set this to a more reasonable value for real runs 4 | accelerator: auto 5 | early_stopping: null 6 | checkpoints: valid_loss 7 | load_best: true 8 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.ruff] 2 | target-version = "py38" 3 | line-length = 120 4 | # Enable Pyflakes `E` and `F` codes by default. 5 | select = [ 6 | "E", "W", # see: https://pypi.org/project/pycodestyle 7 | "F", # see: https://pypi.org/project/pyflakes 8 | "I", # isort 9 | "UP", # see: https://docs.astral.sh/ruff/rules/#pyupgrade-up 10 | "RUF100", # yesqa 11 | # "D", # see: https://pypi.org/project/pydocstyle 12 | # "N", # see: https://pypi.org/project/pep8-naming 13 | ] 14 | extend-select = [ 15 | "C4", # see: https://pypi.org/project/flake8-comprehensions 16 | # "SIM", # see: https://pypi.org/project/flake8-simplify 17 | # "RET", # see: https://pypi.org/project/flake8-return 18 | # "PT", # see: https://pypi.org/project/flake8-pytest-style 19 | ] 20 | ignore = [ 21 | "E731", # Do not assign a lambda expression, use a def 22 | ] 23 | # Exclude a variety of commonly ignored directories.
24 | exclude = [ 25 | ".eggs", 26 | ".git", 27 | ".mypy_cache", 28 | ".ruff_cache", 29 | "__pypackages__", 30 | "_build", 31 | "build", 32 | "dist", 33 | "docs" 34 | ] 35 | ignore-init-module-imports = true 36 | 37 | [tool.ruff.per-file-ignores] 38 | "setup.py" = ["D100", "SIM115"] 39 | "__about__.py" = ["D100"] 40 | "__init__.py" = ["D100"] 41 | 42 | [tool.ruff.pydocstyle] 43 | # Use Google-style docstrings. 44 | convention = "google" 45 | 46 | [tool.docformatter] 47 | recursive = true 48 | # this needs to be shorter as some docstrings are r"""... 49 | wrap-summaries = 119 50 | wrap-descriptions = 120 51 | blank = true 52 | -------------------------------------------------------------------------------- /requirements/base.txt: -------------------------------------------------------------------------------- 1 | torch >=1.11.0 2 | numpy >1.20.0, <2.0 3 | pandas >=1.1.5 4 | scikit-learn >=1.3.0 5 | pytorch-lightning >=2.0.0, <2.5.0 6 | omegaconf >=2.3.0 7 | torchmetrics >=0.10.0, <1.7.0 8 | tensorboard >2.2.0, !=2.5.0 9 | protobuf >=3.20.0, <5.30.0 10 | pytorch-tabnet ==4.1 11 | PyYAML >=5.4, <6.1.0 12 | # importlib-metadata <1,>=0.12 13 | matplotlib >3.1 14 | ipywidgets 15 | einops >=0.6.0, <0.8.0 16 | rich >=11.0.0 17 | fsspec >=2022.5.0, <2024.4.0 ; python_version == "3.8" 18 | -------------------------------------------------------------------------------- /requirements/dev.txt: -------------------------------------------------------------------------------- 1 | wget 2 | bump2version ==1.0.1 3 | # coverage >=4.5.4 4 | mkdocs-material ==9.5.* 5 | ipython[notebook] >8.0, <9.0 6 | mkdocstrings[python] ==0.26.*; python_version < "3.9" 7 | mkdocstrings[python] ==0.29.*; python_version >= "3.9" 8 | mknotebooks ==0.8.* 9 | pytest >=5.3.2 10 | pytest-runner >=5.1 11 | torch_optimizer 12 | -------------------------------------------------------------------------------- /requirements/extra.txt: -------------------------------------------------------------------------------- 1 | wandb >=0.15.0, <0.19.0 2 | plotly>=5.13.0, <5.25.0 3 | kaleido >=0.2.0, <0.3.0 4 | captum >=0.5.0, <0.8.0 5 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 1.1.0 3 | commit = True 4 | tag = False 5 | parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?
6 | serialize = 7 | {major}.{minor}.{patch}-{release}{build} 8 | {major}.{minor}.{patch} 9 | 10 | [bumpversion:part:release] 11 | optional_value = prod 12 | first_value = dev 13 | values = 14 | dev 15 | prod 16 | 17 | [bumpversion:part:build] 18 | 19 | [bumpversion:file:setup.py] 20 | search = version="{current_version}" 21 | replace = version="{new_version}" 22 | 23 | [bumpversion:file:src/pytorch_tabular/__init__.py] 24 | search = __version__ = "{current_version}" 25 | replace = __version__ = "{new_version}" 26 | 27 | [bdist_wheel] 28 | universal = 1 29 | 30 | [aliases] 31 | test = pytest 32 | 33 | [tool:pytest] 34 | collect_ignore = ['setup.py'] 35 | addopts = 36 | --color=yes 37 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """The setup script.""" 4 | 5 | import os 6 | 7 | from setuptools import find_packages, setup 8 | 9 | 10 | def read_requirements(thelibFolder, filename): 11 | requirementPath = os.path.join(thelibFolder, filename) 12 | requirements = [] 13 | if os.path.isfile(requirementPath): 14 | with open(requirementPath) as f: 15 | requirements = f.read().splitlines() 16 | return requirements 17 | 18 | 19 | with open("README.md") as readme_file: 20 | readme = readme_file.read() 21 | 22 | with open("docs/history.md") as history_file: 23 | history = history_file.read() 24 | 25 | this_folder = os.path.dirname(os.path.realpath(__file__)) 26 | req_folder = os.path.join(this_folder, "requirements") 27 | 28 | requirements = read_requirements(req_folder, "base.txt") 29 | requirements_testing = read_requirements(req_folder, "dev.txt") 30 | requirements_extra = read_requirements(req_folder, "extra.txt") 31 | 32 | # setup_requirements = ['pytest-runner', ] 33 | 34 | test_requirements = requirements_testing 35 | 36 | setup( 37 | author="Manu Joseph", 38 | author_email="manujosephv@gmail.com", 39 | python_requires=">=3.8", 40 | classifiers=[ 41 | "Development Status :: 4 - Beta", 42 | "Intended Audience :: Developers", 43 | "License :: OSI Approved :: MIT License", 44 | "Natural Language :: English", 45 | "Programming Language :: Python :: 3", 46 | "Programming Language :: Python :: 3.8", 47 | "Programming Language :: Python :: 3.9", 48 | "Programming Language :: Python :: 3.10", 49 | ], 50 | description="A standard framework for using Deep Learning for tabular data", 51 | install_requires=requirements, 52 | extras_require={"extra": requirements_extra, "dev": requirements_testing}, 53 | license="MIT license", 54 | long_description=readme + "\n\n" + history, 55 | long_description_content_type="text/markdown", 56 | include_package_data=True, 57 | keywords="pytorch, tabular, pytorch-lightning, neural network", 58 | name="pytorch_tabular", 59 | packages=find_packages(where="src", include=["pytorch_tabular", "pytorch_tabular.*"]), 60 | package_dir={"": "src"}, 61 | # setup_requires=test_requirements, 62 | test_suite="tests", 63 | tests_require=test_requirements, 64 | url="https://github.com/manujosephv/pytorch_tabular", 65 | version="1.1.0", 66 | zip_safe=False, 67 | ) 68 | -------------------------------------------------------------------------------- /src/pytorch_tabular/__init__.py: -------------------------------------------------------------------------------- 1 | """Top-level package for Pytorch Tabular.""" 2 | 3 | __author__ = """Manu Joseph""" 4 | __email__ = "manujosephv@gmail.com" 5 | __version__ = "1.1.0" 6 | 7 | from . 
import models, ssl_models 8 | from .categorical_encoders import CategoricalEmbeddingTransformer 9 | from .feature_extractor import DeepFeatureExtractor 10 | from .tabular_datamodule import TabularDatamodule 11 | from .tabular_model import TabularModel 12 | from .tabular_model_sweep import MODEL_SWEEP_PRESETS, model_sweep 13 | from .tabular_model_tuner import TabularModelTuner 14 | from .utils import available_models, available_ssl_models, get_logger 15 | 16 | logger = get_logger("pytorch_tabular") 17 | 18 | __all__ = [ 19 | "TabularModel", 20 | "TabularModelTuner", 21 | "TabularDatamodule", 22 | "models", 23 | "ssl_models", 24 | "CategoricalEmbeddingTransformer", 25 | "DeepFeatureExtractor", 26 | "utils", 27 | "model_sweep", 28 | "available_models", 29 | "available_ssl_models", 30 | "model_sweep", 31 | "MODEL_SWEEP_PRESETS", 32 | ] 33 | 34 | # fix Sphinx issues, see https://bit.ly/2K2eptM 35 | for item in __all__: 36 | if hasattr(item, "__module__"): 37 | setattr(item, "__module__", __name__) 38 | -------------------------------------------------------------------------------- /src/pytorch_tabular/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import ( 2 | DataConfig, 3 | ExperimentConfig, 4 | ExperimentRunManager, 5 | InferredConfig, 6 | ModelConfig, 7 | OptimizerConfig, 8 | SSLModelConfig, 9 | TrainerConfig, 10 | _validate_choices, 11 | ) 12 | 13 | __all__ = [ 14 | "TrainerConfig", 15 | "DataConfig", 16 | "ModelConfig", 17 | "InferredConfig", 18 | "ExperimentConfig", 19 | "OptimizerConfig", 20 | "SSLModelConfig", 21 | "ExperimentRunManager", 22 | "_validate_choices", 23 | ] 24 | -------------------------------------------------------------------------------- /src/pytorch_tabular/feature_extractor.py: -------------------------------------------------------------------------------- 1 | # Pytorch Tabular 2 | # Author: Manu Joseph 3 | # For license information, see LICENSE.TXT 4 | from collections import defaultdict 5 | 6 | import pandas as pd 7 | from rich.progress import track 8 | from sklearn.base import BaseEstimator, TransformerMixin 9 | 10 | from pytorch_tabular.models import NODEModel, TabNetModel 11 | from pytorch_tabular.models.mixture_density import MDNModel 12 | 13 | try: 14 | import cPickle as pickle 15 | except ImportError: 16 | import pickle 17 | 18 | import torch 19 | 20 | 21 | class DeepFeatureExtractor(BaseEstimator, TransformerMixin): 22 | def __init__(self, tabular_model, extract_keys=["backbone_features"], drop_original=True): 23 | """Initializes the Transformer and extracts the neural features. 24 | 25 | Args: 26 | tabular_model (TabularModel): The trained TabularModel object 27 | extract_keys (list, optional): The keys of the features to extract. Defaults to ["backbone_features"]. 28 | drop_original (bool, optional): Whether to drop the original columns. Defaults to True. 29 | 30 | """ 31 | assert not ( 32 | isinstance(tabular_model.model, NODEModel) 33 | or isinstance(tabular_model.model, TabNetModel) 34 | or isinstance(tabular_model.model, MDNModel) 35 | ), "FeatureExtractor doesn't work for Mixture Density Networks, NODE Model, & Tabnet Model" 36 | self.tabular_model = tabular_model 37 | self.extract_keys = extract_keys 38 | self.drop_original = drop_original 39 | 40 | def fit(self, X, y=None): 41 | """Just for compatibility. 
42 | 
43 |         The method itself learns nothing and simply returns ``self``.
44 | 
45 |         """
46 |         return self
47 | 
48 |     def transform(self, X: pd.DataFrame, y=None) -> pd.DataFrame:
49 |         """Transforms the input rows into the neural features learned by the trained model.
50 | 
51 |         Args:
52 |             X (pd.DataFrame): DataFrame of features, shape (n_samples, n_features). Must contain columns to encode.
53 |             y (optional): Only for compatibility. Not used. Defaults to None.
54 | 
55 |         Note:
56 |             Multi-dimensional extracted features are expanded into one column per dimension, named ``<key>_<i>``.
57 | 
58 |         Returns:
59 |             pd.DataFrame: The encoded dataframe
60 | 
61 |         """
62 | 
63 |         X_encoded = X.copy(deep=True)
64 |         orig_features = X_encoded.columns
65 |         self.tabular_model.model.eval()
66 |         inference_dataloader = self.tabular_model.datamodule.prepare_inference_dataloader(X_encoded)
67 |         logits_predictions = defaultdict(list)
68 |         for batch in track(inference_dataloader, description="Generating Features..."):
69 |             for k, v in batch.items():
70 |                 if isinstance(v, list) and (len(v) == 0):
71 |                     # Skipping empty list
72 |                     continue
73 |                 batch[k] = v.to(self.tabular_model.model.device)
74 |             if self.tabular_model.config.task == "ssl":
75 |                 ret_value = {"backbone_features": self.tabular_model.model.predict(batch, ret_model_output=True)}
76 |             else:
77 |                 _, ret_value = self.tabular_model.model.predict(batch, ret_model_output=True)
78 |             for k in self.extract_keys:
79 |                 if k in ret_value.keys():
80 |                     logits_predictions[k].append(ret_value[k].detach().cpu())
81 | 
82 |         logits_dfs = []
83 |         for k, v in logits_predictions.items():
84 |             v = torch.cat(v, dim=0).numpy()
85 |             if v.ndim == 1:
86 |                 v = v.reshape(-1, 1)
87 |             if v.shape[-1] > 1:
88 |                 temp_df = pd.DataFrame({f"{k}_{i}": v[:, i] for i in range(v.shape[-1])})
89 |             else:
90 |                 temp_df = pd.DataFrame({f"{k}": v[:, 0]})
91 | 
92 |             # Append the temp DataFrame to the list
93 |             logits_dfs.append(temp_df)
94 | 
95 |         preds = pd.concat(logits_dfs, axis=1)
96 |         X_encoded = pd.concat([X_encoded, preds], axis=1)
97 | 
98 |         if self.drop_original:
99 |             X_encoded.drop(columns=orig_features, inplace=True)
100 |         return X_encoded
101 | 
102 |     def fit_transform(self, X: pd.DataFrame, y=None) -> pd.DataFrame:
103 |         """Encode given columns of X based on the learned features.
104 | 
105 |         Args:
106 |             X (pd.DataFrame): DataFrame of features, shape (n_samples, n_features). Must contain columns to encode.
107 |             y (optional): Only for compatibility. Not used. Defaults to None.
108 | 
109 |         Returns:
110 |             pd.DataFrame: The encoded dataframe
111 | 
112 |         """
113 |         self.fit(X, y)
114 |         return self.transform(X)
115 | 
116 |     def save_as_object_file(self, path):
117 |         """Saves the feature extractor as a pickle file.
118 | 
119 |         Args:
120 |             path (str): The path to save the file
121 | 
122 |         """
123 |         with open(path, "wb") as f:
124 |             pickle.dump(self.__dict__, f)
125 | 
126 |     def load_from_object_file(self, path):
127 |         """Loads the feature extractor from a pickle file.
128 | 
129 |         Args:
130 |             path (str): The path to load the file from
131 | 
132 |         """
133 |         with open(path, "rb") as f:
134 |             for k, v in pickle.load(f).items():
135 |                 setattr(self, k, v)
136 | 
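# A typical end-to-end flow, shown as a hedged sketch (`tabular_model` stands for an
# already-trained TabularModel and `df` for a DataFrame with the training columns;
# both variable names are hypothetical):
#
#     extractor = DeepFeatureExtractor(tabular_model, extract_keys=["backbone_features"])
#     feature_df = extractor.fit_transform(df)  # drops the original columns by default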
-------------------------------------------------------------------------------- /src/pytorch_tabular/models/__init__.py: --------------------------------------------------------------------------------
1 | from . import (
2 |     autoint,
3 |     category_embedding,
4 |     danet,
5 |     ft_transformer,
6 |     gandalf,
7 |     gate,
8 |     mixture_density,
9 |     node,
10 |     stacking,
11 |     tab_transformer,
12 |     tabnet,
13 | )
14 | from .autoint import AutoIntConfig, AutoIntModel
15 | from .base_model import BaseModel
16 | from .category_embedding import CategoryEmbeddingModel, CategoryEmbeddingModelConfig
17 | from .danet import DANetConfig, DANetModel
18 | from .ft_transformer import FTTransformerConfig, FTTransformerModel
19 | from .gandalf import GANDALFBackbone, GANDALFConfig, GANDALFModel
20 | from .gate import GatedAdditiveTreeEnsembleConfig, GatedAdditiveTreeEnsembleModel
21 | from .mixture_density import MDNConfig, MDNModel
22 | from .node import NodeConfig, NODEModel
23 | from .stacking import StackingModel, StackingModelConfig
24 | from .tab_transformer import TabTransformerConfig, TabTransformerModel
25 | from .tabnet import TabNetModel, TabNetModelConfig
26 | 
27 | __all__ = [
28 |     "CategoryEmbeddingModel",
29 |     "CategoryEmbeddingModelConfig",
30 |     "NODEModel",
31 |     "NodeConfig",
32 |     "TabNetModel",
33 |     "TabNetModelConfig",
34 |     "BaseModel",
35 |     "MDNModel",
36 |     "MDNConfig",
37 |     "AutoIntConfig",
38 |     "AutoIntModel",
39 |     "TabTransformerConfig",
40 |     "TabTransformerModel",
41 |     "FTTransformerConfig",
42 |     "FTTransformerModel",
43 |     "GatedAdditiveTreeEnsembleConfig",
44 |     "GatedAdditiveTreeEnsembleModel",
45 |     "GANDALFConfig",
46 |     "GANDALFModel",
47 |     "GANDALFBackbone",
48 |     "DANetConfig",
49 |     "DANetModel",
50 |     "StackingModel",
51 |     "StackingModelConfig",
52 |     "category_embedding",
53 |     "node",
54 |     "mixture_density",
55 |     "tabnet",
56 |     "autoint",
57 |     "ft_transformer",
58 |     "tab_transformer",
59 |     "gate",
60 |     "gandalf",
61 |     "danet",
62 |     "stacking",
63 | ]
64 | 
-------------------------------------------------------------------------------- /src/pytorch_tabular/models/autoint/__init__.py: --------------------------------------------------------------------------------
1 | from .autoint import AutoIntBackbone, AutoIntModel
2 | from .config import AutoIntConfig
3 | 
4 | __all__ = ["AutoIntModel", "AutoIntBackbone", "AutoIntConfig"]
5 | 
-------------------------------------------------------------------------------- /src/pytorch_tabular/models/autoint/autoint.py: --------------------------------------------------------------------------------
1 | # Pytorch Tabular
2 | # Author: Manu Joseph
3 | # For license information, see LICENSE.TXT
4 | # Inspired by https://github.com/rixwew/pytorch-fm/blob/master/torchfm/model/afi.py
5 | """AutomaticFeatureInteraction Model."""
6 | 
7 | import torch
8 | import torch.nn as nn
9 | from omegaconf import DictConfig
10 | 
11 | from pytorch_tabular.models.common.layers import Embedding2dLayer
12 | from pytorch_tabular.utils import _initialize_layers, _linear_dropout_bn
13 | 
14 | from ..base_model import BaseModel
15 | 
16 | 
17 | class AutoIntBackbone(nn.Module):
18 |     def __init__(self, config: DictConfig):
19 |         """Automatic Feature Interaction Network.
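        Learns higher-order feature interactions by passing per-feature embeddings through
        stacked multi-head self-attention blocks (AutoInt, Song et al., 2019). A hedged
        sketch of the public entry point that builds this backbone for you:

            from pytorch_tabular.models import AutoIntConfig

            model_config = AutoIntConfig(task="classification")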
20 | 21 | Args: 22 | config (DictConfig): config of the model 23 | 24 | """ 25 | super().__init__() 26 | self.hparams = config 27 | self._build_network() 28 | 29 | def _build_network(self): 30 | # Deep Layers 31 | _curr_units = self.hparams.embedding_dim 32 | if self.hparams.deep_layers: 33 | # Linear Layers 34 | layers = [] 35 | for units in self.hparams.layers.split("-"): 36 | layers.extend( 37 | _linear_dropout_bn( 38 | self.hparams.activation, 39 | self.hparams.initialization, 40 | self.hparams.use_batch_norm, 41 | _curr_units, 42 | int(units), 43 | self.hparams.dropout, 44 | ) 45 | ) 46 | _curr_units = int(units) 47 | self.linear_layers = nn.Sequential(*layers) 48 | # Projection to Multi-Headed Attention Dims 49 | self.attn_proj = nn.Linear(_curr_units, self.hparams.attn_embed_dim) 50 | _initialize_layers(self.hparams.activation, self.hparams.initialization, self.attn_proj) 51 | # Multi-Headed Attention Layers 52 | self.self_attns = nn.ModuleList( 53 | [ 54 | nn.MultiheadAttention( 55 | self.hparams.attn_embed_dim, 56 | self.hparams.num_heads, 57 | dropout=self.hparams.attn_dropouts, 58 | ) 59 | for _ in range(self.hparams.num_attn_blocks) 60 | ] 61 | ) 62 | if self.hparams.has_residuals: 63 | self.V_res_embedding = torch.nn.Linear( 64 | _curr_units, 65 | ( 66 | self.hparams.attn_embed_dim * self.hparams.num_attn_blocks 67 | if self.hparams.attention_pooling 68 | else self.hparams.attn_embed_dim 69 | ), 70 | ) 71 | self.output_dim = (self.hparams.continuous_dim + self.hparams.categorical_dim) * self.hparams.attn_embed_dim 72 | if self.hparams.attention_pooling: 73 | self.output_dim = self.output_dim * self.hparams.num_attn_blocks 74 | 75 | def _build_embedding_layer(self): 76 | return Embedding2dLayer( 77 | continuous_dim=self.hparams.continuous_dim, 78 | categorical_cardinality=self.hparams.categorical_cardinality, 79 | embedding_dim=self.hparams.embedding_dim, 80 | shared_embedding_strategy=self.hparams.share_embedding_strategy, 81 | frac_shared_embed=self.hparams.shared_embedding_fraction, 82 | embedding_bias=self.hparams.embedding_bias, 83 | batch_norm_continuous_input=self.hparams.batch_norm_continuous_input, 84 | embedding_dropout=self.hparams.embedding_dropout, 85 | initialization=self.hparams.embedding_initialization, 86 | ) 87 | 88 | def forward(self, x: torch.Tensor) -> torch.Tensor: 89 | if self.hparams.deep_layers: 90 | x = self.linear_layers(x) 91 | # (N, B, E*) --> E* is the Attn Dimention 92 | cross_term = self.attn_proj(x).transpose(0, 1) 93 | if self.hparams.attention_pooling: 94 | attention_ops = [] 95 | for self_attn in self.self_attns: 96 | cross_term, _ = self_attn(cross_term, cross_term, cross_term) 97 | if self.hparams.attention_pooling: 98 | attention_ops.append(cross_term) 99 | if self.hparams.attention_pooling: 100 | cross_term = torch.cat(attention_ops, dim=-1) 101 | # (B, N, E*) 102 | cross_term = cross_term.transpose(0, 1) 103 | if self.hparams.has_residuals: 104 | # (B, N, E*) --> Projecting Embedded input to Attention sub-space 105 | V_res = self.V_res_embedding(x) 106 | cross_term = cross_term + V_res 107 | # (B, NxE*) 108 | cross_term = nn.ReLU()(cross_term).reshape(-1, self.output_dim) 109 | return cross_term 110 | 111 | 112 | class AutoIntModel(BaseModel): 113 | def __init__(self, config: DictConfig, **kwargs): 114 | super().__init__(config, **kwargs) 115 | 116 | @property 117 | def backbone(self): 118 | return self._backbone 119 | 120 | @property 121 | def embedding_layer(self): 122 | return self._embedding_layer 123 | 124 | @property 125 | def 
head(self): 126 | return self._head 127 | 128 | def _build_network(self): 129 | # Backbone 130 | self._backbone = AutoIntBackbone(self.hparams) 131 | # Embedding Layer 132 | self._embedding_layer = self._backbone._build_embedding_layer() 133 | # Head 134 | self._head = self._get_head_from_config() 135 | -------------------------------------------------------------------------------- /src/pytorch_tabular/models/category_embedding/__init__.py: -------------------------------------------------------------------------------- 1 | from .category_embedding_model import CategoryEmbeddingBackbone, CategoryEmbeddingModel 2 | from .config import CategoryEmbeddingModelConfig 3 | 4 | __all__ = ["CategoryEmbeddingModel", "CategoryEmbeddingModelConfig", "CategoryEmbeddingBackbone"] 5 | -------------------------------------------------------------------------------- /src/pytorch_tabular/models/category_embedding/category_embedding_model.py: -------------------------------------------------------------------------------- 1 | # Pytorch Tabular 2 | # Author: Manu Joseph 3 | # For license information, see LICENSE.TXT 4 | """Category Embedding Model.""" 5 | 6 | import torch 7 | import torch.nn as nn 8 | from omegaconf import DictConfig 9 | 10 | from pytorch_tabular.models.common.layers import Embedding1dLayer 11 | from pytorch_tabular.utils import _initialize_layers, _linear_dropout_bn 12 | 13 | from ..base_model import BaseModel 14 | 15 | 16 | class CategoryEmbeddingBackbone(nn.Module): 17 | def __init__(self, config: DictConfig, **kwargs): 18 | super().__init__() 19 | self.hparams = config 20 | self._build_network() 21 | 22 | def _build_network(self): 23 | # Linear Layers 24 | layers = [] 25 | if hasattr(self.hparams, "_backbone_input_dim"): 26 | _curr_units = self.hparams._backbone_input_dim # TODO implement this backdoor in every model? 
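        # `_backbone_input_dim` is an optional backdoor on the config (see the TODO above):
        # when a wrapper sets it, that value overrides the inferred input width of
        # embedded-categorical + continuous dimensions computed in the `else` branch below.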
27 |         else:
28 |             _curr_units = self.hparams.embedded_cat_dim + self.hparams.continuous_dim
29 |         for units in self.hparams.layers.split("-"):
30 |             layers.extend(
31 |                 _linear_dropout_bn(
32 |                     self.hparams.activation,
33 |                     self.hparams.initialization,
34 |                     self.hparams.use_batch_norm,
35 |                     _curr_units,
36 |                     int(units),
37 |                     self.hparams.dropout,
38 |                 )
39 |             )
40 |             _curr_units = int(units)
41 |         self.linear_layers = nn.Sequential(*layers)
42 |         _initialize_layers(self.hparams.activation, self.hparams.initialization, self.linear_layers)
43 |         self.output_dim = _curr_units
44 | 
45 |     def _build_embedding_layer(self):
46 |         return Embedding1dLayer(
47 |             continuous_dim=self.hparams.continuous_dim,
48 |             categorical_embedding_dims=self.hparams.embedding_dims,
49 |             embedding_dropout=self.hparams.embedding_dropout,
50 |             batch_norm_continuous_input=self.hparams.batch_norm_continuous_input,
51 |             virtual_batch_size=self.hparams.virtual_batch_size,
52 |         )
53 | 
54 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
55 |         x = self.linear_layers(x)
56 |         return x
57 | 
58 | 
59 | class CategoryEmbeddingModel(BaseModel):
60 |     def __init__(self, config: DictConfig, **kwargs):
61 |         super().__init__(config, **kwargs)
62 | 
63 |     @property
64 |     def backbone(self):
65 |         return self._backbone
66 | 
67 |     @property
68 |     def embedding_layer(self):
69 |         return self._embedding_layer
70 | 
71 |     @property
72 |     def head(self):
73 |         return self._head
74 | 
75 |     def _build_network(self):
76 |         # Backbone
77 |         self._backbone = CategoryEmbeddingBackbone(self.hparams)
78 |         # Embedding Layer
79 |         self._embedding_layer = self._backbone._build_embedding_layer()
80 |         # Head
81 |         self._head = self._get_head_from_config()  # assign the private attribute; `head` is a read-only property
82 | 
-------------------------------------------------------------------------------- /src/pytorch_tabular/models/common/__init__.py: --------------------------------------------------------------------------------
1 | from pytorch_tabular.models.common import heads, layers
2 | from pytorch_tabular.models.common.layers import activations
3 | 
4 | __all__ = ["activations", "layers", "heads"]
5 | 
-------------------------------------------------------------------------------- /src/pytorch_tabular/models/common/heads/__init__.py: --------------------------------------------------------------------------------
1 | from .blocks import LinearHead, MixtureDensityHead
2 | from .config import LinearHeadConfig, MixtureDensityHeadConfig
3 | 
4 | __all__ = ["LinearHead", "MixtureDensityHead", "LinearHeadConfig", "MixtureDensityHeadConfig"]
5 | 
-------------------------------------------------------------------------------- /src/pytorch_tabular/models/common/layers/__init__.py: --------------------------------------------------------------------------------
1 | from . 
import activations 2 | from .batch_norm import GBN, BatchNorm1d 3 | from .embeddings import Embedding1dLayer, Embedding2dLayer, PreEncoded1dLayer, SharedEmbeddings 4 | from .gated_units import GEGLU, GatedFeatureLearningUnit, PositionWiseFeedForward, ReGLU, SwiGLU 5 | from .misc import Add, Lambda, ModuleWithInit, Residual 6 | from .soft_trees import ODST, NeuralDecisionTree 7 | from .transformers import AddNorm, AppendCLSToken, MultiHeadedAttention, TransformerEncoderBlock 8 | 9 | __all__ = [ 10 | "PreEncoded1dLayer", 11 | "SharedEmbeddings", 12 | "Embedding1dLayer", 13 | "Embedding2dLayer", 14 | "Residual", 15 | "Add", 16 | "Lambda", 17 | "ModuleWithInit", 18 | "PositionWiseFeedForward", 19 | "AddNorm", 20 | "MultiHeadedAttention", 21 | "TransformerEncoderBlock", 22 | "AppendCLSToken", 23 | "ODST", 24 | "activations", 25 | "GEGLU", 26 | "ReGLU", 27 | "SwiGLU", 28 | "NeuralDecisionTree", 29 | "GatedFeatureLearningUnit", 30 | "GBN", 31 | "BatchNorm1d", 32 | ] 33 | -------------------------------------------------------------------------------- /src/pytorch_tabular/models/common/layers/batch_norm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | 5 | 6 | class GBN(nn.Module): 7 | """ 8 | Ghost Batch Normalization 9 | https://arxiv.org/abs/1705.08741 10 | """ 11 | 12 | def __init__(self, input_dim, virtual_batch_size=512): 13 | super().__init__() 14 | self.input_dim = input_dim 15 | self.virtual_batch_size = virtual_batch_size 16 | self.bn = nn.BatchNorm1d(self.input_dim) 17 | 18 | def forward(self, x): 19 | if self.training: 20 | chunks = x.chunk(int(np.ceil(x.shape[0] / self.virtual_batch_size)), 0) 21 | res = [self.bn(x_) for x_ in chunks] 22 | return torch.cat(res, dim=0) 23 | else: 24 | return self.bn(x) 25 | 26 | 27 | class BatchNorm1d(nn.Module): 28 | """BatchNorm1d with Ghost Batch Normalization.""" 29 | 30 | def __init__(self, num_features, virtual_batch_size=None): 31 | super().__init__() 32 | self.num_features = num_features 33 | self.virtual_batch_size = virtual_batch_size 34 | if self.virtual_batch_size is None: 35 | self.bn = nn.BatchNorm1d(self.num_features) 36 | else: 37 | self.bn = GBN(self.num_features, self.virtual_batch_size) 38 | 39 | def forward(self, x): 40 | return self.bn(x) 41 | -------------------------------------------------------------------------------- /src/pytorch_tabular/models/common/layers/misc.py: -------------------------------------------------------------------------------- 1 | # W605 2 | from typing import Callable, Union 3 | 4 | import torch 5 | from torch import nn 6 | 7 | 8 | class Residual(nn.Module): 9 | def __init__(self, fn): 10 | super().__init__() 11 | self.fn = fn 12 | 13 | def forward(self, x, **kwargs): 14 | return self.fn(x, **kwargs) + x 15 | 16 | 17 | class Lambda(nn.Module): 18 | """A wrapper for a lambda function as a pytorch module.""" 19 | 20 | def __init__(self, func: Callable): 21 | """Initialize lambda module 22 | Args: 23 | func: any function/callable 24 | """ 25 | super().__init__() 26 | self.func = func 27 | 28 | def forward(self, *args, **kwargs): 29 | return self.func(*args, **kwargs) 30 | 31 | 32 | class ModuleWithInit(nn.Module): 33 | """Base class for pytorch module with data-aware initializer on first batch.""" 34 | 35 | def __init__(self): 36 | super().__init__() 37 | self._is_initialized_tensor = nn.Parameter(torch.tensor(0, dtype=torch.uint8), requires_grad=False) 38 | self._is_initialized_bool = None 39 | # Note: 
this module keeps two initialization flags (a uint8 buffer tensor and a plain bool) so as to achieve both
40 |         # * persistence: the tensor flag is saved alongside the model in state_dict
41 |         # * speed: the bool mirror avoids reading the tensor on every call
42 |         # please DO NOT use these flags in child modules
43 | 
44 |     def initialize(self, *args, **kwargs):
45 |         """Initialize module tensors using first batch of data."""
46 |         raise NotImplementedError("Please implement `initialize` in the child module")
47 | 
48 |     def __call__(self, *args, **kwargs):
49 |         if self._is_initialized_bool is None:
50 |             self._is_initialized_bool = bool(self._is_initialized_tensor.item())
51 |         if not self._is_initialized_bool:
52 |             self.initialize(*args, **kwargs)
53 |             self._is_initialized_tensor.data[...] = 1
54 |             self._is_initialized_bool = True
55 |         return super().__call__(*args, **kwargs)
56 | 
57 | 
58 | class Add(nn.Module):
59 |     """A module that adds a constant/parameter value to the input."""
60 | 
61 |     def __init__(self, add_value: Union[float, torch.Tensor]):
62 |         """Initialize the module.
63 | 
64 |         Args:
65 |             add_value: The value to add to the input
66 | 
67 |         """
68 |         super().__init__()
69 |         self.add_value = add_value
70 | 
71 |     def forward(self, x):
72 |         return x + self.add_value
73 | 
-------------------------------------------------------------------------------- /src/pytorch_tabular/models/common/layers/transformers.py: --------------------------------------------------------------------------------
1 | # W605
2 | import math
3 | from typing import Optional
4 | 
5 | import torch
6 | from einops import rearrange
7 | from torch import einsum, nn
8 | 
9 | from pytorch_tabular.utils import _initialize_kaiming
10 | 
11 | from .gated_units import GEGLU, PositionWiseFeedForward, ReGLU, SwiGLU
12 | 
13 | # from . import activations
14 | 
15 | 
16 | GATED_UNITS = {"GEGLU": GEGLU, "ReGLU": ReGLU, "SwiGLU": SwiGLU}
17 | 
18 | 
19 | class AddNorm(nn.Module):
20 |     """Applies LayerNorm, Dropout and adds to input.
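    The forward pass computes ``LayerNorm(Dropout(Y) + X)``: the residual branch ``Y`` is
    dropped out, added to the skip input ``X``, and layer-normalized.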
21 | 
22 |     Standard AddNorm operations in Transformers
23 | 
24 |     """
25 | 
26 |     def __init__(self, input_dim: int, dropout: float):
27 |         super().__init__()
28 |         self.dropout = nn.Dropout(dropout)
29 |         self.ln = nn.LayerNorm(input_dim)
30 | 
31 |     def forward(self, X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
32 |         return self.ln(self.dropout(Y) + X)
33 | 
34 | 
35 | class MultiHeadedAttention(nn.Module):
36 |     """Multi Headed Attention Block in Transformers."""
37 | 
38 |     def __init__(
39 |         self,
40 |         input_dim: int,
41 |         num_heads: int = 8,
42 |         head_dim: int = 16,
43 |         dropout: float = 0.1,
44 |         keep_attn: bool = True,
45 |     ):
46 |         super().__init__()
47 |         assert input_dim % num_heads == 0, "'input_dim' must be a multiple of 'num_heads'"
48 |         inner_dim = head_dim * num_heads
49 |         self.n_heads = num_heads
50 |         self.scale = head_dim**-0.5
51 |         self.keep_attn = keep_attn
52 | 
53 |         self.to_qkv = nn.Linear(input_dim, inner_dim * 3, bias=False)
54 |         self.to_out = nn.Linear(inner_dim, input_dim)
55 | 
56 |         self.dropout = nn.Dropout(dropout)
57 | 
58 |     def forward(self, x):
59 |         h = self.n_heads
60 |         q, k, v = self.to_qkv(x).chunk(3, dim=-1)
61 |         q, k, v = (rearrange(t, "b n (h d) -> b h n d", h=h) for t in (q, k, v))
62 |         sim = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
63 | 
64 |         attn = sim.softmax(dim=-1)
65 |         attn = self.dropout(attn)
66 |         if self.keep_attn:
67 |             self.attn_weights = attn
68 |         out = einsum("b h i j, b h j d -> b h i d", attn, v)
69 |         out = rearrange(out, "b h n d -> b n (h d)", h=h)
70 |         return self.to_out(out)
71 | 
72 | 
73 | class TransformerEncoderBlock(nn.Module):
74 |     """A single Transformer Encoder Block."""
75 | 
76 |     def __init__(
77 |         self,
78 |         input_embed_dim: int,
79 |         num_heads: int = 8,
80 |         ff_hidden_multiplier: int = 4,
81 |         ff_activation: str = "GEGLU",
82 |         attn_dropout: float = 0.1,
83 |         keep_attn: bool = True,
84 |         ff_dropout: float = 0.1,
85 |         add_norm_dropout: float = 0.1,
86 |         transformer_head_dim: Optional[int] = None,
87 |     ):
88 |         """
89 |         Args:
90 |             input_embed_dim: The input embedding dimension
91 |             num_heads: The number of attention heads
92 |             ff_hidden_multiplier: The hidden dimension multiplier for the position-wise feed-forward layer
93 |             ff_activation: The activation function for the position-wise feed-forward layer
94 |             attn_dropout: The dropout probability for the attention layer
95 |             keep_attn: Whether to keep the attention weights
96 |             ff_dropout: The dropout probability for the position-wise feed-forward layer
97 |             add_norm_dropout: The dropout probability for the residual connections
98 |             transformer_head_dim: The dimension of the attention heads. 
If None, will default to input_embed_dim 99 | """ 100 | super().__init__() 101 | self.mha = MultiHeadedAttention( 102 | input_embed_dim, 103 | num_heads, 104 | head_dim=input_embed_dim if transformer_head_dim is None else transformer_head_dim, 105 | dropout=attn_dropout, 106 | keep_attn=keep_attn, 107 | ) 108 | 109 | try: 110 | self.pos_wise_ff = GATED_UNITS[ff_activation]( 111 | d_model=input_embed_dim, 112 | d_ff=input_embed_dim * ff_hidden_multiplier, 113 | dropout=ff_dropout, 114 | ) 115 | except (AttributeError, KeyError): 116 | self.pos_wise_ff = PositionWiseFeedForward( 117 | d_model=input_embed_dim, 118 | d_ff=input_embed_dim * ff_hidden_multiplier, 119 | dropout=ff_dropout, 120 | activation=getattr(nn, ff_activation)(), 121 | ) 122 | self.attn_add_norm = AddNorm(input_embed_dim, add_norm_dropout) 123 | self.ff_add_norm = AddNorm(input_embed_dim, add_norm_dropout) 124 | 125 | def forward(self, x): 126 | y = self.mha(x) 127 | x = self.attn_add_norm(x, y) 128 | y = self.pos_wise_ff(y) 129 | return self.ff_add_norm(x, y) 130 | 131 | 132 | class AppendCLSToken(nn.Module): 133 | """Appends the [CLS] token for BERT-like inference.""" 134 | 135 | def __init__(self, d_token: int, initialization: str) -> None: 136 | """Initialize self.""" 137 | super().__init__() 138 | self.weight = nn.Parameter(torch.Tensor(d_token)) 139 | d_sqrt_inv = 1 / math.sqrt(d_token) 140 | _initialize_kaiming(self.weight, initialization, d_sqrt_inv) 141 | 142 | def forward(self, x: torch.Tensor) -> torch.Tensor: 143 | """Perform the forward pass.""" 144 | assert x.ndim == 3 145 | return torch.cat([x, self.weight.view(1, 1, -1).repeat(len(x), 1, 1)], dim=1) 146 | -------------------------------------------------------------------------------- /src/pytorch_tabular/models/danet/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import DANetConfig 2 | from .danet import DANetBackbone, DANetModel 3 | 4 | __all__ = ["DANetConfig", "DANetBackbone", "DANetModel"] 5 | -------------------------------------------------------------------------------- /src/pytorch_tabular/models/danet/arch_blocks.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from pytorch_tabular.models.common.layers.activations import entmax15 9 | from pytorch_tabular.models.common.layers.batch_norm import GBN 10 | 11 | 12 | def initialize_glu(module, input_dim, output_dim): 13 | gain_value = np.sqrt((input_dim + output_dim) / np.sqrt(input_dim)) 14 | torch.nn.init.xavier_normal_(module.weight, gain=gain_value) 15 | return 16 | 17 | 18 | class LearnableLocality(nn.Module): 19 | def __init__(self, input_dim, k): 20 | super().__init__() 21 | self.register_parameter("weight", nn.Parameter(torch.rand(k, input_dim))) 22 | self.smax = partial(entmax15, dim=-1) 23 | 24 | def forward(self, x): 25 | mask = self.smax(self.weight) 26 | masked_x = torch.einsum("nd,bd->bnd", mask, x) # [B, k, D] 27 | return masked_x 28 | 29 | 30 | class AbstractLayer(nn.Module): 31 | def __init__(self, base_input_dim, base_output_dim, k, virtual_batch_size, bias=True): 32 | super().__init__() 33 | self.masker = LearnableLocality(input_dim=base_input_dim, k=k) 34 | self.fc = nn.Conv1d( 35 | base_input_dim * k, 36 | 2 * k * base_output_dim, 37 | kernel_size=1, 38 | groups=k, 39 | bias=bias, 40 | ) 41 | initialize_glu(self.fc, input_dim=base_input_dim 
* k, output_dim=2 * k * base_output_dim) 42 | self.bn = GBN(2 * base_output_dim * k, virtual_batch_size) 43 | self.k = k 44 | self.base_output_dim = base_output_dim 45 | 46 | def forward(self, x): 47 | b = x.size(0) 48 | x = self.masker(x) # [B, D] -> [B, k, D] 49 | x = self.fc(x.view(b, -1, 1)) # [B, k, D] -> [B, k * D, 1] -> [B, k * (2 * D'), 1] 50 | x = self.bn(x) 51 | chunks = x.chunk(self.k, 1) # k * [B, 2 * D', 1] 52 | x = sum( 53 | [ 54 | F.relu(torch.sigmoid(x_[:, : self.base_output_dim, :]) * x_[:, self.base_output_dim :, :]) 55 | for x_ in chunks 56 | ] 57 | ) # k * [B, D', 1] -> [B, D', 1] 58 | return x.squeeze(-1) 59 | 60 | 61 | class BasicBlock(nn.Module): 62 | def __init__( 63 | self, 64 | input_dim, 65 | abstlay_dim_1, 66 | abstlay_dim_2, 67 | k, 68 | virtual_batch_size, 69 | fix_input_dim, 70 | drop_rate, 71 | block_activation, 72 | ): 73 | super().__init__() 74 | self.conv1 = AbstractLayer(input_dim, abstlay_dim_1, k, virtual_batch_size) 75 | self.conv2 = AbstractLayer(abstlay_dim_1, abstlay_dim_2, k, virtual_batch_size) 76 | 77 | self.downsample = nn.Sequential( 78 | nn.Dropout(drop_rate), 79 | AbstractLayer(fix_input_dim, abstlay_dim_2, k, virtual_batch_size), 80 | ) 81 | self.block_activation = block_activation 82 | 83 | def forward(self, x, pre_out=None): 84 | if pre_out is None: 85 | pre_out = x 86 | out = self.conv1(pre_out) 87 | out = self.conv2(out) 88 | identity = self.downsample(x) 89 | out += identity 90 | return self.block_activation(out) 91 | -------------------------------------------------------------------------------- /src/pytorch_tabular/models/danet/danet.py: -------------------------------------------------------------------------------- 1 | # Pytorch Tabular 2 | # Author: Manu Joseph 3 | # For license information, see LICENSE.TXT 4 | """DANet Model.""" 5 | 6 | import warnings 7 | 8 | import torch 9 | import torch.nn as nn 10 | from omegaconf import DictConfig 11 | 12 | from pytorch_tabular.models.common.layers.embeddings import Embedding1dLayer 13 | 14 | from ..base_model import BaseModel 15 | from .arch_blocks import BasicBlock 16 | 17 | 18 | class DANetBackbone(nn.Module): 19 | def __init__( 20 | self, 21 | n_continuous_features: int, 22 | cat_embedding_dims: list, 23 | n_layers: int, 24 | abstlay_dim_1: int, 25 | abstlay_dim_2: int, 26 | k: int, 27 | dropout_rate: float, 28 | block_activation: nn.Module, 29 | virtual_batch_size: int, 30 | embedding_dropout: float, 31 | batch_norm_continuous_input: bool, 32 | ): 33 | super().__init__() 34 | self.cat_embedding_dims = cat_embedding_dims 35 | self.n_continuous_features = n_continuous_features 36 | self.input_dim = n_continuous_features + sum([y for x, y in cat_embedding_dims]) 37 | self.n_layers = n_layers 38 | self.abstlay_dim_1 = abstlay_dim_1 39 | self.abstlay_dim_2 = abstlay_dim_2 40 | self.k = k 41 | self.dropout_rate = dropout_rate 42 | self.block_activation = block_activation 43 | self.virtual_batch_size = virtual_batch_size 44 | self.batch_norm_continuous_input = batch_norm_continuous_input 45 | self.embedding_dropout = embedding_dropout 46 | 47 | self.output_dim = self.abstlay_dim_2 48 | self._build_network() 49 | 50 | def _build_network(self): 51 | params = { 52 | "fix_input_dim": self.input_dim, 53 | "k": self.k, 54 | "virtual_batch_size": self.virtual_batch_size, 55 | "abstlay_dim_1": self.abstlay_dim_1, 56 | "abstlay_dim_2": self.abstlay_dim_2, 57 | "drop_rate": self.dropout_rate, 58 | "block_activation": self.block_activation, 59 | } 60 | self.init_layer = BasicBlock(self.input_dim, 
**params) 61 | self.layers = nn.ModuleList() 62 | for i in range(self.n_layers - 1): 63 | self.layers.append(BasicBlock(self.abstlay_dim_2, **params)) 64 | 65 | def _build_embedding_layer(self): 66 | return Embedding1dLayer( 67 | continuous_dim=self.n_continuous_features, 68 | categorical_embedding_dims=self.cat_embedding_dims, 69 | embedding_dropout=self.embedding_dropout, 70 | batch_norm_continuous_input=self.batch_norm_continuous_input, 71 | virtual_batch_size=self.virtual_batch_size, 72 | ) 73 | 74 | def forward(self, x: torch.Tensor) -> torch.Tensor: 75 | out = self.init_layer(x) 76 | for layer in self.layers: 77 | out = layer(x, pre_out=out) 78 | return out 79 | 80 | # Not Tested Properly 81 | # def _calculate_feature_importance(self): 82 | # n, h, f, _ = self.attention_weights_[0].shape 83 | # device = self.attention_weights_[0].device 84 | # L = len(self.attention_weights_) 85 | # self.local_feature_importance = torch.zeros((n, f), device=device) 86 | # for attn_weights in self.attention_weights_: 87 | # self.local_feature_importance += attn_weights[:, :, :, -1].sum(dim=1) 88 | # self.local_feature_importance = (1 / (h * L)) * self.local_feature_importance[ 89 | # :, :-1 90 | # ] 91 | # self.feature_importance_ = ( 92 | # self.local_feature_importance.mean(dim=0).detach().cpu().numpy() 93 | # ) 94 | # self.feature_importance_count_+=attn_weights.shape[0] 95 | 96 | 97 | class DANetModel(BaseModel): 98 | def __init__(self, config: DictConfig, **kwargs): 99 | super().__init__(config, **kwargs) 100 | 101 | @property 102 | def backbone(self): 103 | return self._backbone 104 | 105 | @property 106 | def embedding_layer(self): 107 | return self._embedding_layer 108 | 109 | @property 110 | def head(self): 111 | return self._head 112 | 113 | def _build_network(self): 114 | if self.hparams.virtual_batch_size > self.hparams.batch_size: 115 | warnings.warn( 116 | f"virtual_batch_size({self.hparams.virtual_batch_size}) is greater " 117 | f"than batch_size ({self.hparams.batch_size}). Setting virtual_batch_size " 118 | f"to {self.hparams.batch_size}. DANet uses Ghost Batch Normalization, " 119 | f"which works best when virtual_batch_size is small. Consider setting " 120 | "virtual_batch_size to something like 256 or 512." 
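                # (Hedged note: Ghost Batch Norm computes its statistics over chunks of
                # `virtual_batch_size` rows, so values much smaller than `batch_size`
                # yield noisier statistics and hence stronger regularization.)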
121 |             )
122 |             self.hparams.virtual_batch_size = self.hparams.batch_size
123 |         # Backbone
124 |         self._backbone = DANetBackbone(
125 |             cat_embedding_dims=self.hparams.embedding_dims,
126 |             n_continuous_features=self.hparams.continuous_dim,
127 |             n_layers=self.hparams.n_layers,
128 |             abstlay_dim_1=self.hparams.abstlay_dim_1,
129 |             abstlay_dim_2=self.hparams.abstlay_dim_2,
130 |             k=self.hparams.k,
131 |             dropout_rate=self.hparams.dropout_rate,
132 |             block_activation=getattr(nn, self.hparams.block_activation)(),
133 |             virtual_batch_size=self.hparams.virtual_batch_size,
134 |             embedding_dropout=self.hparams.embedding_dropout,
135 |             batch_norm_continuous_input=self.hparams.batch_norm_continuous_input,
136 |         )
137 |         # Embedding Layer
138 |         self._embedding_layer = self._backbone._build_embedding_layer()
139 |         # Head
140 |         self._head = self._get_head_from_config()
141 | 
-------------------------------------------------------------------------------- /src/pytorch_tabular/models/ft_transformer/__init__.py: --------------------------------------------------------------------------------
1 | from .config import FTTransformerConfig
2 | from .ft_transformer import FTTransformerBackbone, FTTransformerModel
3 | 
4 | __all__ = ["FTTransformerBackbone", "FTTransformerModel", "FTTransformerConfig"]
5 | 
-------------------------------------------------------------------------------- /src/pytorch_tabular/models/ft_transformer/ft_transformer.py: --------------------------------------------------------------------------------
1 | # Pytorch Tabular
2 | # Author: Manu Joseph
3 | # For license information, see LICENSE.TXT
4 | """Feature Tokenizer Transformer Model."""
5 | 
6 | from collections import OrderedDict
7 | 
8 | import torch
9 | import torch.nn as nn
10 | from omegaconf import DictConfig
11 | 
12 | from pytorch_tabular.models.common.layers.batch_norm import BatchNorm1d
13 | 
14 | from ..base_model import BaseModel
15 | from ..common.layers import AppendCLSToken, Embedding2dLayer, TransformerEncoderBlock
16 | 
17 | 
18 | class FTTransformerBackbone(nn.Module):
19 |     def __init__(self, config: DictConfig):
20 |         super().__init__()
21 |         assert config.share_embedding_strategy in [
22 |             "add",
23 |             "fraction",
24 |         ], (
25 |             "`share_embedding_strategy` should be one of `add` or `fraction`,"
26 |             f" not {config.share_embedding_strategy}"  # use `config` here: `self.hparams` is not assigned yet
27 |         )
28 |         self.hparams = config
29 |         self._build_network()
30 | 
31 |     def _build_network(self):
32 |         self.add_cls = AppendCLSToken(
33 |             d_token=self.hparams.input_embed_dim,
34 |             initialization=self.hparams.embedding_initialization,
35 |         )
36 |         self.transformer_blocks = OrderedDict()
37 |         for i in range(self.hparams.num_attn_blocks):
38 |             self.transformer_blocks[f"mha_block_{i}"] = TransformerEncoderBlock(
39 |                 input_embed_dim=self.hparams.input_embed_dim,
40 |                 num_heads=self.hparams.num_heads,
41 |                 ff_hidden_multiplier=self.hparams.ff_hidden_multiplier,
42 |                 ff_activation=self.hparams.transformer_activation,
43 |                 attn_dropout=self.hparams.attn_dropout,
44 |                 ff_dropout=self.hparams.ff_dropout,
45 |                 add_norm_dropout=self.hparams.add_norm_dropout,
46 |                 keep_attn=self.hparams.attn_feature_importance,  # Can use Attn Weights to derive feature importance
47 |             )
48 |         self.transformer_blocks = nn.Sequential(self.transformer_blocks)
49 |         if self.hparams.attn_feature_importance:
50 |             self.attention_weights_ = [None] * self.hparams.num_attn_blocks
51 |         if self.hparams.batch_norm_continuous_input:
52 |             self.normalizing_batch_norm = BatchNorm1d(self.hparams.continuous_dim, 
self.hparams.virtual_batch_size)
53 | 
54 |         self.output_dim = self.hparams.input_embed_dim
55 | 
56 |     def _build_embedding_layer(self):
57 |         return Embedding2dLayer(
58 |             continuous_dim=self.hparams.continuous_dim,
59 |             categorical_cardinality=self.hparams.categorical_cardinality,
60 |             embedding_dim=self.hparams.input_embed_dim,
61 |             shared_embedding_strategy=self.hparams.share_embedding_strategy,
62 |             frac_shared_embed=self.hparams.shared_embedding_fraction,
63 |             embedding_bias=self.hparams.embedding_bias,
64 |             batch_norm_continuous_input=self.hparams.batch_norm_continuous_input,
65 |             embedding_dropout=self.hparams.embedding_dropout,
66 |             initialization=self.hparams.embedding_initialization,
67 |             virtual_batch_size=self.hparams.virtual_batch_size,
68 |         )
69 | 
70 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
71 |         x = self.add_cls(x)
72 |         for i, block in enumerate(self.transformer_blocks):
73 |             x = block(x)
74 |             if self.hparams.attn_feature_importance:
75 |                 self.attention_weights_[i] = block.mha.attn_weights
76 |                 # self.feature_importance_+=block.mha.attn_weights[:,:,:,-1].sum(dim=1)
77 |                 # self._calculate_feature_importance(block.mha.attn_weights)
78 |         if self.hparams.attn_feature_importance:
79 |             self._calculate_feature_importance()
80 |         # Flatten (Batch, N_Categorical, Hidden) --> (Batch, N_CategoricalxHidden)
81 |         # x = rearrange(x, "b n h -> b (n h)")
82 |         # Taking only CLS token for the prediction head
83 |         return x[:, -1]
84 | 
85 |     # Not Tested Properly
86 |     def _calculate_feature_importance(self):
87 |         n, h, f, _ = self.attention_weights_[0].shape
88 |         device = self.attention_weights_[0].device
89 |         L = len(self.attention_weights_)
90 |         self.local_feature_importance = torch.zeros((n, f), device=device)
91 |         for attn_weights in self.attention_weights_:
92 |             self.local_feature_importance += attn_weights[:, :, :, -1].sum(dim=1)
93 |         self.local_feature_importance = (1 / (h * L)) * self.local_feature_importance[:, :-1]
94 |         self.feature_importance_ = self.local_feature_importance.mean(dim=0).detach().cpu().numpy()
95 |         # self.feature_importance_count_+=attn_weights.shape[0]
96 | 
97 | 
98 | class FTTransformerModel(BaseModel):
99 |     def __init__(self, config: DictConfig, **kwargs):
100 |         super().__init__(config, **kwargs)
101 | 
102 |     @property
103 |     def backbone(self):
104 |         return self._backbone
105 | 
106 |     @property
107 |     def embedding_layer(self):
108 |         return self._embedding_layer
109 | 
110 |     @property
111 |     def head(self):
112 |         return self._head
113 | 
114 |     def _build_network(self):
115 |         # Backbone
116 |         self._backbone = FTTransformerBackbone(self.hparams)
117 |         # Embedding Layer
118 |         self._embedding_layer = self._backbone._build_embedding_layer()
119 |         # Head
120 |         self._head = self._get_head_from_config()
121 | 
122 |     def feature_importance(self):
123 |         if self.hparams.attn_feature_importance:
124 |             return super().feature_importance()
125 |         else:
126 |             raise ValueError("If you want Feature Importance, `attn_feature_importance` should be `True`.")
127 | 
-------------------------------------------------------------------------------- /src/pytorch_tabular/models/gandalf/__init__.py: --------------------------------------------------------------------------------
1 | from .config import GANDALFConfig
2 | from .gandalf import GANDALFBackbone, GANDALFModel
3 | 
4 | __all__ = [
5 |     "GANDALFBackbone",
6 |     "GANDALFModel",
7 |     "GANDALFConfig",
8 | ]
9 | 
-------------------------------------------------------------------------------- /src/pytorch_tabular/models/gandalf/config.py: 
-------------------------------------------------------------------------------- 1 | # Pytorch Tabular 2 | # Author: Manu Joseph 3 | # For license information, see LICENSE.TXT 4 | """AutomaticFeatureInteraction Config.""" 5 | 6 | from dataclasses import dataclass, field 7 | 8 | from pytorch_tabular.config import ModelConfig 9 | 10 | 11 | @dataclass 12 | class GANDALFConfig(ModelConfig): 13 | """Gated Adaptive Network for Deep Automated Learning of Features (GANDALF) Config. 14 | 15 | Args: 16 | gflu_stages (int): Number of layers in the feature abstraction layer. Defaults to 6 17 | 18 | gflu_dropout (float): Dropout rate for the feature abstraction layer. Defaults to 0.0 19 | 20 | gflu_feature_init_sparsity (float): Only valid for t-softmax. The percentage of features 21 | to be selected in each GFLU stage. This is just initialized and during learning 22 | it may change. Defaults to 0.3 23 | 24 | learnable_sparsity (bool): Only valid for t-softmax. If True, the sparsity parameters 25 | will be learned. If False, the sparsity parameters will be fixed to the initial 26 | values specified in `gflu_feature_init_sparsity` and `tree_feature_init_sparsity`. 27 | Defaults to True 28 | 29 | task (str): Specify whether the problem is regression or classification. `backbone` is a task which 30 | considers the model as a backbone to generate features. Mostly used internally for SSL and related 31 | tasks. Choices are: [`regression`,`classification`,`backbone`]. 32 | 33 | head (Optional[str]): The head to be used for the model. Should be one of the heads defined in 34 | `pytorch_tabular.models.common.heads`. Defaults to LinearHead. Choices are: 35 | [`None`,`LinearHead`,`MixtureDensityHead`]. 36 | 37 | head_config (Optional[Dict]): The config as a dict which defines the head. If left empty, will be 38 | initialized as default linear head. 39 | 40 | embedding_dims (Optional[List]): The dimensions of the embedding for each categorical column as a 41 | list of tuples (cardinality, embedding_dim). If left empty, will infer using the cardinality of 42 | the categorical column using the rule min(50, (x + 1) // 2) 43 | 44 | embedding_dropout (float): Dropout to be applied to the Categorical Embedding. Defaults to 0.0 45 | 46 | batch_norm_continuous_input (bool): If True, we will normalize the continuous layer by passing it 47 | through a BatchNorm layer. 48 | 49 | learning_rate (float): The learning rate of the model. Defaults to 1e-3. 50 | 51 | loss (Optional[str]): The loss function to be applied. By Default, it is MSELoss for regression and 52 | CrossEntropyLoss for classification. Unless you are sure what you are doing, leave it at MSELoss 53 | or L1Loss for regression and CrossEntropyLoss for classification 54 | 55 | metrics (Optional[List[str]]): the list of metrics you need to track during training. The metrics 56 | should be one of the functional metrics implemented in ``torchmetrics``. By default, it is 57 | accuracy if classification and mean_squared_error for regression 58 | 59 | metrics_params (Optional[List]): The parameters to be passed to the metrics function. `task` is forced to 60 | be `multiclass` because the multiclass version can handle binary as well and for simplicity we are 61 | only using `multiclass`. 62 | 63 | metrics_prob_input (Optional[List]): Is a mandatory parameter for classification metrics defined in the config. 64 | This defines whether the input to the metric function is the probability or the class. Length should be 65 | same as the number of metrics. Defaults to None. 
66 | 
67 |         target_range (Optional[List]): The range in which we should limit the output variable. Currently
68 |             ignored for multi-target regression. Typically used for Regression problems. If left empty, will
69 |             not apply any restrictions
70 | 
71 |         seed (int): The seed for reproducibility. Defaults to 42
72 | 
73 |     """
74 | 
75 |     gflu_stages: int = field(
76 |         default=6,
77 |         metadata={"help": "Number of layers in the feature abstraction layer. Defaults to 6"},
78 |     )
79 | 
80 |     gflu_dropout: float = field(
81 |         default=0.0, metadata={"help": "Dropout rate for the feature abstraction layer. Defaults to 0.0"}
82 |     )
83 | 
84 |     gflu_feature_init_sparsity: float = field(
85 |         default=0.3,
86 |         metadata={
87 |             "help": "Only valid for t-softmax. The percentage of features to be selected in "
88 |             "each GFLU stage. This is just initialized and during learning it may change"
89 |         },
90 |     )
91 |     learnable_sparsity: bool = field(
92 |         default=True,
93 |         metadata={
94 |             "help": "Only valid for t-softmax. If True, the sparsity parameters will be learned. "
95 |             "If False, the sparsity parameters will be fixed to the initial values specified in "
96 |             "`gflu_feature_init_sparsity` and `tree_feature_init_sparsity`"
97 |         },
98 |     )
99 |     _module_src: str = field(default="models.gandalf")
100 |     _model_name: str = field(default="GANDALFModel")
101 |     _backbone_name: str = field(default="GANDALFBackbone")
102 |     _config_name: str = field(default="GANDALFConfig")
103 | 
104 |     def __post_init__(self):
105 |         assert self.gflu_stages > 0, "gflu_stages should be greater than 0"
106 |         return super().__post_init__()
107 | 
108 | 
109 | if __name__ == "__main__":
110 |     from pytorch_tabular.utils import generate_doc_dataclass
111 | 
112 |     print(generate_doc_dataclass(GANDALFConfig, desc="GANDALF Config"))
113 | 
-------------------------------------------------------------------------------- /src/pytorch_tabular/models/gandalf/gandalf.py: --------------------------------------------------------------------------------
1 | # Pytorch Tabular
2 | # Author: Manu Joseph
3 | # For license information, see LICENSE.TXT
4 | import torch
5 | import torch.nn as nn
6 | from omegaconf import DictConfig
7 | 
8 | from pytorch_tabular.models.common.layers import Add, Embedding1dLayer, GatedFeatureLearningUnit
9 | from pytorch_tabular.models.common.layers.activations import t_softmax
10 | from pytorch_tabular.utils import get_logger
11 | 
12 | from ..base_model import BaseModel
13 | 
14 | logger = get_logger(__name__)
15 | 
16 | 
17 | class GANDALFBackbone(nn.Module):
18 |     def __init__(
19 |         self,
20 |         cat_embedding_dims: list,
21 |         n_continuous_features: int,
22 |         gflu_stages: int,
23 |         gflu_dropout: float = 0.0,
24 |         gflu_feature_init_sparsity: float = 0.3,
25 |         learnable_sparsity: bool = True,
26 |         batch_norm_continuous_input: bool = True,
27 |         virtual_batch_size: int = None,
28 |         embedding_dropout: float = 0.0,
29 |     ):
30 |         super().__init__()
31 |         self.gflu_stages = gflu_stages
32 |         self.gflu_dropout = gflu_dropout
33 |         self.batch_norm_continuous_input = batch_norm_continuous_input
34 |         self.n_continuous_features = n_continuous_features
35 |         self.cat_embedding_dims = cat_embedding_dims
36 |         self._embedded_cat_features = sum([y for x, y in cat_embedding_dims])
37 |         self.n_features = self._embedded_cat_features + n_continuous_features
38 |         self.embedding_dropout = embedding_dropout
39 |         self.output_dim = self.n_continuous_features + self._embedded_cat_features
40 |         self.gflu_feature_init_sparsity = gflu_feature_init_sparsity
41 | 
self.learnable_sparsity = learnable_sparsity 42 | self.virtual_batch_size = virtual_batch_size 43 | self._build_network() 44 | 45 | def _build_network(self): 46 | self.gflus = GatedFeatureLearningUnit( 47 | n_features_in=self.n_features, 48 | n_stages=self.gflu_stages, 49 | feature_mask_function=t_softmax, 50 | dropout=self.gflu_dropout, 51 | feature_sparsity=self.gflu_feature_init_sparsity, 52 | learnable_sparsity=self.learnable_sparsity, 53 | ) 54 | 55 | def _build_embedding_layer(self): 56 | return Embedding1dLayer( 57 | continuous_dim=self.n_continuous_features, 58 | categorical_embedding_dims=self.cat_embedding_dims, 59 | embedding_dropout=self.embedding_dropout, 60 | batch_norm_continuous_input=self.batch_norm_continuous_input, 61 | virtual_batch_size=self.virtual_batch_size, 62 | ) 63 | 64 | def forward(self, x: torch.Tensor) -> torch.Tensor: 65 | return self.gflus(x) 66 | 67 | @property 68 | def feature_importance_(self): 69 | return self.gflus.feature_mask_function(self.gflus.feature_masks).sum(dim=0).detach().cpu().numpy() 70 | 71 | 72 | class GANDALFModel(BaseModel): 73 | def __init__(self, config: DictConfig, **kwargs): 74 | super().__init__(config, **kwargs) 75 | 76 | @property 77 | def backbone(self): 78 | return self._backbone 79 | 80 | @property 81 | def embedding_layer(self): 82 | return self._embedding_layer 83 | 84 | @property 85 | def head(self): 86 | return self._head 87 | 88 | def _build_network(self): 89 | # Backbone 90 | self._backbone = GANDALFBackbone( 91 | cat_embedding_dims=self.hparams.embedding_dims, 92 | n_continuous_features=self.hparams.continuous_dim, 93 | gflu_stages=self.hparams.gflu_stages, 94 | gflu_dropout=self.hparams.gflu_dropout, 95 | gflu_feature_init_sparsity=self.hparams.gflu_feature_init_sparsity, 96 | learnable_sparsity=self.hparams.learnable_sparsity, 97 | batch_norm_continuous_input=self.hparams.batch_norm_continuous_input, 98 | embedding_dropout=self.hparams.embedding_dropout, 99 | virtual_batch_size=self.hparams.virtual_batch_size, 100 | ) 101 | # Embedding Layer 102 | self._embedding_layer = self._backbone._build_embedding_layer() 103 | # Head 104 | self.T0 = nn.Parameter(torch.rand(self.hparams.output_dim), requires_grad=True) 105 | self._head = nn.Sequential(self._get_head_from_config(), Add(self.T0)) 106 | 107 | def data_aware_initialization(self, datamodule): 108 | if self.hparams.task == "regression": 109 | logger.info("Data Aware Initialization of T0") 110 | # Need a big batch to initialize properly 111 | alt_loader = datamodule.train_dataloader(batch_size=self.hparams.data_aware_init_batch_size) 112 | batch = next(iter(alt_loader)) 113 | self.T0.data = torch.mean(batch["target"], dim=0) 114 | -------------------------------------------------------------------------------- /src/pytorch_tabular/models/gate/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import GatedAdditiveTreeEnsembleConfig 2 | from .gate_model import GatedAdditiveTreeEnsembleModel, GatedAdditiveTreesBackbone 3 | 4 | __all__ = [ 5 | "GatedAdditiveTreesBackbone", 6 | "GatedAdditiveTreeEnsembleModel", 7 | "GatedAdditiveTreeEnsembleConfig", 8 | ] 9 | -------------------------------------------------------------------------------- /src/pytorch_tabular/models/mixture_density/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import MDNConfig 2 | from .mdn import MDNModel 3 | 4 | __all__ = [ 5 | # "MixtureDensityHead", 6 | # 
"MixtureDensityHeadConfig", 7 | # "CategoryEmbeddingMDNConfig", 8 | # "CategoryEmbeddingMDN", 9 | # "NODEMDN", 10 | "MDNModel", 11 | "MDNConfig", 12 | # "AutoIntMDNConfig", 13 | # "AutoIntMDN" 14 | ] 15 | -------------------------------------------------------------------------------- /src/pytorch_tabular/models/mixture_density/config.py: -------------------------------------------------------------------------------- 1 | # Pytorch Tabular 2 | # Author: Manu Joseph 3 | # For license information, see LICENSE.TXT 4 | """Mixture Density Head Config.""" 5 | 6 | from dataclasses import dataclass, field 7 | from typing import Dict 8 | 9 | from pytorch_tabular.config.config import ModelConfig 10 | 11 | INCOMPATIBLE_BACKBONES = ["NodeConfig", "TabNetModelConfig", "MDNConfig"] 12 | 13 | 14 | @dataclass 15 | class MDNConfig(ModelConfig): 16 | """MDN configuration. 17 | 18 | Args: 19 | backbone_config_class (str): The config class for defining the Backbone. The config class should be 20 | a valid module path from `models`. e.g. `FTTransformerConfig` 21 | 22 | backbone_config_params (Dict): The dict of config parameters for defining the Backbone. 23 | 24 | 25 | task (str): Specify whether the problem is regression or classification. `backbone` is a task which 26 | considers the model as a backbone to generate features. Mostly used internally for SSL and related 27 | tasks. Choices are: [`regression`,`classification`,`backbone`]. 28 | 29 | head (str): 30 | 31 | head_config (Dict): The config for defining the Mixed Density Network Head 32 | 33 | embedding_dims (Optional[List]): The dimensions of the embedding for each categorical column as a 34 | list of tuples (cardinality, embedding_dim). If left empty, will infer using the cardinality of 35 | the categorical column using the rule min(50, (x + 1) // 2) 36 | 37 | embedding_dropout (float): Dropout to be applied to the Categorical Embedding. Defaults to 0.0 38 | 39 | batch_norm_continuous_input (bool): If True, we will normalize the continuous layer by passing it 40 | through a BatchNorm layer. 41 | 42 | learning_rate (float): The learning rate of the model. Defaults to 1e-3. 43 | 44 | loss (Optional[str]): The loss function to be applied. By Default, it is MSELoss for regression and 45 | CrossEntropyLoss for classification. Unless you are sure what you are doing, leave it at MSELoss 46 | or L1Loss for regression and CrossEntropyLoss for classification 47 | 48 | metrics (Optional[List[str]]): the list of metrics you need to track during training. The metrics 49 | should be one of the functional metrics implemented in ``torchmetrics``. By default, it is 50 | accuracy if classification and mean_squared_error for regression 51 | 52 | metrics_params (Optional[List]): The parameters to be passed to the metrics function. `task` is forced to 53 | be `multiclass` because the multiclass version can handle binary as well and for simplicity we are 54 | only using `multiclass`. 55 | 56 | metrics_prob_input (Optional[List]): Is a mandatory parameter for classification metrics defined in the config. 57 | This defines whether the input to the metric function is the probability or the class. Length should be 58 | same as the number of metrics. Defaults to None. 59 | 60 | target_range (Optional[List]): The range in which we should limit the output variable. Currently 61 | ignored for multi-target regression. Typically used for Regression problems. If left empty, will 62 | not apply any restrictions 63 | 64 | seed (int): The seed for reproducibility. 
65 | 
66 |     """
67 | 
68 |     backbone_config_class: str = field(
69 |         default=None,
70 |         metadata={
71 |             "help": "The config class for defining the Backbone."
72 |             " The config class should be a valid module path from `models`. e.g. `FTTransformerConfig`"
73 |         },
74 |     )
75 |     backbone_config_params: Dict = field(
76 |         default=None,
77 |         metadata={"help": "The dict of config parameters for defining the Backbone."},
78 |     )
79 |     head: str = field(init=False, default="MixtureDensityHead")
80 |     head_config: Dict = field(
81 |         default=None,
82 |         metadata={"help": "The config for defining the Mixture Density Network Head"},
83 |     )
84 |     _module_src: str = field(default="models.mixture_density")
85 |     _model_name: str = field(default="MDNModel")
86 |     _config_name: str = field(default="MDNConfig")
87 |     _probabilistic: bool = field(default=True)
88 | 
89 |     def __post_init__(self):
90 |         assert (
91 |             self.backbone_config_class not in INCOMPATIBLE_BACKBONES
92 |         ), f"{self.backbone_config_class} is not a supported backbone for MDN head"
93 |         assert self.head == "MixtureDensityHead"
94 |         return super().__post_init__()
95 | 
96 | 
97 | if __name__ == "__main__":
98 |     from pytorch_tabular.utils import generate_doc_dataclass
99 | 
100 |     print(generate_doc_dataclass(MDNConfig))
101 | 
-------------------------------------------------------------------------------- /src/pytorch_tabular/models/node/__init__.py: --------------------------------------------------------------------------------
1 | from .config import NodeConfig
2 | from .node_model import NODEBackbone, NODEModel
3 | 
4 | __all__ = ["NODEModel", "NodeConfig", "NODEBackbone"]
5 | 
-------------------------------------------------------------------------------- /src/pytorch_tabular/models/node/architecture_blocks.py: --------------------------------------------------------------------------------
1 | # Neural Oblivious Decision Ensembles
2 | # Author: Sergey Popov, Julian Qian
3 | # https://github.com/Qwicen/node
4 | # For license information, see https://github.com/Qwicen/node/blob/master/LICENSE.md
5 | """Dense ODST Block."""
6 | 
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | 
11 | from pytorch_tabular.models.common.layers import ODST
12 | 
13 | 
14 | class DenseODSTBlock(nn.Sequential):
15 |     def __init__(
16 |         self,
17 |         input_dim,
18 |         num_trees,
19 |         num_layers,
20 |         tree_output_dim=1,
21 |         max_features=None,
22 |         input_dropout=0.0,
23 |         flatten_output=False,
24 |         Module=ODST,
25 |         **kwargs,
26 |     ):
27 |         layers = []
28 |         for _ in range(num_layers):
29 |             odst = Module(input_dim, num_trees, tree_output_dim=tree_output_dim, flatten_output=True, **kwargs)
30 |             input_dim = min(input_dim + num_trees * tree_output_dim, max_features or float("inf"))
31 |             layers.append(odst)
32 | 
33 |         super().__init__(*layers)
34 |         self.num_layers, self.layer_dim, self.tree_dim = (
35 |             num_layers,
36 |             num_trees,
37 |             tree_output_dim,
38 |         )
39 |         self.max_features, self.flatten_output = max_features, flatten_output
40 |         self.input_dropout = input_dropout
41 | 
42 |     def forward(self, x):
43 |         initial_features = x.shape[-1]
44 |         for layer in self:
45 |             layer_inp = x
46 |             if self.max_features is not None:
47 |                 tail_features = min(self.max_features, layer_inp.shape[-1]) - initial_features
48 |                 if tail_features != 0:
49 |                     layer_inp = torch.cat(
50 |                         [
51 |                             layer_inp[..., :initial_features],
52 |                             layer_inp[..., -tail_features:],
53 |                         ],
54 |                         dim=-1,
55 |                     )
56 |             if self.training and self.input_dropout:
57 |                 layer_inp = F.dropout(layer_inp, 
self.input_dropout) 58 | h = layer(layer_inp) 59 | x = torch.cat([x, h], dim=-1) 60 | 61 | outputs = x[..., initial_features:] 62 | if not self.flatten_output: 63 | outputs = outputs.view(*outputs.shape[:-1], self.num_layers * self.layer_dim, self.tree_dim) 64 | return outputs 65 | -------------------------------------------------------------------------------- /src/pytorch_tabular/models/node/node_model.py: -------------------------------------------------------------------------------- 1 | # Pytorch Tabular 2 | # Author: Manu Joseph 3 | # For license information, see LICENSE.TXT 4 | """Tabular Model.""" 5 | 6 | import warnings 7 | 8 | import torch 9 | import torch.nn as nn 10 | from omegaconf import DictConfig 11 | 12 | from pytorch_tabular.models.common.layers import Embedding1dLayer 13 | from pytorch_tabular.utils import get_logger 14 | 15 | from ..base_model import BaseModel 16 | from ..common import activations 17 | from ..common.layers import Lambda 18 | from .architecture_blocks import DenseODSTBlock 19 | 20 | logger = get_logger(__name__) 21 | 22 | 23 | class NODEBackbone(nn.Module): 24 | def __init__(self, config: DictConfig, **kwargs): 25 | super().__init__() 26 | self.hparams = config 27 | # self.hparams.output_dim = (0 if self.hparams.output_dim is None else self.hparams.output_dim) 28 | # For SSL cases where output_dim will be None 29 | self._build_network() 30 | 31 | def _build_network(self): 32 | self.hparams.node_input_dim = self.hparams.continuous_dim + self.hparams.embedded_cat_dim 33 | self.dense_block = DenseODSTBlock( 34 | input_dim=self.hparams.node_input_dim, 35 | num_trees=self.hparams.num_trees, 36 | num_layers=self.hparams.num_layers, 37 | tree_output_dim=self.hparams.output_dim + self.hparams.additional_tree_output_dim, 38 | max_features=self.hparams.max_features, 39 | input_dropout=self.hparams.input_dropout, 40 | depth=self.hparams.depth, 41 | choice_function=getattr(activations, self.hparams.choice_function), 42 | bin_function=getattr(activations, self.hparams.bin_function), 43 | initialize_response_=getattr(nn.init, self.hparams.initialize_response + "_"), 44 | initialize_selection_logits_=getattr(nn.init, self.hparams.initialize_selection_logits + "_"), 45 | threshold_init_beta=self.hparams.threshold_init_beta, 46 | threshold_init_cutoff=self.hparams.threshold_init_cutoff, 47 | ) 48 | self.output_dim = self.hparams.output_dim + self.hparams.additional_tree_output_dim 49 | 50 | def _build_embedding_layer(self): 51 | embedding = Embedding1dLayer( 52 | continuous_dim=self.hparams.continuous_dim, 53 | categorical_embedding_dims=self.hparams.embedding_dims, 54 | embedding_dropout=self.hparams.embedding_dropout, 55 | batch_norm_continuous_input=self.hparams.batch_norm_continuous_input, 56 | virtual_batch_size=self.hparams.virtual_batch_size, 57 | ) 58 | return embedding 59 | 60 | def forward(self, x: torch.Tensor): 61 | x = self.dense_block(x) 62 | return x 63 | 64 | 65 | class NODEModel(BaseModel): 66 | def __init__(self, config: DictConfig, **kwargs): 67 | super().__init__(config, **kwargs) 68 | 69 | def subset(self, x): 70 | return x[..., : self.hparams.output_dim].mean(dim=-2) 71 | 72 | def data_aware_initialization(self, datamodule): 73 | """Performs data-aware initialization for NODE.""" 74 | logger.info( 75 | "Data Aware Initialization of NODE using a forward pass with " 76 | f"{self.hparams.data_aware_init_batch_size} batch size...." 
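            # (Hedged note: the reference NODE implementation draws the initial ODST feature
            # thresholds from data quantiles, so this first batch needs to be large enough
            # for those statistics to be representative.)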
77 | ) 78 | # Need a big batch to initialize properly 79 | alt_loader = datamodule.train_dataloader(batch_size=self.hparams.data_aware_init_batch_size) 80 | batch = next(iter(alt_loader)) 81 | for k, v in batch.items(): 82 | if isinstance(v, list) and (len(v) == 0): 83 | # Skipping empty list 84 | continue 85 | # batch[k] = v.to("cpu" if self.config.gpu == 0 else "cuda") 86 | batch[k] = v.to(self.device) 87 | 88 | # single forward pass to initialize the ODST 89 | with torch.no_grad(): 90 | self(batch) 91 | 92 | @property 93 | def backbone(self): 94 | return self._backbone 95 | 96 | @property 97 | def embedding_layer(self): 98 | return self._embedding_layer 99 | 100 | @property 101 | def head(self): 102 | return self._head 103 | 104 | def _build_network(self): 105 | self._backbone = NODEBackbone(self.hparams) 106 | # Embedding Layer 107 | self._embedding_layer = self._backbone._build_embedding_layer() 108 | # average first n channels of every tree, where n is the number of output targets for regression 109 | # and number of classes for classification 110 | # Not using config head because NODE has a specific head 111 | warnings.warn("Ignoring head config because NODE has a specific head which subsets the tree outputs") 112 | self._head = Lambda(self.subset) 113 | -------------------------------------------------------------------------------- /src/pytorch_tabular/models/stacking/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import StackingModelConfig 2 | from .stacking_model import StackingBackbone, StackingModel 3 | 4 | __all__ = ["StackingModel", "StackingModelConfig", "StackingBackbone"] 5 | -------------------------------------------------------------------------------- /src/pytorch_tabular/models/stacking/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | from pytorch_tabular.config import ModelConfig 4 | 5 | 6 | @dataclass 7 | class StackingModelConfig(ModelConfig): 8 | """StackingModelConfig is a configuration class for the StackingModel. It is used to stack multiple models 9 | together. Currently, CategoryEmbeddingModel, TabNetModel, FTTransformerModel, GatedAdditiveTreeEnsembleModel, DANetModel, 10 | AutoIntModel, GANDALFModel, and NodeModel are supported. 11 | 12 | Args: 13 | model_configs (list[ModelConfig]): List of model configs to stack. 
14 | 15 | """ 16 | 17 | model_configs: list = field(default_factory=list, metadata={"help": "List of model configs to stack"}) 18 | _module_src: str = field(default="models.stacking") 19 | _model_name: str = field(default="StackingModel") 20 | _backbone_name: str = field(default="StackingBackbone") 21 | _config_name: str = field(default="StackingConfig") 22 | 23 | 24 | # if __name__ == "__main__": 25 | # from pytorch_tabular.utils import generate_doc_dataclass 26 | # print(generate_doc_dataclass(StackingModelConfig)) 27 | -------------------------------------------------------------------------------- /src/pytorch_tabular/models/tab_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import TabTransformerConfig 2 | from .tab_transformer import TabTransformerBackbone, TabTransformerModel 3 | 4 | __all__ = ["TabTransformerBackbone", "TabTransformerModel", "TabTransformerConfig"] 5 | -------------------------------------------------------------------------------- /src/pytorch_tabular/models/tabnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import TabNetModelConfig 2 | from .tabnet_model import TabNetBackbone, TabNetModel 3 | 4 | __all__ = ["TabNetModel", "TabNetModelConfig", "TabNetBackbone"] 5 | -------------------------------------------------------------------------------- /src/pytorch_tabular/models/tabnet/tabnet_model.py: -------------------------------------------------------------------------------- 1 | # Pytorch Tabular 2 | # Author: Manu Joseph 3 | # For license information, see LICENSE.TXT 4 | """TabNet Model.""" 5 | 6 | from typing import Dict 7 | 8 | import torch 9 | import torch.nn as nn 10 | from omegaconf import DictConfig 11 | from pytorch_tabnet.tab_network import TabNet 12 | from pytorch_tabnet.utils import create_group_matrix 13 | 14 | from ..base_model import BaseModel 15 | 16 | 17 | class TabNetBackbone(nn.Module): 18 | def __init__(self, config: DictConfig, **kwargs): 19 | super().__init__() 20 | self.hparams = config 21 | self._build_network() 22 | 23 | def _build_network(self): 24 | if self.hparams.grouped_features: 25 | # converting the grouped_features into a nested list of indices 26 | features = self.hparams.categorical_cols + self.hparams.continuous_cols 27 | grp_list = [ 28 | [features.index(col) for col in grp if col in features] for grp in self.hparams.grouped_features 29 | ] 30 | else: 31 | # creating a default grp_list with each feature as a group 32 | grp_list = [[i] for i in range(self.hparams.continuous_dim + self.hparams.categorical_dim)] 33 | group_matrix = create_group_matrix( 34 | grp_list, 35 | self.hparams.continuous_dim + self.hparams.categorical_dim, 36 | ) 37 | self.tabnet = TabNet( 38 | input_dim=self.hparams.continuous_dim + self.hparams.categorical_dim, 39 | output_dim=self.hparams.output_dim, 40 | n_d=self.hparams.n_d, 41 | n_a=self.hparams.n_a, 42 | n_steps=self.hparams.n_steps, 43 | gamma=self.hparams.gamma, 44 | cat_idxs=list(range(self.hparams.categorical_dim)), 45 | cat_dims=[cardinality for cardinality, _ in self.hparams.embedding_dims], 46 | cat_emb_dim=[embed_dim for _, embed_dim in self.hparams.embedding_dims], 47 | n_independent=self.hparams.n_independent, 48 | n_shared=self.hparams.n_shared, 49 | epsilon=1e-15, 50 | virtual_batch_size=self.hparams.virtual_batch_size, 51 | momentum=0.02, 52 | mask_type=self.hparams.mask_type, 53 | group_attention_matrix=group_matrix, 54 | ) 55 | 56 | def 
unpack_input(self, x: Dict): 57 | # unpacking into a tuple 58 | x = x["categorical"], x["continuous"] 59 | # eliminating None in case there are no categorical or continuous columns 60 | x = (item for item in x if len(item) > 0) 61 | x = torch.cat(tuple(x), dim=1) 62 | return x 63 | 64 | def forward(self, x: Dict): 65 | # unpacking into a tuple 66 | x = self.unpack_input(x) 67 | # Moving the two group matrices to the same device as the input. 68 | self.tabnet.embedder.embedding_group_matrix = self.tabnet.embedder.embedding_group_matrix.to(x.device) 69 | self.tabnet.tabnet.encoder.group_attention_matrix = self.tabnet.tabnet.encoder.group_attention_matrix.to( 70 | x.device 71 | ) 72 | # Returns output and Masked Loss. We only need the output 73 | x, _ = self.tabnet(x) 74 | return x 75 | 76 | 77 | class TabNetModel(BaseModel): 78 | def __init__(self, config: DictConfig, **kwargs): 79 | assert config.task in [ 80 | "regression", 81 | "classification", 82 | ], "TabNet is only implemented for Regression and Classification" 83 | super().__init__(config, **kwargs) 84 | 85 | @property 86 | def backbone(self): 87 | return self._backbone 88 | 89 | @property 90 | def embedding_layer(self): 91 | return self._embedding_layer 92 | 93 | @property 94 | def head(self): 95 | return self._head 96 | 97 | def _build_network(self): 98 | # TabNet has its own embedding layer. 99 | # So we are not using the embedding layer from BaseModel 100 | self._embedding_layer = nn.Identity() 101 | self._backbone = TabNetBackbone(self.hparams) 102 | setattr(self.backbone, "output_dim", self.hparams.output_dim) 103 | # TabNet has its own head 104 | self._head = nn.Identity() 105 | 106 | def extract_embedding(self): 107 | raise ValueError("Extracting embeddings is not supported by TabNet. Please use another" " compatible model") 108 | -------------------------------------------------------------------------------- /src/pytorch_tabular/ssl_models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import dae 2 | from .base_model import SSLBaseModel 3 | from .dae import DenoisingAutoEncoderConfig, DenoisingAutoEncoderModel 4 | 5 | __all__ = ["DenoisingAutoEncoderConfig", "DenoisingAutoEncoderModel", "SSLBaseModel", "dae"] 6 | -------------------------------------------------------------------------------- /src/pytorch_tabular/ssl_models/common/__init__.py: -------------------------------------------------------------------------------- -------------------------------------------------------------------------------- /src/pytorch_tabular/ssl_models/common/augmentations.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def mixup(batch: Dict, lam: float = 0.5) -> Dict: 8 | """Apply the mixup augmentation: each row of a tensor is replaced by a weighted average of itself and a 9 | randomly chosen row of the same tensor. 
10 | 11 | :param batch: Dict of tensors to which the mixup augmentation is applied 12 | :param lam: weight of the original values in the linear combination with the randomly permuted rows 13 | 14 | """ 15 | result = {} 16 | for key, value in batch.items(): 17 | random_index = _get_random_index(value) 18 | result[key] = lam * value + (1 - lam) * value[random_index, :] 19 | result[key] = result[key].to(dtype=value.dtype) 20 | return result 21 | 22 | 23 | def cutmix(batch: Dict, lam: float = 0.1) -> Dict: 24 | """Apply the cutmix augmentation to a tensor. 25 | 26 | :param batch: Dict of tensors to which the cutmix augmentation is applied 27 | :param lam: probability that an entry of the binary random mask is 0, i.e. the probability that an original value 28 | is replaced by the value from a random row 29 | 30 | """ 31 | result = {} 32 | for key, value in batch.items(): 33 | random_index = _get_random_index(value) 34 | x_binary_mask = torch.from_numpy(np.random.choice(2, size=value.shape, p=[lam, 1 - lam])) 35 | x_random = value[random_index, :] 36 | x_noised = value.clone().detach() 37 | x_noised[x_binary_mask == 0] = x_random[x_binary_mask == 0] 38 | result[key] = x_noised 39 | return result 40 | 41 | 42 | def _get_random_index(x: torch.Tensor) -> torch.Tensor: 43 | """Given a tensor, compute a random permutation of the indices of its first dimension. 44 | 45 | :param x: Tensor used to get the number of rows 46 | 47 | """ 48 | batch_size = x.size()[0] 49 | index = torch.randperm(batch_size) 50 | return index 51 | -------------------------------------------------------------------------------- /src/pytorch_tabular/ssl_models/common/heads.py: -------------------------------------------------------------------------------- 1 | # Pytorch Tabular 2 | # Author: Manu Joseph 3 | # For license information, see LICENSE.TXT 4 | """SSL Heads.""" 5 | 6 | import torch.nn as nn 7 | 8 | 9 | class MultiTaskHead(nn.Module): 10 | """Simple linear transformations that map the last hidden representation back to reconstructions of the inputs. 11 | 12 | Output is a dictionary mapping variable type ("binary", "categorical", "continuous") to tensor. 
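Example (illustrative; output shapes follow forward() below):

    head = MultiTaskHead(in_features=64, n_binary=1, n_categorical=2, n_numerical=3, cardinality=[4, 5])
    out = head(features)  # {"binary": (B, 1), "categorical": [(B, 4), (B, 5)], "continuous": (B, 3)}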
13 | 14 | """ 15 | 16 | def __init__(self, in_features, n_binary=0, n_categorical=0, n_numerical=0, cardinality=[]): 17 | super().__init__() 18 | assert n_categorical == len(cardinality), "require cardinalities for each categorical variable" 19 | assert n_binary + n_categorical + n_numerical, "need some targets" 20 | self.n_binary = n_binary 21 | self.n_categorical = n_categorical 22 | self.n_numerical = n_numerical 23 | 24 | self.binary_linear = nn.Linear(in_features, n_binary) if n_binary else None 25 | self.categorical_linears = nn.ModuleList([nn.Linear(in_features, card) for card in cardinality]) 26 | self.numerical_linear = nn.Linear(in_features, n_numerical) if n_numerical else None 27 | 28 | def forward(self, features): 29 | outputs = {} 30 | 31 | if self.binary_linear: 32 | outputs["binary"] = self.binary_linear(features) 33 | 34 | if self.categorical_linears: 35 | outputs["categorical"] = [linear(features) for linear in self.categorical_linears] 36 | 37 | if self.numerical_linear: 38 | outputs["continuous"] = self.numerical_linear(features) 39 | 40 | return outputs 41 | -------------------------------------------------------------------------------- /src/pytorch_tabular/ssl_models/common/layers.py: -------------------------------------------------------------------------------- 1 | # W605 2 | from collections import OrderedDict 3 | from typing import Any, Dict, Tuple 4 | 5 | import torch 6 | from torch import nn 7 | 8 | from pytorch_tabular.models.common.layers.batch_norm import BatchNorm1d 9 | from pytorch_tabular.ssl_models.common.utils import OneHot 10 | 11 | 12 | class MixedEmbedding1dLayer(nn.Module): 13 | """Enables different values in a categorical feature to have different embeddings.""" 14 | 15 | def __init__( 16 | self, 17 | continuous_dim: int, 18 | categorical_embedding_dims: Tuple[int, int], 19 | max_onehot_cardinality: int = 4, 20 | embedding_dropout: float = 0.0, 21 | batch_norm_continuous_input: bool = False, 22 | virtual_batch_size: int = None, 23 | ): 24 | super().__init__() 25 | self.continuous_dim = continuous_dim 26 | self.categorical_embedding_dims = categorical_embedding_dims 27 | self.categorical_dim = len(categorical_embedding_dims) 28 | self.batch_norm_continuous_input = batch_norm_continuous_input 29 | 30 | binary_feat_idx = [] 31 | onehot_feat_idx = [] 32 | embedding_feat_idx = [] 33 | embd_layers = {} 34 | one_hot_layers = {} 35 | for i, (cardinality, embed_dim) in enumerate(categorical_embedding_dims): 36 | # conditions based on enhanced cardinality (including missing/new value placeholder) 37 | if cardinality == 2: 38 | binary_feat_idx.append(i) 39 | elif cardinality <= max_onehot_cardinality: 40 | onehot_feat_idx.append(i) 41 | one_hot_layers[str(i)] = OneHot(cardinality) 42 | else: 43 | embedding_feat_idx.append(i) 44 | embd_layers[str(i)] = nn.Embedding(cardinality, embed_dim) 45 | 46 | if self.categorical_dim > 0: 47 | # Embedding layers 48 | self.embedding_layer = nn.ModuleDict(embd_layers) 49 | self.one_hot_layers = nn.ModuleDict(one_hot_layers) 50 | self._onehot_feat_idx = onehot_feat_idx 51 | self._binary_feat_idx = binary_feat_idx 52 | self._embedding_feat_idx = embedding_feat_idx 53 | 54 | if embedding_dropout > 0 and len(embedding_feat_idx) > 0: 55 | self.embd_dropout = nn.Dropout(embedding_dropout) 56 | else: 57 | self.embd_dropout = None 58 | # Continuous Layers 59 | if batch_norm_continuous_input: 60 | self.normalizing_batch_norm = BatchNorm1d(continuous_dim, virtual_batch_size) 61 | 62 | @property 63 | def embedded_cat_dim(self): 
return sum( 65 | [ 66 | embd_dim 67 | for i, (_, embd_dim) in enumerate(self.categorical_embedding_dims) 68 | if i in self._embedding_feat_idx 69 | ] 70 | ) 71 | 72 | def forward(self, x: Dict[str, Any]) -> torch.Tensor: 73 | assert "continuous" in x or "categorical" in x, "x must contain continuous or categorical features" 74 | # (B, N) 75 | continuous_data, categorical_data = ( 76 | x.get("continuous", torch.empty(0, 0)), 77 | x.get("categorical", torch.empty(0, 0)), 78 | ) 79 | assert categorical_data.shape[1] == len( 80 | self._onehot_feat_idx + self._binary_feat_idx + self._embedding_feat_idx 81 | ), "categorical_data must have same number of columns as categorical embedding layers" 82 | assert ( 83 | continuous_data.shape[1] == self.continuous_dim 84 | ), "continuous_data must have same number of columns as continuous dim" 85 | # embed = None 86 | if continuous_data.shape[1] > 0: 87 | if self.batch_norm_continuous_input: 88 | continuous_data = self.normalizing_batch_norm(continuous_data) 89 | # (B, N, C) 90 | if categorical_data.shape[1] > 0: 91 | x_cat = [] 92 | x_cat_orig = [] 93 | x_binary = [] 94 | x_embed = [] 95 | for i in range(self.categorical_dim): 96 | if i in self._binary_feat_idx: 97 | x_binary.append(categorical_data[:, i : i + 1]) 98 | elif i in self._onehot_feat_idx: 99 | x_cat.append(self.one_hot_layers[str(i)](categorical_data[:, i])) 100 | x_cat_orig.append(categorical_data[:, i : i + 1]) 101 | else: 102 | x_embed.append(self.embedding_layer[str(i)](categorical_data[:, i])) 103 | # (B, N, E) 104 | x_cat = torch.cat(x_cat, 1) if len(x_cat) > 0 else None 105 | x_cat_orig = torch.cat(x_cat_orig, 1) if len(x_cat_orig) > 0 else None 106 | x_binary = torch.cat(x_binary, 1) if len(x_binary) > 0 else None 107 | x_embed = torch.cat(x_embed, 1) if len(x_embed) > 0 else None 108 | all_none = (x_cat is None) and (x_binary is None) and (x_embed is None) 109 | assert not all_none, "All inputs can't be None!" 110 | if self.embd_dropout is not None: 111 | x_embed = self.embd_dropout(x_embed) 112 | else: 113 | x_cat = None 114 | x_cat_orig = None 115 | x_binary = None 116 | x_embed = None 117 | return OrderedDict( 118 | binary=x_binary, 119 | categorical=x_cat, 120 | _categorical_orig=x_cat_orig, 121 | continuous=continuous_data, 122 | embedding=x_embed, 123 | ) 124 | -------------------------------------------------------------------------------- /src/pytorch_tabular/ssl_models/common/noise_generators.py: -------------------------------------------------------------------------------- 1 | # Pytorch Tabular 2 | # Author: Manu Joseph 3 | # For license information, see LICENSE.TXT 4 | # Inspired by implementation https://github.com/ryancheunggit/tabular_dae 5 | """Noise Generators.""" 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | 11 | 12 | class SwapNoiseCorrupter(nn.Module): 13 | """Apply swap noise on the input data. 14 | 15 | Each data point has a specified chance of being replaced by a random value from the same column. 
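Example (illustrative; probas needs one entry per input column):

    corrupter = SwapNoiseCorrupter(probas=[0.1] * x.shape[1])
    corrupted_x, mask = corrupter(x)  # mask is 1.0 wherever a value was swapped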
16 | 17 | """ 18 | 19 | def __init__(self, probas): 20 | super().__init__() 21 | self.probas = torch.from_numpy(np.array(probas, dtype=np.float32)) 22 | 23 | def forward(self, x): 24 | should_swap = torch.bernoulli(self.probas.to(x.device) * torch.ones(x.shape).to(x.device)) 25 | corrupted_x = torch.where(should_swap == 1, x[torch.randperm(x.shape[0])], x) 26 | mask = (corrupted_x != x).float() 27 | return corrupted_x, mask 28 | -------------------------------------------------------------------------------- /src/pytorch_tabular/ssl_models/common/ssl_losses.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def loss_contrastive(y_hat, y): 5 | return -nn.functional.cosine_similarity(y_hat, y).add_(-1).sum() 6 | -------------------------------------------------------------------------------- /src/pytorch_tabular/ssl_models/common/ssl_utils.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | from torch import Tensor 3 | 4 | from pytorch_tabular.models.common import PositionWiseFeedForward 5 | 6 | 7 | class Denoising(pl.LightningModule): 8 | def __init__(self, input_dim: int): 9 | super().__init__() 10 | self.mlp = PositionWiseFeedForward(d_model=input_dim, d_ff=2 * input_dim) 11 | 12 | def forward(self, x: Tensor): 13 | return {"logits": self.mlp(x)} 14 | 15 | 16 | class Contrastive(pl.LightningModule): 17 | def __init__(self, input_dim: int): 18 | super().__init__() 19 | self.mlp = PositionWiseFeedForward(d_model=input_dim, d_ff=2 * input_dim) 20 | 21 | def forward(self, x: Tensor): 22 | x = x / x.norm(dim=-1, keepdim=True) 23 | return {"logits": self.mlp(x)} 24 | -------------------------------------------------------------------------------- /src/pytorch_tabular/ssl_models/common/utils.py: -------------------------------------------------------------------------------- 1 | # Pytorch Tabular 2 | # Author: Manu Joseph 3 | # For license information, see LICENSE.TXT 4 | """Utilities.""" 5 | 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class OneHot(nn.Module): 11 | def __init__(self, cardinality): 12 | super().__init__() 13 | self.cardinality = cardinality 14 | 15 | def forward(self, x): 16 | return F.one_hot(x, self.cardinality) 17 | -------------------------------------------------------------------------------- /src/pytorch_tabular/ssl_models/dae/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import DenoisingAutoEncoderConfig 2 | from .dae import DenoisingAutoEncoderModel 3 | 4 | __all__ = ["DenoisingAutoEncoderModel", "DenoisingAutoEncoderConfig"] 5 | -------------------------------------------------------------------------------- /src/pytorch_tabular/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_utils import ( 2 | get_balanced_sampler, 3 | get_class_weighted_cross_entropy, 4 | get_gaussian_centers, 5 | load_covertype_dataset, 6 | make_mixed_dataset, 7 | print_metrics, 8 | ) 9 | from .logger import get_logger 10 | from .nn_utils import ( 11 | OOMException, 12 | OutOfMemoryHandler, 13 | _initialize_kaiming, 14 | _initialize_layers, 15 | _linear_dropout_bn, 16 | _make_ix_like, 17 | count_parameters, 18 | reset_all_weights, 19 | to_one_hot, 20 | ) 21 | from .python_utils import ( 22 | available_models, 23 | available_ssl_models, 24 | check_numpy, 25 | enable_lightning_logs, 26 | 
generate_doc_dataclass, 27 | getattr_nested, 28 | ifnone, 29 | int_to_human_readable, 30 | pl_load, 31 | suppress_lightning_logs, 32 | ) 33 | 34 | __all__ = [ 35 | "get_logger", 36 | "getattr_nested", 37 | "generate_doc_dataclass", 38 | "ifnone", 39 | "pl_load", 40 | "_initialize_layers", 41 | "_linear_dropout_bn", 42 | "reset_all_weights", 43 | "get_class_weighted_cross_entropy", 44 | "get_balanced_sampler", 45 | "get_gaussian_centers", 46 | "_make_ix_like", 47 | "to_one_hot", 48 | "_initialize_kaiming", 49 | "check_numpy", 50 | "OutOfMemoryHandler", 51 | "OOMException", 52 | "make_mixed_dataset", 53 | "print_metrics", 54 | "load_covertype_dataset", 55 | "count_parameters", 56 | "int_to_human_readable", 57 | "suppress_lightning_logs", 58 | "enable_lightning_logs", 59 | "available_models", 60 | "available_ssl_models", 61 | ] 62 | -------------------------------------------------------------------------------- /src/pytorch_tabular/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | from rich.logging import RichHandler 5 | 6 | 7 | def get_logger(name): 8 | logger = logging.getLogger(name) 9 | # ch = logging.StreamHandler() 10 | logger.setLevel(level=os.environ.get("PT_LOGLEVEL", "INFO")) 11 | formatter = logging.Formatter("%(asctime)s - {%(name)s:%(lineno)d} - %(levelname)s - %(message)s") 12 | if not logger.hasHandlers(): 13 | ch = RichHandler(show_level=False, show_time=False, show_path=False, rich_tracebacks=True) 14 | ch.setLevel(logging.DEBUG) 15 | ch.setFormatter(formatter) 16 | logger.addHandler(ch) 17 | logger.propagate = False 18 | return logger 19 | -------------------------------------------------------------------------------- /src/pytorch_tabular/utils/nn_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .logger import get_logger 5 | 6 | logger = get_logger(__name__) 7 | 8 | 9 | def _initialize_layers(activation, initialization, layers): 10 | if type(layers) is nn.Sequential: 11 | for layer in layers: 12 | if hasattr(layer, "weight"): 13 | _initialize_layers(activation, initialization, layer) 14 | else: 15 | if activation == "ReLU": 16 | nonlinearity = "relu" 17 | elif activation == "LeakyReLU": 18 | nonlinearity = "leaky_relu" 19 | else: 20 | if initialization == "kaiming": 21 | logger.warning("Kaiming initialization is only recommended for ReLU and" " LeakyReLU.") 22 | nonlinearity = "leaky_relu" 23 | else: 24 | nonlinearity = "relu" 25 | 26 | if initialization == "kaiming": 27 | nn.init.kaiming_normal_(layers.weight, nonlinearity=nonlinearity) 28 | elif initialization == "xavier": 29 | nn.init.xavier_normal_( 30 | layers.weight, 31 | gain=(nn.init.calculate_gain(nonlinearity) if activation in ["ReLU", "LeakyReLU"] else 1), 32 | ) 33 | elif initialization == "random": 34 | nn.init.normal_(layers.weight) 35 | 36 | 37 | def _linear_dropout_bn(activation, initialization, use_batch_norm, in_units, out_units, dropout): 38 | if isinstance(activation, str): 39 | _activation = getattr(nn, activation) 40 | else: 41 | _activation = activation 42 | layers = [] 43 | if use_batch_norm: 44 | from pytorch_tabular.models.common.layers.batch_norm import BatchNorm1d 45 | 46 | layers.append(BatchNorm1d(num_features=in_units)) 47 | linear = nn.Linear(in_units, out_units) 48 | _initialize_layers(activation, initialization, linear) 49 | layers.extend([linear, _activation()]) 50 | if dropout != 0: 51 | 
layers.append(nn.Dropout(dropout)) 52 | return layers 53 | 54 | 55 | def reset_all_weights(model: nn.Module) -> None: 56 | """Resets all parameters in a network. 57 | 58 | Args: 59 | model: The model to reset the parameters of. 60 | 61 | refs: 62 | - https://discuss.pytorch.org/t/how-to-re-set-alll-parameters-in-a-network/20819/6 63 | - https://stackoverflow.com/questions/63627997/reset-parameters-of-a-neural-network-in-pytorch 64 | - https://pytorch.org/docs/stable/generated/torch.nn.Module.html 65 | 66 | """ 67 | 68 | @torch.no_grad() 69 | def weight_reset(m: nn.Module): 70 | # - check if the current module has reset_parameters and, if it is callable, call it on m 71 | reset_parameters = getattr(m, "reset_parameters", None) 72 | if callable(reset_parameters): 73 | m.reset_parameters() 74 | 75 | # Applies fn recursively to every submodule see: https://pytorch.org/docs/stable/generated/torch.nn.Module.html 76 | model.apply(fn=weight_reset) 77 | 78 | 79 | def to_one_hot(y, depth=None): 80 | r"""Takes integer with n dims and converts it to 1-hot representation with n + 1 dims. 81 | 82 | The n+1'st dimension will have zeros everywhere but at y'th index, where it will be equal to 1. 83 | Args: 84 | y: input integer (IntTensor, LongTensor or Variable) of any shape 85 | depth (int): the size of the one hot dimension 86 | 87 | """ 88 | y_flat = y.to(torch.int64).view(-1, 1) 89 | depth = depth or int(torch.max(y_flat)) + 1 90 | y_one_hot = torch.zeros(y_flat.size()[0], depth, device=y.device).scatter_(1, y_flat, 1) 91 | y_one_hot = y_one_hot.view(*(tuple(y.shape) + (-1,))) 92 | return y_one_hot 93 | 94 | 95 | def _make_ix_like(input, dim=0): 96 | d = input.size(dim) 97 | rho = torch.arange(1, d + 1, device=input.device, dtype=input.dtype) 98 | view = [1] * input.dim() 99 | view[0] = -1 100 | return rho.view(view).transpose(0, dim) 101 | 102 | 103 | def _initialize_kaiming(x, initialization, d_sqrt_inv): 104 | if initialization == "kaiming_uniform": 105 | nn.init.uniform_(x, a=-d_sqrt_inv, b=d_sqrt_inv) 106 | elif initialization == "kaiming_normal": 107 | nn.init.normal_(x, std=d_sqrt_inv) 108 | elif initialization is None: 109 | pass 110 | else: 111 | raise NotImplementedError("initialization should be one of `kaiming_normal`, `kaiming_uniform`," " `None`") 112 | 113 | 114 | class OutOfMemoryHandler: 115 | """Context manager to handle out of memory errors. 116 | 117 | Args: 118 | handle_oom: Whether to handle the error or not. If set to False, 119 | the exception will be propagated. 
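Example (illustrative; oom_triggered and oom_msg are set in __exit__ below):

    with OutOfMemoryHandler(handle_oom=True) as handler:
        output = model(large_batch)
    if handler.oom_triggered:
        print(handler.oom_msg)  # the CUDA cache has already been emptied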
120 | 121 | """ 122 | 123 | def __init__(self, handle_oom: bool = True): 124 | self.handle_oom = handle_oom 125 | self.oom_triggered = False 126 | self.oom_msg = None 127 | 128 | def __enter__(self): 129 | return self 130 | 131 | def __exit__(self, exc_type, exc_value, traceback): 132 | is_oom_runtime_error = exc_type is RuntimeError and "out of memory" in str(exc_value) 133 | try: 134 | is_cuda_oom_error = exc_type is torch.cuda.OutOfMemoryError 135 | except AttributeError: 136 | # before torch 1.13.0, torch.cuda.OutOfMemoryError did not exist 137 | is_cuda_oom_error = False 138 | if (is_oom_runtime_error or is_cuda_oom_error) and self.handle_oom: 139 | self.oom_triggered = True 140 | self.oom_msg = exc_value.args[0] 141 | torch.cuda.empty_cache() 142 | return True # Suppress the exception 143 | return False # Propagate any other exceptions 144 | 145 | 146 | class OOMException(Exception): 147 | pass 148 | 149 | 150 | def count_parameters(model): 151 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 152 | -------------------------------------------------------------------------------- /src/pytorch_tabular/utils/python_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import textwrap 3 | from pathlib import Path 4 | from typing import IO, Any, Callable, Dict, Optional, Union 5 | 6 | import numpy as np 7 | import torch 8 | 9 | try: # for 1.8 10 | from pytorch_lightning.utilities.cloud_io import get_filesystem 11 | except ImportError: # for 1.9 12 | from pytorch_lightning.core.saving import get_filesystem 13 | 14 | import pytorch_tabular as root_module 15 | 16 | from .logger import get_logger 17 | 18 | _PATH = Union[str, Path] 19 | _DEVICE = Union[torch.device, str, int] 20 | _MAP_LOCATION_TYPE = Optional[Union[_DEVICE, Callable[[_DEVICE], _DEVICE], Dict[_DEVICE, _DEVICE]]] 21 | 22 | 23 | logger = get_logger(__name__) 24 | 25 | 26 | def getattr_nested(_module_src, _model_name): 27 | module = root_module 28 | for m in _module_src.split("."): 29 | module = getattr(module, m) 30 | return getattr(module, _model_name) 31 | 32 | 33 | def ifnone(arg, default_arg): 34 | return default_arg if arg is None else arg 35 | 36 | 37 | def generate_doc_dataclass(dataclass, desc=None, width=100): 38 | if desc is not None: 39 | doc_str = f"{desc}\nArgs:" 40 | else: 41 | doc_str = "Args:" 42 | for key in dataclass.__dataclass_fields__.keys(): 43 | if key.startswith("_"): # Skipping private fields 44 | continue 45 | atr = dataclass.__dataclass_fields__[key] 46 | if atr.init: 47 | type = str(atr.type).replace("<class '", "").replace("'>", "").replace("typing.", "") 48 | help_str = atr.metadata.get("help", "") 49 | if "choices" in atr.metadata.keys(): 50 | help_str += ". Choices are:" f" [{','.join(['`'+str(ch)+'`' for ch in atr.metadata['choices']])}]." 51 | # help_str += f'. Defaults to {atr.default}' 52 | h_str = textwrap.fill( 53 | f"{key} ({type}): {help_str}", 54 | width=width, 55 | subsequent_indent="\t\t", 56 | initial_indent="\t", 57 | ) 58 | h_str = f"\n{h_str}\n" 59 | doc_str += h_str 60 | return doc_str 61 | 62 | 63 | # Copied over pytorch_lightning.utilities.cloud_io.load as it was deprecated 64 | def pl_load( 65 | path_or_url: Union[IO, _PATH], 66 | map_location: _MAP_LOCATION_TYPE = None, 67 | ) -> Any: 68 | """Loads a checkpoint. 69 | 70 | Args: 71 | path_or_url: Path or URL of the checkpoint. 72 | map_location: a function, ``torch.device``, string or a dict specifying how to remap storage locations. 
73 | 74 | """ 75 | if not isinstance(path_or_url, (str, Path)): 76 | # any sort of BytesIO or similar 77 | # get the torch version 78 | torch_version = torch.__version__ 79 | if torch_version < "2.6": 80 | return torch.load(path_or_url, map_location=map_location) # for torch version < 2.6 81 | elif torch_version >= "2.6": 82 | return torch.load(path_or_url, map_location=map_location, weights_only=False) 83 | if str(path_or_url).startswith("http"): 84 | return torch.hub.load_state_dict_from_url( 85 | str(path_or_url), 86 | map_location=map_location, # type: ignore[arg-type] # upstream annotation is not correct 87 | ) 88 | fs = get_filesystem(path_or_url) 89 | with fs.open(path_or_url, "rb") as f: 90 | torch_version = torch.__version__ 91 | if torch_version < "2.6": 92 | return torch.load(f, map_location=map_location) # for torch version < 2.6 93 | elif torch_version >= "2.6": 94 | return torch.load(f, map_location=map_location, weights_only=False) 95 | 96 | 97 | def check_numpy(x): 98 | """Makes sure x is a numpy array.""" 99 | if isinstance(x, torch.Tensor): 100 | x = x.detach().cpu().numpy() 101 | x = np.asarray(x) 102 | assert isinstance(x, np.ndarray) 103 | return x 104 | 105 | 106 | def int_to_human_readable(number: int, round_number=True) -> str: 107 | millnames = ["", " K", " M", " B", " T"] 108 | n = float(number) 109 | millidx = max( 110 | 0, 111 | min( 112 | len(millnames) - 1, 113 | int(math.floor(0 if n == 0 else math.log10(abs(n)) / 3)), 114 | ), 115 | ) 116 | if round_number: 117 | return f"{int(n / 10 ** (3 * millidx))}{millnames[millidx]}" 118 | else: 119 | return f"{n / 10 ** (3 * millidx):.2f}{millnames[millidx]}" 120 | 121 | 122 | def suppress_lightning_logs(log_level=None): 123 | import logging 124 | 125 | log_level = log_level or logging.ERROR 126 | for logger_name in logging.root.manager.loggerDict: 127 | if logger_name.startswith("pytorch_lightning") or logger_name.startswith("lightning"): 128 | logging.getLogger(logger_name).setLevel(log_level) 129 | 130 | 131 | def enable_lightning_logs(log_level=None): 132 | import logging 133 | 134 | log_level = log_level or logging.INFO 135 | 136 | for logger_name in logging.root.manager.loggerDict: 137 | if logger_name.startswith("pytorch_lightning") or logger_name.startswith("lightning"): 138 | logging.getLogger(logger_name).setLevel(log_level) 139 | 140 | 141 | def available_models(): 142 | from pytorch_tabular import models 143 | 144 | return [cl for cl in dir(models) if "config" in cl.lower()] 145 | 146 | 147 | def available_ssl_models(): 148 | from pytorch_tabular import ssl_models 149 | 150 | return [cl for cl in dir(ssl_models) if "config" in cl.lower()] 151 | -------------------------------------------------------------------------------- /tests/___test_augmentations.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from pytorch_tabular.ssl_models.common.augmentations import _get_random_index, cutmix, mixup 5 | 6 | 7 | def test_get_random_index(): 8 | torch.manual_seed(0) 9 | x = torch.Tensor([1, 2, 3]) 10 | expected = np.array([2, 0, 1]) 11 | actual = _get_random_index(x).numpy() 12 | np.testing.assert_array_equal(actual, expected) 13 | 14 | 15 | def test_mixup(): 16 | torch.manual_seed(0) 17 | np.random.seed(0) 18 | x = torch.Tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]) 19 | lam = 0.5 20 | expected = torch.Tensor([[2.0, 2.0], [1.5, 1.5], [2.5, 2.5]]).numpy() 21 | actual = mixup(batch={"x": x}, lam=lam) 22 | 
np.testing.assert_array_equal(actual["x"].numpy(), expected) 23 | 24 | 25 | def test_cutmix(): 26 | torch.manual_seed(0) 27 | np.random.seed(0) 28 | x = torch.Tensor([[1, 1], [2, 2], [3, 3]]) 29 | lam = 0.5 30 | expected = torch.Tensor([[1.0, 1.0], [2.0, 2.0], [2.0, 3.0]]).numpy() 31 | actual = cutmix(batch={"x": x}, lam=lam) 32 | np.testing.assert_array_equal(actual["x"].numpy(), expected) 33 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Unit test package for pytorch_tabular.""" 2 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from zipfile import ZipFile 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import pytest 7 | from sklearn.datasets import fetch_california_housing, fetch_covtype 8 | 9 | _PATH_TEST = os.path.dirname(__file__) 10 | PATH_DATASETS = os.path.join(_PATH_TEST, ".datasets") 11 | os.makedirs(PATH_DATASETS, exist_ok=True) 12 | 13 | DATASET_ZIP_OCCUPANCY = os.path.join(PATH_DATASETS, "occupancy_data.zip") 14 | if not os.path.isfile(DATASET_ZIP_OCCUPANCY): 15 | import urllib.request 16 | 17 | urllib.request.urlretrieve( 18 | "http://archive.ics.uci.edu/ml/machine-learning-databases/00357/occupancy_data.zip", DATASET_ZIP_OCCUPANCY 19 | ) 20 | 21 | 22 | def load_regression_data(): 23 | dataset = fetch_california_housing(data_home="data", as_frame=True) 24 | df = dataset.frame.sample(5000) 25 | df["HouseAgeBin"] = pd.qcut(df["HouseAge"], q=4) 26 | df["HouseAgeBin"] = "age_" + df.HouseAgeBin.cat.codes.astype(str) 27 | test_idx = df.sample(int(0.2 * len(df)), random_state=42).index 28 | test = df[df.index.isin(test_idx)] 29 | train = df[~df.index.isin(test_idx)] 30 | return (train, test, dataset.target_names) 31 | 32 | 33 | def load_classification_data(): 34 | dataset = fetch_covtype(data_home="data") 35 | data = np.hstack([dataset.data, dataset.target.reshape(-1, 1)])[:10000, :] 36 | col_names = [f"feature_{i}" for i in range(data.shape[-1])] 37 | col_names[-1] = "target" 38 | data = pd.DataFrame(data, columns=col_names) 39 | data["feature_0_cat"] = pd.qcut(data["feature_0"], q=4) 40 | data["feature_0_cat"] = "feature_0_" + data.feature_0_cat.cat.codes.astype(str) 41 | test_idx = data.sample(int(0.2 * len(data)), random_state=42).index 42 | test = data[data.index.isin(test_idx)] 43 | train = data[~data.index.isin(test_idx)] 44 | return (train, test, ["target"]) 45 | 46 | 47 | def load_timeseries_data(): 48 | zipfile = ZipFile(DATASET_ZIP_OCCUPANCY) 49 | train = pd.read_csv(zipfile.open("datatraining.txt"), sep=",") 50 | val = pd.read_csv(zipfile.open("datatest.txt"), sep=",") 51 | test = pd.read_csv(zipfile.open("datatest2.txt"), sep=",") 52 | return (pd.concat([train, val], sort=False), test, ["Occupancy"]) 53 | 54 | 55 | @pytest.fixture(scope="session", autouse=True) 56 | def regression_data(): 57 | return load_regression_data() 58 | 59 | 60 | @pytest.fixture(scope="session", autouse=True) 61 | def classification_data(): 62 | return load_classification_data() 63 | 64 | 65 | @pytest.fixture(scope="session", autouse=True) 66 | def timeseries_data(): 67 | return load_timeseries_data() 68 | -------------------------------------------------------------------------------- /tests/test_mdn.py: -------------------------------------------------------------------------------- 1 | 
#!/usr/bin/env python 2 | """Tests for `pytorch_tabular` package.""" 3 | 4 | import pytest 5 | 6 | from pytorch_tabular import TabularModel 7 | from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig 8 | from pytorch_tabular.models import MDNConfig 9 | 10 | 11 | @pytest.mark.parametrize("multi_target", [False]) 12 | @pytest.mark.parametrize( 13 | "continuous_cols", 14 | [ 15 | [ 16 | "AveRooms", 17 | "AveBedrms", 18 | "Population", 19 | "AveOccup", 20 | "Latitude", 21 | "Longitude", 22 | ] 23 | ], 24 | ) 25 | @pytest.mark.parametrize("categorical_cols", [["HouseAgeBin"]]) 26 | @pytest.mark.parametrize("continuous_feature_transform", [None]) 27 | @pytest.mark.parametrize("normalize_continuous_features", [True]) 28 | @pytest.mark.parametrize("variant", ["CategoryEmbeddingModelConfig", "TabTransformerConfig", "FTTransformerConfig"]) 29 | @pytest.mark.parametrize("num_gaussian", [1, 2]) 30 | def test_regression( 31 | regression_data, 32 | multi_target, 33 | continuous_cols, 34 | categorical_cols, 35 | continuous_feature_transform, 36 | normalize_continuous_features, 37 | variant, 38 | num_gaussian, 39 | ): 40 | (train, test, target) = regression_data 41 | data_config = DataConfig( 42 | target=target + ["MedInc"] if multi_target else target, 43 | continuous_cols=continuous_cols, 44 | categorical_cols=categorical_cols, 45 | continuous_feature_transform=continuous_feature_transform, 46 | normalize_continuous_features=normalize_continuous_features, 47 | ) 48 | model_config_params = {"task": "regression"} 49 | mdn_config = {"num_gaussian": num_gaussian} 50 | model_config_params["head_config"] = mdn_config 51 | model_config_params["backbone_config_class"] = variant 52 | model_config_params["backbone_config_params"] = {"task": "backbone"} 53 | 54 | model_config = MDNConfig(**model_config_params) 55 | trainer_config = TrainerConfig( 56 | max_epochs=3, 57 | checkpoints=None, 58 | early_stopping=None, 59 | accelerator="cpu", 60 | fast_dev_run=True, 61 | ) 62 | optimizer_config = OptimizerConfig() 63 | 64 | tabular_model = TabularModel( 65 | data_config=data_config, 66 | model_config=model_config, 67 | optimizer_config=optimizer_config, 68 | trainer_config=trainer_config, 69 | ) 70 | tabular_model.fit(train=train) 71 | 72 | result = tabular_model.evaluate(test) 73 | # print(result[0]["valid_loss"]) 74 | assert "test_mean_squared_error" in result[0].keys() 75 | pred_df = tabular_model.predict(test) 76 | assert pred_df.shape[0] == test.shape[0] 77 | 78 | 79 | @pytest.mark.parametrize("multi_target", [False, True]) 80 | @pytest.mark.parametrize( 81 | "continuous_cols", 82 | [ 83 | [f"feature_{i}" for i in range(54)], 84 | [], 85 | ], 86 | ) 87 | @pytest.mark.parametrize("categorical_cols", [["feature_0_cat"]]) 88 | @pytest.mark.parametrize("continuous_feature_transform", [None]) 89 | @pytest.mark.parametrize("normalize_continuous_features", [True]) 90 | @pytest.mark.parametrize("num_gaussian", [1, 2]) 91 | def test_classification( 92 | classification_data, 93 | multi_target, 94 | continuous_cols, 95 | categorical_cols, 96 | continuous_feature_transform, 97 | normalize_continuous_features, 98 | num_gaussian, 99 | ): 100 | (train, test, target) = classification_data 101 | data_config = DataConfig( 102 | target=target + ["feature_53"] if multi_target else target, 103 | continuous_cols=continuous_cols, 104 | categorical_cols=categorical_cols, 105 | continuous_feature_transform=continuous_feature_transform, 106 | normalize_continuous_features=normalize_continuous_features, 107 | ) 108 | 
model_config_params = {"task": "classification"} 109 | mdn_config = {"num_gaussian": num_gaussian} 110 | model_config_params["head_config"] = mdn_config 111 | model_config_params["backbone_config_class"] = "CategoryEmbeddingMDNConfig" 112 | model_config_params["backbone_config_params"] = {"task": "backbone"} 113 | 114 | model_config = MDNConfig(**model_config_params) 115 | trainer_config = TrainerConfig( 116 | max_epochs=3, 117 | checkpoints=None, 118 | early_stopping=None, 119 | accelerator="cpu", 120 | fast_dev_run=True, 121 | ) 122 | optimizer_config = OptimizerConfig() 123 | with pytest.raises(AssertionError): 124 | tabular_model = TabularModel( 125 | data_config=data_config, 126 | model_config=model_config, 127 | optimizer_config=optimizer_config, 128 | trainer_config=trainer_config, 129 | ) 130 | tabular_model.fit(train=train) 131 | -------------------------------------------------------------------------------- /tests/test_tabnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Tests for `pytorch_tabular` package.""" 3 | 4 | import pytest 5 | 6 | from pytorch_tabular import TabularModel 7 | from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig 8 | from pytorch_tabular.models import TabNetModelConfig 9 | 10 | 11 | @pytest.mark.parametrize("multi_target", [True, False]) 12 | @pytest.mark.parametrize( 13 | "continuous_cols", 14 | [ 15 | [ 16 | "AveRooms", 17 | "AveBedrms", 18 | "Population", 19 | "AveOccup", 20 | "Latitude", 21 | "Longitude", 22 | ], 23 | ], 24 | ) 25 | @pytest.mark.parametrize("categorical_cols", [["HouseAgeBin"]]) 26 | @pytest.mark.parametrize("continuous_feature_transform", [None]) 27 | @pytest.mark.parametrize("normalize_continuous_features", [True]) 28 | @pytest.mark.parametrize("target_range", [True, False]) 29 | def test_regression( 30 | regression_data, 31 | multi_target, 32 | continuous_cols, 33 | categorical_cols, 34 | continuous_feature_transform, 35 | normalize_continuous_features, 36 | target_range, 37 | ): 38 | (train, test, target) = regression_data 39 | data_config = DataConfig( 40 | target=target + ["MedInc"] if multi_target else target, 41 | continuous_cols=continuous_cols, 42 | categorical_cols=categorical_cols, 43 | continuous_feature_transform=continuous_feature_transform, 44 | normalize_continuous_features=normalize_continuous_features, 45 | ) 46 | model_config_params = {"task": "regression"} 47 | if target_range: 48 | _target_range = [] 49 | for target in data_config.target: 50 | _target_range.append( 51 | ( 52 | float(train[target].min()), 53 | float(train[target].max()), 54 | ) 55 | ) 56 | model_config_params["target_range"] = _target_range 57 | model_config = TabNetModelConfig(**model_config_params) 58 | trainer_config = TrainerConfig( 59 | max_epochs=1, 60 | checkpoints=None, 61 | early_stopping=None, 62 | accelerator="cpu", 63 | fast_dev_run=True, 64 | ) 65 | optimizer_config = OptimizerConfig() 66 | 67 | tabular_model = TabularModel( 68 | data_config=data_config, 69 | model_config=model_config, 70 | optimizer_config=optimizer_config, 71 | trainer_config=trainer_config, 72 | ) 73 | tabular_model.fit(train=train) 74 | 75 | result = tabular_model.evaluate(test) 76 | assert "test_mean_squared_error" in result[0].keys() 77 | pred_df = tabular_model.predict(test) 78 | assert pred_df.shape[0] == test.shape[0] 79 | 80 | 81 | @pytest.mark.parametrize("multi_target", [False, True]) 82 | @pytest.mark.parametrize( 83 | "continuous_cols", 84 | [[f"feature_{i}" 
for i in range(54)]], 85 | ) 86 | @pytest.mark.parametrize("categorical_cols", [["feature_0_cat"]]) 87 | @pytest.mark.parametrize("continuous_feature_transform", [None]) 88 | @pytest.mark.parametrize("normalize_continuous_features", [True]) 89 | def test_classification( 90 | classification_data, 91 | multi_target, 92 | continuous_cols, 93 | categorical_cols, 94 | continuous_feature_transform, 95 | normalize_continuous_features, 96 | ): 97 | (train, test, target) = classification_data 98 | data_config = DataConfig( 99 | target=target + ["feature_53"] if multi_target else target, 100 | continuous_cols=continuous_cols, 101 | categorical_cols=categorical_cols, 102 | continuous_feature_transform=continuous_feature_transform, 103 | normalize_continuous_features=normalize_continuous_features, 104 | ) 105 | model_config_params = {"task": "classification"} 106 | model_config = TabNetModelConfig(**model_config_params) 107 | trainer_config = TrainerConfig( 108 | max_epochs=1, 109 | checkpoints=None, 110 | early_stopping=None, 111 | accelerator="cpu", 112 | fast_dev_run=True, 113 | ) 114 | optimizer_config = OptimizerConfig() 115 | 116 | tabular_model = TabularModel( 117 | data_config=data_config, 118 | model_config=model_config, 119 | optimizer_config=optimizer_config, 120 | trainer_config=trainer_config, 121 | ) 122 | tabular_model.fit(train=train) 123 | 124 | result = tabular_model.evaluate(test) 125 | assert "test_accuracy" in result[0].keys() 126 | pred_df = tabular_model.predict(test) 127 | assert pred_df.shape[0] == test.shape[0] 128 | --------------------------------------------------------------------------------
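A minimal end-to-end sketch of how the MDN pieces above fit together (an MDNConfig wrapping a backbone config plus a MixtureDensityHead head_config), distilled from tests/test_mdn.py. The `train`/`test` DataFrames and the column and target names mirror the California-housing fixture in tests/conftest.py; anything beyond what those files show is an assumption:

    from pytorch_tabular import TabularModel
    from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
    from pytorch_tabular.models import MDNConfig

    # `train` and `test` are pandas DataFrames shaped like the conftest.py fixture (assumption)
    data_config = DataConfig(
        target=["MedHouseVal"],
        continuous_cols=["AveRooms", "AveBedrms", "Population", "AveOccup", "Latitude", "Longitude"],
        categorical_cols=["HouseAgeBin"],
        normalize_continuous_features=True,
    )
    # MDN always uses MixtureDensityHead; the backbone is selected via its config class name
    model_config = MDNConfig(
        task="regression",
        backbone_config_class="CategoryEmbeddingModelConfig",
        backbone_config_params={"task": "backbone"},
        head_config={"num_gaussian": 2},
    )
    trainer_config = TrainerConfig(max_epochs=3, checkpoints=None, early_stopping=None, accelerator="cpu")
    tabular_model = TabularModel(
        data_config=data_config,
        model_config=model_config,
        optimizer_config=OptimizerConfig(),
        trainer_config=trainer_config,
    )
    tabular_model.fit(train=train)
    pred_df = tabular_model.predict(test)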