├── .circleci
│   └── config.yml
├── .dockerignore
├── .editorconfig
├── .flake8
├── .gitatttributes
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   └── feature_request.md
│   └── PULL_REQUEST_TEMPLATE.md
├── .gitignore
├── .nojekyll
├── CHANGELOG.md
├── Dockerfile
├── Dockerfile_gpu
├── LICENSE
├── Makefile
├── README.md
├── census_example.ipynb
├── customizing_example.ipynb
├── docs-scripts
│   ├── Makefile
│   ├── make.bat
│   ├── module_template.tpl
│   ├── package_template.tpl
│   ├── rst_generator.sh
│   └── source
│       ├── _static
│       │   └── default.css
│       ├── _templates
│       │   └── layout.html
│       ├── conf.py
│       └── index.rst
├── docs
│   ├── .nojekyll
│   ├── _modules
│   │   ├── index.html
│   │   ├── pytorch_tabnet
│   │   │   ├── abstract_model.html
│   │   │   ├── augmentations.html
│   │   │   ├── callbacks.html
│   │   │   ├── metrics.html
│   │   │   ├── multiclass_utils.html
│   │   │   ├── multitask.html
│   │   │   ├── pretraining.html
│   │   │   ├── pretraining_utils.html
│   │   │   ├── sparsemax.html
│   │   │   ├── tab_model.html
│   │   │   ├── tab_network.html
│   │   │   └── utils.html
│   │   └── torch
│   │       └── optim
│   │           └── adam.html
│   ├── _sources
│   │   ├── generated_docs
│   │   │   ├── README.md.txt
│   │   │   ├── docs-scripts....pytorch_tabnet.rst.txt
│   │   │   └── pytorch_tabnet.rst.txt
│   │   └── index.rst.txt
│   ├── _static
│   │   ├── basic.css
│   │   ├── css
│   │   │   ├── badge_only.css
│   │   │   ├── fonts
│   │   │   │   ├── Roboto-Slab-Bold.woff
│   │   │   │   ├── Roboto-Slab-Bold.woff2
│   │   │   │   ├── Roboto-Slab-Regular.woff
│   │   │   │   ├── Roboto-Slab-Regular.woff2
│   │   │   │   ├── fontawesome-webfont.eot
│   │   │   │   ├── fontawesome-webfont.svg
│   │   │   │   ├── fontawesome-webfont.ttf
│   │   │   │   ├── fontawesome-webfont.woff
│   │   │   │   ├── fontawesome-webfont.woff2
│   │   │   │   ├── lato-bold-italic.woff
│   │   │   │   ├── lato-bold-italic.woff2
│   │   │   │   ├── lato-bold.woff
│   │   │   │   ├── lato-bold.woff2
│   │   │   │   ├── lato-normal-italic.woff
│   │   │   │   ├── lato-normal-italic.woff2
│   │   │   │   ├── lato-normal.woff
│   │   │   │   └── lato-normal.woff2
│   │   │   └── theme.css
│   │   ├── default.css
│   │   ├── doctools.js
│   │   ├── documentation_options.js
│   │   ├── file.png
│   │   ├── fonts
│   │   │   ├── FontAwesome.otf
│   │   │   ├── Lato
│   │   │   │   ├── lato-bold.eot
│   │   │   │   ├── lato-bold.ttf
│   │   │   │   ├── lato-bold.woff
│   │   │   │   ├── lato-bold.woff2
│   │   │   │   ├── lato-bolditalic.eot
│   │   │   │   ├── lato-bolditalic.ttf
│   │   │   │   ├── lato-bolditalic.woff
│   │   │   │   ├── lato-bolditalic.woff2
│   │   │   │   ├── lato-italic.eot
│   │   │   │   ├── lato-italic.ttf
│   │   │   │   ├── lato-italic.woff
│   │   │   │   ├── lato-italic.woff2
│   │   │   │   ├── lato-regular.eot
│   │   │   │   ├── lato-regular.ttf
│   │   │   │   ├── lato-regular.woff
│   │   │   │   └── lato-regular.woff2
│   │   │   ├── Roboto-Slab-Bold.woff
│   │   │   ├── Roboto-Slab-Bold.woff2
│   │   │   ├── Roboto-Slab-Light.woff
│   │   │   ├── Roboto-Slab-Light.woff2
│   │   │   ├── Roboto-Slab-Regular.woff
│   │   │   ├── Roboto-Slab-Regular.woff2
│   │   │   ├── Roboto-Slab-Thin.woff
│   │   │   ├── Roboto-Slab-Thin.woff2
│   │   │   ├── RobotoSlab
│   │   │   │   ├── roboto-slab-v7-bold.eot
│   │   │   │   ├── roboto-slab-v7-bold.ttf
│   │   │   │   ├── roboto-slab-v7-bold.woff
│   │   │   │   ├── roboto-slab-v7-bold.woff2
│   │   │   │   ├── roboto-slab-v7-regular.eot
│   │   │   │   ├── roboto-slab-v7-regular.ttf
│   │   │   │   ├── roboto-slab-v7-regular.woff
│   │   │   │   └── roboto-slab-v7-regular.woff2
│   │   │   ├── fontawesome-webfont.eot
│   │   │   ├── fontawesome-webfont.svg
│   │   │   ├── fontawesome-webfont.ttf
│   │   │   ├── fontawesome-webfont.woff
│   │   │   ├── fontawesome-webfont.woff2
│   │   │   ├── lato-bold-italic.woff
│   │   │   ├── lato-bold-italic.woff2
│   │   │   ├── lato-bold.woff
│   │   │   ├── lato-bold.woff2
│   │   │   ├── lato-normal-italic.woff
│   │   │   ├── lato-normal-italic.woff2
│   │   │   ├── lato-normal.woff
│   │   │   └── lato-normal.woff2
│   │   ├── graphviz.css
│   │   ├── jquery-3.4.1.js
│   │   ├── jquery.js
│   │   ├── js
│   │   │   ├── badge_only.js
│   │   │   ├── html5shiv-printshiv.min.js
│   │   │   ├── html5shiv.min.js
│   │   │   ├── modernizr.min.js
│   │   │   └── theme.js
│   │   ├── language_data.js
│   │   ├── minus.png
│   │   ├── plus.png
│   │   ├── pygments.css
│   │   ├── searchtools.js
│   │   ├── underscore-1.3.1.js
│   │   └── underscore.js
│   ├── generated_docs
│   │   ├── README.html
│   │   ├── docs-scripts....pytorch_tabnet.html
│   │   └── pytorch_tabnet.html
│   ├── genindex.html
│   ├── index.html
│   ├── py-modindex.html
│   ├── search.html
│   └── searchindex.js
├── forest_example.ipynb
├── multi_regression_example.ipynb
├── multi_task_example.ipynb
├── poetry.lock
├── poetry.toml
├── pretraining_example.ipynb
├── pyproject.toml
├── pytorch_tabnet
│   ├── abstract_model.py
│   ├── augmentations.py
│   ├── callbacks.py
│   ├── metrics.py
│   ├── multiclass_utils.py
│   ├── multitask.py
│   ├── pretraining.py
│   ├── pretraining_utils.py
│   ├── sparsemax.py
│   ├── tab_model.py
│   ├── tab_network.py
│   └── utils.py
├── regression_example.ipynb
├── release-script
│   ├── Dockerfile_changelog
│   ├── do-release.sh
│   └── prepare-release.sh
├── renovate.json
└── tests
    └── test_unsupervised_loss.py
/.circleci/config.yml: -------------------------------------------------------------------------------- 1 | # Python CircleCI 2.0 configuration file 2 | # 3 | # Check https://circleci.com/docs/2.0/language-python/ for more details 4 | # 5 | --- 6 | version: 2.1 7 | executors: 8 | # here we can define an executor that will be shared across different jobs 9 | python-executor: 10 | docker: 11 | - image: python:3.7-slim-buster@sha256:fecbb1a9695d25c974906263c64ffba6548ce14a169ed36be58331659383c25e 12 | environment: 13 | POETRY_CACHE: /work/.cache/poetry 14 | PIP_CACHE_DIR: /work/.cache/pip 15 | JUPYTER_RUNTIME_DIR: /work/.cache/jupyter/runtime 16 | JUPYTER_CONFIG_DIR: /work/.cache/jupyter/config 17 | SHELL: bash -l 18 | working_directory: /work 19 | resource_class: large 20 | docker-executor: 21 | docker: 22 | - image: dreamquark/docker:latest@sha256:0dfd1a7a7b519e33fde3f2285f19cdb81c9a9f01e457f1940bac36a7b5ca8347 23 | working_directory: /work 24 | resource_class: small 25 | 26 | commands: 27 | # here we can define steps that will be shared across different jobs 28 | install_poetry: 29 | description: Install poetry 30 | steps: 31 | - run: 32 | name: Install prerequisites and poetry 33 | command: | 34 | apt update && apt install curl make git libopenblas-base build-essential -y 35 | curl -sSL https://install.python-poetry.org | python3 - 36 | export PATH="/root/.local/bin:$PATH" 37 | poetry config virtualenvs.path $POETRY_CACHE 38 | poetry run pip install --upgrade --no-cache-dir pip==20.1; 39 | 40 | jobs: 41 | test-build-docker: 42 | executor: docker-executor 43 | steps: 44 | - checkout 45 | - setup_remote_docker 46 | - run: 47 | name: build docker 48 | command: | 49 | make build 50 | test-build-docker-gpu: 51 | executor: docker-executor 52 | steps: 53 | - checkout 54 | - setup_remote_docker 55 | - run: 56 | name: build docker gpu 57 | command: | 58 | make build-gpu 59 | lint-code: 60 | executor: python-executor 61 | resource_class: small 62 | steps: 63 | - checkout 64 | # Download and cache dependencies 65 | - restore_cache: 66 | keys: 67 | - v1-dependencies-{{ checksum "poetry.lock" }} 68 | - install_poetry 69 | - run: 70 | name: LintCode 71 | shell: bash -leo pipefail 72 | command: | 73 | export PATH="/root/.local/bin:$PATH" 74 | poetry run flake8 75 | install: 76 | executor: python-executor 77 | resource_class: medium 78 | steps: 79 | - checkout 80 | # Download and cache dependencies 81 | - restore_cache: 82 | keys: 83 | - v1-dependencies-{{ checksum "poetry.lock" }} 84 | # fallback to using the latest cache if no exact match is found 85 | - v1-dependencies- 86 | - install_poetry 87 | - run: 88 | name: Install 
dependencies 89 | shell: bash -leo pipefail 90 | command: | 91 | export PATH="/root/.local/bin:$PATH" 92 | poetry config virtualenvs.path $POETRY_CACHE 93 | poetry run pip install torch==1.4.0+cpu -f https://download.pytorch.org/whl/torch_stable.html 94 | poetry install --no-ansi 95 | - save_cache: 96 | paths: 97 | - /work/.cache/poetry 98 | key: v1-dependencies-{{ checksum "poetry.lock" }} 99 | unit-tests: 100 | executor: python-executor 101 | steps: 102 | - checkout 103 | # Download and cache dependencies 104 | - restore_cache: 105 | keys: 106 | - v1-dependencies-{{ checksum "poetry.lock" }} 107 | - install_poetry 108 | - run: 109 | name: run unit-tests 110 | shell: bash -leo pipefail 111 | command: | 112 | export PATH="/root/.local/bin:$PATH" 113 | make unit-tests 114 | test-nb-census: 115 | executor: python-executor 116 | steps: 117 | - checkout 118 | # Download and cache dependencies 119 | - restore_cache: 120 | keys: 121 | - v1-dependencies-{{ checksum "poetry.lock" }} 122 | - install_poetry 123 | - run: 124 | name: run test-nb-census 125 | shell: bash -leo pipefail 126 | command: | 127 | export PATH="/root/.local/bin:$PATH" 128 | make test-nb-census 129 | test-nb-multi-regression: 130 | executor: python-executor 131 | steps: 132 | - checkout 133 | # Download and cache dependencies 134 | - restore_cache: 135 | keys: 136 | - v1-dependencies-{{ checksum "poetry.lock" }} 137 | - install_poetry 138 | - run: 139 | name: run test-nb-multi-regression 140 | shell: bash -leo pipefail 141 | command: | 142 | export PATH="/root/.local/bin:$PATH" 143 | make test-nb-multi-regression 144 | test-nb-forest: 145 | executor: python-executor 146 | steps: 147 | - checkout 148 | # Download and cache dependencies 149 | - restore_cache: 150 | keys: 151 | - v1-dependencies-{{ checksum "poetry.lock" }} 152 | - install_poetry 153 | - run: 154 | name: run test-nb-forest 155 | shell: bash -leo pipefail 156 | command: | 157 | export PATH="/root/.local/bin:$PATH" 158 | make test-nb-forest 159 | test-nb-regression: 160 | executor: python-executor 161 | steps: 162 | - checkout 163 | # Download and cache dependencies 164 | - restore_cache: 165 | keys: 166 | - v1-dependencies-{{ checksum "poetry.lock" }} 167 | - install_poetry 168 | - run: 169 | name: run test-nb-regression 170 | shell: bash -leo pipefail 171 | command: | 172 | export PATH="/root/.local/bin:$PATH" 173 | make test-nb-regression 174 | test-nb-multi-task: 175 | executor: python-executor 176 | steps: 177 | - checkout 178 | # Download and cache dependencies 179 | - restore_cache: 180 | keys: 181 | - v1-dependencies-{{ checksum "poetry.lock" }} 182 | - install_poetry 183 | - run: 184 | name: run test-nb-multi-task 185 | shell: bash -leo pipefail 186 | command: | 187 | export PATH="/root/.local/bin:$PATH" 188 | make test-nb-multi-task 189 | test-nb-customization: 190 | executor: python-executor 191 | steps: 192 | - checkout 193 | # Download and cache dependencies 194 | - restore_cache: 195 | keys: 196 | - v1-dependencies-{{ checksum "poetry.lock" }} 197 | - install_poetry 198 | - run: 199 | name: run test-nb-customization 200 | shell: bash -leo pipefail 201 | command: | 202 | export PATH="/root/.local/bin:$PATH" 203 | make test-nb-customization 204 | test-nb-pretraining: 205 | executor: python-executor 206 | steps: 207 | - checkout 208 | # Download and cache dependencies 209 | - restore_cache: 210 | keys: 211 | - v1-dependencies-{{ checksum "poetry.lock" }} 212 | - install_poetry 213 | - run: 214 | name: run test-nb-pretraining 215 | shell: bash -leo 
pipefail 216 | command: | 217 | export PATH="/root/.local/bin:$PATH" 218 | make test-nb-pretraining 219 | workflows: 220 | version: 2 221 | CI-tabnet: 222 | jobs: 223 | - test-build-docker 224 | - test-build-docker-gpu 225 | - install 226 | - unit-tests: 227 | requires: 228 | - install 229 | - test-nb-census: 230 | requires: 231 | - install 232 | - test-nb-multi-regression: 233 | requires: 234 | - install 235 | - test-nb-regression: 236 | requires: 237 | - install 238 | - test-nb-forest: 239 | requires: 240 | - install 241 | - test-nb-multi-task: 242 | requires: 243 | - install 244 | - test-nb-customization: 245 | requires: 246 | - install 247 | - test-nb-pretraining: 248 | requires: 249 | - install 250 | - lint-code: 251 | requires: 252 | - install 253 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | * 2 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # EditorConfig helps developers define and maintain consistent 2 | # coding styles between different editors and IDEs 3 | # editorconfig.org 4 | 5 | root = true 6 | 7 | 8 | [*] 9 | 10 | # Change these settings to your own preference 11 | indent_style = space 12 | indent_size = 4 13 | 14 | # We recommend you to keep these unchanged 15 | end_of_line = lf 16 | charset = utf-8 17 | trim_trailing_whitespace = true 18 | insert_final_newline = true 19 | 20 | [*.md] 21 | trim_trailing_whitespace = false 22 | 23 | [Makefile] 24 | indent_style = tab 25 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 100 3 | ignore = E203, W503 4 | count = True 5 | exclude = 6 | .git, 7 | dist, 8 | *ipynb, 9 | *.egg, 10 | .cache 11 | -------------------------------------------------------------------------------- /.gitatttributes: -------------------------------------------------------------------------------- 1 | * text=auto 2 | # Basic .gitattributes for a python repo. 3 | 4 | # Source files 5 | # ============ 6 | *.pxd text diff=python 7 | *.py text diff=python 8 | *.py3 text diff=python 9 | *.pyw text diff=python 10 | *.pyx text diff=python 11 | *.pyz text diff=python 12 | 13 | # Binary files 14 | # ============ 15 | *.db binary 16 | *.p binary 17 | *.pkl binary 18 | *.pickle binary 19 | *.pyc binary 20 | *.pyd binary 21 | *.pyo binary 22 | 23 | # Jupyter notebook 24 | *.ipynb text 25 | 26 | # Note: .db, .p, and .pkl files are associated 27 | # with the python modules ``pickle``, ``dbm.*``, 28 | # ``shelve``, ``marshal``, ``anydbm``, & ``bsddb`` 29 | # (among others). 
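# Illustrative check (the path below is an arbitrary example, not prescribed by
# this file): to see which of the above attributes git applies to a file, run
#   git check-attr text diff -- pytorch_tabnet/tab_model.py
# which for a *.py path should report "text: set" and "diff: python".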
30 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: eduardocarvp, Hartorn, j-abi, Optimox 7 | 8 | --- 9 | 10 | 11 | **Describe the bug** 12 | 13 | 14 | **What is the current behavior?** 15 | 16 | **If the current behavior is a bug, please provide the steps to reproduce.** 17 | 23 | 24 | **Expected behavior** 25 | 26 | 27 | 28 | **Screenshots** 29 | 30 | 31 | **Other relevant information:** 32 | poetry version: 33 | python version: 34 | Operating System: 35 | Additional tools: 36 | 37 | **Additional context** 38 | 39 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement 6 | assignees: eduardocarvp, Hartorn, j-abi, Optimox 7 | 8 | --- 9 | 10 | 11 | 12 | ## Feature request 13 | 14 | 15 | **What is the expected behavior?** 16 | 17 | 18 | **What is motivation or use case for adding/changing the behavior?** 19 | 20 | 21 | **How should this be implemented in your opinion?** 22 | 23 | 24 | **Are you willing to work on this yourself?** 25 | yes 26 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | **IMPORTANT: Please do not create a Pull Request without creating an issue first.** 4 | 5 | *Any change needs to be discussed before proceeding. Failure to do so may result in the rejection of the pull request.* 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | **What kind of change does this PR introduce?** 15 | 16 | 17 | 18 | **Does this PR introduce a breaking change?** 19 | 20 | 21 | 22 | **What needs to be documented once your changes are merged?** 23 | 24 | 25 | 26 | 27 | **Closing issues** 28 | 29 | Put `closes #XXXX` in your comment to auto-close the issue that your PR fixes (if such). 30 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .doctrees 2 | docs-scripts/source 3 | objects.inv 4 | .buildinfo 5 | .cache/ 6 | ../.history/ 7 | .history/ 8 | data/ 9 | .ipynb_checkpoints/ 10 | .vscode/ 11 | *.pt 12 | *~ 13 | .vscode/ 14 | 15 | # Notebook to python 16 | forest_example.py 17 | regression_example.py 18 | census_example.py 19 | multi_regression_example.py 20 | 21 | # Byte-compiled / optimized / DLL files 22 | __pycache__/ 23 | *.py[cod] 24 | *$py.class 25 | 26 | # C extensions 27 | *.so 28 | 29 | # Distribution / packaging 30 | .Python 31 | build/ 32 | develop-eggs/ 33 | dist/ 34 | downloads/ 35 | eggs/ 36 | .eggs/ 37 | lib/ 38 | lib64/ 39 | parts/ 40 | sdist/ 41 | var/ 42 | wheels/ 43 | pip-wheel-metadata/ 44 | share/python-wheels/ 45 | *.egg-info/ 46 | .installed.cfg 47 | *.egg 48 | MANIFEST 49 | 50 | # PyInstaller 51 | # Usually these files are written by a python script from a template 52 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
53 | *.manifest 54 | *.spec 55 | 56 | # Installer logs 57 | pip-log.txt 58 | pip-delete-this-directory.txt 59 | 60 | # Unit test / coverage reports 61 | htmlcov/ 62 | .tox/ 63 | .nox/ 64 | .coverage 65 | .coverage.* 66 | .cache 67 | nosetests.xml 68 | coverage.xml 69 | *.cover 70 | *.py,cover 71 | .hypothesis/ 72 | .pytest_cache/ 73 | 74 | # Translations 75 | *.mo 76 | *.pot 77 | 78 | # Django stuff: 79 | *.log 80 | local_settings.py 81 | db.sqlite3 82 | db.sqlite3-journal 83 | 84 | # Flask stuff: 85 | instance/ 86 | .webassets-cache 87 | 88 | # Scrapy stuff: 89 | .scrapy 90 | 91 | # Sphinx documentation 92 | docs/_build/ 93 | 94 | # PyBuilder 95 | target/ 96 | 97 | # Jupyter Notebook 98 | .ipynb_checkpoints 99 | 100 | # IPython 101 | profile_default/ 102 | ipython_config.py 103 | 104 | # pyenv 105 | .python-version 106 | 107 | # pipenv 108 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 109 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 110 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 111 | # install all needed dependencies. 112 | #Pipfile.lock 113 | 114 | # celery beat schedule file 115 | celerybeat-schedule 116 | 117 | # SageMath parsed files 118 | *.sage.py 119 | 120 | # Environments 121 | .env 122 | .venv 123 | env/ 124 | venv/ 125 | ENV/ 126 | env.bak/ 127 | venv.bak/ 128 | 129 | # Spyder project settings 130 | .spyderproject 131 | .spyproject 132 | 133 | # Rope project settings 134 | .ropeproject 135 | 136 | # mkdocs documentation 137 | /site 138 | 139 | # mypy 140 | .mypy_cache/ 141 | .dmypy.json 142 | dmypy.json 143 | 144 | # Pyre type checker 145 | .pyre/ 146 | -------------------------------------------------------------------------------- /.nojekyll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/.nojekyll -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.7-slim-buster@sha256:50de4af76270c893fe36a9ae428951057d6e1a681312d11861970baa150a62e2 2 | RUN apt update && apt install curl make git libopenblas-base -y 3 | RUN curl -sSL https://install.python-poetry.org | python3 - 4 | ENV SHELL /bin/bash -l 5 | 6 | ENV POETRY_CACHE /work/.cache/poetry 7 | ENV PIP_CACHE_DIR /work/.cache/pip 8 | ENV JUPYTER_RUNTIME_DIR /work/.cache/jupyter/runtime 9 | ENV JUPYTER_CONFIG_DIR /work/.cache/jupyter/config 10 | 11 | RUN /root/.local/bin/poetry config virtualenvs.path $POETRY_CACHE 12 | 13 | ENV PATH /root/.local/bin:/bin:/usr/local/bin:/usr/bin 14 | 15 | CMD ["bash", "-l"] 16 | -------------------------------------------------------------------------------- /Dockerfile_gpu: -------------------------------------------------------------------------------- 1 | # GENERATED FROM SCRIPTS 2 | FROM nvidia/cuda:11.3.1-runtime-ubuntu20.04 3 | 4 | # Avoid tzdata interactive action 5 | ENV DEBIAN_FRONTEND noninteractive 6 | 7 | # Adding Python to image 8 | 9 | # Dockerfile generated fragment to install Python and Pip 10 | # Source: https://raw.githubusercontent.com/docker-library/python/master/3.7/slim-buster/Dockerfile 11 | # Python: 3.7.13 12 | # Pip: 22.0.4 13 | 14 | 15 | 16 | 17 | # ensure local python is preferred over distribution python 18 | ENV PATH /usr/local/bin:$PATH 19 | 20 | # 
http://bugs.python.org/issue19846 21 | # > At the moment, setting "LANG=C" on a Linux system *fundamentally breaks Python 3*, and that's not OK. 22 | ENV LANG C.UTF-8 23 | 24 | # runtime dependencies 25 | RUN set -eux; \ 26 | apt-get update; \ 27 | apt-get install -y --no-install-recommends \ 28 | ca-certificates \ 29 | netbase \ 30 | tzdata \ 31 | ; \ 32 | rm -rf /var/lib/apt/lists/* 33 | 34 | ENV GPG_KEY 0D96DF4D4110E5C43FBFB17F2D347EA6AA65421D 35 | ENV PYTHON_VERSION 3.7.13 36 | 37 | RUN set -eux; \ 38 | \ 39 | savedAptMark="$(apt-mark showmanual)"; \ 40 | apt-get update; \ 41 | apt-get install -y --no-install-recommends \ 42 | dpkg-dev \ 43 | gcc \ 44 | gnupg dirmngr \ 45 | libbluetooth-dev \ 46 | libbz2-dev \ 47 | libc6-dev \ 48 | libexpat1-dev \ 49 | libffi-dev \ 50 | libgdbm-dev \ 51 | liblzma-dev \ 52 | libncursesw5-dev \ 53 | libreadline-dev \ 54 | libsqlite3-dev \ 55 | libssl-dev \ 56 | make \ 57 | tk-dev \ 58 | uuid-dev \ 59 | wget \ 60 | xz-utils \ 61 | zlib1g-dev \ 62 | ; \ 63 | \ 64 | wget -O python.tar.xz "https://www.python.org/ftp/python/${PYTHON_VERSION%%[a-z]*}/Python-$PYTHON_VERSION.tar.xz"; \ 65 | wget -O python.tar.xz.asc "https://www.python.org/ftp/python/${PYTHON_VERSION%%[a-z]*}/Python-$PYTHON_VERSION.tar.xz.asc"; \ 66 | GNUPGHOME="$(mktemp -d)"; export GNUPGHOME; \ 67 | gpg --batch --keyserver hkps://keys.openpgp.org --recv-keys "$GPG_KEY"; \ 68 | gpg --batch --verify python.tar.xz.asc python.tar.xz; \ 69 | command -v gpgconf > /dev/null && gpgconf --kill all || :; \ 70 | rm -rf "$GNUPGHOME" python.tar.xz.asc; \ 71 | mkdir -p /usr/src/python; \ 72 | tar --extract --directory /usr/src/python --strip-components=1 --file python.tar.xz; \ 73 | rm python.tar.xz; \ 74 | \ 75 | cd /usr/src/python; \ 76 | gnuArch="$(dpkg-architecture --query DEB_BUILD_GNU_TYPE)"; \ 77 | ./configure \ 78 | --build="$gnuArch" \ 79 | --enable-loadable-sqlite-extensions \ 80 | --enable-optimizations \ 81 | --enable-option-checking=fatal \ 82 | --enable-shared \ 83 | --with-system-expat \ 84 | --without-ensurepip \ 85 | ; \ 86 | nproc="$(nproc)"; \ 87 | make -j "$nproc" \ 88 | LDFLAGS="-Wl,--strip-all" \ 89 | # setting PROFILE_TASK makes "--enable-optimizations" reasonable: https://bugs.python.org/issue36044 / https://github.com/docker-library/python/issues/160#issuecomment-509426916 90 | PROFILE_TASK='-m test.regrtest --pgo \ 91 | test_array \ 92 | test_base64 \ 93 | test_binascii \ 94 | test_binhex \ 95 | test_binop \ 96 | test_bytes \ 97 | test_c_locale_coercion \ 98 | test_class \ 99 | test_cmath \ 100 | test_codecs \ 101 | test_compile \ 102 | test_complex \ 103 | test_csv \ 104 | test_decimal \ 105 | test_dict \ 106 | test_float \ 107 | test_fstring \ 108 | test_hashlib \ 109 | test_io \ 110 | test_iter \ 111 | test_json \ 112 | test_long \ 113 | test_math \ 114 | test_memoryview \ 115 | test_pickle \ 116 | test_re \ 117 | test_set \ 118 | test_slice \ 119 | test_struct \ 120 | test_threading \ 121 | test_time \ 122 | test_traceback \ 123 | test_unicode \ 124 | ' \ 125 | ; \ 126 | make install && cd /usr/local; \ 127 | \ 128 | cd /; \ 129 | rm -rf /usr/src/python; \ 130 | \ 131 | find /usr/local -depth \ 132 | \( \ 133 | \( -type d -a \( -name test -o -name tests -o -name idle_test \) \) \ 134 | -o \( -type f -a \( -name '*.pyc' -o -name '*.pyo' -o -name 'libpython*.a' \) \) \ 135 | -o \( -type f -a -name 'wininst-*.exe' \) \ 136 | \) -exec rm -rf '{}' + \ 137 | ; \ 138 | \ 139 | ldconfig; \ 140 | \ 141 | apt-mark auto '.*' > /dev/null; \ 142 | apt-mark manual $savedAptMark; \ 143 | find 
/usr/local -type f -executable -not \( -name '*tkinter*' \) -exec ldd '{}' ';' \ 144 | | awk '/=>/ { print $(NF-1) }' \ 145 | | sort -u \ 146 | | xargs -r dpkg-query --search \ 147 | | cut -d: -f1 \ 148 | | sort -u \ 149 | | xargs -r apt-mark manual \ 150 | ; \ 151 | apt-get purge -y --auto-remove -o APT::AutoRemove::RecommendsImportant=false; \ 152 | rm -rf /var/lib/apt/lists/*; \ 153 | \ 154 | python3 --version 155 | 156 | # make some useful symlinks that are expected to exist ("/usr/local/bin/python" and friends) 157 | RUN set -eux; \ 158 | for src in idle3 pydoc3 python3 python3-config; do \ 159 | dst="$(echo "$src" | tr -d 3)"; \ 160 | [ -s "/usr/local/bin/$src" ]; \ 161 | [ ! -e "/usr/local/bin/$dst" ]; \ 162 | ln -svT "$src" "/usr/local/bin/$dst"; \ 163 | done 164 | 165 | # if this is called "PIP_VERSION", pip explodes with "ValueError: invalid truth value ''" 166 | ENV PYTHON_PIP_VERSION 22.0.4 167 | # https://github.com/docker-library/python/issues/365 168 | ENV PYTHON_SETUPTOOLS_VERSION 57.5.0 169 | # https://github.com/pypa/get-pip 170 | ENV PYTHON_GET_PIP_URL https://github.com/pypa/get-pip/raw/6ce3639da143c5d79b44f94b04080abf2531fd6e/public/get-pip.py 171 | ENV PYTHON_GET_PIP_SHA256 ba3ab8267d91fd41c58dbce08f76db99f747f716d85ce1865813842bb035524d 172 | 173 | RUN set -eux; \ 174 | \ 175 | savedAptMark="$(apt-mark showmanual)"; \ 176 | apt-get update; \ 177 | apt-get install -y --no-install-recommends wget; \ 178 | \ 179 | wget -O get-pip.py "$PYTHON_GET_PIP_URL"; \ 180 | echo "$PYTHON_GET_PIP_SHA256 *get-pip.py" | sha256sum -c -; \ 181 | \ 182 | apt-mark auto '.*' > /dev/null; \ 183 | [ -z "$savedAptMark" ] || apt-mark manual $savedAptMark > /dev/null; \ 184 | apt-get purge -y --auto-remove -o APT::AutoRemove::RecommendsImportant=false; \ 185 | rm -rf /var/lib/apt/lists/*; \ 186 | \ 187 | export PYTHONDONTWRITEBYTECODE=1; \ 188 | \ 189 | python get-pip.py \ 190 | --disable-pip-version-check \ 191 | --no-cache-dir \ 192 | --no-compile \ 193 | "pip==$PYTHON_PIP_VERSION" \ 194 | "setuptools==$PYTHON_SETUPTOOLS_VERSION" \ 195 | ; \ 196 | rm -f get-pip.py; \ 197 | \ 198 | pip --version 199 | 200 | RUN apt update && apt install curl make git libopenblas-base -y 201 | 202 | RUN curl -sSL https://install.python-poetry.org | python3 - 203 | 204 | ENV SHELL /bin/bash -l 205 | 206 | ENV POETRY_CACHE /work/.cache/poetry 207 | 208 | ENV PIP_CACHE_DIR /work/.cache/pip 209 | 210 | ENV JUPYTER_RUNTIME_DIR /work/.cache/jupyter/runtime 211 | 212 | ENV JUPYTER_CONFIG_DIR /work/.cache/jupyter/config 213 | 214 | RUN /root/.local/bin/poetry config virtualenvs.path $POETRY_CACHE 215 | 216 | ENV PATH /root/.local/bin:/bin:/usr/local/bin:/usr/bin 217 | 218 | CMD ["bash", "-l"] 219 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 DreamQuark 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # set default shell 2 | SHELL := $(shell which bash) 3 | FOLDER=$$(pwd) 4 | # for Windows users 5 | # FOLDER=$(CURDIR) 6 | # default shell options 7 | .SHELLFLAGS = -c 8 | NO_COLOR=\\e[39m 9 | OK_COLOR=\\e[32m 10 | ERROR_COLOR=\\e[31m 11 | WARN_COLOR=\\e[33m 12 | PORT=8887 13 | .SILENT: ; 14 | default: help; # default target 15 | 16 | IMAGE_NAME=tabnet:latest 17 | IMAGE_RELEASER_NAME=release-changelog:latest 18 | NOTEBOOKS_DIR=/work 19 | 20 | DOCKER_RUN = docker run --rm -v ${FOLDER}:/work -w /work --entrypoint bash -lc ${IMAGE_NAME} -c 21 | 22 | prepare-release: build build-releaser ## Prepare release branch with changelog for given version 23 | ./release-script/prepare-release.sh 24 | .PHONY: prepare-release 25 | 26 | do-release: build build-releaser ## Perform the release for given version 27 | ./release-script/do-release.sh 28 | .PHONY: do-release 29 | 30 | build-releaser: ## Build docker image for releaser 31 | echo "Building Dockerfile" 32 | docker build -f ./release-script/Dockerfile_changelog -t ${IMAGE_RELEASER_NAME} . 33 | .PHONY: build-releaser 34 | 35 | build: ## Build docker image 36 | echo "Building Dockerfile" 37 | docker build -t ${IMAGE_NAME} . 38 | .PHONY: build 39 | 40 | build-gpu: ## Build GPU docker image 41 | echo "Building Dockerfile" 42 | docker build -t ${IMAGE_NAME} . -f Dockerfile_gpu 43 | .PHONY: build-gpu 44 | 45 | start: build ## Start docker container 46 | echo "Starting container ${IMAGE_NAME}" 47 | docker run --shm-size="32gb" --rm -it -v ${FOLDER}:/work -w /work -p ${PORT}:${PORT} -e "JUPYTER_PORT=${PORT}" ${IMAGE_NAME} 48 | .PHONY: start 49 | 50 | start-gpu: build-gpu ## Start GPU docker container 51 | echo "Starting container ${IMAGE_NAME}" 52 | docker run --runtime nvidia --shm-size="32gb" --rm -it -v ${FOLDER}:/work -w /work -p ${PORT}:${PORT} -e "JUPYTER_PORT=${PORT}" ${IMAGE_NAME} 53 | .PHONY: start-gpu 54 | 55 | install: build ## Install dependencies 56 | $(DOCKER_RUN) 'poetry install' 57 | .PHONY: install 58 | 59 | lint: ## Check lint 60 | $(DOCKER_RUN) 'poetry run flake8' 61 | .PHONY: lint 62 | 63 | notebook: ## Start the Jupyter notebook 64 | poetry run jupyter notebook --allow-root --ip 0.0.0.0 --port ${PORT} --no-browser --notebook-dir . 
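# Usage note (illustrative): "make notebook" above calls poetry directly, so it
# is meant to be run inside a container started with "make start" (or in a
# local poetry environment); Jupyter then listens on ${PORT}, 8887 by default.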
65 | .PHONY: notebook 66 | 67 | root_bash: ## Start a root bash inside the container 68 | docker exec -it --user root $$(docker ps --filter ancestor=${IMAGE_NAME} --filter expose=${PORT} -q) bash 69 | .PHONY: root_bash 70 | 71 | _run_notebook: 72 | set -e 73 | echo "$(NB_FILE)" | xargs -n1 -I {} echo "poetry run jupyter nbconvert --to=script $(NOTEBOOKS_DIR)/{} || exit 1" | sh 74 | echo "$(NB_FILE)" | xargs -n1 -I {} echo "echo 'Running {}' && poetry run ipython $(NOTEBOOKS_DIR)/{} && echo 'Notebook $(NOTEBOOKS_DIR)/{} OK' || exit 1" | sed 's/.ipynb/.py/' | sh 75 | echo "$(NB_FILE)" | sed 's/.ipynb/.py/' | xargs -n1 -I {} echo "echo 'Cleaning up $(NOTEBOOKS_DIR)/{}' && rm $(NOTEBOOKS_DIR)/{} || exit 1" | sh 76 | .PHONY: _run_notebook 77 | 78 | doc: build ## Build and generate docs 79 | $(DOCKER_RUN) 'cd ./docs-scripts && ./rst_generator.sh' 80 | $(DOCKER_RUN) 'poetry run sphinx-build ./docs-scripts/source ./docs -b html' 81 | $(DOCKER_RUN) 'touch ./docs/.nojekyll' 82 | .PHONY: doc 83 | 84 | test-nb-census: ## run census income tests using notebooks 85 | $(MAKE) _run_notebook NB_FILE="./census_example.ipynb" 86 | .PHONY: test-nb-census 87 | 88 | test-nb-forest: ## run forest example tests using notebooks 89 | $(MAKE) _run_notebook NB_FILE="./forest_example.ipynb" 90 | .PHONY: test-nb-forest 91 | 92 | test-nb-regression: ## run regression example tests using notebooks 93 | $(MAKE) _run_notebook NB_FILE="./regression_example.ipynb" 94 | .PHONY: test-nb-regression 95 | 96 | test-nb-multi-regression: ## run multi regression example tests using notebooks 97 | $(MAKE) _run_notebook NB_FILE="./multi_regression_example.ipynb" 98 | .PHONY: test-nb-multi-regression 99 | 100 | test-nb-multi-task: ## run multi task classification example tests using notebooks 101 | $(MAKE) _run_notebook NB_FILE="./multi_task_example.ipynb" 102 | .PHONY: test-nb-multi-task 103 | 104 | test-nb-customization: ## run customization example tests using notebooks 105 | $(MAKE) _run_notebook NB_FILE="./customizing_example.ipynb" 106 | .PHONY: test-nb-customization 107 | 108 | test-nb-pretraining: ## run pretraining example tests using notebooks 109 | $(MAKE) _run_notebook NB_FILE="./pretraining_example.ipynb" 110 | .PHONY: test-nb-pretraining 111 | 112 | unit-tests: ## run all unit tests 113 | poetry run pytest -s tests/ 114 | .PHONY: unit-tests 115 | 116 | help: ## Display help 117 | @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' 118 | .PHONY: help 119 | 120 | -------------------------------------------------------------------------------- /docs-scripts/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
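# For instance, with the defaults above, "make html" run from docs-scripts/
# expands via the catch-all rule below to (illustrative expansion):
#   sphinx-build -M html "source" "build"
# and any other builder Sphinx's make mode knows (linkcheck, latexpdf, ...)
# is forwarded the same way.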
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs-scripts/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs-scripts/module_template.tpl: -------------------------------------------------------------------------------- 1 | 2 | #MODULE_NAME_TITLE# 3 | 4 | .. automodule:: #MODULE_NAME# 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | -------------------------------------------------------------------------------- /docs-scripts/package_template.tpl: -------------------------------------------------------------------------------- 1 | #PACKAGE_NAME# 2 | 3 | -------------------------------------------------------------------------------- /docs-scripts/rst_generator.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | SCRIPT_DIR=$(dirname ${BASH_SOURCE[0]}) 5 | DOC_FOLDER_NAME="source/generated_docs" 6 | DOC_FOLDER="${SCRIPT_DIR}/${DOC_FOLDER_NAME}" 7 | TARGET_FOLDER="../pytorch_tabnet" 8 | 9 | rm -rf ${DOC_FOLDER} 10 | mkdir -p ${DOC_FOLDER} 11 | 12 | MODULES_LIST=$(find ${SCRIPT_DIR}/${TARGET_FOLDER} -type f 2>/dev/null | grep -v \.cache | grep -v "/tests" | grep -e '.py$' | sed -e 's/^[\.+/]*\///g' | sed -e 's/\.py$//g' | sed -e 's/\//./g') 13 | PACKAGE_LIST=$(echo ${MODULES_LIST} | tr ' ' '\n' | sed -e 's/^\(.*\)\.\(.*\)/\1/g' | sort | uniq) 14 | 15 | #echo $(seq -s= $(($(echo pytorch_tabnet.sparsemax|wc -c) + 9))|tr -d '[:digit:]') 16 | 17 | echo ${PACKAGE_LIST} | tr ' ' '\n' | xargs -n1 -I {} echo "cp ${SCRIPT_DIR}/package_template.tpl ${DOC_FOLDER}/{}.rst" | sh 18 | echo ${PACKAGE_LIST} | tr ' ' '\n' | xargs -n1 -I {} echo "sed -si \"s/#PACKAGE_NAME#/{} package\n\$(seq -s= \$((\$(echo {}|wc -c) + 8))|tr -d '[:digit:]')/g\" ${DOC_FOLDER}/{}.rst" | sh 19 | 20 | # module 21 | echo ${MODULES_LIST} |\ 22 | tr ' ' '\n' |\ 23 | xargs -n1 -I {} echo "cat ${SCRIPT_DIR}/module_template.tpl | sed -s \"s/#MODULE_NAME_TITLE#/{} module\n\$(seq -s. 
\$((\$(echo {}|wc -c) + 9))|tr -d '[:digit:]')/g\" | sed -s \"s/#MODULE_NAME#/{}/g\" >> \$(echo ${DOC_FOLDER}/{} | sed -e 's/^\(.*\)\.\(.*\)/\1.rst/g')" | sh 24 | cp "${SCRIPT_DIR}/../README.md" "${DOC_FOLDER}/README.md" -------------------------------------------------------------------------------- /docs-scripts/source/_static/default.css: -------------------------------------------------------------------------------- 1 | /** 2 | * Alternate Sphinx design 3 | * Originally created by Armin Ronacher for Werkzeug, adapted by Georg Brandl. 4 | */ 5 | 6 | body { 7 | font-family: 'Lucida Grande', 'Lucida Sans Unicode', 'Geneva', 'Verdana', sans-serif; 8 | font-size: 14px; 9 | letter-spacing: -0.01em; 10 | line-height: 150%; 11 | } 12 | 13 | pre { 14 | font-family: 'Consolas', 'Deja Vu Sans Mono', 'Bitstream Vera Sans Mono', monospace; 15 | font-size: 0.95em; 16 | letter-spacing: 0.015em; 17 | } 18 | 19 | cite, code, tt { 20 | font-family: 'Consolas', 'Deja Vu Sans Mono', 'Bitstream Vera Sans Mono', monospace; 21 | font-size: 0.95em; 22 | letter-spacing: 0.01em; 23 | } 24 | 25 | .wy-nav-content { 26 | max-width: 100% !important; 27 | } -------------------------------------------------------------------------------- /docs-scripts/source/_templates/layout.html: -------------------------------------------------------------------------------- 1 | {% extends "!layout.html" %} 2 | 3 | 4 | {% block rootrellink %} 5 |
home 6 | search 7 | {% endblock %} 8 | 9 | 10 | {% block relbar1 %} 11 | 12 | 13 | Pytorch Tabnet 14 |
    15 | {{ super() }} 16 | {% endblock %} 17 | 18 | {# put the sidebar before the body #} 19 | {% block sidebar1 %}{{ sidebar() }}{% endblock %} 20 | {% block sidebar2 %}{% endblock %} 21 | -------------------------------------------------------------------------------- /docs-scripts/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | # import os 14 | # import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = "pytorch_tabnet" 21 | copyright = "2019, Dreamquark" 22 | author = "Dreamquark" 23 | 24 | 25 | # -- General configuration --------------------------------------------------- 26 | 27 | # Add any Sphinx extension module names here, as strings. They can be 28 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 29 | # ones. 30 | # extensions = [ 31 | # 'sphinx.ext.todo', 'sphinx.ext.viewcode', 'sphinx.ext.autodoc' 32 | # ] 33 | extensions = [ 34 | "sphinx.ext.mathjax", 35 | "sphinx.ext.autodoc", 36 | "sphinx.ext.doctest", 37 | "sphinx.ext.inheritance_diagram", 38 | "sphinx.ext.viewcode", 39 | "sphinx.ext.napoleon", 40 | "sphinx_rtd_theme", 41 | "recommonmark", 42 | ] 43 | 44 | # Add any paths that contain templates here, relative to this directory. 45 | templates_path = ["./_templates"] 46 | 47 | # List of patterns, relative to source directory, that match files and 48 | # directories to ignore when looking for source files. 49 | # This pattern also affects html_static_path and html_extra_path. 50 | exclude_patterns = [] 51 | 52 | 53 | # -- Options for HTML output ------------------------------------------------- 54 | 55 | # The theme to use for HTML and HTML Help pages. See the documentation for 56 | # a list of builtin themes. 57 | # 58 | html_theme = "sphinx_rtd_theme" 59 | pygments_style = "sphinx" 60 | 61 | # Add any paths that contain custom static files (such as style sheets) here, 62 | # relative to this directory. They are copied after the builtin static files, 63 | # so a file named "default.csssphinxdoc" will overwrite the builtin "default.css". 64 | html_static_path = ["./_static"] 65 | html_css_files = [ 66 | "./default.css", 67 | ] 68 | 69 | # The suffix(es) of source filenames. 70 | # You can specify multiple suffix as a list of string: 71 | # 72 | source_suffix = [".rst", ".md"] 73 | -------------------------------------------------------------------------------- /docs-scripts/source/index.rst: -------------------------------------------------------------------------------- 1 | .. pytorch_tabnet documentation master file, created by 2 | sphinx-quickstart on Tue Nov 12 10:52:41 2019. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to pytorch_tabnet's documentation! 
7 | ========================================== 8 | 9 | .. toctree:: 10 | :maxdepth: 2 11 | :caption: Contents: 12 | 13 | generated_docs/README.md 14 | generated_docs/pytorch_tabnet.rst 15 | 16 | Indices and tables 17 | ================== 18 | 19 | * :ref:`genindex` 20 | * :ref:`modindex` 21 | * :ref:`search` 22 | -------------------------------------------------------------------------------- /docs/.nojekyll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/.nojekyll -------------------------------------------------------------------------------- /docs/_modules/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | Overview: module code — pytorch_tabnet documentation 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 |
47 | [generated Read the Docs page chrome: navigation sidebar with search box, breadcrumb "Overview: module code", footer "© Copyright 2019, Dreamquark" and "Built with Sphinx using a theme provided by Read the Docs"]
    208 | 209 | 210 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | -------------------------------------------------------------------------------- /docs/_sources/generated_docs/docs-scripts....pytorch_tabnet.rst.txt: -------------------------------------------------------------------------------- 1 | docs-scripts....pytorch_tabnet package 2 | ====================================== 3 | 4 | 5 | docs-scripts....pytorch_tabnet.metrics module 6 | ............................................... 7 | 8 | .. automodule:: docs-scripts....pytorch_tabnet.metrics 9 | :members: 10 | :undoc-members: 11 | :show-inheritance: 12 | 13 | 14 | docs-scripts....pytorch_tabnet.sparsemax module 15 | ................................................. 16 | 17 | .. automodule:: docs-scripts....pytorch_tabnet.sparsemax 18 | :members: 19 | :undoc-members: 20 | :show-inheritance: 21 | 22 | 23 | docs-scripts....pytorch_tabnet.callbacks module 24 | ................................................. 25 | 26 | .. automodule:: docs-scripts....pytorch_tabnet.callbacks 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | 32 | docs-scripts....pytorch_tabnet.tab_network module 33 | ................................................... 34 | 35 | .. automodule:: docs-scripts....pytorch_tabnet.tab_network 36 | :members: 37 | :undoc-members: 38 | :show-inheritance: 39 | 40 | 41 | docs-scripts....pytorch_tabnet.utils module 42 | ............................................. 43 | 44 | .. automodule:: docs-scripts....pytorch_tabnet.utils 45 | :members: 46 | :undoc-members: 47 | :show-inheritance: 48 | 49 | 50 | docs-scripts....pytorch_tabnet.multiclass_utils module 51 | ........................................................ 52 | 53 | .. automodule:: docs-scripts....pytorch_tabnet.multiclass_utils 54 | :members: 55 | :undoc-members: 56 | :show-inheritance: 57 | 58 | 59 | docs-scripts....pytorch_tabnet.abstract_model module 60 | ...................................................... 61 | 62 | .. automodule:: docs-scripts....pytorch_tabnet.abstract_model 63 | :members: 64 | :undoc-members: 65 | :show-inheritance: 66 | 67 | 68 | docs-scripts....pytorch_tabnet.multitask module 69 | ................................................. 70 | 71 | .. automodule:: docs-scripts....pytorch_tabnet.multitask 72 | :members: 73 | :undoc-members: 74 | :show-inheritance: 75 | 76 | 77 | docs-scripts....pytorch_tabnet.tab_model module 78 | ................................................. 79 | 80 | .. automodule:: docs-scripts....pytorch_tabnet.tab_model 81 | :members: 82 | :undoc-members: 83 | :show-inheritance: 84 | 85 | -------------------------------------------------------------------------------- /docs/_sources/generated_docs/pytorch_tabnet.rst.txt: -------------------------------------------------------------------------------- 1 | pytorch_tabnet package 2 | ====================== 3 | 4 | 5 | pytorch_tabnet.pretraining_utils module 6 | ......................................... 7 | 8 | .. automodule:: pytorch_tabnet.pretraining_utils 9 | :members: 10 | :undoc-members: 11 | :show-inheritance: 12 | 13 | 14 | pytorch_tabnet.augmentations module 15 | ..................................... 16 | 17 | .. automodule:: pytorch_tabnet.augmentations 18 | :members: 19 | :undoc-members: 20 | :show-inheritance: 21 | 22 | 23 | pytorch_tabnet.tab_network module 24 | ................................... 25 | 26 | .. 
automodule:: pytorch_tabnet.tab_network 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | 32 | pytorch_tabnet.metrics module 33 | ............................... 34 | 35 | .. automodule:: pytorch_tabnet.metrics 36 | :members: 37 | :undoc-members: 38 | :show-inheritance: 39 | 40 | 41 | pytorch_tabnet.tab_model module 42 | ................................. 43 | 44 | .. automodule:: pytorch_tabnet.tab_model 45 | :members: 46 | :undoc-members: 47 | :show-inheritance: 48 | 49 | 50 | pytorch_tabnet.sparsemax module 51 | ................................. 52 | 53 | .. automodule:: pytorch_tabnet.sparsemax 54 | :members: 55 | :undoc-members: 56 | :show-inheritance: 57 | 58 | 59 | pytorch_tabnet.callbacks module 60 | ................................. 61 | 62 | .. automodule:: pytorch_tabnet.callbacks 63 | :members: 64 | :undoc-members: 65 | :show-inheritance: 66 | 67 | 68 | pytorch_tabnet.abstract_model module 69 | ...................................... 70 | 71 | .. automodule:: pytorch_tabnet.abstract_model 72 | :members: 73 | :undoc-members: 74 | :show-inheritance: 75 | 76 | 77 | pytorch_tabnet.pretraining module 78 | ................................... 79 | 80 | .. automodule:: pytorch_tabnet.pretraining 81 | :members: 82 | :undoc-members: 83 | :show-inheritance: 84 | 85 | 86 | pytorch_tabnet.utils module 87 | ............................. 88 | 89 | .. automodule:: pytorch_tabnet.utils 90 | :members: 91 | :undoc-members: 92 | :show-inheritance: 93 | 94 | 95 | pytorch_tabnet.multitask module 96 | ................................. 97 | 98 | .. automodule:: pytorch_tabnet.multitask 99 | :members: 100 | :undoc-members: 101 | :show-inheritance: 102 | 103 | 104 | pytorch_tabnet.multiclass_utils module 105 | ........................................ 106 | 107 | .. automodule:: pytorch_tabnet.multiclass_utils 108 | :members: 109 | :undoc-members: 110 | :show-inheritance: 111 | 112 | -------------------------------------------------------------------------------- /docs/_sources/index.rst.txt: -------------------------------------------------------------------------------- 1 | .. pytorch_tabnet documentation master file, created by 2 | sphinx-quickstart on Tue Nov 12 10:52:41 2019. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to pytorch_tabnet's documentation! 7 | ========================================== 8 | 9 | .. 
toctree:: 10 | :maxdepth: 2 11 | :caption: Contents: 12 | 13 | generated_docs/README.md 14 | generated_docs/pytorch_tabnet.rst 15 | 16 | Indices and tables 17 | ================== 18 | 19 | * :ref:`genindex` 20 | * :ref:`modindex` 21 | * :ref:`search` 22 | -------------------------------------------------------------------------------- /docs/_static/css/badge_only.css: -------------------------------------------------------------------------------- 1 | .fa:before{-webkit-font-smoothing:antialiased}.clearfix{*zoom:1}.clearfix:after,.clearfix:before{display:table;content:""}.clearfix:after{clear:both}@font-face{font-family:FontAwesome;font-style:normal;font-weight:400;src:url(fonts/fontawesome-webfont.eot?674f50d287a8c48dc19ba404d20fe713?#iefix) format("embedded-opentype"),url(fonts/fontawesome-webfont.woff2?af7ae505a9eed503f8b8e6982036873e) format("woff2"),url(fonts/fontawesome-webfont.woff?fee66e712a8a08eef5805a46892932ad) format("woff"),url(fonts/fontawesome-webfont.ttf?b06871f281fee6b241d60582ae9369b9) format("truetype"),url(fonts/fontawesome-webfont.svg?912ec66d7572ff821749319396470bde#FontAwesome) format("svg")}.fa:before{font-family:FontAwesome;font-style:normal;font-weight:400;line-height:1}.fa:before,a .fa{text-decoration:inherit}.fa:before,a .fa,li .fa{display:inline-block}li .fa-large:before{width:1.875em}ul.fas{list-style-type:none;margin-left:2em;text-indent:-.8em}ul.fas li .fa{width:.8em}ul.fas li .fa-large:before{vertical-align:baseline}.fa-book:before,.icon-book:before{content:"\f02d"}.fa-caret-down:before,.icon-caret-down:before{content:"\f0d7"}.fa-caret-up:before,.icon-caret-up:before{content:"\f0d8"}.fa-caret-left:before,.icon-caret-left:before{content:"\f0d9"}.fa-caret-right:before,.icon-caret-right:before{content:"\f0da"}.rst-versions{position:fixed;bottom:0;left:0;width:300px;color:#fcfcfc;background:#1f1d1d;font-family:Lato,proxima-nova,Helvetica Neue,Arial,sans-serif;z-index:400}.rst-versions a{color:#2980b9;text-decoration:none}.rst-versions .rst-badge-small{display:none}.rst-versions .rst-current-version{padding:12px;background-color:#272525;display:block;text-align:right;font-size:90%;cursor:pointer;color:#27ae60}.rst-versions .rst-current-version:after{clear:both;content:"";display:block}.rst-versions .rst-current-version .fa{color:#fcfcfc}.rst-versions .rst-current-version .fa-book,.rst-versions .rst-current-version .icon-book{float:left}.rst-versions .rst-current-version.rst-out-of-date{background-color:#e74c3c;color:#fff}.rst-versions .rst-current-version.rst-active-old-version{background-color:#f1c40f;color:#000}.rst-versions.shift-up{height:auto;max-height:100%;overflow-y:scroll}.rst-versions.shift-up .rst-other-versions{display:block}.rst-versions .rst-other-versions{font-size:90%;padding:12px;color:grey;display:none}.rst-versions .rst-other-versions hr{display:block;height:1px;border:0;margin:20px 0;padding:0;border-top:1px solid #413d3d}.rst-versions .rst-other-versions dd{display:inline-block;margin:0}.rst-versions .rst-other-versions dd a{display:inline-block;padding:6px;color:#fcfcfc}.rst-versions.rst-badge{width:auto;bottom:20px;right:20px;left:auto;border:none;max-width:300px;max-height:90%}.rst-versions.rst-badge .fa-book,.rst-versions.rst-badge .icon-book{float:none;line-height:30px}.rst-versions.rst-badge.shift-up .rst-current-version{text-align:right}.rst-versions.rst-badge.shift-up .rst-current-version .fa-book,.rst-versions.rst-badge.shift-up .rst-current-version 
.icon-book{float:left}.rst-versions.rst-badge>.rst-current-version{width:auto;height:30px;line-height:30px;padding:0 6px;display:block;text-align:center}@media screen and (max-width:768px){.rst-versions{width:85%;display:none}.rst-versions.shift{display:block}} -------------------------------------------------------------------------------- /docs/_static/css/fonts/Roboto-Slab-Bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/css/fonts/Roboto-Slab-Bold.woff -------------------------------------------------------------------------------- /docs/_static/css/fonts/Roboto-Slab-Bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/css/fonts/Roboto-Slab-Bold.woff2 -------------------------------------------------------------------------------- /docs/_static/css/fonts/Roboto-Slab-Regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/css/fonts/Roboto-Slab-Regular.woff -------------------------------------------------------------------------------- /docs/_static/css/fonts/Roboto-Slab-Regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/css/fonts/Roboto-Slab-Regular.woff2 -------------------------------------------------------------------------------- /docs/_static/css/fonts/fontawesome-webfont.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/css/fonts/fontawesome-webfont.eot -------------------------------------------------------------------------------- /docs/_static/css/fonts/fontawesome-webfont.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/css/fonts/fontawesome-webfont.ttf -------------------------------------------------------------------------------- /docs/_static/css/fonts/fontawesome-webfont.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/css/fonts/fontawesome-webfont.woff -------------------------------------------------------------------------------- /docs/_static/css/fonts/fontawesome-webfont.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/css/fonts/fontawesome-webfont.woff2 -------------------------------------------------------------------------------- /docs/_static/css/fonts/lato-bold-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/css/fonts/lato-bold-italic.woff -------------------------------------------------------------------------------- 
/docs/_static/css/fonts/lato-bold-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/css/fonts/lato-bold-italic.woff2 -------------------------------------------------------------------------------- /docs/_static/css/fonts/lato-bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/css/fonts/lato-bold.woff -------------------------------------------------------------------------------- /docs/_static/css/fonts/lato-bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/css/fonts/lato-bold.woff2 -------------------------------------------------------------------------------- /docs/_static/css/fonts/lato-normal-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/css/fonts/lato-normal-italic.woff -------------------------------------------------------------------------------- /docs/_static/css/fonts/lato-normal-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/css/fonts/lato-normal-italic.woff2 -------------------------------------------------------------------------------- /docs/_static/css/fonts/lato-normal.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/css/fonts/lato-normal.woff -------------------------------------------------------------------------------- /docs/_static/css/fonts/lato-normal.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/css/fonts/lato-normal.woff2 -------------------------------------------------------------------------------- /docs/_static/default.css: -------------------------------------------------------------------------------- 1 | /** 2 | * Alternate Sphinx design 3 | * Originally created by Armin Ronacher for Werkzeug, adapted by Georg Brandl. 4 | */ 5 | 6 | body { 7 | font-family: 'Lucida Grande', 'Lucida Sans Unicode', 'Geneva', 'Verdana', sans-serif; 8 | font-size: 14px; 9 | letter-spacing: -0.01em; 10 | line-height: 150%; 11 | } 12 | 13 | pre { 14 | font-family: 'Consolas', 'Deja Vu Sans Mono', 'Bitstream Vera Sans Mono', monospace; 15 | font-size: 0.95em; 16 | letter-spacing: 0.015em; 17 | } 18 | 19 | cite, code, tt { 20 | font-family: 'Consolas', 'Deja Vu Sans Mono', 'Bitstream Vera Sans Mono', monospace; 21 | font-size: 0.95em; 22 | letter-spacing: 0.01em; 23 | } 24 | 25 | .wy-nav-content { 26 | max-width: 100% !important; 27 | } -------------------------------------------------------------------------------- /docs/_static/doctools.js: -------------------------------------------------------------------------------- 1 | /* 2 | * doctools.js 3 | * ~~~~~~~~~~~ 4 | * 5 | * Sphinx JavaScript utilities for all documentation. 
6 | * 7 | * :copyright: Copyright 2007-2020 by the Sphinx team, see AUTHORS. 8 | * :license: BSD, see LICENSE for details. 9 | * 10 | */ 11 | 12 | /** 13 | * select a different prefix for underscore 14 | */ 15 | $u = _.noConflict(); 16 | 17 | /** 18 | * make the code below compatible with browsers without 19 | * an installed firebug like debugger 20 | if (!window.console || !console.firebug) { 21 | var names = ["log", "debug", "info", "warn", "error", "assert", "dir", 22 | "dirxml", "group", "groupEnd", "time", "timeEnd", "count", "trace", 23 | "profile", "profileEnd"]; 24 | window.console = {}; 25 | for (var i = 0; i < names.length; ++i) 26 | window.console[names[i]] = function() {}; 27 | } 28 | */ 29 | 30 | /** 31 | * small helper function to urldecode strings 32 | */ 33 | jQuery.urldecode = function(x) { 34 | return decodeURIComponent(x).replace(/\+/g, ' '); 35 | }; 36 | 37 | /** 38 | * small helper function to urlencode strings 39 | */ 40 | jQuery.urlencode = encodeURIComponent; 41 | 42 | /** 43 | * This function returns the parsed url parameters of the 44 | * current request. Multiple values per key are supported, 45 | * it will always return arrays of strings for the value parts. 46 | */ 47 | jQuery.getQueryParameters = function(s) { 48 | if (typeof s === 'undefined') 49 | s = document.location.search; 50 | var parts = s.substr(s.indexOf('?') + 1).split('&'); 51 | var result = {}; 52 | for (var i = 0; i < parts.length; i++) { 53 | var tmp = parts[i].split('=', 2); 54 | var key = jQuery.urldecode(tmp[0]); 55 | var value = jQuery.urldecode(tmp[1]); 56 | if (key in result) 57 | result[key].push(value); 58 | else 59 | result[key] = [value]; 60 | } 61 | return result; 62 | }; 63 | 64 | /** 65 | * highlight a given string on a jquery object by wrapping it in 66 | * span elements with the given class name. 
67 | */ 68 | jQuery.fn.highlightText = function(text, className) { 69 | function highlight(node, addItems) { 70 | if (node.nodeType === 3) { 71 | var val = node.nodeValue; 72 | var pos = val.toLowerCase().indexOf(text); 73 | if (pos >= 0 && 74 | !jQuery(node.parentNode).hasClass(className) && 75 | !jQuery(node.parentNode).hasClass("nohighlight")) { 76 | var span; 77 | var isInSVG = jQuery(node).closest("body, svg, foreignObject").is("svg"); 78 | if (isInSVG) { 79 | span = document.createElementNS("http://www.w3.org/2000/svg", "tspan"); 80 | } else { 81 | span = document.createElement("span"); 82 | span.className = className; 83 | } 84 | span.appendChild(document.createTextNode(val.substr(pos, text.length))); 85 | node.parentNode.insertBefore(span, node.parentNode.insertBefore( 86 | document.createTextNode(val.substr(pos + text.length)), 87 | node.nextSibling)); 88 | node.nodeValue = val.substr(0, pos); 89 | if (isInSVG) { 90 | var rect = document.createElementNS("http://www.w3.org/2000/svg", "rect"); 91 | var bbox = node.parentElement.getBBox(); 92 | rect.x.baseVal.value = bbox.x; 93 | rect.y.baseVal.value = bbox.y; 94 | rect.width.baseVal.value = bbox.width; 95 | rect.height.baseVal.value = bbox.height; 96 | rect.setAttribute('class', className); 97 | addItems.push({ 98 | "parent": node.parentNode, 99 | "target": rect}); 100 | } 101 | } 102 | } 103 | else if (!jQuery(node).is("button, select, textarea")) { 104 | jQuery.each(node.childNodes, function() { 105 | highlight(this, addItems); 106 | }); 107 | } 108 | } 109 | var addItems = []; 110 | var result = this.each(function() { 111 | highlight(this, addItems); 112 | }); 113 | for (var i = 0; i < addItems.length; ++i) { 114 | jQuery(addItems[i].parent).before(addItems[i].target); 115 | } 116 | return result; 117 | }; 118 | 119 | /* 120 | * backward compatibility for jQuery.browser 121 | * This will be supported until firefox bug is fixed. 122 | */ 123 | if (!jQuery.browser) { 124 | jQuery.uaMatch = function(ua) { 125 | ua = ua.toLowerCase(); 126 | 127 | var match = /(chrome)[ \/]([\w.]+)/.exec(ua) || 128 | /(webkit)[ \/]([\w.]+)/.exec(ua) || 129 | /(opera)(?:.*version|)[ \/]([\w.]+)/.exec(ua) || 130 | /(msie) ([\w.]+)/.exec(ua) || 131 | ua.indexOf("compatible") < 0 && /(mozilla)(?:.*? rv:([\w.]+)|)/.exec(ua) || 132 | []; 133 | 134 | return { 135 | browser: match[ 1 ] || "", 136 | version: match[ 2 ] || "0" 137 | }; 138 | }; 139 | jQuery.browser = {}; 140 | jQuery.browser[jQuery.uaMatch(navigator.userAgent).browser] = true; 141 | } 142 | 143 | /** 144 | * Small JavaScript module for the documentation. 145 | */ 146 | var Documentation = { 147 | 148 | init : function() { 149 | this.fixFirefoxAnchorBug(); 150 | this.highlightSearchWords(); 151 | this.initIndexTable(); 152 | if (DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS) { 153 | this.initOnKeyListeners(); 154 | } 155 | }, 156 | 157 | /** 158 | * i18n support 159 | */ 160 | TRANSLATIONS : {}, 161 | PLURAL_EXPR : function(n) { return n === 1 ? 0 : 1; }, 162 | LOCALE : 'unknown', 163 | 164 | // gettext and ngettext don't access this so that the functions 165 | // can safely bound to a different name (_ = Documentation.gettext) 166 | gettext : function(string) { 167 | var translated = Documentation.TRANSLATIONS[string]; 168 | if (typeof translated === 'undefined') 169 | return string; 170 | return (typeof translated === 'string') ? 
translated : translated[0]; 171 | }, 172 | 173 | ngettext : function(singular, plural, n) { 174 | var translated = Documentation.TRANSLATIONS[singular]; 175 | if (typeof translated === 'undefined') 176 | return (n == 1) ? singular : plural; 177 | return translated[Documentation.PLURALEXPR(n)]; 178 | }, 179 | 180 | addTranslations : function(catalog) { 181 | for (var key in catalog.messages) 182 | this.TRANSLATIONS[key] = catalog.messages[key]; 183 | this.PLURAL_EXPR = new Function('n', 'return +(' + catalog.plural_expr + ')'); 184 | this.LOCALE = catalog.locale; 185 | }, 186 | 187 | /** 188 | * add context elements like header anchor links 189 | */ 190 | addContextElements : function() { 191 | $('div[id] > :header:first').each(function() { 192 | $('<a class="headerlink">\u00B6</a>'). 193 | attr('href', '#' + this.id). 194 | attr('title', _('Permalink to this headline')). 195 | appendTo(this); 196 | }); 197 | $('dt[id]').each(function() { 198 | $('<a class="headerlink">\u00B6</a>'). 199 | attr('href', '#' + this.id). 200 | attr('title', _('Permalink to this definition')). 201 | appendTo(this); 202 | }); 203 | }, 204 | 205 | /** 206 | * workaround a firefox stupidity 207 | * see: https://bugzilla.mozilla.org/show_bug.cgi?id=645075 208 | */ 209 | fixFirefoxAnchorBug : function() { 210 | if (document.location.hash && $.browser.mozilla) 211 | window.setTimeout(function() { 212 | document.location.href += ''; 213 | }, 10); 214 | }, 215 | 216 | /** 217 | * highlight the search words provided in the url in the text 218 | */ 219 | highlightSearchWords : function() { 220 | var params = $.getQueryParameters(); 221 | var terms = (params.highlight) ? params.highlight[0].split(/\s+/) : []; 222 | if (terms.length) { 223 | var body = $('div.body'); 224 | if (!body.length) { 225 | body = $('body'); 226 | } 227 | window.setTimeout(function() { 228 | $.each(terms, function() { 229 | body.highlightText(this.toLowerCase(), 'highlighted'); 230 | }); 231 | }, 10); 232 | $('<p class="highlight-link"><a href="javascript:Documentation.' + 233 | 'hideSearchWords()">' + _('Hide Search Matches') + '</a></p>') 234 | .appendTo($('#searchbox')); 235 | } 236 | }, 237 | 238 | /** 239 | * init the domain index toggle buttons 240 | */ 241 | initIndexTable : function() { 242 | var togglers = $('img.toggler').click(function() { 243 | var src = $(this).attr('src'); 244 | var idnum = $(this).attr('id').substr(7); 245 | $('tr.cg-' + idnum).toggle(); 246 | if (src.substr(-9) === 'minus.png') 247 | $(this).attr('src', src.substr(0, src.length-9) + 'plus.png'); 248 | else 249 | $(this).attr('src', src.substr(0, src.length-8) + 'minus.png'); 250 | }).css('display', ''); 251 | if (DOCUMENTATION_OPTIONS.COLLAPSE_INDEX) { 252 | togglers.click(); 253 | } 254 | }, 255 | 256 | /** 257 | * helper function to hide the search marks again 258 | */ 259 | hideSearchWords : function() { 260 | $('#searchbox .highlight-link').fadeOut(300); 261 | $('span.highlighted').removeClass('highlighted'); 262 | }, 263 | 264 | /** 265 | * make the url absolute 266 | */ 267 | makeURL : function(relativeURL) { 268 | return DOCUMENTATION_OPTIONS.URL_ROOT + '/' + relativeURL; 269 | }, 270 | 271 | /** 272 | * get the current relative url 273 | */ 274 | getCurrentURL : function() { 275 | var path = document.location.pathname; 276 | var parts = path.split(/\//); 277 | $.each(DOCUMENTATION_OPTIONS.URL_ROOT.split(/\//), function() { 278 | if (this === '..') 279 | parts.pop(); 280 | }); 281 | var url = parts.join('/'); 282 | return path.substring(url.lastIndexOf('/') + 1, path.length - 1); 283 | }, 284 | 285 | initOnKeyListeners: function() { 286 | $(document).keydown(function(event) { 287 | var activeElementType = document.activeElement.tagName;
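// Note: init() above installs this keydown listener only when
// DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS is true; documentation_options.js
// (below) sets NAVIGATION_WITH_KEYS: false, so the arrow-key prev/next
// navigation handled here is disabled in these built docs.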
288 | // don't navigate when in search box or textarea 289 | if (activeElementType !== 'TEXTAREA' && activeElementType !== 'INPUT' && activeElementType !== 'SELECT' 290 | && !event.altKey && !event.ctrlKey && !event.metaKey && !event.shiftKey) { 291 | switch (event.keyCode) { 292 | case 37: // left 293 | var prevHref = $('link[rel="prev"]').prop('href'); 294 | if (prevHref) { 295 | window.location.href = prevHref; 296 | return false; 297 | } 298 | case 39: // right 299 | var nextHref = $('link[rel="next"]').prop('href'); 300 | if (nextHref) { 301 | window.location.href = nextHref; 302 | return false; 303 | } 304 | } 305 | } 306 | }); 307 | } 308 | }; 309 | 310 | // quick alias for translations 311 | _ = Documentation.gettext; 312 | 313 | $(document).ready(function() { 314 | Documentation.init(); 315 | }); 316 | -------------------------------------------------------------------------------- /docs/_static/documentation_options.js: -------------------------------------------------------------------------------- 1 | var DOCUMENTATION_OPTIONS = { 2 | URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'), 3 | VERSION: '', 4 | LANGUAGE: 'None', 5 | COLLAPSE_INDEX: false, 6 | BUILDER: 'html', 7 | FILE_SUFFIX: '.html', 8 | HAS_SOURCE: true, 9 | SOURCELINK_SUFFIX: '.txt', 10 | NAVIGATION_WITH_KEYS: false 11 | }; -------------------------------------------------------------------------------- /docs/_static/file.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/file.png -------------------------------------------------------------------------------- /docs/_static/fonts/FontAwesome.otf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/FontAwesome.otf -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-bold.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/Lato/lato-bold.eot -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/Lato/lato-bold.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/Lato/lato-bold.woff -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/Lato/lato-bold.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-bolditalic.eot: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/Lato/lato-bolditalic.eot -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-bolditalic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/Lato/lato-bolditalic.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-bolditalic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/Lato/lato-bolditalic.woff -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-bolditalic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/Lato/lato-bolditalic.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-italic.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/Lato/lato-italic.eot -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-italic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/Lato/lato-italic.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/Lato/lato-italic.woff -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/Lato/lato-italic.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-regular.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/Lato/lato-regular.eot -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/Lato/lato-regular.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-regular.woff: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/Lato/lato-regular.woff -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/Lato/lato-regular.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/Roboto-Slab-Bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/Roboto-Slab-Bold.woff -------------------------------------------------------------------------------- /docs/_static/fonts/Roboto-Slab-Bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/Roboto-Slab-Bold.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/Roboto-Slab-Light.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/Roboto-Slab-Light.woff -------------------------------------------------------------------------------- /docs/_static/fonts/Roboto-Slab-Light.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/Roboto-Slab-Light.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/Roboto-Slab-Regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/Roboto-Slab-Regular.woff -------------------------------------------------------------------------------- /docs/_static/fonts/Roboto-Slab-Regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/Roboto-Slab-Regular.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/Roboto-Slab-Thin.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/Roboto-Slab-Thin.woff -------------------------------------------------------------------------------- /docs/_static/fonts/Roboto-Slab-Thin.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/Roboto-Slab-Thin.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.eot: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.eot -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.eot -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/fontawesome-webfont.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/fontawesome-webfont.eot -------------------------------------------------------------------------------- /docs/_static/fonts/fontawesome-webfont.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/fontawesome-webfont.ttf -------------------------------------------------------------------------------- 
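Binary assets in this dump are replaced by raw.githubusercontent.com URLs pinned to commit 2c0c4eb. A short sketch (hypothetical helper, not part of the repo; assumes Node 18+ for the global fetch) of restoring one such file locally from its placeholder URL:

const fs = require("node:fs/promises");

async function restoreAsset(url, dest) {
  const res = await fetch(url); // download the blob pinned to the commit
  if (!res.ok) throw new Error("HTTP " + res.status + " for " + url);
  await fs.writeFile(dest, Buffer.from(await res.arrayBuffer()));
}

restoreAsset(
  "https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/fontawesome-webfont.woff",
  "docs/_static/fonts/fontawesome-webfont.woff"
).catch(console.error);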
/docs/_static/fonts/fontawesome-webfont.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/fontawesome-webfont.woff -------------------------------------------------------------------------------- /docs/_static/fonts/fontawesome-webfont.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/fontawesome-webfont.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/lato-bold-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/lato-bold-italic.woff -------------------------------------------------------------------------------- /docs/_static/fonts/lato-bold-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/lato-bold-italic.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/lato-bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/lato-bold.woff -------------------------------------------------------------------------------- /docs/_static/fonts/lato-bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/lato-bold.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/lato-normal-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/lato-normal-italic.woff -------------------------------------------------------------------------------- /docs/_static/fonts/lato-normal-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/lato-normal-italic.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/lato-normal.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/lato-normal.woff -------------------------------------------------------------------------------- /docs/_static/fonts/lato-normal.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/fonts/lato-normal.woff2 -------------------------------------------------------------------------------- /docs/_static/graphviz.css: -------------------------------------------------------------------------------- 1 | /* 2 | * graphviz.css 3 | * 
~~~~~~~~~~~~ 4 | * 5 | * Sphinx stylesheet -- graphviz extension. 6 | * 7 | * :copyright: Copyright 2007-2020 by the Sphinx team, see AUTHORS. 8 | * :license: BSD, see LICENSE for details. 9 | * 10 | */ 11 | 12 | img.graphviz { 13 | border: 0; 14 | max-width: 100%; 15 | } 16 | 17 | object.graphviz { 18 | max-width: 100%; 19 | } 20 | -------------------------------------------------------------------------------- /docs/_static/js/badge_only.js: -------------------------------------------------------------------------------- 1 | !function(e){var t={};function r(n){if(t[n])return t[n].exports;var o=t[n]={i:n,l:!1,exports:{}};return e[n].call(o.exports,o,o.exports,r),o.l=!0,o.exports}r.m=e,r.c=t,r.d=function(e,t,n){r.o(e,t)||Object.defineProperty(e,t,{enumerable:!0,get:n})},r.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},r.t=function(e,t){if(1&t&&(e=r(e)),8&t)return e;if(4&t&&"object"==typeof e&&e&&e.__esModule)return e;var n=Object.create(null);if(r.r(n),Object.defineProperty(n,"default",{enumerable:!0,value:e}),2&t&&"string"!=typeof e)for(var o in e)r.d(n,o,function(t){return e[t]}.bind(null,o));return n},r.n=function(e){var t=e&&e.__esModule?function(){return e.default}:function(){return e};return r.d(t,"a",t),t},r.o=function(e,t){return Object.prototype.hasOwnProperty.call(e,t)},r.p="",r(r.s=4)}({4:function(e,t,r){}}); -------------------------------------------------------------------------------- /docs/_static/js/html5shiv-printshiv.min.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @preserve HTML5 Shiv 3.7.3-pre | @afarkas @jdalton @jon_neal @rem | MIT/GPL2 Licensed 3 | */ 4 | !function(a,b){function c(a,b){var c=a.createElement("p"),d=a.getElementsByTagName("head")[0]||a.documentElement;return c.innerHTML="x",d.insertBefore(c.lastChild,d.firstChild)}function d(){var a=y.elements;return"string"==typeof a?a.split(" "):a}function e(a,b){var c=y.elements;"string"!=typeof c&&(c=c.join(" ")),"string"!=typeof a&&(a=a.join(" ")),y.elements=c+" "+a,j(b)}function f(a){var b=x[a[v]];return b||(b={},w++,a[v]=w,x[w]=b),b}function g(a,c,d){if(c||(c=b),q)return c.createElement(a);d||(d=f(c));var e;return e=d.cache[a]?d.cache[a].cloneNode():u.test(a)?(d.cache[a]=d.createElem(a)).cloneNode():d.createElem(a),!e.canHaveChildren||t.test(a)||e.tagUrn?e:d.frag.appendChild(e)}function h(a,c){if(a||(a=b),q)return a.createDocumentFragment();c=c||f(a);for(var e=c.frag.cloneNode(),g=0,h=d(),i=h.length;i>g;g++)e.createElement(h[g]);return e}function i(a,b){b.cache||(b.cache={},b.createElem=a.createElement,b.createFrag=a.createDocumentFragment,b.frag=b.createFrag()),a.createElement=function(c){return y.shivMethods?g(c,a,b):b.createElem(c)},a.createDocumentFragment=Function("h,f","return function(){var n=f.cloneNode(),c=n.createElement;h.shivMethods&&("+d().join().replace(/[\w\-:]+/g,function(a){return b.createElem(a),b.frag.createElement(a),'c("'+a+'")'})+");return n}")(y,b.frag)}function j(a){a||(a=b);var d=f(a);return!y.shivCSS||p||d.hasCSS||(d.hasCSS=!!c(a,"article,aside,dialog,figcaption,figure,footer,header,hgroup,main,nav,section{display:block}mark{background:#FF0;color:#000}template{display:none}")),q||i(a,d),a}function k(a){for(var b,c=a.getElementsByTagName("*"),e=c.length,f=RegExp("^(?:"+d().join("|")+")$","i"),g=[];e--;)b=c[e],f.test(b.nodeName)&&g.push(b.applyElement(l(b)));return g}function l(a){for(var 
b,c=a.attributes,d=c.length,e=a.ownerDocument.createElement(A+":"+a.nodeName);d--;)b=c[d],b.specified&&e.setAttribute(b.nodeName,b.nodeValue);return e.style.cssText=a.style.cssText,e}function m(a){for(var b,c=a.split("{"),e=c.length,f=RegExp("(^|[\\s,>+~])("+d().join("|")+")(?=[[\\s,>+~#.:]|$)","gi"),g="$1"+A+"\\:$2";e--;)b=c[e]=c[e].split("}"),b[b.length-1]=b[b.length-1].replace(f,g),c[e]=b.join("}");return c.join("{")}function n(a){for(var b=a.length;b--;)a[b].removeNode()}function o(a){function b(){clearTimeout(g._removeSheetTimer),d&&d.removeNode(!0),d=null}var d,e,g=f(a),h=a.namespaces,i=a.parentWindow;return!B||a.printShived?a:("undefined"==typeof h[A]&&h.add(A),i.attachEvent("onbeforeprint",function(){b();for(var f,g,h,i=a.styleSheets,j=[],l=i.length,n=Array(l);l--;)n[l]=i[l];for(;h=n.pop();)if(!h.disabled&&z.test(h.media)){try{f=h.imports,g=f.length}catch(o){g=0}for(l=0;g>l;l++)n.push(f[l]);try{j.push(h.cssText)}catch(o){}}j=m(j.reverse().join("")),e=k(a),d=c(a,j)}),i.attachEvent("onafterprint",function(){n(e),clearTimeout(g._removeSheetTimer),g._removeSheetTimer=setTimeout(b,500)}),a.printShived=!0,a)}var p,q,r="3.7.3",s=a.html5||{},t=/^<|^(?:button|map|select|textarea|object|iframe|option|optgroup)$/i,u=/^(?:a|b|code|div|fieldset|h1|h2|h3|h4|h5|h6|i|label|li|ol|p|q|span|strong|style|table|tbody|td|th|tr|ul)$/i,v="_html5shiv",w=0,x={};!function(){try{var a=b.createElement("a");a.innerHTML="",p="hidden"in a,q=1==a.childNodes.length||function(){b.createElement("a");var a=b.createDocumentFragment();return"undefined"==typeof a.cloneNode||"undefined"==typeof a.createDocumentFragment||"undefined"==typeof a.createElement}()}catch(c){p=!0,q=!0}}();var y={elements:s.elements||"abbr article aside audio bdi canvas data datalist details dialog figcaption figure footer header hgroup main mark meter nav output picture progress section summary template time video",version:r,shivCSS:s.shivCSS!==!1,supportsUnknownElements:q,shivMethods:s.shivMethods!==!1,type:"default",shivDocument:j,createElement:g,createDocumentFragment:h,addElements:e};a.html5=y,j(b);var z=/^$|\b(?:all|print)\b/,A="html5shiv",B=!q&&function(){var c=b.documentElement;return!("undefined"==typeof b.namespaces||"undefined"==typeof b.parentWindow||"undefined"==typeof c.applyElement||"undefined"==typeof c.removeNode||"undefined"==typeof a.attachEvent)}();y.type+=" print",y.shivPrint=o,o(b),"object"==typeof module&&module.exports&&(module.exports=y)}("undefined"!=typeof window?window:this,document); -------------------------------------------------------------------------------- /docs/_static/js/html5shiv.min.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @preserve HTML5 Shiv 3.7.3 | @afarkas @jdalton @jon_neal @rem | MIT/GPL2 Licensed 3 | */ 4 | !function(a,b){function c(a,b){var c=a.createElement("p"),d=a.getElementsByTagName("head")[0]||a.documentElement;return c.innerHTML="x",d.insertBefore(c.lastChild,d.firstChild)}function d(){var a=t.elements;return"string"==typeof a?a.split(" "):a}function e(a,b){var c=t.elements;"string"!=typeof c&&(c=c.join(" ")),"string"!=typeof a&&(a=a.join(" ")),t.elements=c+" "+a,j(b)}function f(a){var b=s[a[q]];return b||(b={},r++,a[q]=r,s[r]=b),b}function g(a,c,d){if(c||(c=b),l)return c.createElement(a);d||(d=f(c));var e;return e=d.cache[a]?d.cache[a].cloneNode():p.test(a)?(d.cache[a]=d.createElem(a)).cloneNode():d.createElem(a),!e.canHaveChildren||o.test(a)||e.tagUrn?e:d.frag.appendChild(e)}function h(a,c){if(a||(a=b),l)return 
a.createDocumentFragment();c=c||f(a);for(var e=c.frag.cloneNode(),g=0,h=d(),i=h.length;i>g;g++)e.createElement(h[g]);return e}function i(a,b){b.cache||(b.cache={},b.createElem=a.createElement,b.createFrag=a.createDocumentFragment,b.frag=b.createFrag()),a.createElement=function(c){return t.shivMethods?g(c,a,b):b.createElem(c)},a.createDocumentFragment=Function("h,f","return function(){var n=f.cloneNode(),c=n.createElement;h.shivMethods&&("+d().join().replace(/[\w\-:]+/g,function(a){return b.createElem(a),b.frag.createElement(a),'c("'+a+'")'})+");return n}")(t,b.frag)}function j(a){a||(a=b);var d=f(a);return!t.shivCSS||k||d.hasCSS||(d.hasCSS=!!c(a,"article,aside,dialog,figcaption,figure,footer,header,hgroup,main,nav,section{display:block}mark{background:#FF0;color:#000}template{display:none}")),l||i(a,d),a}var k,l,m="3.7.3-pre",n=a.html5||{},o=/^<|^(?:button|map|select|textarea|object|iframe|option|optgroup)$/i,p=/^(?:a|b|code|div|fieldset|h1|h2|h3|h4|h5|h6|i|label|li|ol|p|q|span|strong|style|table|tbody|td|th|tr|ul)$/i,q="_html5shiv",r=0,s={};!function(){try{var a=b.createElement("a");a.innerHTML="",k="hidden"in a,l=1==a.childNodes.length||function(){b.createElement("a");var a=b.createDocumentFragment();return"undefined"==typeof a.cloneNode||"undefined"==typeof a.createDocumentFragment||"undefined"==typeof a.createElement}()}catch(c){k=!0,l=!0}}();var t={elements:n.elements||"abbr article aside audio bdi canvas data datalist details dialog figcaption figure footer header hgroup main mark meter nav output picture progress section summary template time video",version:m,shivCSS:n.shivCSS!==!1,supportsUnknownElements:l,shivMethods:n.shivMethods!==!1,type:"default",shivDocument:j,createElement:g,createDocumentFragment:h,addElements:e};a.html5=t,j(b),"object"==typeof module&&module.exports&&(module.exports=t)}("undefined"!=typeof window?window:this,document); -------------------------------------------------------------------------------- /docs/_static/js/theme.js: -------------------------------------------------------------------------------- 1 | !function(n){var e={};function t(i){if(e[i])return e[i].exports;var o=e[i]={i:i,l:!1,exports:{}};return n[i].call(o.exports,o,o.exports,t),o.l=!0,o.exports}t.m=n,t.c=e,t.d=function(n,e,i){t.o(n,e)||Object.defineProperty(n,e,{enumerable:!0,get:i})},t.r=function(n){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(n,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(n,"__esModule",{value:!0})},t.t=function(n,e){if(1&e&&(n=t(n)),8&e)return n;if(4&e&&"object"==typeof n&&n&&n.__esModule)return n;var i=Object.create(null);if(t.r(i),Object.defineProperty(i,"default",{enumerable:!0,value:n}),2&e&&"string"!=typeof n)for(var o in n)t.d(i,o,function(e){return n[e]}.bind(null,o));return i},t.n=function(n){var e=n&&n.__esModule?function(){return n.default}:function(){return n};return t.d(e,"a",e),e},t.o=function(n,e){return Object.prototype.hasOwnProperty.call(n,e)},t.p="",t(t.s=0)}([function(n,e,t){t(1),n.exports=t(3)},function(n,e,t){(function(){var e="undefined"!=typeof window?window.jQuery:t(2);n.exports.ThemeNav={navBar:null,win:null,winScroll:!1,winResize:!1,linkScroll:!1,winPosition:0,winHeight:null,docHeight:null,isRunning:!1,enable:function(n){var t=this;void 
0===n&&(n=!0),t.isRunning||(t.isRunning=!0,e((function(e){t.init(e),t.reset(),t.win.on("hashchange",t.reset),n&&t.win.on("scroll",(function(){t.linkScroll||t.winScroll||(t.winScroll=!0,requestAnimationFrame((function(){t.onScroll()})))})),t.win.on("resize",(function(){t.winResize||(t.winResize=!0,requestAnimationFrame((function(){t.onResize()})))})),t.onResize()})))},enableSticky:function(){this.enable(!0)},init:function(n){n(document);var e=this;this.navBar=n("div.wy-side-scroll:first"),this.win=n(window),n(document).on("click","[data-toggle='wy-nav-top']",(function(){n("[data-toggle='wy-nav-shift']").toggleClass("shift"),n("[data-toggle='rst-versions']").toggleClass("shift")})).on("click",".wy-menu-vertical .current ul li a",(function(){var t=n(this);n("[data-toggle='wy-nav-shift']").removeClass("shift"),n("[data-toggle='rst-versions']").toggleClass("shift"),e.toggleCurrent(t),e.hashChange()})).on("click","[data-toggle='rst-current-version']",(function(){n("[data-toggle='rst-versions']").toggleClass("shift-up")})),n("table.docutils:not(.field-list,.footnote,.citation)").wrap("
    "),n("table.docutils.footnote").wrap("
    "),n("table.docutils.citation").wrap("
    "),n(".wy-menu-vertical ul").not(".simple").siblings("a").each((function(){var t=n(this);expand=n(''),expand.on("click",(function(n){return e.toggleCurrent(t),n.stopPropagation(),!1})),t.prepend(expand)}))},reset:function(){var n=encodeURI(window.location.hash)||"#";try{var e=$(".wy-menu-vertical"),t=e.find('[href="'+n+'"]');if(0===t.length){var i=$('.document [id="'+n.substring(1)+'"]').closest("div.section");0===(t=e.find('[href="#'+i.attr("id")+'"]')).length&&(t=e.find('[href="#"]'))}t.length>0&&($(".wy-menu-vertical .current").removeClass("current"),t.addClass("current"),t.closest("li.toctree-l1").addClass("current"),t.closest("li.toctree-l1").parent().addClass("current"),t.closest("li.toctree-l1").addClass("current"),t.closest("li.toctree-l2").addClass("current"),t.closest("li.toctree-l3").addClass("current"),t.closest("li.toctree-l4").addClass("current"),t.closest("li.toctree-l5").addClass("current"),t[0].scrollIntoView())}catch(n){console.log("Error expanding nav for anchor",n)}},onScroll:function(){this.winScroll=!1;var n=this.win.scrollTop(),e=n+this.winHeight,t=this.navBar.scrollTop()+(n-this.winPosition);n<0||e>this.docHeight||(this.navBar.scrollTop(t),this.winPosition=n)},onResize:function(){this.winResize=!1,this.winHeight=this.win.height(),this.docHeight=$(document).height()},hashChange:function(){this.linkScroll=!0,this.win.one("hashchange",(function(){this.linkScroll=!1}))},toggleCurrent:function(n){var e=n.closest("li");e.siblings("li.current").removeClass("current"),e.siblings().find("li.current").removeClass("current"),e.find("> ul li.current").removeClass("current"),e.toggleClass("current")}},"undefined"!=typeof window&&(window.SphinxRtdTheme={Navigation:n.exports.ThemeNav,StickyNav:n.exports.ThemeNav}),function(){for(var n=0,e=["ms","moz","webkit","o"],t=0;t0 62 | var meq1 = "^(" + C + ")?" + V + C + "(" + V + ")?$"; // [C]VC[V] is m=1 63 | var mgr1 = "^(" + C + ")?" + V + C + V + C; // [C]VCVC... is m>1 64 | var s_v = "^(" + C + ")?" 
+ v; // vowel in stem 65 | 66 | this.stemWord = function (w) { 67 | var stem; 68 | var suffix; 69 | var firstch; 70 | var origword = w; 71 | 72 | if (w.length < 3) 73 | return w; 74 | 75 | var re; 76 | var re2; 77 | var re3; 78 | var re4; 79 | 80 | firstch = w.substr(0,1); 81 | if (firstch == "y") 82 | w = firstch.toUpperCase() + w.substr(1); 83 | 84 | // Step 1a 85 | re = /^(.+?)(ss|i)es$/; 86 | re2 = /^(.+?)([^s])s$/; 87 | 88 | if (re.test(w)) 89 | w = w.replace(re,"$1$2"); 90 | else if (re2.test(w)) 91 | w = w.replace(re2,"$1$2"); 92 | 93 | // Step 1b 94 | re = /^(.+?)eed$/; 95 | re2 = /^(.+?)(ed|ing)$/; 96 | if (re.test(w)) { 97 | var fp = re.exec(w); 98 | re = new RegExp(mgr0); 99 | if (re.test(fp[1])) { 100 | re = /.$/; 101 | w = w.replace(re,""); 102 | } 103 | } 104 | else if (re2.test(w)) { 105 | var fp = re2.exec(w); 106 | stem = fp[1]; 107 | re2 = new RegExp(s_v); 108 | if (re2.test(stem)) { 109 | w = stem; 110 | re2 = /(at|bl|iz)$/; 111 | re3 = new RegExp("([^aeiouylsz])\\1$"); 112 | re4 = new RegExp("^" + C + v + "[^aeiouwxy]$"); 113 | if (re2.test(w)) 114 | w = w + "e"; 115 | else if (re3.test(w)) { 116 | re = /.$/; 117 | w = w.replace(re,""); 118 | } 119 | else if (re4.test(w)) 120 | w = w + "e"; 121 | } 122 | } 123 | 124 | // Step 1c 125 | re = /^(.+?)y$/; 126 | if (re.test(w)) { 127 | var fp = re.exec(w); 128 | stem = fp[1]; 129 | re = new RegExp(s_v); 130 | if (re.test(stem)) 131 | w = stem + "i"; 132 | } 133 | 134 | // Step 2 135 | re = /^(.+?)(ational|tional|enci|anci|izer|bli|alli|entli|eli|ousli|ization|ation|ator|alism|iveness|fulness|ousness|aliti|iviti|biliti|logi)$/; 136 | if (re.test(w)) { 137 | var fp = re.exec(w); 138 | stem = fp[1]; 139 | suffix = fp[2]; 140 | re = new RegExp(mgr0); 141 | if (re.test(stem)) 142 | w = stem + step2list[suffix]; 143 | } 144 | 145 | // Step 3 146 | re = /^(.+?)(icate|ative|alize|iciti|ical|ful|ness)$/; 147 | if (re.test(w)) { 148 | var fp = re.exec(w); 149 | stem = fp[1]; 150 | suffix = fp[2]; 151 | re = new RegExp(mgr0); 152 | if (re.test(stem)) 153 | w = stem + step3list[suffix]; 154 | } 155 | 156 | // Step 4 157 | re = /^(.+?)(al|ance|ence|er|ic|able|ible|ant|ement|ment|ent|ou|ism|ate|iti|ous|ive|ize)$/; 158 | re2 = /^(.+?)(s|t)(ion)$/; 159 | if (re.test(w)) { 160 | var fp = re.exec(w); 161 | stem = fp[1]; 162 | re = new RegExp(mgr1); 163 | if (re.test(stem)) 164 | w = stem; 165 | } 166 | else if (re2.test(w)) { 167 | var fp = re2.exec(w); 168 | stem = fp[1] + fp[2]; 169 | re2 = new RegExp(mgr1); 170 | if (re2.test(stem)) 171 | w = stem; 172 | } 173 | 174 | // Step 5 175 | re = /^(.+?)e$/; 176 | if (re.test(w)) { 177 | var fp = re.exec(w); 178 | stem = fp[1]; 179 | re = new RegExp(mgr1); 180 | re2 = new RegExp(meq1); 181 | re3 = new RegExp("^" + C + v + "[^aeiouwxy]$"); 182 | if (re.test(stem) || (re2.test(stem) && !(re3.test(stem)))) 183 | w = stem; 184 | } 185 | re = /ll$/; 186 | re2 = new RegExp(mgr1); 187 | if (re.test(w) && re2.test(w)) { 188 | re = /.$/; 189 | w = w.replace(re,""); 190 | } 191 | 192 | // and turn initial Y back to y 193 | if (firstch == "y") 194 | w = firstch.toLowerCase() + w.substr(1); 195 | return w; 196 | } 197 | } 198 | 199 | 200 | 201 | 202 | 203 | var splitChars = (function() { 204 | var result = {}; 205 | var singles = [96, 180, 187, 191, 215, 247, 749, 885, 903, 907, 909, 930, 1014, 1648, 206 | 1748, 1809, 2416, 2473, 2481, 2526, 2601, 2609, 2612, 2615, 2653, 2702, 207 | 2706, 2729, 2737, 2740, 2857, 2865, 2868, 2910, 2928, 2948, 2961, 2971, 208 | 2973, 3085, 3089, 3113, 3124, 3213, 3217, 3241, 
3252, 3295, 3341, 3345, 209 | 3369, 3506, 3516, 3633, 3715, 3721, 3736, 3744, 3748, 3750, 3756, 3761, 210 | 3781, 3912, 4239, 4347, 4681, 4695, 4697, 4745, 4785, 4799, 4801, 4823, 211 | 4881, 5760, 5901, 5997, 6313, 7405, 8024, 8026, 8028, 8030, 8117, 8125, 212 | 8133, 8181, 8468, 8485, 8487, 8489, 8494, 8527, 11311, 11359, 11687, 11695, 213 | 11703, 11711, 11719, 11727, 11735, 12448, 12539, 43010, 43014, 43019, 43587, 214 | 43696, 43713, 64286, 64297, 64311, 64317, 64319, 64322, 64325, 65141]; 215 | var i, j, start, end; 216 | for (i = 0; i < singles.length; i++) { 217 | result[singles[i]] = true; 218 | } 219 | var ranges = [[0, 47], [58, 64], [91, 94], [123, 169], [171, 177], [182, 184], [706, 709], 220 | [722, 735], [741, 747], [751, 879], [888, 889], [894, 901], [1154, 1161], 221 | [1318, 1328], [1367, 1368], [1370, 1376], [1416, 1487], [1515, 1519], [1523, 1568], 222 | [1611, 1631], [1642, 1645], [1750, 1764], [1767, 1773], [1789, 1790], [1792, 1807], 223 | [1840, 1868], [1958, 1968], [1970, 1983], [2027, 2035], [2038, 2041], [2043, 2047], 224 | [2070, 2073], [2075, 2083], [2085, 2087], [2089, 2307], [2362, 2364], [2366, 2383], 225 | [2385, 2391], [2402, 2405], [2419, 2424], [2432, 2436], [2445, 2446], [2449, 2450], 226 | [2483, 2485], [2490, 2492], [2494, 2509], [2511, 2523], [2530, 2533], [2546, 2547], 227 | [2554, 2564], [2571, 2574], [2577, 2578], [2618, 2648], [2655, 2661], [2672, 2673], 228 | [2677, 2692], [2746, 2748], [2750, 2767], [2769, 2783], [2786, 2789], [2800, 2820], 229 | [2829, 2830], [2833, 2834], [2874, 2876], [2878, 2907], [2914, 2917], [2930, 2946], 230 | [2955, 2957], [2966, 2968], [2976, 2978], [2981, 2983], [2987, 2989], [3002, 3023], 231 | [3025, 3045], [3059, 3076], [3130, 3132], [3134, 3159], [3162, 3167], [3170, 3173], 232 | [3184, 3191], [3199, 3204], [3258, 3260], [3262, 3293], [3298, 3301], [3312, 3332], 233 | [3386, 3388], [3390, 3423], [3426, 3429], [3446, 3449], [3456, 3460], [3479, 3481], 234 | [3518, 3519], [3527, 3584], [3636, 3647], [3655, 3663], [3674, 3712], [3717, 3718], 235 | [3723, 3724], [3726, 3731], [3752, 3753], [3764, 3772], [3774, 3775], [3783, 3791], 236 | [3802, 3803], [3806, 3839], [3841, 3871], [3892, 3903], [3949, 3975], [3980, 4095], 237 | [4139, 4158], [4170, 4175], [4182, 4185], [4190, 4192], [4194, 4196], [4199, 4205], 238 | [4209, 4212], [4226, 4237], [4250, 4255], [4294, 4303], [4349, 4351], [4686, 4687], 239 | [4702, 4703], [4750, 4751], [4790, 4791], [4806, 4807], [4886, 4887], [4955, 4968], 240 | [4989, 4991], [5008, 5023], [5109, 5120], [5741, 5742], [5787, 5791], [5867, 5869], 241 | [5873, 5887], [5906, 5919], [5938, 5951], [5970, 5983], [6001, 6015], [6068, 6102], 242 | [6104, 6107], [6109, 6111], [6122, 6127], [6138, 6159], [6170, 6175], [6264, 6271], 243 | [6315, 6319], [6390, 6399], [6429, 6469], [6510, 6511], [6517, 6527], [6572, 6592], 244 | [6600, 6607], [6619, 6655], [6679, 6687], [6741, 6783], [6794, 6799], [6810, 6822], 245 | [6824, 6916], [6964, 6980], [6988, 6991], [7002, 7042], [7073, 7085], [7098, 7167], 246 | [7204, 7231], [7242, 7244], [7294, 7400], [7410, 7423], [7616, 7679], [7958, 7959], 247 | [7966, 7967], [8006, 8007], [8014, 8015], [8062, 8063], [8127, 8129], [8141, 8143], 248 | [8148, 8149], [8156, 8159], [8173, 8177], [8189, 8303], [8306, 8307], [8314, 8318], 249 | [8330, 8335], [8341, 8449], [8451, 8454], [8456, 8457], [8470, 8472], [8478, 8483], 250 | [8506, 8507], [8512, 8516], [8522, 8525], [8586, 9311], [9372, 9449], [9472, 10101], 251 | [10132, 11263], [11493, 11498], [11503, 11516], [11518, 
11519], [11558, 11567], 252 | [11622, 11630], [11632, 11647], [11671, 11679], [11743, 11822], [11824, 12292], 253 | [12296, 12320], [12330, 12336], [12342, 12343], [12349, 12352], [12439, 12444], 254 | [12544, 12548], [12590, 12592], [12687, 12689], [12694, 12703], [12728, 12783], 255 | [12800, 12831], [12842, 12880], [12896, 12927], [12938, 12976], [12992, 13311], 256 | [19894, 19967], [40908, 40959], [42125, 42191], [42238, 42239], [42509, 42511], 257 | [42540, 42559], [42592, 42593], [42607, 42622], [42648, 42655], [42736, 42774], 258 | [42784, 42785], [42889, 42890], [42893, 43002], [43043, 43055], [43062, 43071], 259 | [43124, 43137], [43188, 43215], [43226, 43249], [43256, 43258], [43260, 43263], 260 | [43302, 43311], [43335, 43359], [43389, 43395], [43443, 43470], [43482, 43519], 261 | [43561, 43583], [43596, 43599], [43610, 43615], [43639, 43641], [43643, 43647], 262 | [43698, 43700], [43703, 43704], [43710, 43711], [43715, 43738], [43742, 43967], 263 | [44003, 44015], [44026, 44031], [55204, 55215], [55239, 55242], [55292, 55295], 264 | [57344, 63743], [64046, 64047], [64110, 64111], [64218, 64255], [64263, 64274], 265 | [64280, 64284], [64434, 64466], [64830, 64847], [64912, 64913], [64968, 65007], 266 | [65020, 65135], [65277, 65295], [65306, 65312], [65339, 65344], [65371, 65381], 267 | [65471, 65473], [65480, 65481], [65488, 65489], [65496, 65497]]; 268 | for (i = 0; i < ranges.length; i++) { 269 | start = ranges[i][0]; 270 | end = ranges[i][1]; 271 | for (j = start; j <= end; j++) { 272 | result[j] = true; 273 | } 274 | } 275 | return result; 276 | })(); 277 | 278 | function splitQuery(query) { 279 | var result = []; 280 | var start = -1; 281 | for (var i = 0; i < query.length; i++) { 282 | if (splitChars[query.charCodeAt(i)]) { 283 | if (start !== -1) { 284 | result.push(query.slice(start, i)); 285 | start = -1; 286 | } 287 | } else if (start === -1) { 288 | start = i; 289 | } 290 | } 291 | if (start !== -1) { 292 | result.push(query.slice(start)); 293 | } 294 | return result; 295 | } 296 | 297 | 298 | -------------------------------------------------------------------------------- /docs/_static/minus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/minus.png -------------------------------------------------------------------------------- /docs/_static/plus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dreamquark-ai/tabnet/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/docs/_static/plus.png -------------------------------------------------------------------------------- /docs/_static/pygments.css: -------------------------------------------------------------------------------- 1 | pre { line-height: 125%; } 2 | td.linenos .normal { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; } 3 | span.linenos { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; } 4 | td.linenos .special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; } 5 | span.linenos.special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; } 6 | .highlight .hll { background-color: #ffffcc } 7 | .highlight { background: #eeffcc; } 8 | .highlight .c { color: #408090; font-style: italic } /* Comment */ 9 | .highlight .err { border: 1px solid #FF0000 } /* Error */ 10 | 
.highlight .k { color: #007020; font-weight: bold } /* Keyword */ 11 | .highlight .o { color: #666666 } /* Operator */ 12 | .highlight .ch { color: #408090; font-style: italic } /* Comment.Hashbang */ 13 | .highlight .cm { color: #408090; font-style: italic } /* Comment.Multiline */ 14 | .highlight .cp { color: #007020 } /* Comment.Preproc */ 15 | .highlight .cpf { color: #408090; font-style: italic } /* Comment.PreprocFile */ 16 | .highlight .c1 { color: #408090; font-style: italic } /* Comment.Single */ 17 | .highlight .cs { color: #408090; background-color: #fff0f0 } /* Comment.Special */ 18 | .highlight .gd { color: #A00000 } /* Generic.Deleted */ 19 | .highlight .ge { font-style: italic } /* Generic.Emph */ 20 | .highlight .gr { color: #FF0000 } /* Generic.Error */ 21 | .highlight .gh { color: #000080; font-weight: bold } /* Generic.Heading */ 22 | .highlight .gi { color: #00A000 } /* Generic.Inserted */ 23 | .highlight .go { color: #333333 } /* Generic.Output */ 24 | .highlight .gp { color: #c65d09; font-weight: bold } /* Generic.Prompt */ 25 | .highlight .gs { font-weight: bold } /* Generic.Strong */ 26 | .highlight .gu { color: #800080; font-weight: bold } /* Generic.Subheading */ 27 | .highlight .gt { color: #0044DD } /* Generic.Traceback */ 28 | .highlight .kc { color: #007020; font-weight: bold } /* Keyword.Constant */ 29 | .highlight .kd { color: #007020; font-weight: bold } /* Keyword.Declaration */ 30 | .highlight .kn { color: #007020; font-weight: bold } /* Keyword.Namespace */ 31 | .highlight .kp { color: #007020 } /* Keyword.Pseudo */ 32 | .highlight .kr { color: #007020; font-weight: bold } /* Keyword.Reserved */ 33 | .highlight .kt { color: #902000 } /* Keyword.Type */ 34 | .highlight .m { color: #208050 } /* Literal.Number */ 35 | .highlight .s { color: #4070a0 } /* Literal.String */ 36 | .highlight .na { color: #4070a0 } /* Name.Attribute */ 37 | .highlight .nb { color: #007020 } /* Name.Builtin */ 38 | .highlight .nc { color: #0e84b5; font-weight: bold } /* Name.Class */ 39 | .highlight .no { color: #60add5 } /* Name.Constant */ 40 | .highlight .nd { color: #555555; font-weight: bold } /* Name.Decorator */ 41 | .highlight .ni { color: #d55537; font-weight: bold } /* Name.Entity */ 42 | .highlight .ne { color: #007020 } /* Name.Exception */ 43 | .highlight .nf { color: #06287e } /* Name.Function */ 44 | .highlight .nl { color: #002070; font-weight: bold } /* Name.Label */ 45 | .highlight .nn { color: #0e84b5; font-weight: bold } /* Name.Namespace */ 46 | .highlight .nt { color: #062873; font-weight: bold } /* Name.Tag */ 47 | .highlight .nv { color: #bb60d5 } /* Name.Variable */ 48 | .highlight .ow { color: #007020; font-weight: bold } /* Operator.Word */ 49 | .highlight .w { color: #bbbbbb } /* Text.Whitespace */ 50 | .highlight .mb { color: #208050 } /* Literal.Number.Bin */ 51 | .highlight .mf { color: #208050 } /* Literal.Number.Float */ 52 | .highlight .mh { color: #208050 } /* Literal.Number.Hex */ 53 | .highlight .mi { color: #208050 } /* Literal.Number.Integer */ 54 | .highlight .mo { color: #208050 } /* Literal.Number.Oct */ 55 | .highlight .sa { color: #4070a0 } /* Literal.String.Affix */ 56 | .highlight .sb { color: #4070a0 } /* Literal.String.Backtick */ 57 | .highlight .sc { color: #4070a0 } /* Literal.String.Char */ 58 | .highlight .dl { color: #4070a0 } /* Literal.String.Delimiter */ 59 | .highlight .sd { color: #4070a0; font-style: italic } /* Literal.String.Doc */ 60 | .highlight .s2 { color: #4070a0 } /* Literal.String.Double */ 61 | .highlight .se { 
color: #4070a0; font-weight: bold } /* Literal.String.Escape */ 62 | .highlight .sh { color: #4070a0 } /* Literal.String.Heredoc */ 63 | .highlight .si { color: #70a0d0; font-style: italic } /* Literal.String.Interpol */ 64 | .highlight .sx { color: #c65d09 } /* Literal.String.Other */ 65 | .highlight .sr { color: #235388 } /* Literal.String.Regex */ 66 | .highlight .s1 { color: #4070a0 } /* Literal.String.Single */ 67 | .highlight .ss { color: #517918 } /* Literal.String.Symbol */ 68 | .highlight .bp { color: #007020 } /* Name.Builtin.Pseudo */ 69 | .highlight .fm { color: #06287e } /* Name.Function.Magic */ 70 | .highlight .vc { color: #bb60d5 } /* Name.Variable.Class */ 71 | .highlight .vg { color: #bb60d5 } /* Name.Variable.Global */ 72 | .highlight .vi { color: #bb60d5 } /* Name.Variable.Instance */ 73 | .highlight .vm { color: #bb60d5 } /* Name.Variable.Magic */ 74 | .highlight .il { color: #208050 } /* Literal.Number.Integer.Long */ -------------------------------------------------------------------------------- /docs/_static/underscore.js: -------------------------------------------------------------------------------- 1 | // Underscore.js 1.3.1 2 | // (c) 2009-2012 Jeremy Ashkenas, DocumentCloud Inc. 3 | // Underscore is freely distributable under the MIT license. 4 | // Portions of Underscore are inspired or borrowed from Prototype, 5 | // Oliver Steele's Functional, and John Resig's Micro-Templating. 6 | // For all details and documentation: 7 | // http://documentcloud.github.com/underscore 8 | (function(){function q(a,c,d){if(a===c)return a!==0||1/a==1/c;if(a==null||c==null)return a===c;if(a._chain)a=a._wrapped;if(c._chain)c=c._wrapped;if(a.isEqual&&b.isFunction(a.isEqual))return a.isEqual(c);if(c.isEqual&&b.isFunction(c.isEqual))return c.isEqual(a);var e=l.call(a);if(e!=l.call(c))return false;switch(e){case "[object String]":return a==String(c);case "[object Number]":return a!=+a?c!=+c:a==0?1/a==1/c:a==+c;case "[object Date]":case "[object Boolean]":return+a==+c;case "[object RegExp]":return a.source== 9 | c.source&&a.global==c.global&&a.multiline==c.multiline&&a.ignoreCase==c.ignoreCase}if(typeof a!="object"||typeof c!="object")return false;for(var f=d.length;f--;)if(d[f]==a)return true;d.push(a);var f=0,g=true;if(e=="[object Array]"){if(f=a.length,g=f==c.length)for(;f--;)if(!(g=f in a==f in c&&q(a[f],c[f],d)))break}else{if("constructor"in a!="constructor"in c||a.constructor!=c.constructor)return false;for(var h in a)if(b.has(a,h)&&(f++,!(g=b.has(c,h)&&q(a[h],c[h],d))))break;if(g){for(h in c)if(b.has(c, 10 | h)&&!f--)break;g=!f}}d.pop();return g}var r=this,G=r._,n={},k=Array.prototype,o=Object.prototype,i=k.slice,H=k.unshift,l=o.toString,I=o.hasOwnProperty,w=k.forEach,x=k.map,y=k.reduce,z=k.reduceRight,A=k.filter,B=k.every,C=k.some,p=k.indexOf,D=k.lastIndexOf,o=Array.isArray,J=Object.keys,s=Function.prototype.bind,b=function(a){return new m(a)};if(typeof exports!=="undefined"){if(typeof module!=="undefined"&&module.exports)exports=module.exports=b;exports._=b}else r._=b;b.VERSION="1.3.1";var j=b.each= 11 | b.forEach=function(a,c,d){if(a!=null)if(w&&a.forEach===w)a.forEach(c,d);else if(a.length===+a.length)for(var e=0,f=a.length;e2;a== 12 | null&&(a=[]);if(y&&a.reduce===y)return e&&(c=b.bind(c,e)),f?a.reduce(c,d):a.reduce(c);j(a,function(a,b,i){f?d=c.call(e,d,a,b,i):(d=a,f=true)});if(!f)throw new TypeError("Reduce of empty array with no initial value");return d};b.reduceRight=b.foldr=function(a,c,d,e){var f=arguments.length>2;a==null&&(a=[]);if(z&&a.reduceRight===z)return 
e&&(c=b.bind(c,e)),f?a.reduceRight(c,d):a.reduceRight(c);var g=b.toArray(a).reverse();e&&!f&&(c=b.bind(c,e));return f?b.reduce(g,c,d,e):b.reduce(g,c)};b.find=b.detect= 13 | function(a,c,b){var e;E(a,function(a,g,h){if(c.call(b,a,g,h))return e=a,true});return e};b.filter=b.select=function(a,c,b){var e=[];if(a==null)return e;if(A&&a.filter===A)return a.filter(c,b);j(a,function(a,g,h){c.call(b,a,g,h)&&(e[e.length]=a)});return e};b.reject=function(a,c,b){var e=[];if(a==null)return e;j(a,function(a,g,h){c.call(b,a,g,h)||(e[e.length]=a)});return e};b.every=b.all=function(a,c,b){var e=true;if(a==null)return e;if(B&&a.every===B)return a.every(c,b);j(a,function(a,g,h){if(!(e= 14 | e&&c.call(b,a,g,h)))return n});return e};var E=b.some=b.any=function(a,c,d){c||(c=b.identity);var e=false;if(a==null)return e;if(C&&a.some===C)return a.some(c,d);j(a,function(a,b,h){if(e||(e=c.call(d,a,b,h)))return n});return!!e};b.include=b.contains=function(a,c){var b=false;if(a==null)return b;return p&&a.indexOf===p?a.indexOf(c)!=-1:b=E(a,function(a){return a===c})};b.invoke=function(a,c){var d=i.call(arguments,2);return b.map(a,function(a){return(b.isFunction(c)?c||a:a[c]).apply(a,d)})};b.pluck= 15 | function(a,c){return b.map(a,function(a){return a[c]})};b.max=function(a,c,d){if(!c&&b.isArray(a))return Math.max.apply(Math,a);if(!c&&b.isEmpty(a))return-Infinity;var e={computed:-Infinity};j(a,function(a,b,h){b=c?c.call(d,a,b,h):a;b>=e.computed&&(e={value:a,computed:b})});return e.value};b.min=function(a,c,d){if(!c&&b.isArray(a))return Math.min.apply(Math,a);if(!c&&b.isEmpty(a))return Infinity;var e={computed:Infinity};j(a,function(a,b,h){b=c?c.call(d,a,b,h):a;bd?1:0}),"value")};b.groupBy=function(a,c){var d={},e=b.isFunction(c)?c:function(a){return a[c]};j(a,function(a,b){var c=e(a,b);(d[c]||(d[c]=[])).push(a)});return d};b.sortedIndex=function(a, 17 | c,d){d||(d=b.identity);for(var e=0,f=a.length;e>1;d(a[g])=0})})};b.difference=function(a){var c=b.flatten(i.call(arguments,1));return b.filter(a,function(a){return!b.include(c,a)})};b.zip=function(){for(var a=i.call(arguments),c=b.max(b.pluck(a,"length")),d=Array(c),e=0;e=0;d--)b=[a[d].apply(this,b)];return b[0]}}; 24 | b.after=function(a,b){return a<=0?b():function(){if(--a<1)return b.apply(this,arguments)}};b.keys=J||function(a){if(a!==Object(a))throw new TypeError("Invalid object");var c=[],d;for(d in a)b.has(a,d)&&(c[c.length]=d);return c};b.values=function(a){return b.map(a,b.identity)};b.functions=b.methods=function(a){var c=[],d;for(d in a)b.isFunction(a[d])&&c.push(d);return c.sort()};b.extend=function(a){j(i.call(arguments,1),function(b){for(var d in b)a[d]=b[d]});return a};b.defaults=function(a){j(i.call(arguments, 25 | 1),function(b){for(var d in b)a[d]==null&&(a[d]=b[d])});return a};b.clone=function(a){return!b.isObject(a)?a:b.isArray(a)?a.slice():b.extend({},a)};b.tap=function(a,b){b(a);return a};b.isEqual=function(a,b){return q(a,b,[])};b.isEmpty=function(a){if(b.isArray(a)||b.isString(a))return a.length===0;for(var c in a)if(b.has(a,c))return false;return true};b.isElement=function(a){return!!(a&&a.nodeType==1)};b.isArray=o||function(a){return l.call(a)=="[object Array]"};b.isObject=function(a){return a===Object(a)}; 26 | b.isArguments=function(a){return l.call(a)=="[object Arguments]"};if(!b.isArguments(arguments))b.isArguments=function(a){return!(!a||!b.has(a,"callee"))};b.isFunction=function(a){return l.call(a)=="[object Function]"};b.isString=function(a){return l.call(a)=="[object String]"};b.isNumber=function(a){return l.call(a)=="[object 
Number]"};b.isNaN=function(a){return a!==a};b.isBoolean=function(a){return a===true||a===false||l.call(a)=="[object Boolean]"};b.isDate=function(a){return l.call(a)=="[object Date]"}; 27 | b.isRegExp=function(a){return l.call(a)=="[object RegExp]"};b.isNull=function(a){return a===null};b.isUndefined=function(a){return a===void 0};b.has=function(a,b){return I.call(a,b)};b.noConflict=function(){r._=G;return this};b.identity=function(a){return a};b.times=function(a,b,d){for(var e=0;e/g,">").replace(/"/g,""").replace(/'/g,"'").replace(/\//g,"/")};b.mixin=function(a){j(b.functions(a), 28 | function(c){K(c,b[c]=a[c])})};var L=0;b.uniqueId=function(a){var b=L++;return a?a+b:b};b.templateSettings={evaluate:/<%([\s\S]+?)%>/g,interpolate:/<%=([\s\S]+?)%>/g,escape:/<%-([\s\S]+?)%>/g};var t=/.^/,u=function(a){return a.replace(/\\\\/g,"\\").replace(/\\'/g,"'")};b.template=function(a,c){var d=b.templateSettings,d="var __p=[],print=function(){__p.push.apply(__p,arguments);};with(obj||{}){__p.push('"+a.replace(/\\/g,"\\\\").replace(/'/g,"\\'").replace(d.escape||t,function(a,b){return"',_.escape("+ 29 | u(b)+"),'"}).replace(d.interpolate||t,function(a,b){return"',"+u(b)+",'"}).replace(d.evaluate||t,function(a,b){return"');"+u(b).replace(/[\r\n\t]/g," ")+";__p.push('"}).replace(/\r/g,"\\r").replace(/\n/g,"\\n").replace(/\t/g,"\\t")+"');}return __p.join('');",e=new Function("obj","_",d);return c?e(c,b):function(a){return e.call(this,a,b)}};b.chain=function(a){return b(a).chain()};var m=function(a){this._wrapped=a};b.prototype=m.prototype;var v=function(a,c){return c?b(a).chain():a},K=function(a,c){m.prototype[a]= 30 | function(){var a=i.call(arguments);H.call(a,this._wrapped);return v(c.apply(b,a),this._chain)}};b.mixin(b);j("pop,push,reverse,shift,sort,splice,unshift".split(","),function(a){var b=k[a];m.prototype[a]=function(){var d=this._wrapped;b.apply(d,arguments);var e=d.length;(a=="shift"||a=="splice")&&e===0&&delete d[0];return v(d,this._chain)}});j(["concat","join","slice"],function(a){var b=k[a];m.prototype[a]=function(){return v(b.apply(this._wrapped,arguments),this._chain)}});m.prototype.chain=function(){this._chain= 31 | true;return this};m.prototype.value=function(){return this._wrapped}}).call(this); 32 | -------------------------------------------------------------------------------- /docs/generated_docs/docs-scripts....pytorch_tabnet.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | docs-scripts….pytorch_tabnet package — pytorch_tabnet documentation 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 |
[Generated Sphinx HTML page (Read the Docs theme); markup stripped during extraction. Recoverable content: a table of contents for "docs-scripts….pytorch_tabnet package" linking to the module pages metrics, sparsemax, callbacks, tab_network, utils, multiclass_utils, abstract_model, multitask and tab_model, plus the standard "© Copyright 2019, Dreamquark" / "Built with Sphinx using a theme provided by Read the Docs" footer.]
-------------------------------------------------------------------------------- /docs/py-modindex.html: --------------------------------------------------------------------------------
[Generated Sphinx HTML page "Python Module Index — pytorch_tabnet documentation"; markup stripped during extraction. Recoverable content: the module index under "p" — pytorch_tabnet and its submodules abstract_model, augmentations, callbacks, metrics, multiclass_utils, multitask, pretraining, pretraining_utils, sparsemax, tab_model, tab_network, utils — plus the standard theme footer.]
-------------------------------------------------------------------------------- /docs/search.html: --------------------------------------------------------------------------------
[Generated Sphinx HTML page "Search — pytorch_tabnet documentation"; markup stripped during extraction. Theme chrome and search box only — no document content beyond the "© Copyright 2019, Dreamquark" / "Built with Sphinx" footer.]
    209 | 210 | 211 | 216 | 217 | 218 | 219 | 220 | 221 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | -------------------------------------------------------------------------------- /multi_regression_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from pytorch_tabnet.tab_model import TabNetRegressor\n", 10 | "\n", 11 | "import torch\n", 12 | "from sklearn.preprocessing import LabelEncoder\n", 13 | "from sklearn.metrics import mean_squared_error\n", 14 | "\n", 15 | "import pandas as pd\n", 16 | "import numpy as np\n", 17 | "np.random.seed(0)\n", 18 | "\n", 19 | "\n", 20 | "import os\n", 21 | "import wget\n", 22 | "from pathlib import Path\n", 23 | "\n", 24 | "\n", 25 | "%load_ext autoreload\n", 26 | "\n", 27 | "%autoreload 2" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "# Download census-income dataset" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data\"\n", 44 | "dataset_name = 'census-income'\n", 45 | "out = Path(os.getcwd()+'/data/'+dataset_name+'.csv')" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "out.parent.mkdir(parents=True, exist_ok=True)\n", 55 | "if out.exists():\n", 56 | " print(\"File already exists.\")\n", 57 | "else:\n", 58 | " print(\"Downloading file...\")\n", 59 | " wget.download(url, out.as_posix())" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "metadata": {}, 65 | "source": [ 66 | "# Load data and split" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "train = pd.read_csv(out)\n", 76 | "target = ' <=50K'\n", 77 | "if \"Set\" not in train.columns:\n", 78 | " train[\"Set\"] = np.random.choice([\"train\", \"valid\", \"test\"], p =[.8, .1, .1], size=(train.shape[0],))\n", 79 | "\n", 80 | "train_indices = train[train.Set==\"train\"].index\n", 81 | "valid_indices = train[train.Set==\"valid\"].index\n", 82 | "test_indices = train[train.Set==\"test\"].index" 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": {}, 88 | "source": [ 89 | "# Simple preprocessing\n", 90 | "\n", 91 | "Label encode categorical features and fill empty cells." 
92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "categorical_columns = []\n", 101 | "categorical_dims = {}\n", 102 | "for col in train.columns[train.dtypes == object]:\n", 103 | " print(col, train[col].nunique())\n", 104 | " l_enc = LabelEncoder()\n", 105 | " train[col] = train[col].fillna(\"VV_likely\")\n", 106 | " train[col] = l_enc.fit_transform(train[col].values)\n", 107 | " categorical_columns.append(col)\n", 108 | " categorical_dims[col] = len(l_enc.classes_)\n", 109 | "\n", 110 | "for col in train.columns[train.dtypes == 'float64']:\n", 111 | " train.fillna(train.loc[train_indices, col].mean(), inplace=True)" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "metadata": {}, 117 | "source": [ 118 | "# Define categorical features for categorical embeddings" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "unused_feat = ['Set']\n", 128 | "\n", 129 | "features = [ col for col in train.columns if col not in unused_feat+[target]] \n", 130 | "\n", 131 | "cat_idxs = [ i for i, f in enumerate(features) if f in categorical_columns]\n", 132 | "\n", 133 | "cat_dims = [ categorical_dims[f] for i, f in enumerate(features) if f in categorical_columns]\n", 134 | "\n", 135 | "# define your embedding sizes : here just a random choice\n", 136 | "cat_emb_dim = [5, 4, 3, 6, 2, 2, 1, 10]" 137 | ] 138 | }, 139 | { 140 | "cell_type": "markdown", 141 | "metadata": {}, 142 | "source": [ 143 | "# Network parameters" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "clf = TabNetRegressor(cat_dims=cat_dims, cat_emb_dim=cat_emb_dim, cat_idxs=cat_idxs)" 153 | ] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "metadata": {}, 158 | "source": [ 159 | "# Training" 160 | ] 161 | }, 162 | { 163 | "cell_type": "markdown", 164 | "metadata": {}, 165 | "source": [ 166 | "### We will simulate 5 targets here to perform multi regression without changing anything!" 
167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "n_targets = 8\n", 176 | "\n", 177 | "X_train = train[features].values[train_indices]\n", 178 | "y_train = train[target].values[train_indices]\n", 179 | "y_train = np.transpose(np.tile(y_train, (n_targets,1)))\n", 180 | "\n", 181 | "X_valid = train[features].values[valid_indices]\n", 182 | "y_valid = train[target].values[valid_indices]\n", 183 | "y_valid = np.transpose(np.tile(y_valid, (n_targets,1)))\n", 184 | "\n", 185 | "X_test = train[features].values[test_indices]\n", 186 | "y_test = train[target].values[test_indices]\n", 187 | "y_test = np.transpose(np.tile(y_test, (n_targets,1)))" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": null, 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [ 196 | "max_epochs = 1000 if not os.getenv(\"CI\", False) else 2" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": null, 202 | "metadata": { 203 | "scrolled": false 204 | }, 205 | "outputs": [], 206 | "source": [ 207 | "clf.fit(\n", 208 | " X_train=X_train, y_train=y_train,\n", 209 | " eval_set=[(X_train, y_train), (X_valid, y_valid)],\n", 210 | " eval_name=['train', 'valid'],\n", 211 | " eval_metric=['rmsle', 'mae', 'rmse', 'mse'],\n", 212 | " max_epochs=max_epochs,\n", 213 | " patience=50,\n", 214 | " batch_size=1024, virtual_batch_size=128,\n", 215 | " num_workers=0,\n", 216 | " drop_last=False\n", 217 | ") " 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": null, 223 | "metadata": {}, 224 | "outputs": [], 225 | "source": [ 226 | "preds = clf.predict(X_test)\n", 227 | "\n", 228 | "test_mse = mean_squared_error(y_pred=preds, y_true=y_test)\n", 229 | "\n", 230 | "print(f\"BEST VALID SCORE FOR {dataset_name} : {clf.best_cost}\")\n", 231 | "print(f\"FINAL TEST SCORE FOR {dataset_name} : {test_mse}\")" 232 | ] 233 | }, 234 | { 235 | "cell_type": "markdown", 236 | "metadata": {}, 237 | "source": [ 238 | "# Global explainability : feat importance summing to 1" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": null, 244 | "metadata": {}, 245 | "outputs": [], 246 | "source": [ 247 | "clf.feature_importances_" 248 | ] 249 | }, 250 | { 251 | "cell_type": "markdown", 252 | "metadata": {}, 253 | "source": [ 254 | "# Local explainability and masks" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": null, 260 | "metadata": {}, 261 | "outputs": [], 262 | "source": [ 263 | "explain_matrix, masks = clf.explain(X_test)" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": null, 269 | "metadata": {}, 270 | "outputs": [], 271 | "source": [ 272 | "from matplotlib import pyplot as plt\n", 273 | "%matplotlib inline" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": null, 279 | "metadata": {}, 280 | "outputs": [], 281 | "source": [ 282 | "fig, axs = plt.subplots(1, 3, figsize=(20,20))\n", 283 | "\n", 284 | "for i in range(3):\n", 285 | " axs[i].imshow(masks[i][:50])\n", 286 | " axs[i].set_title(f\"mask {i}\")\n" 287 | ] 288 | }, 289 | { 290 | "cell_type": "markdown", 291 | "metadata": {}, 292 | "source": [ 293 | "# XGB : unfortunately this is still not possible with XGBoost\n", 294 | "\n", 295 | "https://github.com/dmlc/xgboost/issues/2087" 296 | ] 297 | } 298 | ], 299 | "metadata": { 300 | "kernelspec": { 301 | "display_name": "Python 3", 302 | "language": "python", 303 | "name": 
"python3" 304 | }, 305 | "language_info": { 306 | "codemirror_mode": { 307 | "name": "ipython", 308 | "version": 3 309 | }, 310 | "file_extension": ".py", 311 | "mimetype": "text/x-python", 312 | "name": "python", 313 | "nbconvert_exporter": "python", 314 | "pygments_lexer": "ipython3", 315 | "version": "3.7.6" 316 | }, 317 | "toc": { 318 | "base_numbering": 1, 319 | "nav_menu": {}, 320 | "number_sections": true, 321 | "sideBar": true, 322 | "skip_h1_title": false, 323 | "title_cell": "Table of Contents", 324 | "title_sidebar": "Contents", 325 | "toc_cell": false, 326 | "toc_position": {}, 327 | "toc_section_display": true, 328 | "toc_window_display": false 329 | } 330 | }, 331 | "nbformat": 4, 332 | "nbformat_minor": 2 333 | } 334 | -------------------------------------------------------------------------------- /poetry.toml: -------------------------------------------------------------------------------- 1 | [installer] 2 | parallel = false 3 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "pytorch_tabnet" 3 | version = "4.1.0" 4 | description = "PyTorch implementation of TabNet" 5 | homepage = "https://github.com/dreamquark-ai/tabnet" 6 | repository = "https://github.com/dreamquark-ai/tabnet" 7 | documentation = "https://github.com/dreamquark-ai/tabnet" 8 | readme = "README.md" 9 | authors = [] 10 | keywords = ["tabnet", "pytorch", "neural-networks" ] 11 | exclude = ["tabnet/*.ipynb"] 12 | 13 | [tool.poetry.dependencies] 14 | python = ">=3.7" 15 | 16 | numpy=">=1.17" 17 | torch=">=1.3" 18 | tqdm=">=4.36" 19 | scikit_learn=">0.21" 20 | scipy=">1.4" 21 | 22 | [tool.poetry.dev-dependencies] 23 | jupyter="1.0.0" 24 | xgboost="0.90" 25 | matplotlib="3.1.1" 26 | wget="3.2" 27 | pandas="0.25.3" 28 | flake8="3.7.9" 29 | sphinx = "^2.2" 30 | sphinx-rtd-theme = "0.5.0" 31 | recommonmark = "0.6.0" 32 | pytest = "6.2.5" 33 | 34 | [build-system] 35 | requires = ["poetry>=0.12"] 36 | build-backend = "poetry.masonry.api" 37 | -------------------------------------------------------------------------------- /pytorch_tabnet/augmentations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorch_tabnet.utils import define_device 3 | import numpy as np 4 | 5 | 6 | class RegressionSMOTE(): 7 | """ 8 | Apply SMOTE 9 | 10 | This will average a percentage p of the elements in the batch with other elements. 11 | The target will be averaged as well (this might work with binary classification 12 | and certain loss), following a beta distribution. 13 | """ 14 | def __init__(self, device_name="auto", p=0.8, alpha=0.5, beta=0.5, seed=0): 15 | "" 16 | self.seed = seed 17 | self._set_seed() 18 | self.device = define_device(device_name) 19 | self.alpha = alpha 20 | self.beta = beta 21 | self.p = p 22 | if (p < 0.) or (p > 1.0): 23 | raise ValueError("Value of p should be between 0. 
and 1.") 24 | 25 | def _set_seed(self): 26 | torch.manual_seed(self.seed) 27 | np.random.seed(self.seed) 28 | return 29 | 30 | def __call__(self, X, y): 31 | batch_size = X.shape[0] 32 | random_values = torch.rand(batch_size, device=self.device) 33 | idx_to_change = random_values < self.p 34 | 35 | # ensure that first element to switch has probability > 0.5 36 | np_betas = np.random.beta(self.alpha, self.beta, batch_size) / 2 + 0.5 37 | random_betas = torch.from_numpy(np_betas).to(self.device).float() 38 | index_permute = torch.randperm(batch_size, device=self.device) 39 | 40 | X[idx_to_change] = random_betas[idx_to_change, None] * X[idx_to_change] 41 | X[idx_to_change] += (1 - random_betas[idx_to_change, None]) * X[index_permute][idx_to_change].view(X[idx_to_change].size()) # noqa 42 | 43 | y[idx_to_change] = random_betas[idx_to_change, None] * y[idx_to_change] 44 | y[idx_to_change] += (1 - random_betas[idx_to_change, None]) * y[index_permute][idx_to_change].view(y[idx_to_change].size()) # noqa 45 | 46 | return X, y 47 | 48 | 49 | class ClassificationSMOTE(): 50 | """ 51 | Apply SMOTE for classification tasks. 52 | 53 | This will average a percentage p of the elements in the batch with other elements. 54 | The target will stay unchanged and keep the value of the most important row in the mix. 55 | """ 56 | def __init__(self, device_name="auto", p=0.8, alpha=0.5, beta=0.5, seed=0): 57 | "" 58 | self.seed = seed 59 | self._set_seed() 60 | self.device = define_device(device_name) 61 | self.alpha = alpha 62 | self.beta = beta 63 | self.p = p 64 | if (p < 0.) or (p > 1.0): 65 | raise ValueError("Value of p should be between 0. and 1.") 66 | 67 | def _set_seed(self): 68 | torch.manual_seed(self.seed) 69 | np.random.seed(self.seed) 70 | return 71 | 72 | def __call__(self, X, y): 73 | batch_size = X.shape[0] 74 | random_values = torch.rand(batch_size, device=self.device) 75 | idx_to_change = random_values < self.p 76 | 77 | # ensure that first element to switch has probability > 0.5 78 | np_betas = np.random.beta(self.alpha, self.beta, batch_size) / 2 + 0.5 79 | random_betas = torch.from_numpy(np_betas).to(self.device).float() 80 | index_permute = torch.randperm(batch_size, device=self.device) 81 | 82 | X[idx_to_change] = random_betas[idx_to_change, None] * X[idx_to_change] 83 | X[idx_to_change] += (1 - random_betas[idx_to_change, None]) * X[index_permute][idx_to_change].view(X[idx_to_change].size()) # noqa 84 | 85 | return X, y 86 | -------------------------------------------------------------------------------- /pytorch_tabnet/callbacks.py: -------------------------------------------------------------------------------- 1 | import time 2 | import datetime 3 | import copy 4 | import numpy as np 5 | from dataclasses import dataclass, field 6 | from typing import List, Any 7 | import warnings 8 | 9 | 10 | class Callback: 11 | """ 12 | Abstract base class used to build new callbacks. 
13 | """ 14 | 15 | def __init__(self): 16 | pass 17 | 18 | def set_params(self, params): 19 | self.params = params 20 | 21 | def set_trainer(self, model): 22 | self.trainer = model 23 | 24 | def on_epoch_begin(self, epoch, logs=None): 25 | pass 26 | 27 | def on_epoch_end(self, epoch, logs=None): 28 | pass 29 | 30 | def on_batch_begin(self, batch, logs=None): 31 | pass 32 | 33 | def on_batch_end(self, batch, logs=None): 34 | pass 35 | 36 | def on_train_begin(self, logs=None): 37 | pass 38 | 39 | def on_train_end(self, logs=None): 40 | pass 41 | 42 | 43 | @dataclass 44 | class CallbackContainer: 45 | """ 46 | Container holding a list of callbacks. 47 | """ 48 | 49 | callbacks: List[Callback] = field(default_factory=list) 50 | 51 | def append(self, callback): 52 | self.callbacks.append(callback) 53 | 54 | def set_params(self, params): 55 | for callback in self.callbacks: 56 | callback.set_params(params) 57 | 58 | def set_trainer(self, trainer): 59 | self.trainer = trainer 60 | for callback in self.callbacks: 61 | callback.set_trainer(trainer) 62 | 63 | def on_epoch_begin(self, epoch, logs=None): 64 | logs = logs or {} 65 | for callback in self.callbacks: 66 | callback.on_epoch_begin(epoch, logs) 67 | 68 | def on_epoch_end(self, epoch, logs=None): 69 | logs = logs or {} 70 | for callback in self.callbacks: 71 | callback.on_epoch_end(epoch, logs) 72 | 73 | def on_batch_begin(self, batch, logs=None): 74 | logs = logs or {} 75 | for callback in self.callbacks: 76 | callback.on_batch_begin(batch, logs) 77 | 78 | def on_batch_end(self, batch, logs=None): 79 | logs = logs or {} 80 | for callback in self.callbacks: 81 | callback.on_batch_end(batch, logs) 82 | 83 | def on_train_begin(self, logs=None): 84 | logs = logs or {} 85 | logs["start_time"] = time.time() 86 | for callback in self.callbacks: 87 | callback.on_train_begin(logs) 88 | 89 | def on_train_end(self, logs=None): 90 | logs = logs or {} 91 | for callback in self.callbacks: 92 | callback.on_train_end(logs) 93 | 94 | 95 | @dataclass 96 | class EarlyStopping(Callback): 97 | """EarlyStopping callback to exit the training loop if early_stopping_metric 98 | does not improve by a certain amount for a certain 99 | number of epochs. 100 | 101 | Parameters 102 | --------- 103 | early_stopping_metric : str 104 | Early stopping metric name 105 | is_maximize : bool 106 | Whether to maximize or not early_stopping_metric 107 | tol : float 108 | minimum change in monitored value to qualify as improvement. 109 | This number should be positive. 110 | patience : integer 111 | number of epochs to wait for improvement before terminating. 
112 | The counter will be reset after each improvement. 113 | 114 | """ 115 | 116 | early_stopping_metric: str 117 | is_maximize: bool 118 | tol: float = 0.0 119 | patience: int = 5 120 | 121 | def __post_init__(self): 122 | self.best_epoch = 0 123 | self.stopped_epoch = 0 124 | self.wait = 0 125 | self.best_weights = None 126 | self.best_loss = np.inf 127 | if self.is_maximize: 128 | self.best_loss = -self.best_loss 129 | super().__init__() 130 | 131 | def on_epoch_end(self, epoch, logs=None): 132 | current_loss = logs.get(self.early_stopping_metric) 133 | if current_loss is None: 134 | return 135 | 136 | loss_change = current_loss - self.best_loss 137 | max_improved = self.is_maximize and loss_change > self.tol 138 | min_improved = (not self.is_maximize) and (-loss_change > self.tol) 139 | if max_improved or min_improved: 140 | self.best_loss = current_loss 141 | self.best_epoch = epoch 142 | self.wait = 1 143 | self.best_weights = copy.deepcopy(self.trainer.network.state_dict()) 144 | else: 145 | if self.wait >= self.patience: 146 | self.stopped_epoch = epoch 147 | self.trainer._stop_training = True 148 | self.wait += 1 149 | 150 | def on_train_end(self, logs=None): 151 | self.trainer.best_epoch = self.best_epoch 152 | self.trainer.best_cost = self.best_loss 153 | 154 | if self.best_weights is not None: 155 | self.trainer.network.load_state_dict(self.best_weights) 156 | 157 | if self.stopped_epoch > 0: 158 | msg = f"\nEarly stopping occurred at epoch {self.stopped_epoch}" 159 | msg += ( 160 | f" with best_epoch = {self.best_epoch} and " 161 | + f"best_{self.early_stopping_metric} = {round(self.best_loss, 5)}" 162 | ) 163 | print(msg) 164 | else: 165 | msg = ( 166 | f"Stop training because you reached max_epochs = {self.trainer.max_epochs}" 167 | + f" with best_epoch = {self.best_epoch} and " 168 | + f"best_{self.early_stopping_metric} = {round(self.best_loss, 5)}" 169 | ) 170 | print(msg) 171 | wrn_msg = "Best weights from best epoch are automatically used!" 172 | warnings.warn(wrn_msg) 173 | 174 | 175 | @dataclass 176 | class History(Callback): 177 | """Callback that records events into a `History` object. 178 | This callback is automatically applied to 179 | every SuperModule.
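    After training, the recorded series can be read back through the mapping interface defined below, e.g. (a usage sketch, assuming a fitted model ``clf`` that exposes this callback as ``clf.history``):

        clf.history["loss"]  # per-epoch averaged training loss
        clf.history["lr"]    # per-epoch learning rate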
180 | 181 | Parameters 182 | --------- 183 | trainer : DeepRecoModel 184 | Model class to train 185 | verbose : int 186 | Print results every verbose iteration 187 | 188 | """ 189 | 190 | trainer: Any 191 | verbose: int = 1 192 | 193 | def __post_init__(self): 194 | super().__init__() 195 | self.samples_seen = 0.0 196 | self.total_time = 0.0 197 | 198 | def on_train_begin(self, logs=None): 199 | self.history = {"loss": []} 200 | self.history.update({"lr": []}) 201 | self.history.update({name: [] for name in self.trainer._metrics_names}) 202 | self.start_time = logs["start_time"] 203 | self.epoch_loss = 0.0 204 | 205 | def on_epoch_begin(self, epoch, logs=None): 206 | self.epoch_metrics = {"loss": 0.0} 207 | self.samples_seen = 0.0 208 | 209 | def on_epoch_end(self, epoch, logs=None): 210 | self.epoch_metrics["loss"] = self.epoch_loss 211 | for metric_name, metric_value in self.epoch_metrics.items(): 212 | self.history[metric_name].append(metric_value) 213 | if self.verbose == 0: 214 | return 215 | if epoch % self.verbose != 0: 216 | return 217 | msg = f"epoch {epoch:<3}" 218 | for metric_name, metric_value in self.epoch_metrics.items(): 219 | if metric_name != "lr": 220 | msg += f"| {metric_name:<3}: {np.round(metric_value, 5):<8}" 221 | self.total_time = int(time.time() - self.start_time) 222 | msg += f"| {str(datetime.timedelta(seconds=self.total_time)) + 's':<6}" 223 | print(msg) 224 | 225 | def on_batch_end(self, batch, logs=None): 226 | batch_size = logs["batch_size"] 227 | self.epoch_loss = ( 228 | self.samples_seen * self.epoch_loss + batch_size * logs["loss"] 229 | ) / (self.samples_seen + batch_size) 230 | self.samples_seen += batch_size 231 | 232 | def __getitem__(self, name): 233 | return self.history[name] 234 | 235 | def __repr__(self): 236 | return str(self.history) 237 | 238 | def __str__(self): 239 | return str(self.history) 240 | 241 | 242 | @dataclass 243 | class LRSchedulerCallback(Callback): 244 | """Wrapper for most torch scheduler functions. 
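    A typical epoch-level configuration (a sketch using a standard torch scheduler; any scheduler exposing ``step`` can be plugged in the same way):

        scheduler_fn = torch.optim.lr_scheduler.StepLR
        scheduler_params = {"step_size": 10, "gamma": 0.9}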
245 | 246 | Parameters 247 | --------- 248 | scheduler_fn : torch.optim.lr_scheduler 249 | Torch scheduling class 250 | scheduler_params : dict 251 | Dictionnary containing all parameters for the scheduler_fn 252 | is_batch_level : bool (default = False) 253 | If set to False : lr updates will happen at every epoch 254 | If set to True : lr updates happen at every batch 255 | Set this to True for OneCycleLR for example 256 | """ 257 | 258 | scheduler_fn: Any 259 | optimizer: Any 260 | scheduler_params: dict 261 | early_stopping_metric: str 262 | is_batch_level: bool = False 263 | 264 | def __post_init__( 265 | self, 266 | ): 267 | self.is_metric_related = hasattr(self.scheduler_fn, "is_better") 268 | self.scheduler = self.scheduler_fn(self.optimizer, **self.scheduler_params) 269 | super().__init__() 270 | 271 | def on_batch_end(self, batch, logs=None): 272 | if self.is_batch_level: 273 | self.scheduler.step() 274 | else: 275 | pass 276 | 277 | def on_epoch_end(self, epoch, logs=None): 278 | current_loss = logs.get(self.early_stopping_metric) 279 | if current_loss is None: 280 | return 281 | if self.is_batch_level: 282 | pass 283 | else: 284 | if self.is_metric_related: 285 | self.scheduler.step(current_loss) 286 | else: 287 | self.scheduler.step() 288 | -------------------------------------------------------------------------------- /pytorch_tabnet/multitask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.special import softmax 4 | from pytorch_tabnet.utils import SparsePredictDataset, PredictDataset, filter_weights 5 | from pytorch_tabnet.abstract_model import TabModel 6 | from pytorch_tabnet.multiclass_utils import infer_multitask_output, check_output_dim 7 | from torch.utils.data import DataLoader 8 | import scipy 9 | 10 | 11 | class TabNetMultiTaskClassifier(TabModel): 12 | def __post_init__(self): 13 | super(TabNetMultiTaskClassifier, self).__post_init__() 14 | self._task = 'classification' 15 | self._default_loss = torch.nn.functional.cross_entropy 16 | self._default_metric = 'logloss' 17 | 18 | def prepare_target(self, y): 19 | y_mapped = y.copy() 20 | for task_idx in range(y.shape[1]): 21 | task_mapper = self.target_mapper[task_idx] 22 | y_mapped[:, task_idx] = np.vectorize(task_mapper.get)(y[:, task_idx]) 23 | return y_mapped 24 | 25 | def compute_loss(self, y_pred, y_true): 26 | """ 27 | Computes the loss according to network output and targets 28 | 29 | Parameters 30 | ---------- 31 | y_pred : list of tensors 32 | Output of network 33 | y_true : LongTensor 34 | Targets label encoded 35 | 36 | Returns 37 | ------- 38 | loss : torch.Tensor 39 | output of loss function(s) 40 | 41 | """ 42 | loss = 0 43 | y_true = y_true.long() 44 | if isinstance(self.loss_fn, list): 45 | # if you specify a different loss for each task 46 | for task_loss, task_output, task_id in zip( 47 | self.loss_fn, y_pred, range(len(self.loss_fn)) 48 | ): 49 | loss += task_loss(task_output, y_true[:, task_id]) 50 | else: 51 | # same loss function is applied to all tasks 52 | for task_id, task_output in enumerate(y_pred): 53 | loss += self.loss_fn(task_output, y_true[:, task_id]) 54 | 55 | loss /= len(y_pred) 56 | return loss 57 | 58 | def stack_batches(self, list_y_true, list_y_score): 59 | y_true = np.vstack(list_y_true) 60 | y_score = [] 61 | for i in range(len(self.output_dim)): 62 | score = np.vstack([x[i] for x in list_y_score]) 63 | score = softmax(score, axis=1) 64 | y_score.append(score) 65 | return y_true, 
y_score 66 | 67 | def update_fit_params(self, X_train, y_train, eval_set, weights): 68 | output_dim, train_labels = infer_multitask_output(y_train) 69 | for _, y in eval_set: 70 | for task_idx in range(y.shape[1]): 71 | check_output_dim(train_labels[task_idx], y[:, task_idx]) 72 | self.output_dim = output_dim 73 | self.classes_ = train_labels 74 | self.target_mapper = [ 75 | {class_label: index for index, class_label in enumerate(classes)} 76 | for classes in self.classes_ 77 | ] 78 | self.preds_mapper = [ 79 | {str(index): str(class_label) for index, class_label in enumerate(classes)} 80 | for classes in self.classes_ 81 | ] 82 | self.updated_weights = weights 83 | filter_weights(self.updated_weights) 84 | 85 | def predict(self, X): 86 | """ 87 | Make predictions on a batch (valid) 88 | 89 | Parameters 90 | ---------- 91 | X : a :tensor: `torch.Tensor` or matrix: `scipy.sparse.csr_matrix` 92 | Input data 93 | 94 | Returns 95 | ------- 96 | results : np.array 97 | Predictions of the most probable class 98 | """ 99 | self.network.eval() 100 | 101 | if scipy.sparse.issparse(X): 102 | dataloader = DataLoader( 103 | SparsePredictDataset(X), 104 | batch_size=self.batch_size, 105 | shuffle=False, 106 | ) 107 | else: 108 | dataloader = DataLoader( 109 | PredictDataset(X), 110 | batch_size=self.batch_size, 111 | shuffle=False, 112 | ) 113 | 114 | results = {} 115 | for data in dataloader: 116 | data = data.to(self.device).float() 117 | output, _ = self.network(data) 118 | predictions = [ 119 | torch.argmax(torch.nn.Softmax(dim=1)(task_output), dim=1) 120 | .cpu() 121 | .detach() 122 | .numpy() 123 | .reshape(-1) 124 | for task_output in output 125 | ] 126 | 127 | for task_idx in range(len(self.output_dim)): 128 | results[task_idx] = results.get(task_idx, []) + [predictions[task_idx]] 129 | # stack all task individually 130 | results = [np.hstack(task_res) for task_res in results.values()] 131 | # map all task individually 132 | results = [ 133 | np.vectorize(self.preds_mapper[task_idx].get)(task_res.astype(str)) 134 | for task_idx, task_res in enumerate(results) 135 | ] 136 | return results 137 | 138 | def predict_proba(self, X): 139 | """ 140 | Make predictions for classification on a batch (valid) 141 | 142 | Parameters 143 | ---------- 144 | X : a :tensor: `torch.Tensor` or matrix: `scipy.sparse.csr_matrix` 145 | Input data 146 | 147 | Returns 148 | ------- 149 | res : list of np.ndarray 150 | 151 | """ 152 | self.network.eval() 153 | 154 | if scipy.sparse.issparse(X): 155 | dataloader = DataLoader( 156 | SparsePredictDataset(X), 157 | batch_size=self.batch_size, 158 | shuffle=False, 159 | ) 160 | else: 161 | dataloader = DataLoader( 162 | PredictDataset(X), 163 | batch_size=self.batch_size, 164 | shuffle=False, 165 | ) 166 | 167 | results = {} 168 | for data in dataloader: 169 | data = data.to(self.device).float() 170 | output, _ = self.network(data) 171 | predictions = [ 172 | torch.nn.Softmax(dim=1)(task_output).cpu().detach().numpy() 173 | for task_output in output 174 | ] 175 | for task_idx in range(len(self.output_dim)): 176 | results[task_idx] = results.get(task_idx, []) + [predictions[task_idx]] 177 | res = [np.vstack(task_res) for task_res in results.values()] 178 | return res 179 | -------------------------------------------------------------------------------- /pytorch_tabnet/pretraining_utils.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from pytorch_tabnet.utils import ( 3 | create_sampler, 4 | 
SparsePredictDataset, 5 | PredictDataset, 6 | check_input 7 | ) 8 | import scipy 9 | 10 | 11 | def create_dataloaders( 12 | X_train, eval_set, weights, batch_size, num_workers, drop_last, pin_memory 13 | ): 14 | """ 15 | Create dataloaders with or without subsampling depending on weights and balanced. 16 | 17 | Parameters 18 | ---------- 19 | X_train : np.ndarray or scipy.sparse.csr_matrix 20 | Training data 21 | eval_set : list of np.array (for Xs and ys) or scipy.sparse.csr_matrix (for Xs) 22 | List of eval sets 23 | weights : either 0, 1, dict or iterable 24 | if 0 (default) : no weights will be applied 25 | if 1 : classification only, will balanced class with inverse frequency 26 | if dict : keys are corresponding class values are sample weights 27 | if iterable : list or np array must be of length equal to nb elements 28 | in the training set 29 | batch_size : int 30 | how many samples per batch to load 31 | num_workers : int 32 | how many subprocesses to use for data loading. 0 means that the data 33 | will be loaded in the main process 34 | drop_last : bool 35 | set to True to drop the last incomplete batch, if the dataset size is not 36 | divisible by the batch size. If False and the size of dataset is not 37 | divisible by the batch size, then the last batch will be smaller 38 | pin_memory : bool 39 | Whether to pin GPU memory during training 40 | 41 | Returns 42 | ------- 43 | train_dataloader, valid_dataloader : torch.DataLoader, torch.DataLoader 44 | Training and validation dataloaders 45 | """ 46 | need_shuffle, sampler = create_sampler(weights, X_train) 47 | 48 | if scipy.sparse.issparse(X_train): 49 | train_dataloader = DataLoader( 50 | SparsePredictDataset(X_train), 51 | batch_size=batch_size, 52 | sampler=sampler, 53 | shuffle=need_shuffle, 54 | num_workers=num_workers, 55 | drop_last=drop_last, 56 | pin_memory=pin_memory, 57 | ) 58 | else: 59 | train_dataloader = DataLoader( 60 | PredictDataset(X_train), 61 | batch_size=batch_size, 62 | sampler=sampler, 63 | shuffle=need_shuffle, 64 | num_workers=num_workers, 65 | drop_last=drop_last, 66 | pin_memory=pin_memory, 67 | ) 68 | 69 | valid_dataloaders = [] 70 | for X in eval_set: 71 | if scipy.sparse.issparse(X): 72 | valid_dataloaders.append( 73 | DataLoader( 74 | SparsePredictDataset(X), 75 | batch_size=batch_size, 76 | sampler=sampler, 77 | shuffle=need_shuffle, 78 | num_workers=num_workers, 79 | drop_last=drop_last, 80 | pin_memory=pin_memory, 81 | ) 82 | ) 83 | else: 84 | valid_dataloaders.append( 85 | DataLoader( 86 | PredictDataset(X), 87 | batch_size=batch_size, 88 | sampler=sampler, 89 | shuffle=need_shuffle, 90 | num_workers=num_workers, 91 | drop_last=drop_last, 92 | pin_memory=pin_memory, 93 | ) 94 | ) 95 | 96 | return train_dataloader, valid_dataloaders 97 | 98 | 99 | def validate_eval_set(eval_set, eval_name, X_train): 100 | """Check if the shapes of eval_set are compatible with X_train. 101 | 102 | Parameters 103 | ---------- 104 | eval_set : List of numpy array 105 | The list evaluation set. 106 | The last one is used for early stopping 107 | X_train : np.ndarray 108 | Train owned products 109 | 110 | Returns 111 | ------- 112 | eval_names : list of str 113 | Validated list of eval_names. 
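    Examples
    --------
    A sketch of the expected call (each eval set must have the same number of columns as ``X_train``):

        eval_names = validate_eval_set([X_valid], ["valid"], X_train)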
114 | 115 | """ 116 | eval_names = eval_name or [f"val_{i}" for i in range(len(eval_set))] 117 | assert len(eval_set) == len( 118 | eval_names 119 | ), "eval_set and eval_name have not the same length" 120 | 121 | for set_nb, X in enumerate(eval_set): 122 | check_input(X) 123 | msg = ( 124 | f"Number of columns is different between eval set {set_nb}" 125 | + f"({X.shape[1]}) and X_train ({X_train.shape[1]})" 126 | ) 127 | assert X.shape[1] == X_train.shape[1], msg 128 | return eval_names 129 | -------------------------------------------------------------------------------- /pytorch_tabnet/sparsemax.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.autograd import Function 3 | import torch.nn.functional as F 4 | 5 | import torch 6 | 7 | """ 8 | Other possible implementations: 9 | https://github.com/KrisKorrel/sparsemax-pytorch/blob/master/sparsemax.py 10 | https://github.com/msobroza/SparsemaxPytorch/blob/master/mnist/sparsemax.py 11 | https://github.com/vene/sparse-structured-attention/blob/master/pytorch/torchsparseattn/sparsemax.py 12 | """ 13 | 14 | 15 | # credits to Yandex https://github.com/Qwicen/node/blob/master/lib/nn_utils.py 16 | def _make_ix_like(input, dim=0): 17 | d = input.size(dim) 18 | rho = torch.arange(1, d + 1, device=input.device, dtype=input.dtype) 19 | view = [1] * input.dim() 20 | view[0] = -1 21 | return rho.view(view).transpose(0, dim) 22 | 23 | 24 | class SparsemaxFunction(Function): 25 | """ 26 | An implementation of sparsemax (Martins & Astudillo, 2016). See 27 | :cite:`DBLP:journals/corr/MartinsA16` for detailed description. 28 | By Ben Peters and Vlad Niculae 29 | """ 30 | 31 | @staticmethod 32 | def forward(ctx, input, dim=-1): 33 | """sparsemax: normalizing sparse transform (a la softmax) 34 | 35 | Parameters 36 | ---------- 37 | ctx : torch.autograd.function._ContextMethodMixin 38 | input : torch.Tensor 39 | any shape 40 | dim : int 41 | dimension along which to apply sparsemax 42 | 43 | Returns 44 | ------- 45 | output : torch.Tensor 46 | same shape as input 47 | 48 | """ 49 | ctx.dim = dim 50 | max_val, _ = input.max(dim=dim, keepdim=True) 51 | input -= max_val # same numerical stability trick as for softmax 52 | tau, supp_size = SparsemaxFunction._threshold_and_support(input, dim=dim) 53 | output = torch.clamp(input - tau, min=0) 54 | ctx.save_for_backward(supp_size, output) 55 | return output 56 | 57 | @staticmethod 58 | def backward(ctx, grad_output): 59 | supp_size, output = ctx.saved_tensors 60 | dim = ctx.dim 61 | grad_input = grad_output.clone() 62 | grad_input[output == 0] = 0 63 | 64 | v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze() 65 | v_hat = v_hat.unsqueeze(dim) 66 | grad_input = torch.where(output != 0, grad_input - v_hat, grad_input) 67 | return grad_input, None 68 | 69 | @staticmethod 70 | def _threshold_and_support(input, dim=-1): 71 | """Sparsemax building block: compute the threshold 72 | 73 | Parameters 74 | ---------- 75 | input: torch.Tensor 76 | any dimension 77 | dim : int 78 | dimension along which to apply the sparsemax 79 | 80 | Returns 81 | ------- 82 | tau : torch.Tensor 83 | the threshold value 84 | support_size : torch.Tensor 85 | 86 | """ 87 | 88 | input_srt, _ = torch.sort(input, descending=True, dim=dim) 89 | input_cumsum = input_srt.cumsum(dim) - 1 90 | rhos = _make_ix_like(input, dim) 91 | support = rhos * input_srt > input_cumsum 92 | 93 | support_size = support.sum(dim=dim).unsqueeze(dim) 94 | tau = 
input_cumsum.gather(dim, support_size - 1) 95 | tau /= support_size.to(input.dtype) 96 | return tau, support_size 97 | 98 | 99 | sparsemax = SparsemaxFunction.apply 100 | 101 | 102 | class Sparsemax(nn.Module): 103 | 104 | def __init__(self, dim=-1): 105 | self.dim = dim 106 | super(Sparsemax, self).__init__() 107 | 108 | def forward(self, input): 109 | return sparsemax(input, self.dim) 110 | 111 | 112 | class Entmax15Function(Function): 113 | """ 114 | An implementation of exact Entmax with alpha=1.5 (B. Peters, V. Niculae, A. Martins). See 115 | :cite:`https://arxiv.org/abs/1905.05702 for detailed description. 116 | Source: https://github.com/deep-spin/entmax 117 | """ 118 | 119 | @staticmethod 120 | def forward(ctx, input, dim=-1): 121 | ctx.dim = dim 122 | 123 | max_val, _ = input.max(dim=dim, keepdim=True) 124 | input = input - max_val # same numerical stability trick as for softmax 125 | input = input / 2 # divide by 2 to solve actual Entmax 126 | 127 | tau_star, _ = Entmax15Function._threshold_and_support(input, dim) 128 | output = torch.clamp(input - tau_star, min=0) ** 2 129 | ctx.save_for_backward(output) 130 | return output 131 | 132 | @staticmethod 133 | def backward(ctx, grad_output): 134 | Y, = ctx.saved_tensors 135 | gppr = Y.sqrt() # = 1 / g'' (Y) 136 | dX = grad_output * gppr 137 | q = dX.sum(ctx.dim) / gppr.sum(ctx.dim) 138 | q = q.unsqueeze(ctx.dim) 139 | dX -= q * gppr 140 | return dX, None 141 | 142 | @staticmethod 143 | def _threshold_and_support(input, dim=-1): 144 | Xsrt, _ = torch.sort(input, descending=True, dim=dim) 145 | 146 | rho = _make_ix_like(input, dim) 147 | mean = Xsrt.cumsum(dim) / rho 148 | mean_sq = (Xsrt ** 2).cumsum(dim) / rho 149 | ss = rho * (mean_sq - mean ** 2) 150 | delta = (1 - ss) / rho 151 | 152 | # NOTE this is not exactly the same as in reference algo 153 | # Fortunately it seems the clamped values never wrongly 154 | # get selected by tau <= sorted_z. Prove this! 155 | delta_nz = torch.clamp(delta, 0) 156 | tau = mean - torch.sqrt(delta_nz) 157 | 158 | support_size = (tau <= Xsrt).sum(dim).unsqueeze(dim) 159 | tau_star = tau.gather(dim, support_size - 1) 160 | return tau_star, support_size 161 | 162 | 163 | class Entmoid15(Function): 164 | """ A highly optimized equivalent of lambda x: Entmax15([x, 0]) """ 165 | 166 | @staticmethod 167 | def forward(ctx, input): 168 | output = Entmoid15._forward(input) 169 | ctx.save_for_backward(output) 170 | return output 171 | 172 | @staticmethod 173 | def _forward(input): 174 | input, is_pos = abs(input), input >= 0 175 | tau = (input + torch.sqrt(F.relu(8 - input ** 2))) / 2 176 | tau.masked_fill_(tau <= input, 2.0) 177 | y_neg = 0.25 * F.relu(tau - input, inplace=True) ** 2 178 | return torch.where(is_pos, 1 - y_neg, y_neg) 179 | 180 | @staticmethod 181 | def backward(ctx, grad_output): 182 | return Entmoid15._backward(ctx.saved_tensors[0], grad_output) 183 | 184 | @staticmethod 185 | def _backward(output, grad_output): 186 | gppr0, gppr1 = output.sqrt(), (1 - output).sqrt() 187 | grad_input = grad_output * gppr0 188 | q = grad_input / (gppr0 + gppr1) 189 | grad_input -= q * gppr0 190 | return grad_input 191 | 192 | 193 | entmax15 = Entmax15Function.apply 194 | entmoid15 = Entmoid15.apply 195 | 196 | 197 | class Entmax15(nn.Module): 198 | 199 | def __init__(self, dim=-1): 200 | self.dim = dim 201 | super(Entmax15, self).__init__() 202 | 203 | def forward(self, input): 204 | return entmax15(input, self.dim) 205 | 206 | 207 | # Credits were lost... 
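# The commented-out block below is an earlier, dim=0 variant of the
# sparsemax implementation above, kept for reference only.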
208 | # def _make_ix_like(input, dim=0): 209 | # d = input.size(dim) 210 | # rho = torch.arange(1, d + 1, device=input.device, dtype=input.dtype) 211 | # view = [1] * input.dim() 212 | # view[0] = -1 213 | # return rho.view(view).transpose(0, dim) 214 | # 215 | # 216 | # def _threshold_and_support(input, dim=0): 217 | # """Sparsemax building block: compute the threshold 218 | # Args: 219 | # input: any dimension 220 | # dim: dimension along which to apply the sparsemax 221 | # Returns: 222 | # the threshold value 223 | # """ 224 | # 225 | # input_srt, _ = torch.sort(input, descending=True, dim=dim) 226 | # input_cumsum = input_srt.cumsum(dim) - 1 227 | # rhos = _make_ix_like(input, dim) 228 | # support = rhos * input_srt > input_cumsum 229 | # 230 | # support_size = support.sum(dim=dim).unsqueeze(dim) 231 | # tau = input_cumsum.gather(dim, support_size - 1) 232 | # tau /= support_size.to(input.dtype) 233 | # return tau, support_size 234 | # 235 | # 236 | # class SparsemaxFunction(Function): 237 | # 238 | # @staticmethod 239 | # def forward(ctx, input, dim=0): 240 | # """sparsemax: normalizing sparse transform (a la softmax) 241 | # Parameters: 242 | # input (Tensor): any shape 243 | # dim: dimension along which to apply sparsemax 244 | # Returns: 245 | # output (Tensor): same shape as input 246 | # """ 247 | # ctx.dim = dim 248 | # max_val, _ = input.max(dim=dim, keepdim=True) 249 | # input -= max_val # same numerical stability trick as for softmax 250 | # tau, supp_size = _threshold_and_support(input, dim=dim) 251 | # output = torch.clamp(input - tau, min=0) 252 | # ctx.save_for_backward(supp_size, output) 253 | # return output 254 | # 255 | # @staticmethod 256 | # def backward(ctx, grad_output): 257 | # supp_size, output = ctx.saved_tensors 258 | # dim = ctx.dim 259 | # grad_input = grad_output.clone() 260 | # grad_input[output == 0] = 0 261 | # 262 | # v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze() 263 | # v_hat = v_hat.unsqueeze(dim) 264 | # grad_input = torch.where(output != 0, grad_input - v_hat, grad_input) 265 | # return grad_input, None 266 | # 267 | # 268 | # sparsemax = SparsemaxFunction.apply 269 | # 270 | # 271 | # class Sparsemax(nn.Module): 272 | # 273 | # def __init__(self, dim=0): 274 | # self.dim = dim 275 | # super(Sparsemax, self).__init__() 276 | # 277 | # def forward(self, input): 278 | # return sparsemax(input, self.dim) 279 | -------------------------------------------------------------------------------- /pytorch_tabnet/tab_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.special import softmax 4 | from pytorch_tabnet.utils import SparsePredictDataset, PredictDataset, filter_weights 5 | from pytorch_tabnet.abstract_model import TabModel 6 | from pytorch_tabnet.multiclass_utils import infer_output_dim, check_output_dim 7 | from torch.utils.data import DataLoader 8 | import scipy 9 | 10 | 11 | class TabNetClassifier(TabModel): 12 | def __post_init__(self): 13 | super(TabNetClassifier, self).__post_init__() 14 | self._task = 'classification' 15 | self._default_loss = torch.nn.functional.cross_entropy 16 | self._default_metric = 'accuracy' 17 | 18 | def weight_updater(self, weights): 19 | """ 20 | Updates weights dictionary according to target_mapper. 21 | 22 | Parameters 23 | ---------- 24 | weights : bool or dict 25 | Given weights for balancing training. 
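            For example (hypothetical class labels): with
            ``target_mapper = {"no": 0, "yes": 1}``, the dict
            ``{"no": 1, "yes": 10}`` is remapped to ``{0: 1, 1: 10}``.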
26 | 27 | Returns 28 | ------- 29 | bool or dict 30 | Same bool if weights are bool, updated dict otherwise. 31 | 32 | """ 33 | if isinstance(weights, int): 34 | return weights 35 | elif isinstance(weights, dict): 36 | return {self.target_mapper[key]: value for key, value in weights.items()} 37 | else: 38 | return weights 39 | 40 | def prepare_target(self, y): 41 | return np.vectorize(self.target_mapper.get)(y) 42 | 43 | def compute_loss(self, y_pred, y_true): 44 | return self.loss_fn(y_pred, y_true.long()) 45 | 46 | def update_fit_params( 47 | self, 48 | X_train, 49 | y_train, 50 | eval_set, 51 | weights, 52 | ): 53 | output_dim, train_labels = infer_output_dim(y_train) 54 | for X, y in eval_set: 55 | check_output_dim(train_labels, y) 56 | self.output_dim = output_dim 57 | self._default_metric = ('auc' if self.output_dim == 2 else 'accuracy') 58 | self.classes_ = train_labels 59 | self.target_mapper = { 60 | class_label: index for index, class_label in enumerate(self.classes_) 61 | } 62 | self.preds_mapper = { 63 | str(index): class_label for index, class_label in enumerate(self.classes_) 64 | } 65 | self.updated_weights = self.weight_updater(weights) 66 | 67 | def stack_batches(self, list_y_true, list_y_score): 68 | y_true = np.hstack(list_y_true) 69 | y_score = np.vstack(list_y_score) 70 | y_score = softmax(y_score, axis=1) 71 | return y_true, y_score 72 | 73 | def predict_func(self, outputs): 74 | outputs = np.argmax(outputs, axis=1) 75 | return np.vectorize(self.preds_mapper.get)(outputs.astype(str)) 76 | 77 | def predict_proba(self, X): 78 | """ 79 | Make predictions for classification on a batch (valid) 80 | 81 | Parameters 82 | ---------- 83 | X : a :tensor: `torch.Tensor` or matrix: `scipy.sparse.csr_matrix` 84 | Input data 85 | 86 | Returns 87 | ------- 88 | res : np.ndarray 89 | 90 | """ 91 | self.network.eval() 92 | 93 | if scipy.sparse.issparse(X): 94 | dataloader = DataLoader( 95 | SparsePredictDataset(X), 96 | batch_size=self.batch_size, 97 | shuffle=False, 98 | ) 99 | else: 100 | dataloader = DataLoader( 101 | PredictDataset(X), 102 | batch_size=self.batch_size, 103 | shuffle=False, 104 | ) 105 | 106 | results = [] 107 | for batch_nb, data in enumerate(dataloader): 108 | data = data.to(self.device).float() 109 | 110 | output, M_loss = self.network(data) 111 | predictions = torch.nn.Softmax(dim=1)(output).cpu().detach().numpy() 112 | results.append(predictions) 113 | res = np.vstack(results) 114 | return res 115 | 116 | 117 | class TabNetRegressor(TabModel): 118 | def __post_init__(self): 119 | super(TabNetRegressor, self).__post_init__() 120 | self._task = 'regression' 121 | self._default_loss = torch.nn.functional.mse_loss 122 | self._default_metric = 'mse' 123 | 124 | def prepare_target(self, y): 125 | return y 126 | 127 | def compute_loss(self, y_pred, y_true): 128 | return self.loss_fn(y_pred, y_true) 129 | 130 | def update_fit_params( 131 | self, 132 | X_train, 133 | y_train, 134 | eval_set, 135 | weights 136 | ): 137 | if len(y_train.shape) != 2: 138 | msg = "Targets should be 2D : (n_samples, n_regression) " + \ 139 | f"but y_train.shape={y_train.shape} given.\n" + \ 140 | "Use reshape(-1, 1) for single regression." 
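            # e.g. y_train.reshape(-1, 1) turns shape (n_samples,) into (n_samples, 1)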
141 | raise ValueError(msg) 142 | self.output_dim = y_train.shape[1] 143 | self.preds_mapper = None 144 | 145 | self.updated_weights = weights 146 | filter_weights(self.updated_weights) 147 | 148 | def predict_func(self, outputs): 149 | return outputs 150 | 151 | def stack_batches(self, list_y_true, list_y_score): 152 | y_true = np.vstack(list_y_true) 153 | y_score = np.vstack(list_y_score) 154 | return y_true, y_score 155 | -------------------------------------------------------------------------------- /regression_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from pytorch_tabnet.tab_model import TabNetRegressor\n", 10 | "\n", 11 | "import torch\n", 12 | "from sklearn.preprocessing import LabelEncoder\n", 13 | "from sklearn.metrics import mean_squared_error\n", 14 | "\n", 15 | "import pandas as pd\n", 16 | "import numpy as np\n", 17 | "np.random.seed(0)\n", 18 | "\n", 19 | "\n", 20 | "import os\n", 21 | "import wget\n", 22 | "from pathlib import Path" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "metadata": {}, 28 | "source": [ 29 | "# Download census-income dataset" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data\"\n", 39 | "dataset_name = 'census-income'\n", 40 | "out = Path(os.getcwd()+'/data/'+dataset_name+'.csv')" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "out.parent.mkdir(parents=True, exist_ok=True)\n", 50 | "if out.exists():\n", 51 | " print(\"File already exists.\")\n", 52 | "else:\n", 53 | " print(\"Downloading file...\")\n", 54 | " wget.download(url, out.as_posix())" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": {}, 60 | "source": [ 61 | "# Load data and split" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "train = pd.read_csv(out)\n", 71 | "target = ' <=50K'\n", 72 | "if \"Set\" not in train.columns:\n", 73 | " train[\"Set\"] = np.random.choice([\"train\", \"valid\", \"test\"], p =[.8, .1, .1], size=(train.shape[0],))\n", 74 | "\n", 75 | "train_indices = train[train.Set==\"train\"].index\n", 76 | "valid_indices = train[train.Set==\"valid\"].index\n", 77 | "test_indices = train[train.Set==\"test\"].index" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "metadata": {}, 83 | "source": [ 84 | "# Simple preprocessing\n", 85 | "\n", 86 | "Label encode categorical features and fill empty cells." 
87 |    ]
88 |   },
89 |   {
90 |    "cell_type": "code",
91 |    "execution_count": null,
92 |    "metadata": {},
93 |    "outputs": [],
94 |    "source": [
95 |     "categorical_columns = []\n",
96 |     "categorical_dims = {}\n",
97 |     "for col in train.columns[train.dtypes == object]:\n",
98 |     "    print(col, train[col].nunique())\n",
99 |     "    l_enc = LabelEncoder()\n",
100 |     "    train[col] = train[col].fillna(\"VV_likely\")\n",
101 |     "    train[col] = l_enc.fit_transform(train[col].values)\n",
102 |     "    categorical_columns.append(col)\n",
103 |     "    categorical_dims[col] = len(l_enc.classes_)\n",
104 |     "\n",
105 |     "for col in train.columns[train.dtypes == 'float64']:\n",
106 |     "    train[col] = train[col].fillna(train.loc[train_indices, col].mean())"
107 |    ]
108 |   },
109 |   {
110 |    "cell_type": "markdown",
111 |    "metadata": {},
112 |    "source": [
113 |     "# Define categorical features for categorical embeddings"
114 |    ]
115 |   },
116 |   {
117 |    "cell_type": "code",
118 |    "execution_count": null,
119 |    "metadata": {},
120 |    "outputs": [],
121 |    "source": [
122 |     "unused_feat = ['Set']\n",
123 |     "\n",
124 |     "features = [col for col in train.columns if col not in unused_feat + [target]]\n",
125 |     "\n",
126 |     "cat_idxs = [i for i, f in enumerate(features) if f in categorical_columns]\n",
127 |     "\n",
128 |     "cat_dims = [categorical_dims[f] for i, f in enumerate(features) if f in categorical_columns]\n",
129 |     "\n",
130 |     "# define embedding sizes: arbitrary choices for this example\n",
131 |     "cat_emb_dim = [5, 4, 3, 6, 2, 2, 1, 10]"
132 |    ]
133 |   },
134 |   {
135 |    "cell_type": "markdown",
136 |    "metadata": {},
137 |    "source": [
138 |     "# Network parameters"
139 |    ]
140 |   },
141 |   {
142 |    "cell_type": "code",
143 |    "execution_count": null,
144 |    "metadata": {},
145 |    "outputs": [],
146 |    "source": [
147 |     "clf = TabNetRegressor(cat_dims=cat_dims, cat_emb_dim=cat_emb_dim, cat_idxs=cat_idxs)"
148 |    ]
149 |   },
150 |   {
151 |    "cell_type": "markdown",
152 |    "metadata": {},
153 |    "source": [
154 |     "# Training"
155 |    ]
156 |   },
157 |   {
158 |    "cell_type": "code",
159 |    "execution_count": null,
160 |    "metadata": {},
161 |    "outputs": [],
162 |    "source": [
163 |     "X_train = train[features].values[train_indices]\n",
164 |     "y_train = train[target].values[train_indices].reshape(-1, 1)\n",
165 |     "\n",
166 |     "X_valid = train[features].values[valid_indices]\n",
167 |     "y_valid = train[target].values[valid_indices].reshape(-1, 1)\n",
168 |     "\n",
169 |     "X_test = train[features].values[test_indices]\n",
170 |     "y_test = train[target].values[test_indices].reshape(-1, 1)"
171 |    ]
172 |   },
173 |   {
174 |    "cell_type": "code",
175 |    "execution_count": null,
176 |    "metadata": {},
177 |    "outputs": [],
178 |    "source": [
179 |     "max_epochs = 100 if not os.getenv(\"CI\", False) else 2"
180 |    ]
181 |   },
182 |   {
183 |    "cell_type": "code",
184 |    "execution_count": null,
185 |    "metadata": {},
186 |    "outputs": [],
187 |    "source": [
188 |     "from pytorch_tabnet.augmentations import RegressionSMOTE\n",
189 |     "aug = RegressionSMOTE(p=0.2)"
190 |    ]
191 |   },
192 |   {
193 |    "cell_type": "code",
194 |    "execution_count": null,
195 |    "metadata": {
196 |     "scrolled": true
197 |    },
198 |    "outputs": [],
199 |    "source": [
200 |     "clf.fit(\n",
201 |     "    X_train=X_train, y_train=y_train,\n",
202 |     "    eval_set=[(X_train, y_train), (X_valid, y_valid)],\n",
203 |     "    eval_name=['train', 'valid'],\n",
204 |     "    eval_metric=['rmsle', 'mae', 'rmse', 'mse'],\n",
205 |     "    max_epochs=max_epochs,\n",
206 |     "    patience=50,\n",
207 |     "    batch_size=1024, virtual_batch_size=128,\n",
208 |     "    num_workers=0,\n",
209 |     "    drop_last=False,\n",
210 |     "    augmentations=aug,\n",
211 |     ")"
212 |    ]
213 |   },
214 |   {
215 |    "cell_type": "code",
216 |    "execution_count": null,
217 |    "metadata": {},
218 |    "outputs": [],
219 |    "source": [
220 |     "# Deprecated: the best model is automatically loaded at the end of fit\n",
221 |     "# clf.load_best_model()\n",
222 |     "\n",
223 |     "preds = clf.predict(X_test)\n",
224 |     "\n",
225 |     "y_true = y_test\n",
226 |     "\n",
227 |     "test_score = mean_squared_error(y_pred=preds, y_true=y_true)\n",
228 |     "\n",
229 |     "print(f\"BEST VALID SCORE FOR {dataset_name} : {clf.best_cost}\")\n",
230 |     "print(f\"FINAL TEST SCORE FOR {dataset_name} : {test_score}\")"
231 |    ]
232 |   },
233 |   {
234 |    "cell_type": "markdown",
235 |    "metadata": {},
236 |    "source": [
237 |     "# Save and load the model"
238 |    ]
239 |   },
240 |   {
241 |    "cell_type": "code",
242 |    "execution_count": null,
243 |    "metadata": {},
244 |    "outputs": [],
245 |    "source": [
246 |     "# save tabnet model\n",
247 |     "saving_path_name = \"./tabnet_model_test_1\"\n",
248 |     "saved_filepath = clf.save_model(saving_path_name)"
249 |    ]
250 |   },
251 |   {
252 |    "cell_type": "code",
253 |    "execution_count": null,
254 |    "metadata": {},
255 |    "outputs": [],
256 |    "source": [
257 |     "# define new model with basic parameters and load state dict weights\n",
258 |     "loaded_clf = TabNetRegressor()\n",
259 |     "loaded_clf.load_model(saved_filepath)"
260 |    ]
261 |   },
262 |   {
263 |    "cell_type": "code",
264 |    "execution_count": null,
265 |    "metadata": {},
266 |    "outputs": [],
267 |    "source": [
268 |     "loaded_preds = loaded_clf.predict(X_test)\n",
269 |     "loaded_test_mse = mean_squared_error(loaded_preds, y_test)\n",
270 |     "\n",
271 |     "print(f\"FINAL TEST SCORE FOR {dataset_name} : {loaded_test_mse}\")"
272 |    ]
273 |   },
274 |   {
275 |    "cell_type": "code",
276 |    "execution_count": null,
277 |    "metadata": {},
278 |    "outputs": [],
279 |    "source": [
280 |     "assert test_score == loaded_test_mse"
281 |    ]
282 |   },
283 |   {
284 |    "cell_type": "markdown",
285 |    "metadata": {},
286 |    "source": [
287 |     "# Global explainability: feature importances summing to 1"
288 |    ]
289 |   },
290 |   {
291 |    "cell_type": "code",
292 |    "execution_count": null,
293 |    "metadata": {},
294 |    "outputs": [],
295 |    "source": [
296 |     "clf.feature_importances_"
297 |    ]
298 |   },
299 |   {
300 |    "cell_type": "markdown",
301 |    "metadata": {},
302 |    "source": [
303 |     "# Local explainability and masks"
304 |    ]
305 |   },
306 |   {
307 |    "cell_type": "code",
308 |    "execution_count": null,
309 |    "metadata": {},
310 |    "outputs": [],
311 |    "source": [
312 |     "explain_matrix, masks = clf.explain(X_test)"
313 |    ]
314 |   },
315 |   {
316 |    "cell_type": "code",
317 |    "execution_count": null,
318 |    "metadata": {},
319 |    "outputs": [],
320 |    "source": [
321 |     "from matplotlib import pyplot as plt\n",
322 |     "%matplotlib inline"
323 |    ]
324 |   },
325 |   {
326 |    "cell_type": "code",
327 |    "execution_count": null,
328 |    "metadata": {},
329 |    "outputs": [],
330 |    "source": [
331 |     "fig, axs = plt.subplots(1, 3, figsize=(20, 20))\n",
332 |     "\n",
333 |     "for i in range(3):\n",
334 |     "    axs[i].imshow(masks[i][:50])\n",
335 |     "    axs[i].set_title(f\"mask {i}\")\n"
336 |    ]
337 |   },
338 |   {
339 |    "cell_type": "markdown",
340 |    "metadata": {},
341 |    "source": [
342 |     "# XGBoost baseline"
343 |    ]
344 |   },
345 |   {
346 |    "cell_type": "code",
347 |    "execution_count": null,
348 |    "metadata": {
349 |     "scrolled": true
350 |    },
351 |    "outputs": [],
352 |    "source": [
353 |     "from xgboost import XGBRegressor\n",
354 |     "\n",
355 |     "clf_xgb = XGBRegressor(max_depth=8,\n",
356 |     "                       learning_rate=0.1,\n",
357 |     "                       n_estimators=1000,\n",
358 |     "                       verbosity=0,\n",
359 |     "                       silent=None,\n",
360 |     "                       objective='reg:squarederror',\n",
361 |     "                       booster='gbtree',\n",
362 |     "                       n_jobs=-1,\n",
363 |     "                       nthread=None,\n",
364 |     "                       gamma=0,\n",
365 |     "                       min_child_weight=1,\n",
366 |     "                       max_delta_step=0,\n",
367 |     "                       subsample=0.7,\n",
368 |     "                       colsample_bytree=1,\n",
369 |     "                       colsample_bylevel=1,\n",
370 |     "                       colsample_bynode=1,\n",
371 |     "                       reg_alpha=0,\n",
372 |     "                       reg_lambda=1,\n",
373 |     "                       scale_pos_weight=1,\n",
374 |     "                       base_score=0.5,\n",
375 |     "                       random_state=0,\n",
376 |     "                       seed=None)\n",
377 |     "\n",
378 |     "clf_xgb.fit(X_train, y_train,\n",
379 |     "            eval_set=[(X_valid, y_valid)],\n",
380 |     "            early_stopping_rounds=40,\n",
381 |     "            verbose=10)"
382 |    ]
383 |   },
384 |   {
385 |    "cell_type": "code",
386 |    "execution_count": null,
387 |    "metadata": {},
388 |    "outputs": [],
389 |    "source": [
390 |     "preds = np.array(clf_xgb.predict(X_valid))\n",
391 |     "valid_mse = mean_squared_error(y_pred=preds, y_true=y_valid)\n",
392 |     "print(valid_mse)\n",
393 |     "\n",
394 |     "preds = np.array(clf_xgb.predict(X_test))\n",
395 |     "test_mse = mean_squared_error(y_pred=preds, y_true=y_test)\n",
396 |     "print(test_mse)"
397 |    ]
398 |   }
399 |  ],
400 |  "metadata": {
401 |   "kernelspec": {
402 |    "display_name": "Python 3",
403 |    "language": "python",
404 |    "name": "python3"
405 |   },
406 |   "language_info": {
407 |    "codemirror_mode": {
408 |     "name": "ipython",
409 |     "version": 3
410 |    },
411 |    "file_extension": ".py",
412 |    "mimetype": "text/x-python",
413 |    "name": "python",
414 |    "nbconvert_exporter": "python",
415 |    "pygments_lexer": "ipython3",
416 |    "version": "3.7.13"
417 |   },
418 |   "toc": {
419 |    "base_numbering": 1,
420 |    "nav_menu": {},
421 |    "number_sections": true,
422 |    "sideBar": true,
423 |    "skip_h1_title": false,
424 |    "title_cell": "Table of Contents",
425 |    "title_sidebar": "Contents",
426 |    "toc_cell": false,
427 |    "toc_position": {},
428 |    "toc_section_display": true,
429 |    "toc_window_display": false
430 |   }
431 |  },
432 |  "nbformat": 4,
433 |  "nbformat_minor": 2
434 | }
435 | 
--------------------------------------------------------------------------------
/release-script/Dockerfile_changelog:
--------------------------------------------------------------------------------
1 | FROM node:lts-alpine@sha256:c785e617c8d7015190c0d41af52cc69be8a16e3d9eb7cb21f0bb58bcfca14d6b
2 | 
3 | RUN apk add git
4 | 
5 | RUN npm i -g conventional-github-releaser@3.1.3 conventional-changelog-cli
6 | 
--------------------------------------------------------------------------------
/release-script/do-release.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # stop if error
4 | set -e
5 | 
6 | read -p 'Release version: ' version
7 | read -p 'Your personal access token for GitHub: ' token
8 | read -p 'Your username for PyPI: ' pypiUser
9 | read -p 'Your password for PyPI: ' pypiPassword
10 | 
11 | echo ${version} | grep v && echo "Version should be x.y.z (for example, 1.1.1, 2.0.0, ...)" && exit 1
12 | 
13 | localDir=`readlink -f .`
14 | releaseDir="${localDir}/release-${version}"
15 | rm -rf ${releaseDir}
16 | mkdir ${releaseDir}
17 | cd $releaseDir
18 | 
19 | echo "Cloning repo into tabnet"
20 | git clone -q git@github.com:dreamquark-ai/tabnet.git tabnet
21 | 
22 | cd tabnet
23 | # Check out the release branch created by prepare-release.sh
24 | git checkout release/${version}
25 | 
26 | # Tag the release version
27 | echo "Tagging version v${version}"
28 | git tag v${version}
29 | 
30 | # Build release
31 | echo "Building release"
32 | docker run --rm -v ${PWD}:/work -w /work tabnet:latest poetry build
33 | 
34 | echo "Merging into develop and master"
35 | git checkout master
36 | git merge --no-ff origin/release/${version} -m "chore: release v${version} (merge)"
37 | git checkout develop
38 | git merge --no-ff origin/release/${version} -m "chore: release v${version} (merge)"
39 | 
40 | echo "Pushing branches"
41 | git push origin develop
42 | git push origin master
43 | echo "Pushing tag"
44 | git push origin --tags
45 | 
46 | echo "Making GitHub release"
47 | docker run -v ${PWD}:/work -w /work --entrypoint "" release-changelog:latest conventional-github-releaser -p angular --token ${token}
48 | 
49 | # Build release
50 | echo "Building release"
51 | docker run --rm -v ${PWD}:/work -w /work tabnet:latest poetry build
52 | # Publish release
53 | echo "Publishing release"
54 | docker run --rm -v ${PWD}:/work -w /work tabnet:latest poetry publish -u ${pypiUser} -p ${pypiPassword}
55 | 
56 | echo "Deleting release branch"
57 | git checkout develop
58 | git push origin :release/${version}
59 | 
60 | cd ${localDir}
61 | rm -rf ${releaseDir}
--------------------------------------------------------------------------------
/release-script/prepare-release.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # stop if error
4 | set -e
5 | 
6 | read -p 'Release version: ' version
7 | echo ${version} | grep v && echo "Version should be x.y.z (for example, 1.1.1, 2.0.0, ...)" && exit 1
8 | 
9 | localDir=`readlink -f .`
10 | releaseDir="${localDir}/release-${version}"
11 | rm -rf ${releaseDir}
12 | mkdir ${releaseDir}
13 | cd $releaseDir
14 | 
15 | echo "Cloning repo into tabnet"
16 | git clone -q git@github.com:dreamquark-ai/tabnet.git tabnet
17 | 
18 | cd tabnet
19 | # Create release branch and push it
20 | git checkout -b release/${version}
21 | # Change version of package
22 | docker run --rm -v ${PWD}:/work -w /work tabnet:latest poetry version ${version}
23 | # Generate docs
24 | make install doc
25 | # Add modified files
26 | git add pyproject.toml docs/
27 | # Commit release
28 | git commit -m "chore: release v${version}"
29 | # Create tag for changelog generation
30 | git tag v${version}
31 | docker run -v ${PWD}:/work -w /work --entrypoint "" release-changelog:latest git config --global --add safe.directory /work &&\
32 | conventional-changelog -p angular -i CHANGELOG.md -s -r 0 && \
33 | chmod 777 CHANGELOG.md
34 | # Remove the first 3 lines of the file (tail -n +4 keeps everything from line 4)
35 | echo "$(tail -n +4 CHANGELOG.md)" > CHANGELOG.md
36 | # Deleting tag
37 | git tag -d v${version}
38 | # Adding CHANGELOG to commit
39 | git add CHANGELOG.md
40 | git commit --amend --no-edit
41 | # Push release branch
42 | git push origin release/${version}
43 | 
44 | cd ${localDir}
45 | sudo rm -rf ${releaseDir}
--------------------------------------------------------------------------------
/renovate.json:
--------------------------------------------------------------------------------
1 | {
2 |   "rebaseStalePrs": true,
3 |   "extends": [
4 |     "config:base",
5 |     "docker:enableMajor",
6 |     ":disableRateLimiting"
7 |   ],
8 |   "docker": {
9 |     "enabled": true
10 |   },
11 |   "poetry": {
12 |     "enabled": true
13 |   },
14 |   "pip_requirements": {
15 |     "enabled": false
16 |   },
17 |   "pipenv": {
18 |     "enabled": false
19 |   },
20 |   "pip_setup": {
21 |     "enabled": false
22 |   },
23 |   "pinDigests": true,
24 |   "semanticCommits": true,
25 |   "semanticCommitType": "fix",
26 |   "branchPrefix": "feature/renovate-",
"assignees": [ 28 | "Hartorn", 29 | "j-abi", 30 | "Optimox", 31 | "eduardocarvp" 32 | ], 33 | "labels": [ 34 | "deps" 35 | ], 36 | "baseBranches": [ 37 | "develop" 38 | ], 39 | "major": { 40 | "labels": [ 41 | "deps", 42 | "dep:major" 43 | ] 44 | }, 45 | "minor": { 46 | "labels": [ 47 | "deps", 48 | "dep:minor" 49 | ] 50 | } 51 | } -------------------------------------------------------------------------------- /tests/test_unsupervised_loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import pytest 4 | from pytorch_tabnet.metrics import UnsupervisedLoss, UnsupervisedLossNumpy 5 | 6 | 7 | @pytest.mark.parametrize( 8 | "y_pred,embedded_x,obf_vars", 9 | [ 10 | ( 11 | np.random.uniform(low=-2, high=2, size=(20, 100)), 12 | np.random.uniform(low=-2, high=2, size=(20, 100)), 13 | np.random.choice([0, 1], size=(20, 100), replace=True) 14 | ), 15 | ( 16 | np.random.uniform(low=-2, high=2, size=(30, 50)), 17 | np.ones((30, 50)), 18 | np.random.choice([0, 1], size=(30, 50), replace=True) 19 | ) 20 | ] 21 | ) 22 | def test_equal_losses(y_pred, embedded_x, obf_vars): 23 | numpy_loss = UnsupervisedLossNumpy( 24 | y_pred=y_pred, 25 | embedded_x=embedded_x, 26 | obf_vars=obf_vars 27 | ) 28 | 29 | torch_loss = UnsupervisedLoss( 30 | y_pred=torch.tensor(y_pred, dtype=torch.float64), 31 | embedded_x=torch.tensor(embedded_x, dtype=torch.float64), 32 | obf_vars=torch.tensor(obf_vars, dtype=torch.float64) 33 | ) 34 | 35 | assert np.isclose(numpy_loss, torch_loss.detach().numpy()) 36 | --------------------------------------------------------------------------------