├── rtdl
│   ├── tests
│   │   ├── __init__.py
│   │   ├── test_data.py
│   │   ├── test_modules.py
│   │   └── test_vs_paper.py
│   ├── _utils.py
│   ├── __init__.py
│   ├── functional.py
│   ├── data.py
│   └── modules.py
├── environment.yaml
├── Makefile
├── pyproject.toml
├── README.md
├── .gitignore
└── LICENSE
/rtdl/tests/__init__.py:
--------------------------------------------------------------------------------
1 | """@private"""
2 |
--------------------------------------------------------------------------------
/rtdl/_utils.py:
--------------------------------------------------------------------------------
1 | from typing import TypeVar
2 |
3 | from typing_extensions import ParamSpec
4 |
5 | INTERNAL_ERROR_MESSAGE = (
6 | 'Internal error. Please, open an issue here:'
7 | ' https://github.com/Yura52/rtdl/issues/new'
8 | )
9 |
10 |
11 | def all_or_none(values):
12 | return all(x is None for x in values) or all(x is not None for x in values)
13 |
14 |
15 | P = ParamSpec('P')
16 | T = TypeVar('T')
17 |
--------------------------------------------------------------------------------
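A brief, illustrative usage sketch for the helpers above (the snippet and its function names are examples, not part of the package): `all_or_none` is typically used to validate groups of optional arguments that must be set together, while `P` and `T` are the generic placeholders commonly used to annotate signature-preserving decorators.

    from typing import Callable, Optional

    from rtdl._utils import P, T, all_or_none

    def configure(weight: Optional[list], bias: Optional[list]) -> None:
        # Either both arguments are provided or both are omitted.
        assert all_or_none([weight, bias]), 'weight and bias must be set or unset together'

    def passthrough(fn: Callable[P, T]) -> Callable[P, T]:
        # P and T preserve the decorated function's signature for type checkers.
        return fn

    assert all_or_none([None, None]) and all_or_none([1, 2]) and not all_or_none([1, None])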
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: rtdl
2 | channels:
3 | - https://conda.anaconda.org/conda-forge
4 | - nodefaults
5 | dependencies:
6 | - black=23.10.1
7 | - flit=3.9.0
8 | - isort=5.12.0
9 | - jupyterlab<5
10 | - mypy=1.6.1
11 | - numpy=1.19.5
12 | - pdoc=14.1.0
13 | - pip<24
14 | - pytest=7.4.2
15 | - python=3.8
16 | - pytorch=1.8.0
17 | - scikit-learn=1.0.2
18 | - ruff=0.1.7
19 | - tomli=2.0.1
20 | - tqdm=4.66.1
21 | - xdoctest=1.1.2
22 | - pip:
23 | - delu==0.0.23
24 |
--------------------------------------------------------------------------------
/rtdl/__init__.py:
--------------------------------------------------------------------------------
1 | """Research on tabular deep learning."""
2 |
3 | __version__ = '0.0.14.dev7'
4 |
5 | import warnings
6 |
7 | warnings.warn(
8 | 'The rtdl package is deprecated. See the GitHub repository to learn more.',
9 | DeprecationWarning,
10 | )
11 |
12 | from . import data # noqa: F401
13 | from .functional import geglu, reglu # noqa: F401
14 | from .modules import ( # noqa: F401
15 | GEGLU,
16 | MLP,
17 | CategoricalFeatureTokenizer,
18 | CLSToken,
19 | FeatureTokenizer,
20 | FTTransformer,
21 | MultiheadAttention,
22 | NumericalFeatureTokenizer,
23 | ReGLU,
24 | ResNet,
25 | Transformer,
26 | )
27 |
--------------------------------------------------------------------------------
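Importing the package emits the `DeprecationWarning` above at import time. A small optional sketch (not part of the package) showing how a user could silence it while still using the exported API:

    import warnings

    with warnings.catch_warnings():
        # The warning is raised when the module is first imported,
        # so the filter must be active at that moment.
        warnings.simplefilter('ignore', DeprecationWarning)
        import rtdl

    model = rtdl.MLP.make_baseline(d_in=8, d_layers=[16, 16], dropout=0.1, d_out=1)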
/rtdl/functional.py:
--------------------------------------------------------------------------------
1 | """@private"""
2 | import torch.nn.functional as F
3 | from torch import Tensor
4 |
5 |
6 | def reglu(x: Tensor) -> Tensor:
7 | """The ReGLU activation function from [1].
8 |
9 | References:
10 |
11 | [1] Noam Shazeer, "GLU Variants Improve Transformer", 2020
12 | """
13 | assert x.shape[-1] % 2 == 0
14 | a, b = x.chunk(2, dim=-1)
15 | return a * F.relu(b)
16 |
17 |
18 | def geglu(x: Tensor) -> Tensor:
19 | """The GEGLU activation function from [1].
20 |
21 | References:
22 |
23 | [1] Noam Shazeer, "GLU Variants Improve Transformer", 2020
24 | """
25 | assert x.shape[-1] % 2 == 0
26 | a, b = x.chunk(2, dim=-1)
27 | return a * F.gelu(b)
28 |
--------------------------------------------------------------------------------
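Both functions split the input in half along the last dimension, so the output is half as wide as the input. A quick check, illustrative only:

    import torch

    from rtdl.functional import geglu, reglu

    x = torch.randn(3, 4)
    assert reglu(x).shape == (3, 2)
    assert geglu(x).shape == (3, 2)
    # ReGLU is elementwise a * relu(b) for the two halves a, b of the last dimension.
    a, b = x.chunk(2, dim=-1)
    assert torch.equal(reglu(x), a * torch.relu(b))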
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: default clean doctest lint pre-commit typecheck
2 |
3 | PACKAGE_ROOT = rtdl
4 |
5 | default:
6 | echo "Hello, World!"
7 |
8 | clean:
9 | find . -type f -name "*.py[co]" -delete -o -type d -name __pycache__ -delete
10 | rm -rf .ipynb_checkpoints
11 | rm -rf .mypy_cache
12 | rm -rf .pytest_cache
13 | rm -rf .ruff_cache
14 | rm -rf dist
15 |
16 | doctest:
17 | xdoctest $(PACKAGE_ROOT)
18 | python test_code_blocks.py rtdl/revisiting_models/README.md
19 | python test_code_blocks.py rtdl/num_embeddings/README.md
20 |
21 | lint:
22 | isort $(PACKAGE_ROOT) --check-only
23 | black $(PACKAGE_ROOT) --check
24 | ruff check .
25 |
26 | # The order is important.
27 | pre-commit: clean lint doctest typecheck
28 |
29 | typecheck:
30 | mypy $(PACKAGE_ROOT)
31 |
--------------------------------------------------------------------------------
/rtdl/tests/test_data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 | import rtdl.data
5 |
6 |
7 | def test_get_category_sizes():
8 | get_category_sizes = rtdl.data.get_category_sizes
9 |
10 | # not two dimensions
11 | with pytest.raises(ValueError):
12 | get_category_sizes(np.array([0, 0, 0]))
13 | with pytest.raises(ValueError):
14 | get_category_sizes(np.array([[[0, 0, 0]]]))
15 |
16 | # not signed integers
17 | for dtype in [np.uint32, np.float32, str]:
18 | with pytest.raises(ValueError):
19 | get_category_sizes(np.array([[0, 0, 0]], dtype=dtype))
20 |
21 | # non-zero min value
22 | for x in [-1, 1]:
23 | with pytest.raises(ValueError):
24 | get_category_sizes(np.array([[0, 0, x]]))
25 |
26 | # not full range
27 | with pytest.raises(ValueError):
28 | get_category_sizes(np.array([[0, 0, 0], [2, 1, 1]]))
29 |
30 | # correctness
31 | assert get_category_sizes(np.array([[0]])) == [1]
32 | assert get_category_sizes(np.array([[0], [1]])) == [2]
33 | assert get_category_sizes(np.array([[0, 0]])) == [1, 1]
34 | assert get_category_sizes(np.array([[0, 0], [0, 1]])) == [1, 2]
35 | assert get_category_sizes(np.array([[1, 0], [0, 1]])) == [2, 2]
36 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | build-backend = "flit_core.buildapi"
3 | requires = ["flit_core >=3.2,<4"]
4 |
5 | [project]
6 | authors = [{ name = "Yury Gorishniy" }]
7 | classifiers = [
8 | "Intended Audience :: Developers",
9 | "Intended Audience :: Science/Research",
10 | "License :: OSI Approved :: Apache Software License",
11 | "Programming Language :: Python :: 3",
12 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
13 | "Topic :: Software Development :: Libraries :: Python Modules",
14 | ]
15 | dependencies = ["torch >=1.8,<3"]
16 | dynamic = ["version", "description"]
17 | keywords = [
18 | "artificial intelligence",
19 | "deep learning",
20 | "library",
21 | "python",
22 | "pytorch",
23 | "research",
24 | "torch",
25 | "tabular",
26 | "tabular data",
27 | ]
28 | license = { file = "LICENSE" }
29 | name = "rtdl"
30 | requires-python = ">=3.8"
31 |
32 | [project.urls]
33 | Code = "https://github.com/yandex-research/rtdl"
34 | Documentation = "https://github.com/yandex-research/rtdl"
35 |
36 | [tool.black]
37 | skip_string_normalization = true
38 |
39 | [tool.flit.module]
40 | name = "rtdl"
41 |
42 | [tool.isort]
43 | profile = "black"
44 | multi_line_output = 3
45 | known_first_party = ["rtdl"]
46 |
47 | [tool.mypy]
48 | check_untyped_defs = true
49 | ignore_missing_imports = true
50 |
51 | [tool.ruff]
52 | line-length = 88
53 | extend-select = ["RUF", "UP", "E101", "E501"]
54 | target-version = "py38"
55 |
56 | [tool.ruff.per-file-ignores]
57 | "rtdl/_utils.py" = ["E501"]
58 | "rtdl/data.py" = ["E501"]
59 | "rtdl/modules.py" = ["E501"]
60 | "rtdl/nn/*" = ["E501"]
61 |
62 | [[tool.mypy.overrides]]
63 | module = "rtdl.*.tests.*"
64 | ignore_errors = true
65 |
--------------------------------------------------------------------------------
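The `dynamic = ["version", "description"]` entries are filled in by flit from the package itself: the version from `__version__` and the description from the module docstring in `rtdl/__init__.py`. An illustrative check of where those values come from:

    import rtdl

    # flit packages these two values as the project version and description.
    print(rtdl.__version__)  # '0.0.14.dev7'
    print(rtdl.__doc__)      # 'Research on tabular deep learning.'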
/rtdl/data.py:
--------------------------------------------------------------------------------
1 | """Tools for data (pre)processing. @private"""
2 |
3 | __all__ = ['get_category_sizes']
4 |
5 | from typing import List, TypeVar
6 |
7 | import numpy as np
8 |
9 | Number = TypeVar('Number', int, float)
10 |
11 |
12 | def get_category_sizes(X: np.ndarray) -> List[int]:
13 | """Validate encoded categorical features and count distinct values.
14 |
15 | The function calculates the "category sizes" that can be used to construct
16 | `rtdl.CategoricalFeatureTokenizer` and `rtdl.FTTransformer`. Additionally, the
17 | following conditions are checked:
18 |
19 | * the data is a two-dimensional array of signed integers
20 | * distinct values of each column form zero-based ranges
21 |
22 | Note:
23 | For valid inputs, the result equals :code:`X.max(0) + 1`.
24 |
25 | Args:
26 | X: encoded categorical features (e.g. the output of :code:`sklearn.preprocessing.OrdinalEncoder`)
27 |
28 | Returns:
29 | The counts of distinct values for all columns.
30 |
31 | Examples:
32 | .. testcode::
33 |
34 | assert get_category_sizes(np.array(
35 | [
36 | [0, 0, 0],
37 | [1, 0, 0],
38 | [2, 1, 0],
39 | ]
40 | )) == [3, 2, 1]
41 | """
42 | if X.ndim != 2:
43 | raise ValueError('X must be two-dimensional')
44 | if not issubclass(X.dtype.type, np.signedinteger):
45 | raise ValueError('X data type must be a signed integer')
46 | sizes = []
47 | for i, column in enumerate(X.T):
48 | unique_values = np.unique(column)
49 | min_value = unique_values.min()
50 | if min_value != 0:
51 | raise ValueError(
52 | f'The minimum value of column {i} is {min_value}, but it must be zero.'
53 | )
54 | max_value = unique_values.max()
55 | if max_value + 1 != len(unique_values):
56 | raise ValueError(
57 | f'The values of column {i} do not fully cover the range from zero to maximum_value={max_value}'
58 | )
59 |
60 | sizes.append(len(unique_values))
61 | return sizes
62 |
--------------------------------------------------------------------------------
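In practice, the input of `get_category_sizes` is usually the output of an ordinal encoder, and the result can be passed as the cardinalities to `rtdl.CategoricalFeatureTokenizer` or `rtdl.FTTransformer`. A minimal sketch of that workflow, assuming scikit-learn (pinned in environment.yaml) for the encoding step; the toy data here is made up:

    import numpy as np
    from sklearn.preprocessing import OrdinalEncoder

    from rtdl.data import get_category_sizes

    X_cat = np.array([['red', 'S'], ['green', 'M'], ['blue', 'M'], ['red', 'L']])
    # OrdinalEncoder produces zero-based codes, which is exactly what the function expects.
    X_encoded = OrdinalEncoder().fit_transform(X_cat).astype(np.int64)
    assert get_category_sizes(X_encoded) == [3, 3]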
/README.md:
--------------------------------------------------------------------------------
1 | # RTDL (Research on Tabular Deep Learning)
2 |
3 | RTDL (**R**esearch on **T**abular **D**eep **L**earning) is a collection of papers and packages
4 | on deep learning for tabular data.
5 |
6 | :bell: *To follow announcements on new projects, subscribe to releases in this GitHub repository:
7 | "Watch -> Custom -> Releases".*
8 |
9 | > [!NOTE]
10 | > The list of projects below is up-to-date, but the `rtdl` Python package is deprecated.
11 | > If you used the rtdl package, please read the details below.
12 | >
13 | >
14 | >
15 | > 1. First, to clarify, this repository is **NOT** deprecated,
16 | > only the package `rtdl` is deprecated: it is replaced with other packages.
17 | > 2. If you used the latest `rtdl==0.0.13` installed from PyPI (not from GitHub!)
18 | > as `pip install rtdl`, then the same models
19 | > (MLP, ResNet, FT-Transformer) can be found in the `rtdl_revisiting_models` package,
20 | > though the API is slightly different.
21 | > 3. :exclamation: **If you used the unfinished code from the main branch, it is highly**
22 | > **recommended to switch to the new packages.** In particular,
23 | > the unfinished implementation of embeddings for continuous features
24 | > contained many unresolved issues (the `rtdl_num_embeddings` package, by contrast,
25 | > is more efficient and correct).
26 | >
27 | >
28 |
29 | # Papers
30 |
31 | (2024) TabM: Advancing Tabular Deep Learning with Parameter-Efficient Ensembling
32 | [Paper](https://arxiv.org/abs/2410.24210)
33 | [Code](https://github.com/yandex-research/tabm)
34 | [Usage](https://github.com/yandex-research/tabm#using-tabm-in-practice)
35 |
36 | (2024) TabReD: Analyzing Pitfalls and Filling the Gaps in Tabular Deep Learning Benchmarks
37 | [Paper](https://arxiv.org/abs/2406.19380)
38 | [Code](https://github.com/yandex-research/tabred)
39 |
40 | (2023) TabR: Tabular Deep Learning Meets Nearest Neighbors
41 | [Paper](https://arxiv.org/abs/2307.14338)
42 | [Code](https://github.com/yandex-research/tabular-dl-tabr)
43 |
44 | (2022) TabDDPM: Modelling Tabular Data with Diffusion Models
45 | [Paper](https://arxiv.org/abs/2209.15421)
46 | [Code](https://github.com/yandex-research/tab-ddpm)
47 |
48 | (2022) Revisiting Pretraining Objectives for Tabular Deep Learning
49 | [Paper](https://arxiv.org/abs/2207.03208)
50 | [Code](https://github.com/puhsu/tabular-dl-pretrain-objectives)
51 |
52 | (2022) On Embeddings for Numerical Features in Tabular Deep Learning
53 | [Paper](https://arxiv.org/abs/2203.05556)
54 | [Code](https://github.com/yandex-research/rtdl-num-embeddings)
55 | [Package (rtdl_num_embeddings)](https://github.com/yandex-research/rtdl-num-embeddings/tree/main/package/README.md)
56 |
57 | (2021) Revisiting Deep Learning Models for Tabular Data
58 | [Paper](https://arxiv.org/abs/2106.11959)
59 | [Code](https://github.com/yandex-research/rtdl-revisiting-models)
60 | [Package (rtdl_revisiting_models)](https://github.com/yandex-research/rtdl-revisiting-models/tree/main/package/README.md)
61 |
62 | (2019) Neural Oblivious Decision Ensembles for Deep Learning on Tabular Data
63 | [Paper](https://arxiv.org/abs/1909.06312)
64 | [Code](https://github.com/Qwicen/node)
65 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # >>> GITHUB DEFAULT PYTHON .GITIGNORE
2 | # Byte-compiled / optimized / DLL files
3 | __pycache__/
4 | *.py[cod]
5 | *$py.class
6 |
7 | # C extensions
8 | *.so
9 |
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | # lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | cover/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | .pybuilder/
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | # For a library or package, you might want to ignore these files since the code is
88 | # intended to run in multiple environments; otherwise, check them in:
89 | # .python-version
90 |
91 | # pipenv
92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
95 | # install all needed dependencies.
96 | #Pipfile.lock
97 |
98 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
99 | __pypackages__/
100 |
101 | # Celery stuff
102 | celerybeat-schedule
103 | celerybeat.pid
104 |
105 | # SageMath parsed files
106 | *.sage.py
107 |
108 | # Environments
109 | .env
110 | .venv
111 | env/
112 | venv/
113 | ENV/
114 | env.bak/
115 | venv.bak/
116 |
117 | # Spyder project settings
118 | .spyderproject
119 | .spyproject
120 |
121 | # Rope project settings
122 | .ropeproject
123 |
124 | # mkdocs documentation
125 | /site
126 |
127 | # mypy
128 | .mypy_cache/
129 | .dmypy.json
130 | dmypy.json
131 |
132 | # Pyre type checker
133 | .pyre/
134 |
135 | # pytype static type analyzer
136 | .pytype/
137 |
138 | # Cython debug symbols
139 | cython_debug/
140 |
141 | # <<< GITHUB DEFAULT PYTHON .GITIGNORE
142 |
143 | # The following directory is automatically generated by Sphinx
144 | docs/api
145 |
146 | # Data, checkpoints, etc.
147 | **/data/**
148 | **/catboost_cached_datasets/**
149 | *.bin
150 | *.csv
151 | *.cbm
152 | *.npy
153 | *.pickle
154 | *.pt
155 | *.pth
156 | *.rar
157 | *.tar*
158 | *.tmp
159 | *.zip
160 | events.out.tfevents.*
161 |
162 | # Other
163 | .DS_Store
164 | .vscode/
165 | local/
166 |
--------------------------------------------------------------------------------
/rtdl/tests/test_modules.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | import rtdl
5 |
6 |
7 | def get_devices():
8 | return ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']
9 |
10 |
11 | def test_bad_mlp():
12 | with pytest.raises(AssertionError):
13 | rtdl.MLP.make_baseline(1, [1, 2, 3, 4], 0.0, 1)
14 |
15 |
16 | @pytest.mark.parametrize('n_blocks', range(5))
17 | @pytest.mark.parametrize('d_out', [1, 2])
18 | @pytest.mark.parametrize('constructor', range(2))
19 | @pytest.mark.parametrize('device', get_devices())
20 | def test_mlp(n_blocks, d_out, constructor, device):
21 | if not n_blocks and not d_out:
22 | return
23 | d = 4
24 | d_last = d + 1
25 | d_layers = []
26 | if n_blocks:
27 | d_layers.append(d)
28 | if n_blocks > 2:
29 | d_layers.extend([d + d_out] * (n_blocks - 2))
30 | if n_blocks > 1:
31 | d_layers.append(d_last)
32 |
33 | def f0():
34 | dropouts = [0.1 * x for x in range(len(d_layers))]
35 | return rtdl.MLP(
36 | d_in=d, d_layers=d_layers, dropouts=dropouts, activation='GELU', d_out=d_out
37 | )
38 |
39 | def f1():
40 | return rtdl.MLP.make_baseline(
41 | d_in=d, d_layers=d_layers, dropout=0.1, d_out=d_out
42 | )
43 |
44 | model = locals()[f'f{constructor}']().to(device)
45 | n = 2
46 | assert model(torch.randn(n, d, device=device)).shape == (
47 | (n, d_out) if d_out else (n, d_last) if n_blocks > 1 else (n, d)
48 | )
49 |
50 |
51 | @pytest.mark.parametrize('n_blocks', [1, 2])
52 | @pytest.mark.parametrize('d_out', [1, 2])
53 | @pytest.mark.parametrize('constructor', range(2))
54 | @pytest.mark.parametrize('device', get_devices())
55 | def test_resnet(n_blocks, d_out, constructor, device):
56 | d = 4
57 |
58 | def f0():
59 | return rtdl.ResNet.make_baseline(
60 | d_in=d,
61 | d_main=d,
62 | d_hidden=d * 3,
63 | dropout_first=0.1,
64 | dropout_second=0.2,
65 | n_blocks=n_blocks,
66 | d_out=d_out,
67 | )
68 |
69 | def f1():
70 | return rtdl.ResNet(
71 | d_in=d,
72 | d_main=d,
73 | d_hidden=d * 3,
74 | dropout_first=0.1,
75 | dropout_second=0.2,
76 | n_blocks=n_blocks,
77 | normalization='Identity',
78 | activation='ReLU6',
79 | d_out=d_out,
80 | )
81 |
82 | model = locals()[f'f{constructor}']().to(device)
83 | n = 2
84 | assert model(torch.randn(n, d, device=device)).shape == (
85 | (n, d_out) if d_out else (n, d)
86 | )
87 |
88 |
89 | @pytest.mark.parametrize('n_blocks', range(1, 7))
90 | @pytest.mark.parametrize('d_out', [1, 2])
91 | @pytest.mark.parametrize('last_layer_query_idx', [None, [-1]])
92 | @pytest.mark.parametrize('constructor', range(3))
93 | @pytest.mark.parametrize('device', get_devices())
94 | def test_ft_transformer(n_blocks, d_out, last_layer_query_idx, constructor, device):
95 | n_num_features = 4
96 | model = rtdl.FTTransformer.make_default(
97 | n_num_features=4,
98 | cat_cardinalities=[2, 3],
99 | n_blocks=n_blocks,
100 | last_layer_query_idx=last_layer_query_idx,
101 | kv_compression_ratio=0.5,
102 | kv_compression_sharing='headwise',
103 | d_out=d_out,
104 | ).to(device)
105 | n = 2
106 |
107 | # check that the following methods do not fail
108 | model.optimization_param_groups()
109 | model.make_default_optimizer()
110 |
111 | def f0():
112 | return rtdl.FTTransformer.make_default(
113 | n_num_features=4,
114 | cat_cardinalities=[2, 3],
115 | n_blocks=n_blocks,
116 | last_layer_query_idx=last_layer_query_idx,
117 | kv_compression_ratio=0.5,
118 | kv_compression_sharing='headwise',
119 | d_out=d_out,
120 | )
121 |
122 | def f1():
123 | return rtdl.FTTransformer.make_baseline(
124 | n_num_features=4,
125 | cat_cardinalities=[2, 3],
126 | n_blocks=n_blocks,
127 | d_token=8,
128 | attention_dropout=0.2,
129 | ffn_d_hidden=8,
130 | ffn_dropout=0.3,
131 | residual_dropout=0.4,
132 | last_layer_query_idx=last_layer_query_idx,
133 | kv_compression_ratio=0.5,
134 | kv_compression_sharing='headwise',
135 | d_out=d_out,
136 | )
137 |
138 | def f2():
139 | d_token = 8
140 | rtdl.Transformer.WARNINGS['prenormalization'] = False
141 | model = rtdl.FTTransformer(
142 | rtdl.FeatureTokenizer(4, [2, 3], d_token),
143 | rtdl.Transformer(
144 | d_token=d_token,
145 | attention_n_heads=1,
146 | attention_dropout=0.3,
147 | attention_initialization='xavier',
148 | attention_normalization='Identity',
149 | ffn_d_hidden=4,
150 | ffn_dropout=0.3,
151 | ffn_activation='SELU',
152 | residual_dropout=0.2,
153 | ffn_normalization='Identity',
154 | prenormalization=False,
155 | first_prenormalization=False,
156 | n_tokens=7,
157 | head_activation='PReLU',
158 | head_normalization='Identity',
159 | n_blocks=n_blocks,
160 | last_layer_query_idx=last_layer_query_idx,
161 | kv_compression_ratio=0.5,
162 | kv_compression_sharing='headwise',
163 | d_out=d_out,
164 | ),
165 | )
166 | rtdl.Transformer.WARNINGS['prenormalization'] = True
167 | return model
168 |
169 | model = locals()[f'f{constructor}']().to(device)
170 | x = model(
171 | torch.randn(n, n_num_features, device=device),
172 | torch.tensor([[0, 2], [1, 2]], device=device),
173 | )
174 | if d_out:
175 | assert x.shape == (n, d_out)
176 | else:
177 | assert x.shape == (
178 | n,
179 | model.feature_tokenizer.n_tokens + 1 if last_layer_query_idx is None else 1,
180 | model.feature_tokenizer.d_token,
181 | )
182 |
--------------------------------------------------------------------------------
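For reference, the constructors exercised by these tests combine into the following minimal end-to-end sketch (the hyperparameters and shapes are arbitrary illustrative choices):

    import torch

    import rtdl

    n, n_num_features, cat_cardinalities, d_out = 4, 3, [2, 3], 1
    x_num = torch.randn(n, n_num_features)
    x_cat = torch.stack([torch.randint(c, (n,)) for c in cat_cardinalities], dim=1)

    mlp = rtdl.MLP.make_baseline(d_in=n_num_features, d_layers=[8, 8], dropout=0.1, d_out=d_out)
    resnet = rtdl.ResNet.make_baseline(
        d_in=n_num_features, d_main=8, d_hidden=16,
        dropout_first=0.1, dropout_second=0.0, n_blocks=2, d_out=d_out,
    )
    ft = rtdl.FTTransformer.make_default(
        n_num_features=n_num_features,
        cat_cardinalities=cat_cardinalities,
        n_blocks=1,
        last_layer_query_idx=[-1],
        kv_compression_ratio=None,
        kv_compression_sharing=None,
        d_out=d_out,
    )

    assert mlp(x_num).shape == (n, d_out)
    assert resnet(x_num).shape == (n, d_out)
    assert ft(x_num, x_cat).shape == (n, d_out)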
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2023 Yandex LLC
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/rtdl/tests/test_vs_paper.py:
--------------------------------------------------------------------------------
1 | # Tests in this file validate that the models in `rtdl` are LITERALLY THE SAME as
2 | # the ones used in the paper (https://github.com/Yura52/rtdl/tree/main/bin)
3 | # The testing approach:
4 | # (1) copy weights from the correct model to the RTDL model
5 | # (2) check that the two models produce the same output for the same input
6 | import math
7 | import random
8 | import typing as ty
9 |
10 | import numpy as np
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | from pytest import mark
15 | from torch import Tensor
16 |
17 | import rtdl
18 |
19 |
20 | class Model(nn.Module):
21 | def __init__(self, cat_input_module: nn.Module, model: nn.Module):
22 | super().__init__()
23 | self.cat_input_module = cat_input_module
24 | self.model = model
25 |
26 | def forward(self, x_num, x_cat):
27 | return self.model(
28 | torch.cat([x_num, self.cat_input_module(x_cat).flatten(1, -1)], dim=1)
29 | )
30 |
31 |
32 | def set_seeds(seed):
33 | random.seed(seed)
34 | np.random.seed(seed)
35 | torch.manual_seed(seed)
36 |
37 |
38 | @torch.no_grad()
39 | def copy_layer(dst: nn.Module, src: nn.Module):
40 | for key, _ in dst.named_parameters():
41 | getattr(dst, key).copy_(getattr(src, key))
42 |
43 |
44 | @torch.no_grad()
45 | @mark.parametrize('seed', range(10))
46 | def test_mlp(seed):
47 | # Source: https://github.com/Yura52/rtdl/blob/0e5169659c7ce552bc05bbaa85f7e204adc3d88e/bin/mlp.py
48 |
49 | class CorrectMLP(nn.Module):
50 | def __init__(
51 | self,
52 | *,
53 | d_in: int,
54 | d_layers: ty.List[int],
55 | dropout: float,
56 | d_out: int,
57 | categories: ty.Optional[ty.List[int]],
58 | d_embedding: int,
59 | ) -> None:
60 | super().__init__()
61 |
62 | if categories is not None:
63 | d_in += len(categories) * d_embedding
64 | category_offsets = torch.tensor([0] + categories[:-1]).cumsum(0)
65 | self.register_buffer('category_offsets', category_offsets)
66 | self.category_embeddings = nn.Embedding(sum(categories), d_embedding)
67 | nn.init.kaiming_uniform_(
68 | self.category_embeddings.weight, a=math.sqrt(5)
69 | )
70 | # print(f'{self.category_embeddings.weight.shape=}')
71 |
72 | self.layers = nn.ModuleList(
73 | [
74 | nn.Linear(d_layers[i - 1] if i else d_in, x)
75 | for i, x in enumerate(d_layers)
76 | ]
77 | )
78 | self.dropout = dropout
79 | self.head = nn.Linear(d_layers[-1] if d_layers else d_in, d_out)
80 |
81 | def forward(self, x_num, x_cat):
82 | if x_cat is not None:
83 | x_cat = self.category_embeddings(x_cat + self.category_offsets[None]) # type: ignore
84 | x = torch.cat([x_num, x_cat.view(x_cat.size(0), -1)], dim=-1)
85 | else:
86 | x = x_num
87 |
88 | for layer in self.layers:
89 | x = layer(x)
90 | x = F.relu(x)
91 | if self.dropout:
92 | x = F.dropout(x, self.dropout, self.training)
93 | x = self.head(x)
94 | return x
95 |
96 | n = 32
97 | d_num = 2
98 | categories = [2, 3]
99 | d_embedding = 3
100 | d_in = d_num + len(categories) * d_embedding
101 | d_layers = [3, 4, 5]
102 | dropout = 0.1
103 | d_out = 2
104 |
105 | set_seeds(seed)
106 | correct_model = CorrectMLP(
107 | d_in=d_num,
108 | d_layers=d_layers,
109 | dropout=dropout,
110 | d_out=d_out,
111 | categories=categories,
112 | d_embedding=d_embedding,
113 | )
114 | rtdl_tokenizer = rtdl.CategoricalFeatureTokenizer(
115 | categories, d_embedding, False, 'uniform'
116 | )
117 | rtdl_backbone = rtdl.MLP(
118 | d_in=d_in,
119 | d_layers=d_layers,
120 | dropouts=dropout,
121 | activation='ReLU',
122 | d_out=d_out,
123 | )
124 |
125 | rtdl_tokenizer.embeddings.weight.copy_(correct_model.category_embeddings.weight)
126 | for correct_layer, block in zip(correct_model.layers, rtdl_backbone.blocks):
127 | copy_layer(block.linear, correct_layer)
128 | copy_layer(rtdl_backbone.head, correct_model.head)
129 |
130 | rtdl_model = Model(rtdl_tokenizer, rtdl_backbone)
131 | x_num = torch.randn(n, d_num)
132 | x_cat = torch.cat([torch.randint(x, (n, 1)) for x in categories], dim=1)
133 | set_seeds(seed)
134 | correct_result = correct_model(x_num, x_cat)
135 | set_seeds(seed)
136 | rtdl_result = rtdl_model(x_num, x_cat)
137 | assert (correct_result == rtdl_result).all()
138 |
139 |
140 | @torch.no_grad()
141 | @mark.parametrize('seed', range(10))
142 | def test_resnet(seed):
143 | # Source: https://github.com/Yura52/rtdl/blob/0e5169659c7ce552bc05bbaa85f7e204adc3d88e/bin/resnet.py
144 |
145 | class CorrectResNet(nn.Module):
146 | def __init__(
147 | self,
148 | *,
149 | d_numerical: int,
150 | categories: ty.Optional[ty.List[int]],
151 | d_embedding: int,
152 | d: int,
153 | d_hidden_factor: float,
154 | n_layers: int,
155 | activation: str,
156 | normalization: str,
157 | hidden_dropout: float,
158 | residual_dropout: float,
159 | d_out: int,
160 | ) -> None:
161 | super().__init__()
162 |
163 | def make_normalization():
164 | return {'BatchNorm1d': nn.BatchNorm1d}[normalization](d)
165 |
166 | assert activation == 'ReLU'
167 | self.main_activation = F.relu
168 | self.last_activation = F.relu
169 | self.residual_dropout = residual_dropout
170 | self.hidden_dropout = hidden_dropout
171 |
172 | d_in = d_numerical
173 | d_hidden = int(d * d_hidden_factor)
174 |
175 | if categories is not None:
176 | d_in += len(categories) * d_embedding
177 | category_offsets = torch.tensor([0] + categories[:-1]).cumsum(0)
178 | self.register_buffer('category_offsets', category_offsets)
179 | self.category_embeddings = nn.Embedding(sum(categories), d_embedding)
180 | nn.init.kaiming_uniform_(
181 | self.category_embeddings.weight, a=math.sqrt(5)
182 | )
183 | # print(f'{self.category_embeddings.weight.shape=}')
184 |
185 | self.first_layer = nn.Linear(d_in, d)
186 | self.layers = nn.ModuleList(
187 | [
188 | nn.ModuleDict(
189 | {
190 | 'norm': make_normalization(),
191 | 'linear0': nn.Linear(
192 | d, d_hidden * (2 if activation.endswith('glu') else 1)
193 | ),
194 | 'linear1': nn.Linear(d_hidden, d),
195 | }
196 | )
197 | for _ in range(n_layers)
198 | ]
199 | )
200 | self.last_normalization = make_normalization()
201 | self.head = nn.Linear(d, d_out)
202 |
203 | def forward(self, x_num: Tensor, x_cat: ty.Optional[Tensor]) -> Tensor:
204 | if x_cat is not None:
205 | x_cat = self.category_embeddings(x_cat + self.category_offsets[None]) # type: ignore
206 | x = torch.cat([x_num, x_cat.view(x_cat.size(0), -1)], dim=-1) # type: ignore
207 | else:
208 | x = x_num
209 |
210 | x = self.first_layer(x)
211 | for layer in self.layers:
212 | layer = ty.cast(ty.Dict[str, nn.Module], layer)
213 | z = x
214 | z = layer['norm'](z)
215 | z = layer['linear0'](z)
216 | z = self.main_activation(z)
217 | if self.hidden_dropout:
218 | z = F.dropout(z, self.hidden_dropout, self.training)
219 | z = layer['linear1'](z)
220 | if self.residual_dropout:
221 | z = F.dropout(z, self.residual_dropout, self.training)
222 | x = x + z
223 | x = self.last_normalization(x)
224 | x = self.last_activation(x)
225 | x = self.head(x)
226 | return x
227 |
228 | n = 32
229 | d_num = 2
230 | categories = [2, 3]
231 | d_embedding = 3
232 | d_in = d_num + len(categories) * d_embedding
233 | d = 4
234 | d_hidden_factor = 1.5
235 | n_layers = 2
236 | activation = 'ReLU'
237 | normalization = 'BatchNorm1d'
238 | hidden_dropout = 0.1
239 | residual_dropout = 0.2
240 | d_out = 2
241 |
242 | set_seeds(seed)
243 | correct_model = CorrectResNet(
244 | d_numerical=d_num,
245 | categories=categories,
246 | d_embedding=d_embedding,
247 | d=d,
248 | d_hidden_factor=d_hidden_factor,
249 | n_layers=n_layers,
250 | activation=activation,
251 | normalization=normalization,
252 | hidden_dropout=0.1,
253 | residual_dropout=0.2,
254 | d_out=d_out,
255 | )
256 | rtdl_tokenizer = rtdl.CategoricalFeatureTokenizer(
257 | categories, d_embedding, False, 'uniform'
258 | )
259 | rtdl_backbone = rtdl.ResNet(
260 | d_in=d_in,
261 | n_blocks=n_layers,
262 | d_main=d,
263 | d_hidden=int(d * d_hidden_factor),
264 | dropout_first=hidden_dropout,
265 | dropout_second=residual_dropout,
266 | normalization=normalization,
267 | activation=activation,
268 | d_out=d_out,
269 | )
270 |
271 | rtdl_tokenizer.embeddings.weight.copy_(correct_model.category_embeddings.weight)
272 | copy_layer(rtdl_backbone.first_layer, correct_model.first_layer)
273 | for correct_layer, block in zip(correct_model.layers, rtdl_backbone.blocks):
274 | copy_layer(block.normalization, correct_layer['norm'])
275 | copy_layer(block.linear_first, correct_layer['linear0'])
276 | copy_layer(block.linear_second, correct_layer['linear1'])
277 |
278 | copy_layer(rtdl_backbone.head.normalization, correct_model.last_normalization)
279 | copy_layer(rtdl_backbone.head.linear, correct_model.head)
280 |
281 | rtdl_model = Model(rtdl_tokenizer, rtdl_backbone)
282 | x_num = torch.randn(n, d_num)
283 | x_cat = torch.cat([torch.randint(x, (n, 1)) for x in categories], dim=1)
284 | set_seeds(seed)
285 | correct_result = correct_model(x_num, x_cat)
286 | set_seeds(seed)
287 | rtdl_result = rtdl_model(x_num, x_cat)
288 | assert (correct_result == rtdl_result).all()
289 |
290 |
291 | @torch.no_grad()
292 | @mark.parametrize('seed', range(10))
293 | @mark.parametrize('kv_compression_ratio', [None, 0.5])
294 | def test_ft_transformer(seed, kv_compression_ratio):
295 | # """Source: https://github.com/Yura52/rtdl/blob/0e5169659c7ce552bc05bbaa85f7e204adc3d88e/bin/ft_transformer.py"""
296 | # The only difference is that [CLS] is now the last token.
297 |
298 | def correct_reglu(x: Tensor) -> Tensor:
299 | a, b = x.chunk(2, dim=-1)
300 | return a * F.relu(b)
301 |
302 | class CorrectTokenizer(nn.Module):
303 | category_offsets: ty.Optional[Tensor]
304 |
305 | def __init__(
306 | self,
307 | d_numerical: int,
308 | categories: ty.Optional[ty.List[int]],
309 | d_token: int,
310 | bias: bool,
311 | ) -> None:
312 | super().__init__()
313 | if categories is None:
314 | d_bias = d_numerical
315 | self.category_offsets = None
316 | self.category_embeddings = None
317 | else:
318 | d_bias = d_numerical + len(categories)
319 | category_offsets = torch.tensor([0] + categories[:-1]).cumsum(0)
320 | self.register_buffer('category_offsets', category_offsets)
321 | self.category_embeddings = nn.Embedding(sum(categories), d_token)
322 | nn.init.kaiming_uniform_(
323 | self.category_embeddings.weight, a=math.sqrt(5)
324 | )
325 |
326 | # take [CLS] token into account
327 | self.weight = nn.Parameter(Tensor(d_numerical + 1, d_token))
328 | self.bias = nn.Parameter(Tensor(d_bias, d_token)) if bias else None
329 | # The initialization is inspired by nn.Linear
330 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
331 | if self.bias is not None:
332 | nn.init.kaiming_uniform_(self.bias, a=math.sqrt(5))
333 |
334 | @property
335 | def n_tokens(self) -> int:
336 | return len(self.weight) + (
337 | 0 if self.category_offsets is None else len(self.category_offsets)
338 | )
339 |
340 | def forward(self, x_num: Tensor, x_cat: ty.Optional[Tensor]) -> Tensor:
341 | # x_num = torch.cat(
342 | # [
343 | # torch.ones(len(x_num), 1, device=x_num.device),
344 | # x_num,
345 | # ],
346 | # dim=1,
347 | # )
348 | x = self.weight[:-1][None] * x_num[:, :, None]
349 | if x_cat is not None:
350 | x = torch.cat(
351 | [x, self.category_embeddings(x_cat + self.category_offsets[None])],
352 | dim=1,
353 | )
354 | x = torch.cat(
355 | [x, self.weight[-1][None, None].repeat(len(x), 1, 1)],
356 | dim=1,
357 | )
358 | if self.bias is not None:
359 | bias = torch.cat(
360 | [
361 | self.bias,
362 | torch.zeros(1, self.bias.shape[1], device=x_num.device),
363 | ]
364 | )
365 | x = x + bias[None]
366 | return x
367 |
368 | class CorrectMultiheadAttention(nn.Module):
369 | def __init__(
370 | self, d: int, n_heads: int, dropout: float, initialization: str
371 | ) -> None:
372 | if n_heads > 1:
373 | assert d % n_heads == 0
374 | assert initialization in ['xavier', 'kaiming']
375 |
376 | super().__init__()
377 | self.W_q = nn.Linear(d, d)
378 | self.W_k = nn.Linear(d, d)
379 | self.W_v = nn.Linear(d, d)
380 | self.W_out = nn.Linear(d, d) if n_heads > 1 else None
381 | self.n_heads = n_heads
382 | self.dropout = nn.Dropout(dropout) if dropout else None
383 |
384 | for m in [self.W_q, self.W_k, self.W_v]:
385 | if initialization == 'xavier' and (n_heads > 1 or m is not self.W_v):
386 | # gain is needed since W_qkv is represented with 3 separate layers
387 | nn.init.xavier_uniform_(m.weight, gain=1 / math.sqrt(2))
388 | nn.init.zeros_(m.bias)
389 | if self.W_out is not None:
390 | nn.init.zeros_(self.W_out.bias)
391 |
392 | def _reshape(self, x: Tensor) -> Tensor:
393 | batch_size, n_tokens, d = x.shape
394 | d_head = d // self.n_heads
395 | return (
396 | x.reshape(batch_size, n_tokens, self.n_heads, d_head)
397 | .transpose(1, 2)
398 | .reshape(batch_size * self.n_heads, n_tokens, d_head)
399 | )
400 |
401 | def forward(
402 | self,
403 | x_q: Tensor,
404 | x_kv: Tensor,
405 | key_compression: ty.Optional[nn.Linear],
406 | value_compression: ty.Optional[nn.Linear],
407 | ) -> Tensor:
408 | q, k, v = self.W_q(x_q), self.W_k(x_kv), self.W_v(x_kv)
409 | for tensor in [q, k, v]:
410 | assert tensor.shape[-1] % self.n_heads == 0
411 | if key_compression is not None:
412 | assert value_compression is not None
413 | k = key_compression(k.transpose(1, 2)).transpose(1, 2)
414 | v = value_compression(v.transpose(1, 2)).transpose(1, 2)
415 | else:
416 | assert value_compression is None
417 |
418 | batch_size = len(q)
419 | d_head_key = k.shape[-1] // self.n_heads
420 | d_head_value = v.shape[-1] // self.n_heads
421 | n_q_tokens = q.shape[1]
422 |
423 | q = self._reshape(q)
424 | k = self._reshape(k)
425 | attention = F.softmax(q @ k.transpose(1, 2) / math.sqrt(d_head_key), dim=-1)
426 | if self.dropout is not None:
427 | attention = self.dropout(attention)
428 | x = attention @ self._reshape(v)
429 | x = (
430 | x.reshape(batch_size, self.n_heads, n_q_tokens, d_head_value)
431 | .transpose(1, 2)
432 | .reshape(batch_size, n_q_tokens, self.n_heads * d_head_value)
433 | )
434 | if self.W_out is not None:
435 | x = self.W_out(x)
436 | return x
437 |
438 | class CorrectFTTransformer(nn.Module):
439 | def __init__(
440 | self,
441 | *,
442 | # tokenizer
443 | d_numerical: int,
444 | categories: ty.Optional[ty.List[int]],
445 | token_bias: bool,
446 | # transformer
447 | n_layers: int,
448 | d_token: int,
449 | n_heads: int,
450 | d_ffn_factor: float,
451 | attention_dropout: float,
452 | ffn_dropout: float,
453 | residual_dropout: float,
454 | activation: str,
455 | prenormalization: bool,
456 | initialization: str,
457 | # linformer
458 | kv_compression: ty.Optional[float],
459 | kv_compression_sharing: ty.Optional[str],
460 | #
461 | d_out: int,
462 | ) -> None:
463 | assert (kv_compression is None) ^ (kv_compression_sharing is not None)
464 |
465 | super().__init__()
466 | self.tokenizer = CorrectTokenizer(
467 | d_numerical, categories, d_token, token_bias
468 | )
469 | n_tokens = self.tokenizer.n_tokens
470 |
471 | def make_kv_compression():
472 | assert kv_compression
473 | compression = nn.Linear(
474 | n_tokens, int(n_tokens * kv_compression), bias=False
475 | )
476 | if initialization == 'xavier':
477 | nn.init.xavier_uniform_(compression.weight)
478 | return compression
479 |
480 | self.shared_kv_compression = (
481 | make_kv_compression()
482 | if kv_compression and kv_compression_sharing == 'layerwise'
483 | else None
484 | )
485 |
486 | def make_normalization():
487 | return nn.LayerNorm(d_token)
488 |
489 | d_hidden = int(d_token * d_ffn_factor)
490 | self.layers = nn.ModuleList([])
491 | for layer_idx in range(n_layers):
492 | layer = nn.ModuleDict(
493 | {
494 | 'attention': CorrectMultiheadAttention(
495 | d_token, n_heads, attention_dropout, initialization
496 | ),
497 | 'linear0': nn.Linear(
498 | d_token, d_hidden * (2 if activation.endswith('glu') else 1)
499 | ),
500 | 'linear1': nn.Linear(d_hidden, d_token),
501 | 'norm1': make_normalization(),
502 | }
503 | )
504 | if not prenormalization or layer_idx:
505 | layer['norm0'] = make_normalization()
506 | if kv_compression and self.shared_kv_compression is None:
507 | layer['key_compression'] = make_kv_compression()
508 | if kv_compression_sharing == 'headwise':
509 | layer['value_compression'] = make_kv_compression()
510 | else:
511 | assert kv_compression_sharing == 'key-value'
512 | self.layers.append(layer)
513 |
514 | assert activation == 'reglu'
515 | self.activation = correct_reglu
516 | self.last_activation = F.relu
517 | self.prenormalization = prenormalization
518 | self.last_normalization = make_normalization() if prenormalization else None
519 | self.ffn_dropout = ffn_dropout
520 | self.residual_dropout = residual_dropout
521 | self.head = nn.Linear(d_token, d_out)
522 |
523 | def _get_kv_compressions(self, layer):
524 | return (
525 | (self.shared_kv_compression, self.shared_kv_compression)
526 | if self.shared_kv_compression is not None
527 | else (layer['key_compression'], layer['value_compression'])
528 | if 'key_compression' in layer and 'value_compression' in layer
529 | else (layer['key_compression'], layer['key_compression'])
530 | if 'key_compression' in layer
531 | else (None, None)
532 | )
533 |
534 | def _start_residual(self, x, layer, norm_idx):
535 | x_residual = x
536 | if self.prenormalization:
537 | norm_key = f'norm{norm_idx}'
538 | if norm_key in layer:
539 | x_residual = layer[norm_key](x_residual)
540 | return x_residual
541 |
542 | def _end_residual(self, x, x_residual, layer, norm_idx):
543 | if self.residual_dropout:
544 | x_residual = F.dropout(x_residual, self.residual_dropout, self.training)
545 | x = x + x_residual
546 | if not self.prenormalization:
547 | x = layer[f'norm{norm_idx}'](x)
548 | return x
549 |
550 | def forward(self, x_num: Tensor, x_cat: ty.Optional[Tensor]) -> Tensor:
551 | x = self.tokenizer(x_num, x_cat)
552 |
553 | for layer_idx, layer in enumerate(self.layers):
554 | is_last_layer = layer_idx + 1 == len(self.layers)
555 | layer = ty.cast(ty.Dict[str, nn.Module], layer)
556 |
557 | x_residual = self._start_residual(x, layer, 0)
558 | x_residual = layer['attention'](
559 | # for the last attention, it is enough to process only [CLS]
560 | (x_residual[:, -1:] if is_last_layer else x_residual),
561 | x_residual,
562 | *self._get_kv_compressions(layer),
563 | )
564 | if is_last_layer:
565 | x = x[:, -1:]
566 | x = self._end_residual(x, x_residual, layer, 0)
567 |
568 | x_residual = self._start_residual(x, layer, 1)
569 | x_residual = layer['linear0'](x_residual)
570 | x_residual = self.activation(x_residual)
571 | if self.ffn_dropout:
572 | x_residual = F.dropout(x_residual, self.ffn_dropout, self.training)
573 | x_residual = layer['linear1'](x_residual)
574 | x = self._end_residual(x, x_residual, layer, 1)
575 |
576 | assert x.shape[1] == 1
577 | x = x[:, 0]
578 | if self.last_normalization is not None:
579 | x = self.last_normalization(x)
580 | x = self.last_activation(x)
581 | x = self.head(x)
582 | x = x.squeeze(-1)
583 | return x
584 |
585 | # Source: https://github.com/Yura52/rtdl/blob/0e5169659c7ce552bc05bbaa85f7e204adc3d88e/output/california_housing/ft_transformer/default/0.toml
586 | default_config = {
587 | 'seed': 0,
588 | 'data': {
589 | 'normalization': 'quantile_normal',
590 | 'path': 'data/california_housing',
591 | 'y_policy': 'mean_std',
592 | },
593 | 'model': {
594 | 'activation': 'reglu',
595 | 'attention_dropout': 0.2,
596 | 'd_ffn_factor': 4 / 3,
597 | 'd_token': 192,
598 | 'ffn_dropout': 0.1,
599 | 'initialization': 'kaiming',
600 | 'n_heads': 8,
601 | 'n_layers': 3,
602 | 'prenormalization': True,
603 | 'residual_dropout': 0.0,
604 | },
605 | 'training': {
606 | 'batch_size': 256,
607 | 'eval_batch_size': 8192,
608 | 'lr': 0.0001,
609 | 'lr_n_decays': 0,
610 | 'n_epochs': 1000000000,
611 | 'optimizer': 'adamw',
612 | 'patience': 16,
613 | 'weight_decay': 1e-05,
614 | },
615 | }
616 | dcm = default_config['model']
617 |
618 | n = 4
619 | d_num = 2
620 | categories = [2, 3]
621 | n_tokens = d_num + len(categories) + 1
622 | kv_compression_sharing = 'key-value' if kv_compression_ratio else None
623 | d_out = 2
624 |
625 | set_seeds(seed)
626 | correct_model = CorrectFTTransformer(
627 | d_numerical=d_num,
628 | categories=categories,
629 | token_bias=True,
630 | kv_compression=kv_compression_ratio,
631 | kv_compression_sharing=kv_compression_sharing,
632 | d_out=d_out,
633 | **dcm,
634 | )
635 | rtdl_model = rtdl.FTTransformer(
636 | rtdl.FeatureTokenizer(d_num, categories, dcm['d_token']),
637 | rtdl.Transformer(
638 | d_token=dcm['d_token'],
639 | n_blocks=dcm['n_layers'],
640 | attention_n_heads=dcm['n_heads'],
641 | attention_dropout=dcm['attention_dropout'],
642 | attention_initialization=dcm['initialization'],
643 | attention_normalization='LayerNorm',
644 | ffn_d_hidden=int(dcm['d_token'] * dcm['d_ffn_factor']),
645 | ffn_dropout=dcm['ffn_dropout'],
646 | ffn_activation='ReGLU',
647 | ffn_normalization='LayerNorm',
648 | residual_dropout=dcm['residual_dropout'],
649 | prenormalization=dcm['prenormalization'],
650 | first_prenormalization=False,
651 | last_layer_query_idx=[-1],
652 | n_tokens=n_tokens if kv_compression_ratio else None,
653 | kv_compression_ratio=kv_compression_ratio,
654 | kv_compression_sharing=kv_compression_sharing,
655 | head_activation='ReLU',
656 | head_normalization='LayerNorm',
657 | d_out=d_out,
658 | ),
659 | )
660 | rtdl_default_model = rtdl.FTTransformer.make_default(
661 | n_num_features=d_num,
662 | cat_cardinalities=categories,
663 | last_layer_query_idx=[-1],
664 | kv_compression_ratio=kv_compression_ratio,
665 | kv_compression_sharing=kv_compression_sharing,
666 | d_out=d_out,
667 | )
668 |
669 | rtdl_model.feature_tokenizer.num_tokenizer.weight.copy_(
670 | correct_model.tokenizer.weight[:-1]
671 | )
672 | rtdl_model.feature_tokenizer.num_tokenizer.bias.copy_(
673 | correct_model.tokenizer.bias[:d_num]
674 | )
675 | rtdl_model.feature_tokenizer.cat_tokenizer.embeddings.weight.copy_(
676 | correct_model.tokenizer.category_embeddings.weight
677 | )
678 | rtdl_model.feature_tokenizer.cat_tokenizer.bias.copy_(
679 | correct_model.tokenizer.bias[-len(categories) :]
680 | )
681 | rtdl_model.cls_token.weight.copy_(correct_model.tokenizer.weight[-1])
682 | for correct_layer, block in zip(
683 | correct_model.layers, rtdl_model.transformer.blocks
684 | ):
685 | for key in ['W_q', 'W_k', 'W_v', 'W_out']:
686 | copy_layer(
687 | getattr(block['attention'], key),
688 | getattr(correct_layer['attention'], key),
689 | )
690 | copy_layer(block['ffn'].linear_first, correct_layer['linear0'])
691 | copy_layer(block['ffn'].linear_second, correct_layer['linear1'])
692 | copy_layer(block['ffn_normalization'], correct_layer['norm1'])
693 | if 'norm0' in correct_layer:
694 | copy_layer(block['attention_normalization'], correct_layer['norm0'])
695 | for key in ['key_compression', 'value_compression']:
696 | if key in correct_layer:
697 | copy_layer(block[key], correct_layer[key])
698 | copy_layer(
699 | rtdl_model.transformer.head.normalization, correct_model.last_normalization
700 | )
701 | copy_layer(rtdl_model.transformer.head.linear, correct_model.head)
702 | rtdl_default_model.load_state_dict(rtdl_model.state_dict())
703 |
704 | x_num = torch.randn(n, d_num)
705 | x_cat = torch.cat([torch.randint(x, (n, 1)) for x in categories], dim=1)
706 | results = []
707 | for m in [correct_model, rtdl_model, rtdl_default_model]:
708 | set_seeds(seed)
709 | results.append(m(x_num, x_cat))
710 | correct_result = results[0]
711 | assert (results[1] == results[2]).all()
712 | assert (results[1] == correct_result).all()
713 |
--------------------------------------------------------------------------------
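The equivalence-testing pattern above reduces to: copy parameters layer by layer, then require bit-exact outputs on the same input. A stripped-down illustration of the same idea on plain linear layers (not part of the test suite):

    import torch
    import torch.nn as nn

    @torch.no_grad()
    def copy_layer(dst: nn.Module, src: nn.Module) -> None:
        # Same helper as above: copy every named parameter from src into dst.
        for key, _ in dst.named_parameters():
            getattr(dst, key).copy_(getattr(src, key))

    src, dst = nn.Linear(4, 3), nn.Linear(4, 3)
    copy_layer(dst, src)
    x = torch.randn(2, 4)
    # Identical parameters and identical inputs must give bit-exact outputs.
    assert torch.equal(src(x), dst(x))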
/rtdl/modules.py:
--------------------------------------------------------------------------------
1 | """@private"""
2 | import enum
3 | import math
4 | import time
5 | import warnings
6 | from typing import (
7 | Any,
8 | Callable,
9 | ClassVar,
10 | Dict,
11 | List,
12 | Optional,
13 | Tuple,
14 | Type,
15 | Union,
16 | cast,
17 | )
18 |
19 | import torch
20 | import torch.nn as nn
21 | import torch.nn.functional as F
22 | import torch.optim
23 | from torch import Tensor
24 |
25 | from . import functional as rtdlF
26 |
27 | ModuleType = Union[str, Callable[..., nn.Module]]
28 | _INTERNAL_ERROR_MESSAGE = 'Internal error. Please open an issue.'
29 |
30 |
31 | def _is_glu_activation(activation: ModuleType):
32 | return (
33 | isinstance(activation, str)
34 | and activation.endswith('GLU')
35 | or activation in [ReGLU, GEGLU]
36 | )
37 |
38 |
39 | def _all_or_none(values):
40 | return all(x is None for x in values) or all(x is not None for x in values)
41 |
42 |
43 | class ReGLU(nn.Module):
44 | """The ReGLU activation function from [shazeer2020glu].
45 |
46 | Examples:
47 | .. testcode::
48 |
49 | module = ReGLU()
50 | x = torch.randn(3, 4)
51 | assert module(x).shape == (3, 2)
52 |
53 | References:
54 | * [shazeer2020glu] Noam Shazeer, "GLU Variants Improve Transformer", 2020
55 | """
56 |
57 | def forward(self, x: Tensor) -> Tensor:
58 | return rtdlF.reglu(x)
59 |
60 |
61 | class GEGLU(nn.Module):
62 | """The GEGLU activation function from [shazeer2020glu].
63 |
64 | Examples:
65 | .. testcode::
66 |
67 | module = GEGLU()
68 | x = torch.randn(3, 4)
69 | assert module(x).shape == (3, 2)
70 |
71 | References:
72 | * [shazeer2020glu] Noam Shazeer, "GLU Variants Improve Transformer", 2020
73 | """
74 |
75 | def forward(self, x: Tensor) -> Tensor:
76 | return rtdlF.geglu(x)
77 |
78 |
79 | class _TokenInitialization(enum.Enum):
80 | UNIFORM = 'uniform'
81 | NORMAL = 'normal'
82 |
83 | @classmethod
84 | def from_str(cls, initialization: str) -> '_TokenInitialization':
85 | try:
86 | return cls(initialization)
87 | except ValueError:
88 | valid_values = [x.value for x in _TokenInitialization]
89 | raise ValueError(f'initialization must be one of {valid_values}')
90 |
91 | def apply(self, x: Tensor, d: int) -> None:
92 | d_sqrt_inv = 1 / math.sqrt(d)
93 | if self == _TokenInitialization.UNIFORM:
94 | # used in the paper "Revisiting Deep Learning Models for Tabular Data";
95 | # is equivalent to `nn.init.kaiming_uniform_(x, a=math.sqrt(5))` (which is
96 | # used by torch to initialize nn.Linear.weight, for example)
97 | nn.init.uniform_(x, a=-d_sqrt_inv, b=d_sqrt_inv)
98 | elif self == _TokenInitialization.NORMAL:
99 | nn.init.normal_(x, std=d_sqrt_inv)
100 |
101 |
102 | class NumericalFeatureTokenizer(nn.Module):
103 | """Transforms continuous features to tokens (embeddings).
104 |
105 | See `FeatureTokenizer` for the illustration.
106 |
107 | For one feature, the transformation consists of two steps:
108 |
109 | * the feature is multiplied by a trainable vector
110 | * another trainable vector is added
111 |
112 | Note that each feature has its own pair of trainable vectors, i.e. the vectors
113 | are not shared between features.
114 |
115 | Examples:
116 | .. testcode::
117 |
118 | x = torch.randn(4, 2)
119 | n_objects, n_features = x.shape
120 | d_token = 3
121 | tokenizer = NumericalFeatureTokenizer(n_features, d_token, True, 'uniform')
122 | tokens = tokenizer(x)
123 | assert tokens.shape == (n_objects, n_features, d_token)
124 | """
125 |
126 | def __init__(
127 | self,
128 | n_features: int,
129 | d_token: int,
130 | bias: bool,
131 | initialization: str,
132 | ) -> None:
133 | """
134 | Args:
135 | n_features: the number of continuous (scalar) features
136 | d_token: the size of one token
137 | bias: if `False`, then the transformation will include only multiplication.
138 | **Warning**: :code:`bias=False` leads to significantly worse results for
139 | Transformer-like (token-based) architectures.
140 | initialization: initialization policy for parameters. Must be one of
141 | :code:`['uniform', 'normal']`. Let :code:`s = d ** -0.5`. Then, the
142 | corresponding distributions are :code:`Uniform(-s, s)` and :code:`Normal(0, s)`.
143 | In [gorishniy2021revisiting], the 'uniform' initialization was used.
144 |
145 | References:
146 | * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko, "Revisiting Deep Learning Models for Tabular Data", 2021
147 | """
148 | super().__init__()
149 | initialization_ = _TokenInitialization.from_str(initialization)
150 | self.weight = nn.Parameter(Tensor(n_features, d_token))
151 | self.bias = nn.Parameter(Tensor(n_features, d_token)) if bias else None
152 | for parameter in [self.weight, self.bias]:
153 | if parameter is not None:
154 | initialization_.apply(parameter, d_token)
155 |
156 | @property
157 | def n_tokens(self) -> int:
158 | """The number of tokens."""
159 | return len(self.weight)
160 |
161 | @property
162 | def d_token(self) -> int:
163 | """The size of one token."""
164 | return self.weight.shape[1]
165 |
166 | def forward(self, x: Tensor) -> Tensor:
167 | x = self.weight[None] * x[..., None]
168 | if self.bias is not None:
169 | x = x + self.bias[None]
170 | return x
171 |
172 |
173 | class CategoricalFeatureTokenizer(nn.Module):
174 | """Transforms categorical features to tokens (embeddings).
175 |
176 | See `FeatureTokenizer` for the illustration.
177 |
178 | The module efficiently implements a collection of `torch.nn.Embedding` (with
179 | optional biases).
180 |
181 | Examples:
182 | .. testcode::
183 |
184 | # the input must contain integers. For example, if the first feature can
185 | # take 3 distinct values, then its cardinality is 3 and the first column
186 | # must contain values from the range `[0, 1, 2]`.
187 | cardinalities = [3, 10]
188 | x = torch.tensor([
189 | [0, 5],
190 | [1, 7],
191 | [0, 2],
192 | [2, 4]
193 | ])
194 | n_objects, n_features = x.shape
195 | d_token = 3
196 | tokenizer = CategoricalFeatureTokenizer(cardinalities, d_token, True, 'uniform')
197 | tokens = tokenizer(x)
198 | assert tokens.shape == (n_objects, n_features, d_token)
199 | """
200 |
201 | category_offsets: Tensor
202 |
203 | def __init__(
204 | self,
205 | cardinalities: List[int],
206 | d_token: int,
207 | bias: bool,
208 | initialization: str,
209 | ) -> None:
210 | """
211 | Args:
212 | cardinalities: the number of distinct values for each feature. For example,
213 | :code:`cardinalities=[3, 4]` describes two features: the first one can
214 | take values in the range :code:`[0, 1, 2]` and the second one can take
215 | values in the range :code:`[0, 1, 2, 3]`.
216 | d_token: the size of one token.
217 | bias: if `True`, for each feature, a trainable vector is added to the
218 | embedding regardless of feature value. The bias vectors are not shared
219 | between features.
220 | initialization: initialization policy for parameters. Must be one of
221 | :code:`['uniform', 'normal']`. Let :code:`s = d ** -0.5`. Then, the
222 | corresponding distributions are :code:`Uniform(-s, s)` and :code:`Normal(0, s)`. In
223 | the paper [gorishniy2021revisiting], the 'uniform' initialization was
224 | used.
225 |
226 | References:
227 | * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko, "Revisiting Deep Learning Models for Tabular Data", 2021
228 | """
229 | super().__init__()
230 | assert cardinalities, 'cardinalities must be non-empty'
231 | assert d_token > 0, 'd_token must be positive'
232 | initialization_ = _TokenInitialization.from_str(initialization)
233 |
234 | category_offsets = torch.tensor([0] + cardinalities[:-1]).cumsum(0)
235 | self.register_buffer('category_offsets', category_offsets, persistent=False)
236 | self.embeddings = nn.Embedding(sum(cardinalities), d_token)
237 | self.bias = nn.Parameter(Tensor(len(cardinalities), d_token)) if bias else None
238 |
239 | for parameter in [self.embeddings.weight, self.bias]:
240 | if parameter is not None:
241 | initialization_.apply(parameter, d_token)
242 |
243 | @property
244 | def n_tokens(self) -> int:
245 | """The number of tokens."""
246 | return len(self.category_offsets)
247 |
248 | @property
249 | def d_token(self) -> int:
250 | """The size of one token."""
251 | return self.embeddings.embedding_dim
252 |
253 | def forward(self, x: Tensor) -> Tensor:
254 | x = self.embeddings(x + self.category_offsets[None])
255 | if self.bias is not None:
256 | x = x + self.bias[None]
257 | return x
258 |
259 |
260 | class FeatureTokenizer(nn.Module):
261 | """Combines `NumericalFeatureTokenizer` and `CategoricalFeatureTokenizer`.
262 |
263 | The "Feature Tokenizer" module from [gorishniy2021revisiting]. The module transforms
264 | continuous and categorical features to tokens (embeddings).
265 |
266 | In the illustration below, the red module in the upper brackets represents
267 | `NumericalFeatureTokenizer` and the green module in the lower brackets represents
268 | `CategoricalFeatureTokenizer`.
269 |
270 | .. image:: ../images/feature_tokenizer.png
271 | :scale: 33%
272 | :alt: Feature Tokenizer
273 |
274 | Examples:
275 | .. testcode::
276 |
277 | n_objects = 4
278 | n_num_features = 3
279 | n_cat_features = 2
280 | d_token = 7
281 | x_num = torch.randn(n_objects, n_num_features)
282 | x_cat = torch.tensor([[0, 1], [1, 0], [0, 2], [1, 1]])
283 | # [2, 3] are the cardinalities of the two categorical features in x_cat
284 | tokenizer = FeatureTokenizer(n_num_features, [2, 3], d_token)
285 | tokens = tokenizer(x_num, x_cat)
286 | assert tokens.shape == (n_objects, n_num_features + n_cat_features, d_token)
287 |
288 | References:
289 | * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko "Revisiting Deep Learning Models for Tabular Data", 2021
290 | """
291 |
292 | def __init__(
293 | self,
294 | n_num_features: int,
295 | cat_cardinalities: List[int],
296 | d_token: int,
297 | ) -> None:
298 | """
299 | Args:
300 | n_num_features: the number of continuous features. Pass :code:`0` if there
301 | are no numerical features.
302 | cat_cardinalities: the number of unique values for each feature. See
303 | `CategoricalFeatureTokenizer` for details. Pass an empty list if there
304 | are no categorical features.
305 | d_token: the size of one token.
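
The sketch below is added for illustration (the sizes are arbitrary); it shows the
two degenerate configurations: numerical features only (empty
:code:`cat_cardinalities`) and categorical features only (:code:`n_num_features=0`).

.. testcode::

    # numerical features only
    tokenizer = FeatureTokenizer(3, [], 7)
    assert tokenizer(torch.randn(4, 3), None).shape == (4, 3, 7)

    # categorical features only
    tokenizer = FeatureTokenizer(0, [2, 3], 7)
    assert tokenizer(None, torch.tensor([[0, 2], [1, 1]])).shape == (2, 2, 7)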
306 | """
307 | super().__init__()
308 | assert n_num_features >= 0, 'n_num_features must be non-negative'
309 | assert (
310 | n_num_features or cat_cardinalities
311 | ), 'at least one of n_num_features or cat_cardinalities must be positive/non-empty'
312 | self.initialization = 'uniform'
313 | self.num_tokenizer = (
314 | NumericalFeatureTokenizer(
315 | n_features=n_num_features,
316 | d_token=d_token,
317 | bias=True,
318 | initialization=self.initialization,
319 | )
320 | if n_num_features
321 | else None
322 | )
323 | self.cat_tokenizer = (
324 | CategoricalFeatureTokenizer(
325 | cat_cardinalities, d_token, True, self.initialization
326 | )
327 | if cat_cardinalities
328 | else None
329 | )
330 |
331 | @property
332 | def n_tokens(self) -> int:
333 | """The number of tokens."""
334 | return sum(
335 | x.n_tokens
336 | for x in [self.num_tokenizer, self.cat_tokenizer]
337 | if x is not None
338 | )
339 |
340 | @property
341 | def d_token(self) -> int:
342 | """The size of one token."""
343 | return (
344 | self.cat_tokenizer.d_token # type: ignore
345 | if self.num_tokenizer is None
346 | else self.num_tokenizer.d_token
347 | )
348 |
349 | def forward(self, x_num: Optional[Tensor], x_cat: Optional[Tensor]) -> Tensor:
350 | """Perform the forward pass.
351 |
352 | Args:
353 | x_num: continuous features. Must be present if :code:`n_num_features > 0`
354 | was passed to the constructor.
355 | x_cat: categorical features (see `CategoricalFeatureTokenizer.forward` for
356 | details). Must be present if a non-empty :code:`cat_cardinalities` was
357 | passed to the constructor.
358 | Returns:
359 | tokens
360 | Raises:
361 | AssertionError: if the described requirements for the inputs are not met.
362 | """
363 | assert (
364 | x_num is not None or x_cat is not None
365 | ), 'At least one of x_num and x_cat must not be None'
366 | assert _all_or_none(
367 | [self.num_tokenizer, x_num]
368 | ), 'If self.num_tokenizer is (not) None, then x_num must (not) be None'
369 | assert _all_or_none(
370 | [self.cat_tokenizer, x_cat]
371 | ), 'If self.cat_tokenizer is (not) None, then x_cat must (not) be None'
372 | x = []
373 | if self.num_tokenizer is not None:
374 | x.append(self.num_tokenizer(x_num))
375 | if self.cat_tokenizer is not None:
376 | x.append(self.cat_tokenizer(x_cat))
377 | return x[0] if len(x) == 1 else torch.cat(x, dim=1)
378 |
379 |
380 | class CLSToken(nn.Module):
381 | """[CLS]-token for BERT-like inference.
382 |
383 | To learn about the [CLS]-based inference, see [devlin2018bert].
384 |
385 | When used as a module, the [CLS]-token is appended **to the end** of each item in
386 | the batch.
387 |
388 | Examples:
389 | .. testcode::
390 |
391 | batch_size = 2
392 | n_tokens = 3
393 | d_token = 4
394 | cls_token = CLSToken(d_token, 'uniform')
395 | x = torch.randn(batch_size, n_tokens, d_token)
396 | x = cls_token(x)
397 | assert x.shape == (batch_size, n_tokens + 1, d_token)
398 | assert (x[:, -1, :] == cls_token.expand(len(x))).all()
399 |
400 | References:
401 | * [devlin2018bert] Jacob Devlin, Ming-Wei Chang, Kenton Lee, Kristina Toutanova "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding" 2018
402 | """
403 |
404 | def __init__(self, d_token: int, initialization: str) -> None:
405 | """
406 | Args:
407 | d_token: the size of token
408 | initialization: initialization policy for parameters. Must be one of
409 | :code:`['uniform', 'normal']`. Let :code:`s = d ** -0.5`. Then, the
410 | corresponding distributions are :code:`Uniform(-s, s)` and :code:`Normal(0, s)`. In
411 | the paper [gorishniy2021revisiting], the 'uniform' initialization was
412 | used.
413 |
414 | References:
415 | * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko "Revisiting Deep Learning Models for Tabular Data", 2021
416 | """
417 | super().__init__()
418 | initialization_ = _TokenInitialization.from_str(initialization)
419 | self.weight = nn.Parameter(Tensor(d_token))
420 | initialization_.apply(self.weight, d_token)
421 |
422 | def expand(self, *leading_dimensions: int) -> Tensor:
423 | """Expand (repeat) the underlying [CLS]-token to a tensor with the given leading dimensions.
424 |
425 | A possible use case is building a batch of [CLS]-tokens. See `CLSToken` for
426 | examples of usage.
427 |
428 | Note:
429 | Under the hood, the `torch.Tensor.expand` method is applied to the
430 | underlying :code:`weight` parameter, so gradients will be propagated as
431 | expected.
432 |
433 | Args:
434 | leading_dimensions: the additional new dimensions
435 |
436 | Returns:
437 | tensor of the shape :code:`(*leading_dimensions, len(self.weight))`
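
The example below is an illustrative sketch (the token size is arbitrary): it
expands the [CLS]-token with one and with two leading dimensions.

.. testcode::

    cls_token = CLSToken(d_token=4, initialization='uniform')
    assert cls_token.expand(2).shape == (2, 4)
    assert cls_token.expand(2, 1).shape == (2, 1, 4)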
438 | """
439 | if not leading_dimensions:
440 | return self.weight
441 | new_dims = (1,) * (len(leading_dimensions) - 1)
442 | return self.weight.view(*new_dims, -1).expand(*leading_dimensions, -1)
443 |
444 | def forward(self, x: Tensor) -> Tensor:
445 | """Append self **to the end** of each item in the batch (see `CLSToken`)."""
446 | return torch.cat([x, self.expand(len(x), 1)], dim=1)
447 |
448 |
449 | def _make_nn_module(module_type: ModuleType, *args) -> nn.Module:
450 | if isinstance(module_type, str):
451 | if module_type == 'ReGLU':
452 | return ReGLU()
453 | elif module_type == 'GEGLU':
454 | return GEGLU()
455 | else:
456 | try:
457 | cls = getattr(nn, module_type)
458 | except AttributeError as err:
459 | raise ValueError(
460 | f'Failed to construct the module {module_type} with the arguments {args}'
461 | ) from err
462 | return cls(*args)
463 | else:
464 | return module_type(*args)
465 |
466 |
467 | class MLP(nn.Module):
468 | """The MLP model used in [gorishniy2021revisiting].
469 |
470 | The following scheme describes the architecture:
471 |
472 | .. code-block:: text
473 |
474 | MLP: (in) -> Block -> ... -> Block -> Linear -> (out)
475 | Block: (in) -> Linear -> Activation -> Dropout -> (out)
476 |
477 | Examples:
478 | .. testcode::
479 |
480 | x = torch.randn(4, 2)
481 | module = MLP.make_baseline(x.shape[1], [3, 5], 0.1, 1)
482 | assert module(x).shape == (len(x), 1)
483 |
484 | References:
485 | * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko, "Revisiting Deep Learning Models for Tabular Data", 2021
486 | """
487 |
488 | class Block(nn.Module):
489 | """The main building block of `MLP`."""
490 |
491 | def __init__(
492 | self,
493 | *,
494 | d_in: int,
495 | d_out: int,
496 | bias: bool,
497 | activation: ModuleType,
498 | dropout: float,
499 | ) -> None:
500 | super().__init__()
501 | self.linear = nn.Linear(d_in, d_out, bias)
502 | self.activation = _make_nn_module(activation)
503 | self.dropout = nn.Dropout(dropout)
504 |
505 | def forward(self, x: Tensor) -> Tensor:
506 | return self.dropout(self.activation(self.linear(x)))
507 |
508 | def __init__(
509 | self,
510 | *,
511 | d_in: int,
512 | d_layers: List[int],
513 | dropouts: Union[float, List[float]],
514 | activation: Union[str, Callable[[], nn.Module]],
515 | d_out: int,
516 | ) -> None:
517 | """
518 | Note:
519 | `make_baseline` is the recommended constructor.
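
Compared to `make_baseline`, this constructor additionally allows per-layer
dropout rates and a custom activation. An illustrative sketch (arbitrary sizes):

.. testcode::

    module = MLP(
        d_in=4,
        d_layers=[8, 16],
        dropouts=[0.1, 0.2],
        activation='GELU',
        d_out=1,
    )
    assert module(torch.randn(3, 4)).shape == (3, 1)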
520 | """
521 | super().__init__()
522 | if isinstance(dropouts, float):
523 | dropouts = [dropouts] * len(d_layers)
524 | assert len(d_layers) == len(dropouts)
525 |
526 | self.blocks = nn.Sequential(
527 | *[
528 | MLP.Block(
529 | d_in=d_layers[i - 1] if i else d_in,
530 | d_out=d,
531 | bias=True,
532 | activation=activation,
533 | dropout=dropout,
534 | )
535 | for i, (d, dropout) in enumerate(zip(d_layers, dropouts))
536 | ]
537 | )
538 | self.head = nn.Linear(d_layers[-1] if d_layers else d_in, d_out)
539 |
540 | @classmethod
541 | def make_baseline(
542 | cls: Type['MLP'],
543 | d_in: int,
544 | d_layers: List[int],
545 | dropout: float,
546 | d_out: int,
547 | ) -> 'MLP':
548 | """Create a "baseline" `MLP`.
549 |
550 | This variation of MLP was used in [gorishniy2021revisiting]. Features:
551 |
552 | * :code:`Activation` = :code:`ReLU`
553 | * all linear layers except for the first one and the last one are of the same dimension
554 | * the dropout rate is the same for all dropout layers
555 |
556 | Args:
557 | d_in: the input size
558 | d_layers: the dimensions of the linear layers. If there are more than two
559 | layers, then all of them except for the first and the last ones must
560 | have the same dimension. Valid examples: :code:`[]`, :code:`[8]`,
561 | :code:`[8, 16]`, :code:`[2, 2, 2, 2]`, :code:`[1, 2, 2, 4]`. Invalid
562 | example: :code:`[1, 2, 3, 4]`.
563 | dropout: the dropout rate for all hidden layers
564 | d_out: the output size
565 | Returns:
566 | MLP
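
For example (an illustrative sketch; the sizes are arbitrary), a valid
:code:`d_layers` with more than two elements:

.. testcode::

    module = MLP.make_baseline(4, [1, 2, 2, 4], 0.1, 1)
    assert module(torch.randn(3, 4)).shape == (3, 1)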
567 |
568 | References:
569 | * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko, "Revisiting Deep Learning Models for Tabular Data", 2021
570 | """
571 | assert isinstance(dropout, float), 'In this constructor, dropout must be float'
572 | if len(d_layers) > 2:
573 | assert len(set(d_layers[1:-1])) == 1, (
574 | 'In this constructor, if d_layers contains more than two elements, then'
575 | ' all elements except for the first and the last ones must be equal.'
576 | )
577 | return MLP(
578 | d_in=d_in,
579 | d_layers=d_layers, # type: ignore
580 | dropouts=dropout,
581 | activation='ReLU',
582 | d_out=d_out,
583 | )
584 |
585 | def forward(self, x: Tensor) -> Tensor:
586 | x = self.blocks(x)
587 | x = self.head(x)
588 | return x
589 |
590 |
591 | class ResNet(nn.Module):
592 | """The ResNet model used in [gorishniy2021revisiting].
593 |
594 | The following scheme describes the architecture:
595 |
596 | .. code-block:: text
597 |
598 | ResNet: (in) -> Linear -> Block -> ... -> Block -> Head -> (out)
599 |
600 | |-> Norm -> Linear -> Activation -> Dropout -> Linear -> Dropout ->|
601 | | |
602 | Block: (in) ------------------------------------------------------------> Add -> (out)
603 |
604 | Head: (in) -> Norm -> Activation -> Linear -> (out)
605 |
606 | Examples:
607 | .. testcode::
608 |
609 | x = torch.randn(4, 2)
610 | module = ResNet.make_baseline(
611 | d_in=x.shape[1],
612 | n_blocks=2,
613 | d_main=3,
614 | d_hidden=4,
615 | dropout_first=0.25,
616 | dropout_second=0.0,
617 | d_out=1
618 | )
619 | assert module(x).shape == (len(x), 1)
620 |
621 | References:
622 | * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko, "Revisiting Deep Learning Models for Tabular Data", 2021
623 | """
624 |
625 | class Block(nn.Module):
626 | """The main building block of `ResNet`."""
627 |
628 | def __init__(
629 | self,
630 | *,
631 | d_main: int,
632 | d_hidden: int,
633 | bias_first: bool,
634 | bias_second: bool,
635 | dropout_first: float,
636 | dropout_second: float,
637 | normalization: ModuleType,
638 | activation: ModuleType,
639 | skip_connection: bool,
640 | ) -> None:
641 | super().__init__()
642 | self.normalization = _make_nn_module(normalization, d_main)
643 | self.linear_first = nn.Linear(d_main, d_hidden, bias_first)
644 | self.activation = _make_nn_module(activation)
645 | self.dropout_first = nn.Dropout(dropout_first)
646 | self.linear_second = nn.Linear(d_hidden, d_main, bias_second)
647 | self.dropout_second = nn.Dropout(dropout_second)
648 | self.skip_connection = skip_connection
649 |
650 | def forward(self, x: Tensor) -> Tensor:
651 | x_input = x
652 | x = self.normalization(x)
653 | x = self.linear_first(x)
654 | x = self.activation(x)
655 | x = self.dropout_first(x)
656 | x = self.linear_second(x)
657 | x = self.dropout_second(x)
658 | if self.skip_connection:
659 | x = x_input + x
660 | return x
661 |
662 | class Head(nn.Module):
663 | """The final module of `ResNet`."""
664 |
665 | def __init__(
666 | self,
667 | *,
668 | d_in: int,
669 | d_out: int,
670 | bias: bool,
671 | normalization: ModuleType,
672 | activation: ModuleType,
673 | ) -> None:
674 | super().__init__()
675 | self.normalization = _make_nn_module(normalization, d_in)
676 | self.activation = _make_nn_module(activation)
677 | self.linear = nn.Linear(d_in, d_out, bias)
678 |
679 | def forward(self, x: Tensor) -> Tensor:
680 | if self.normalization is not None:
681 | x = self.normalization(x)
682 | x = self.activation(x)
683 | x = self.linear(x)
684 | return x
685 |
686 | def __init__(
687 | self,
688 | *,
689 | d_in: int,
690 | n_blocks: int,
691 | d_main: int,
692 | d_hidden: int,
693 | dropout_first: float,
694 | dropout_second: float,
695 | normalization: ModuleType,
696 | activation: ModuleType,
697 | d_out: int,
698 | ) -> None:
699 | """
700 | Note:
701 | `make_baseline` is the recommended constructor.
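
Compared to `make_baseline`, this constructor additionally exposes the
normalization and activation modules. An illustrative sketch (arbitrary sizes):

.. testcode::

    module = ResNet(
        d_in=4,
        n_blocks=1,
        d_main=8,
        d_hidden=16,
        dropout_first=0.1,
        dropout_second=0.0,
        normalization='LayerNorm',
        activation='GELU',
        d_out=1,
    )
    assert module(torch.randn(3, 4)).shape == (3, 1)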
702 | """
703 | super().__init__()
704 |
705 | if d_main is None:
706 | d_main = d_in
707 | self.first_layer = nn.Linear(d_in, d_main)
708 | self.blocks = nn.Sequential(
709 | *[
710 | ResNet.Block(
711 | d_main=d_main,
712 | d_hidden=d_hidden,
713 | bias_first=True,
714 | bias_second=True,
715 | dropout_first=dropout_first,
716 | dropout_second=dropout_second,
717 | normalization=normalization,
718 | activation=activation,
719 | skip_connection=True,
720 | )
721 | for _ in range(n_blocks)
722 | ]
723 | )
724 | self.head = ResNet.Head(
725 | d_in=d_main,
726 | d_out=d_out,
727 | bias=True,
728 | normalization=normalization,
729 | activation=activation,
730 | )
731 |
732 | @classmethod
733 | def make_baseline(
734 | cls: Type['ResNet'],
735 | *,
736 | d_in: int,
737 | n_blocks: int,
738 | d_main: int,
739 | d_hidden: int,
740 | dropout_first: float,
741 | dropout_second: float,
742 | d_out: int,
743 | ) -> 'ResNet':
744 | """Create a "baseline" `ResNet`.
745 |
746 | This variation of ResNet was used in [gorishniy2021revisiting]. Features:
747 |
748 | * :code:`Activation` = :code:`ReLU`
749 | * :code:`Norm` = :code:`BatchNorm1d`
750 |
751 | Args:
752 | d_in: the input size
753 | n_blocks: the number of Blocks
754 | d_main: the input size (or, equivalently, the output size) of each Block
755 | d_hidden: the output size of the first linear layer in each Block
756 | dropout_first: the dropout rate of the first dropout layer in each Block.
757 | dropout_second: the dropout rate of the second dropout layer in each Block.
758 | d_out: the output size
759 | References:
760 | * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko, "Revisiting Deep Learning Models for Tabular Data", 2021
761 | """
762 | return cls(
763 | d_in=d_in,
764 | n_blocks=n_blocks,
765 | d_main=d_main,
766 | d_hidden=d_hidden,
767 | dropout_first=dropout_first,
768 | dropout_second=dropout_second,
769 | normalization='BatchNorm1d',
770 | activation='ReLU',
771 | d_out=d_out,
772 | )
773 |
774 | def forward(self, x: Tensor) -> Tensor:
775 | x = self.first_layer(x)
776 | x = self.blocks(x)
777 | x = self.head(x)
778 | return x
779 |
780 |
781 | class MultiheadAttention(nn.Module):
782 | """Multihead Attention (self-/cross-) with optional 'linear' attention.
783 |
784 | To learn more about Multihead Attention, see [devlin2018bert]. See the implementation
785 | of `Transformer` and the examples below to learn how to use the compression technique
786 | from [wang2020linformer] to speed up the module when the number of tokens is large.
787 |
788 | Examples:
789 | .. testcode::
790 |
791 | n_objects, n_tokens, d_token = 2, 3, 12
792 | n_heads = 6
793 | a = torch.randn(n_objects, n_tokens, d_token)
794 | b = torch.randn(n_objects, n_tokens * 2, d_token)
795 | module = MultiheadAttention(
796 | d_token=d_token, n_heads=n_heads, dropout=0.2, bias=True, initialization='kaiming'
797 | )
798 |
799 | # self-attention
800 | x, attention_stats = module(a, a, None, None)
801 | assert x.shape == a.shape
802 | assert attention_stats['attention_probs'].shape == (n_objects * n_heads, n_tokens, n_tokens)
803 | assert attention_stats['attention_logits'].shape == (n_objects * n_heads, n_tokens, n_tokens)
804 |
805 | # cross-attention
806 | assert module(a, b, None, None)
807 |
808 | # Linformer self-attention with the 'headwise' sharing policy
809 | k_compression = torch.nn.Linear(n_tokens, n_tokens // 4)
810 | v_compression = torch.nn.Linear(n_tokens, n_tokens // 4)
811 | assert module(a, a, k_compression, v_compression)
812 |
813 | # Linformer self-attention with the 'key-value' sharing policy
814 | kv_compression = torch.nn.Linear(n_tokens, n_tokens // 4)
815 | assert module(a, a, kv_compression, kv_compression)
816 |
817 | References:
818 | * [devlin2018bert] Jacob Devlin, Ming-Wei Chang, Kenton Lee, Kristina Toutanova "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding" 2018
819 | * [wang2020linformer] Sinong Wang, Belinda Z. Li, Madian Khabsa, Han Fang, Hao Ma "Linformer: Self-Attention with Linear Complexity", 2020
820 | """
821 |
822 | def __init__(
823 | self,
824 | *,
825 | d_token: int,
826 | n_heads: int,
827 | dropout: float,
828 | bias: bool,
829 | initialization: str,
830 | ) -> None:
831 | """
832 | Args:
833 | d_token: the token size. Must be a multiple of :code:`n_heads`.
834 | n_heads: the number of heads. If greater than 1, then the module will have
835 | an additional output layer (the so-called "mixing" layer).
836 | dropout: the dropout rate for the attention map. The dropout is applied to
837 | the *probabilities* and does not affect the logits.
838 | bias: if `True`, then the input (and, if present, the output) layers also have bias.
839 | `True` is a reasonable default choice.
840 | initialization: initialization for input projection layers. Must be one of
841 | :code:`['kaiming', 'xavier']`. `kaiming` is a reasonable default choice.
842 | Raises:
843 | AssertionError: if requirements for the inputs are not met.
844 | """
845 | super().__init__()
846 | if n_heads > 1:
847 | assert d_token % n_heads == 0, 'd_token must be a multiple of n_heads'
848 | assert initialization in ['kaiming', 'xavier']
849 |
850 | self.W_q = nn.Linear(d_token, d_token, bias)
851 | self.W_k = nn.Linear(d_token, d_token, bias)
852 | self.W_v = nn.Linear(d_token, d_token, bias)
853 | self.W_out = nn.Linear(d_token, d_token, bias) if n_heads > 1 else None
854 | self.n_heads = n_heads
855 | self.dropout = nn.Dropout(dropout) if dropout else None
856 |
857 | for m in [self.W_q, self.W_k, self.W_v]:
858 | # the "xavier" branch tries to follow torch.nn.MultiheadAttention;
859 | # the second condition checks if W_v plays the role of W_out; the latter one
860 | # is initialized with Kaiming in torch
861 | if initialization == 'xavier' and (
862 | m is not self.W_v or self.W_out is not None
863 | ):
864 | # gain is needed since W_qkv is represented with 3 separate layers (it
865 | # implies different fan_out)
866 | nn.init.xavier_uniform_(m.weight, gain=1 / math.sqrt(2))
867 | if m.bias is not None:
868 | nn.init.zeros_(m.bias)
869 | if self.W_out is not None:
870 | nn.init.zeros_(self.W_out.bias)
871 |
872 | def _reshape(self, x: Tensor) -> Tensor:
873 | batch_size, n_tokens, d = x.shape
874 | d_head = d // self.n_heads
875 | return (
876 | x.reshape(batch_size, n_tokens, self.n_heads, d_head)
877 | .transpose(1, 2)
878 | .reshape(batch_size * self.n_heads, n_tokens, d_head)
879 | )
880 |
881 | def forward(
882 | self,
883 | x_q: Tensor,
884 | x_kv: Tensor,
885 | key_compression: Optional[nn.Linear],
886 | value_compression: Optional[nn.Linear],
887 | ) -> Tuple[Tensor, Dict[str, Tensor]]:
888 | """Perform the forward pass.
889 |
890 | Args:
891 | x_q: query tokens
892 | x_kv: key-value tokens
893 | key_compression: Linformer-style compression for keys
894 | value_compression: Linformer-style compression for values
895 | Returns:
896 | (tokens, attention_stats)
897 | """
898 | assert _all_or_none(
899 | [key_compression, value_compression]
900 | ), 'If key_compression is (not) None, then value_compression must (not) be None'
901 | q, k, v = self.W_q(x_q), self.W_k(x_kv), self.W_v(x_kv)
902 | for tensor in [q, k, v]:
903 | assert tensor.shape[-1] % self.n_heads == 0, _INTERNAL_ERROR_MESSAGE
904 | if key_compression is not None:
905 | k = key_compression(k.transpose(1, 2)).transpose(1, 2)
906 | v = value_compression(v.transpose(1, 2)).transpose(1, 2) # type: ignore
907 |
908 | batch_size = len(q)
909 | d_head_key = k.shape[-1] // self.n_heads
910 | d_head_value = v.shape[-1] // self.n_heads
911 | n_q_tokens = q.shape[1]
912 |
913 | q = self._reshape(q)
914 | k = self._reshape(k)
915 | attention_logits = q @ k.transpose(1, 2) / math.sqrt(d_head_key)
916 | attention_probs = F.softmax(attention_logits, dim=-1)
917 | if self.dropout is not None:
918 | attention_probs = self.dropout(attention_probs)
919 | x = attention_probs @ self._reshape(v)
920 | x = (
921 | x.reshape(batch_size, self.n_heads, n_q_tokens, d_head_value)
922 | .transpose(1, 2)
923 | .reshape(batch_size, n_q_tokens, self.n_heads * d_head_value)
924 | )
925 | if self.W_out is not None:
926 | x = self.W_out(x)
927 | return x, {
928 | 'attention_logits': attention_logits,
929 | 'attention_probs': attention_probs,
930 | }
931 |
932 |
933 | class Transformer(nn.Module):
934 | """Transformer with extra features.
935 |
936 | This module is the backbone of `FTTransformer`."""
937 |
938 | WARNINGS: ClassVar[Dict[str, bool]] = {
939 | 'first_prenormalization': True,
940 | 'prenormalization': True,
941 | }
942 |
943 | class FFN(nn.Module):
944 | """The Feed-Forward Network module used in every `Transformer` block."""
945 |
946 | def __init__(
947 | self,
948 | *,
949 | d_token: int,
950 | d_hidden: int,
951 | bias_first: bool,
952 | bias_second: bool,
953 | dropout: float,
954 | activation: ModuleType,
955 | ):
956 | super().__init__()
957 | self.linear_first = nn.Linear(
958 | d_token,
959 | d_hidden * (2 if _is_glu_activation(activation) else 1),
960 | bias_first,
961 | )
962 | self.activation = _make_nn_module(activation)
963 | self.dropout = nn.Dropout(dropout)
964 | self.linear_second = nn.Linear(d_hidden, d_token, bias_second)
965 |
966 | def forward(self, x: Tensor) -> Tensor:
967 | x = self.linear_first(x)
968 | x = self.activation(x)
969 | x = self.dropout(x)
970 | x = self.linear_second(x)
971 | return x
972 |
973 | class Head(nn.Module):
974 | """The final module of the `Transformer` that performs BERT-like inference."""
975 |
976 | def __init__(
977 | self,
978 | *,
979 | d_in: int,
980 | bias: bool,
981 | activation: ModuleType,
982 | normalization: ModuleType,
983 | d_out: int,
984 | ):
985 | super().__init__()
986 | self.normalization = _make_nn_module(normalization, d_in)
987 | self.activation = _make_nn_module(activation)
988 | self.linear = nn.Linear(d_in, d_out, bias)
989 |
990 | def forward(self, x: Tensor) -> Tensor:
991 | x = x[:, -1]
992 | x = self.normalization(x)
993 | x = self.activation(x)
994 | x = self.linear(x)
995 | return x
996 |
997 | def __init__(
998 | self,
999 | *,
1000 | d_token: int,
1001 | n_blocks: int,
1002 | attention_n_heads: int,
1003 | attention_dropout: float,
1004 | attention_initialization: str,
1005 | attention_normalization: str,
1006 | ffn_d_hidden: int,
1007 | ffn_dropout: float,
1008 | ffn_activation: str,
1009 | ffn_normalization: str,
1010 | residual_dropout: float,
1011 | prenormalization: bool,
1012 | first_prenormalization: bool,
1013 | last_layer_query_idx: Union[None, List[int], slice],
1014 | n_tokens: Optional[int],
1015 | kv_compression_ratio: Optional[float],
1016 | kv_compression_sharing: Optional[str],
1017 | head_activation: ModuleType,
1018 | head_normalization: ModuleType,
1019 | d_out: int,
1020 | ) -> None:
1021 | super().__init__()
1022 | if isinstance(last_layer_query_idx, int):
1023 | raise ValueError(
1024 | 'last_layer_query_idx must be None, list[int] or slice. '
1025 | f'Do you mean last_layer_query_idx=[{last_layer_query_idx}] ?'
1026 | )
1027 | if not prenormalization:
1028 | assert (
1029 | not first_prenormalization
1030 | ), 'If `prenormalization` is False, then `first_prenormalization` must be False'
1031 | assert _all_or_none([n_tokens, kv_compression_ratio, kv_compression_sharing]), (
1032 | 'If any of the following arguments is (not) None, then all of them must (not) be None: '
1033 | 'n_tokens, kv_compression_ratio, kv_compression_sharing'
1034 | )
1035 | assert kv_compression_sharing in [None, 'headwise', 'key-value', 'layerwise']
1036 | if not prenormalization:
1037 | if self.WARNINGS['prenormalization']:
1038 | warnings.warn(
1039 | 'prenormalization is set to False. Are you sure about this? '
1040 | 'The training can become less stable. '
1041 | 'You can turn off this warning by tweaking the '
1042 | 'rtdl.Transformer.WARNINGS dictionary.',
1043 | UserWarning,
1044 | )
1045 | assert (
1046 | not first_prenormalization
1047 | ), 'If prenormalization is False, then first_prenormalization is ignored and must be set to False'
1048 | if (
1049 | prenormalization
1050 | and first_prenormalization
1051 | and self.WARNINGS['first_prenormalization']
1052 | ):
1053 | warnings.warn(
1054 | 'first_prenormalization is set to True. Are you sure about this? '
1055 | 'For example, the vanilla FTTransformer with '
1056 | 'first_prenormalization=True performs SIGNIFICANTLY worse. '
1057 | 'You can turn off this warning by tweaking the '
1058 | 'rtdl.Transformer.WARNINGS dictionary.',
1059 | UserWarning,
1060 | )
1061 | time.sleep(3)
1062 |
1063 | def make_kv_compression():
1064 | assert (
1065 | n_tokens and kv_compression_ratio
1066 | ), _INTERNAL_ERROR_MESSAGE # for mypy
1067 | # https://github.com/pytorch/fairseq/blob/1bba712622b8ae4efb3eb793a8a40da386fe11d0/examples/linformer/linformer_src/modules/multihead_linear_attention.py#L83
1068 | return nn.Linear(n_tokens, int(n_tokens * kv_compression_ratio), bias=False)
1069 |
1070 | self.shared_kv_compression = (
1071 | make_kv_compression()
1072 | if kv_compression_ratio and kv_compression_sharing == 'layerwise'
1073 | else None
1074 | )
1075 |
1076 | self.prenormalization = prenormalization
1077 | self.last_layer_query_idx = last_layer_query_idx
1078 |
1079 | self.blocks = nn.ModuleList([])
1080 | for layer_idx in range(n_blocks):
1081 | layer = nn.ModuleDict(
1082 | {
1083 | 'attention': MultiheadAttention(
1084 | d_token=d_token,
1085 | n_heads=attention_n_heads,
1086 | dropout=attention_dropout,
1087 | bias=True,
1088 | initialization=attention_initialization,
1089 | ),
1090 | 'ffn': Transformer.FFN(
1091 | d_token=d_token,
1092 | d_hidden=ffn_d_hidden,
1093 | bias_first=True,
1094 | bias_second=True,
1095 | dropout=ffn_dropout,
1096 | activation=ffn_activation,
1097 | ),
1098 | 'attention_residual_dropout': nn.Dropout(residual_dropout),
1099 | 'ffn_residual_dropout': nn.Dropout(residual_dropout),
1100 | 'output': nn.Identity(), # for hooks-based introspection
1101 | }
1102 | )
1103 | if layer_idx or not prenormalization or first_prenormalization:
1104 | layer['attention_normalization'] = _make_nn_module(
1105 | attention_normalization, d_token
1106 | )
1107 | layer['ffn_normalization'] = _make_nn_module(ffn_normalization, d_token)
1108 | if kv_compression_ratio and self.shared_kv_compression is None:
1109 | layer['key_compression'] = make_kv_compression()
1110 | if kv_compression_sharing == 'headwise':
1111 | layer['value_compression'] = make_kv_compression()
1112 | else:
1113 | assert (
1114 | kv_compression_sharing == 'key-value'
1115 | ), _INTERNAL_ERROR_MESSAGE
1116 | self.blocks.append(layer)
1117 |
1118 | self.head = Transformer.Head(
1119 | d_in=d_token,
1120 | d_out=d_out,
1121 | bias=True,
1122 | activation=head_activation, # type: ignore
1123 | normalization=head_normalization if prenormalization else 'Identity',
1124 | )
1125 |
1126 | def _get_kv_compressions(self, layer):
1127 | return (
1128 | (self.shared_kv_compression, self.shared_kv_compression)
1129 | if self.shared_kv_compression is not None
1130 | else (layer['key_compression'], layer['value_compression'])
1131 | if 'key_compression' in layer and 'value_compression' in layer
1132 | else (layer['key_compression'], layer['key_compression'])
1133 | if 'key_compression' in layer
1134 | else (None, None)
1135 | )
1136 |
1137 | def _start_residual(self, layer, stage, x):
1138 | assert stage in ['attention', 'ffn'], _INTERNAL_ERROR_MESSAGE
1139 | x_residual = x
1140 | if self.prenormalization:
1141 | norm_key = f'{stage}_normalization'
1142 | if norm_key in layer:
1143 | x_residual = layer[norm_key](x_residual)
1144 | return x_residual
1145 |
1146 | def _end_residual(self, layer, stage, x, x_residual):
1147 | assert stage in ['attention', 'ffn'], _INTERNAL_ERROR_MESSAGE
1148 | x_residual = layer[f'{stage}_residual_dropout'](x_residual)
1149 | x = x + x_residual
1150 | if not self.prenormalization:
1151 | x = layer[f'{stage}_normalization'](x)
1152 | return x
1153 |
1154 | def forward(self, x: Tensor) -> Tensor:
1155 | assert (
1156 | x.ndim == 3
1157 | ), 'The input must have 3 dimensions: (n_objects, n_tokens, d_token)'
1158 | for layer_idx, layer in enumerate(self.blocks):
1159 | layer = cast(nn.ModuleDict, layer)
1160 |
1161 | query_idx = (
1162 | self.last_layer_query_idx if layer_idx + 1 == len(self.blocks) else None
1163 | )
1164 | x_residual = self._start_residual(layer, 'attention', x)
1165 | x_residual, _ = layer['attention'](
1166 | x_residual if query_idx is None else x_residual[:, query_idx],
1167 | x_residual,
1168 | *self._get_kv_compressions(layer),
1169 | )
1170 | if query_idx is not None:
1171 | x = x[:, query_idx]
1172 | x = self._end_residual(layer, 'attention', x, x_residual)
1173 |
1174 | x_residual = self._start_residual(layer, 'ffn', x)
1175 | x_residual = layer['ffn'](x_residual)
1176 | x = self._end_residual(layer, 'ffn', x, x_residual)
1177 | x = layer['output'](x)
1178 |
1179 | x = self.head(x)
1180 | return x
1181 |
1182 |
1183 | class FTTransformer(nn.Module):
1184 | """The FT-Transformer model proposed in [gorishniy2021revisiting].
1185 |
1186 | Transforms features to tokens with `FeatureTokenizer` and applies `Transformer` [vaswani2017attention]
1187 | to the tokens. The following illustration provides a high-level overview of the
1188 | architecture:
1189 |
1190 | .. image:: ../images/ft_transformer.png
1191 | :scale: 25%
1192 | :alt: FT-Transformer
1193 |
1194 | The following illustration demonstrates one Transformer block for :code:`prenormalization=True`:
1195 |
1196 | .. image:: ../images/transformer_block.png
1197 | :scale: 25%
1198 | :alt: PreNorm Transformer block
1199 |
1200 | Examples:
1201 | .. testcode::
1202 |
1203 | x_num = torch.randn(4, 3)
1204 | x_cat = torch.tensor([[0, 1], [1, 0], [0, 2], [1, 1]])
1205 |
1206 | module = FTTransformer.make_baseline(
1207 | n_num_features=3,
1208 | cat_cardinalities=[2, 3],
1209 | d_token=8,
1210 | n_blocks=2,
1211 | attention_dropout=0.2,
1212 | ffn_d_hidden=6,
1213 | ffn_dropout=0.2,
1214 | residual_dropout=0.0,
1215 | d_out=1,
1216 | )
1217 | x = module(x_num, x_cat)
1218 | assert x.shape == (4, 1)
1219 |
1220 | module = FTTransformer.make_default(
1221 | n_num_features=3,
1222 | cat_cardinalities=[2, 3],
1223 | d_out=1,
1224 | )
1225 | x = module(x_num, x_cat)
1226 | assert x.shape == (4, 1)
1227 |
1228 | To learn more about the baseline and default parameters:
1229 |
1230 | .. testcode::
1231 |
1232 | baseline_parameters = FTTransformer.get_baseline_transformer_subconfig()
1233 | default_parameters = FTTransformer.get_default_transformer_config()
1234 |
1235 | References:
1236 | * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko, "Revisiting Deep Learning Models for Tabular Data", 2021
1237 | * [vaswani2017attention] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin, "Attention Is All You Need", 2017
1238 | """
1239 |
1240 | def __init__(
1241 | self, feature_tokenizer: FeatureTokenizer, transformer: Transformer
1242 | ) -> None:
1243 | """
1244 | Note:
1245 | `make_baseline` and `make_default` are the recommended constructors.
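
An illustrative sketch of manual construction (this is essentially what
`make_baseline` does internally; the sizes are arbitrary):

.. testcode::

    config = FTTransformer.get_baseline_transformer_subconfig()
    config.update(
        d_token=8,
        n_blocks=2,
        attention_dropout=0.2,
        ffn_d_hidden=6,
        ffn_dropout=0.1,
        residual_dropout=0.0,
        d_out=1,
    )
    module = FTTransformer(
        FeatureTokenizer(3, [2, 3], 8),
        Transformer(**config),
    )
    x = module(torch.randn(4, 3), torch.tensor([[0, 1], [1, 0], [0, 2], [1, 1]]))
    assert x.shape == (4, 1)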
1246 | """
1247 | super().__init__()
1248 | if transformer.prenormalization:
1249 | assert 'attention_normalization' not in transformer.blocks[0], (
1250 | 'In the prenormalization setting, FT-Transformer does not '
1251 | 'allow using the first normalization layer '
1252 | 'in the first transformer block'
1253 | )
1254 | self.feature_tokenizer = feature_tokenizer
1255 | self.cls_token = CLSToken(
1256 | feature_tokenizer.d_token, feature_tokenizer.initialization
1257 | )
1258 | self.transformer = transformer
1259 |
1260 | @classmethod
1261 | def get_baseline_transformer_subconfig(
1262 | cls: Type['FTTransformer'],
1263 | ) -> Dict[str, Any]:
1264 | """Get the baseline subset of parameters for the backbone."""
1265 | return {
1266 | 'attention_n_heads': 8,
1267 | 'attention_initialization': 'kaiming',
1268 | 'ffn_activation': 'ReGLU',
1269 | 'attention_normalization': 'LayerNorm',
1270 | 'ffn_normalization': 'LayerNorm',
1271 | 'prenormalization': True,
1272 | 'first_prenormalization': False,
1273 | 'last_layer_query_idx': None,
1274 | 'n_tokens': None,
1275 | 'kv_compression_ratio': None,
1276 | 'kv_compression_sharing': None,
1277 | 'head_activation': 'ReLU',
1278 | 'head_normalization': 'LayerNorm',
1279 | }
1280 |
1281 | @classmethod
1282 | def get_default_transformer_config(
1283 | cls: Type['FTTransformer'], *, n_blocks: int = 3
1284 | ) -> Dict[str, Any]:
1285 | """Get the default parameters for the backbone.
1286 |
1287 | Note:
1288 | The configurations are different for different values of :code:`n_blocks`.
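
For example (the values below follow the grid defined in this method):

.. testcode::

    config = FTTransformer.get_default_transformer_config(n_blocks=3)
    assert config['d_token'] == 192
    assert config['attention_dropout'] == 0.2
    assert config['ffn_dropout'] == 0.1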
1289 | """
1290 | assert 1 <= n_blocks <= 6
1291 | grid = {
1292 | 'd_token': [96, 128, 192, 256, 320, 384],
1293 | 'attention_dropout': [0.1, 0.15, 0.2, 0.25, 0.3, 0.35],
1294 | 'ffn_dropout': [0.0, 0.05, 0.1, 0.15, 0.2, 0.25],
1295 | }
1296 | arch_subconfig = {k: v[n_blocks - 1] for k, v in grid.items()} # type: ignore
1297 | baseline_subconfig = cls.get_baseline_transformer_subconfig()
1298 | # (4 / 3) for ReGLU/GEGLU activations results in almost the same parameter count
1299 | # as (2.0) for element-wise activations (e.g. ReLU or GELU; see the "else" branch)
1300 | ffn_d_hidden_factor = (
1301 | (4 / 3) if _is_glu_activation(baseline_subconfig['ffn_activation']) else 2.0
1302 | )
1303 | return {
1304 | 'n_blocks': n_blocks,
1305 | 'residual_dropout': 0.0,
1306 | 'ffn_d_hidden': int(arch_subconfig['d_token'] * ffn_d_hidden_factor),
1307 | **arch_subconfig,
1308 | **baseline_subconfig,
1309 | }
1310 |
1311 | @classmethod
1312 | def _make(
1313 | cls,
1314 | n_num_features,
1315 | cat_cardinalities,
1316 | transformer_config,
1317 | ):
1318 | feature_tokenizer = FeatureTokenizer(
1319 | n_num_features=n_num_features,
1320 | cat_cardinalities=cat_cardinalities,
1321 | d_token=transformer_config['d_token'],
1322 | )
1323 | if transformer_config['d_out'] is None:
1324 | transformer_config['head_activation'] = None
1325 | if transformer_config['kv_compression_ratio'] is not None:
1326 | transformer_config['n_tokens'] = feature_tokenizer.n_tokens + 1
1327 | return FTTransformer(
1328 | feature_tokenizer,
1329 | Transformer(**transformer_config),
1330 | )
1331 |
1332 | @classmethod
1333 | def make_baseline(
1334 | cls: Type['FTTransformer'],
1335 | *,
1336 | n_num_features: int,
1337 | cat_cardinalities: Optional[List[int]],
1338 | d_token: int,
1339 | n_blocks: int,
1340 | attention_dropout: float,
1341 | ffn_d_hidden: int,
1342 | ffn_dropout: float,
1343 | residual_dropout: float,
1344 | last_layer_query_idx: Union[None, List[int], slice] = None,
1345 | kv_compression_ratio: Optional[float] = None,
1346 | kv_compression_sharing: Optional[str] = None,
1347 | d_out: int,
1348 | ) -> 'FTTransformer':
1349 | """Create a "baseline" `FTTransformer`.
1350 |
1351 | This variation of FT-Transformer was used in [gorishniy2021revisiting]. See
1352 | `get_baseline_transformer_subconfig` to learn the values of other parameters.
1353 | See `FTTransformer` for usage examples.
1354 |
1355 | Tip:
1356 | `get_default_transformer_config` can serve as a starting point for choosing
1357 | hyperparameter values.
1358 |
1359 | Args:
1360 | n_num_features: the number of continuous features
1361 | cat_cardinalities: the cardinalities of categorical features (see
1362 | `CategoricalFeatureTokenizer` to learn more about cardinalities)
1363 | d_token: the token size for each feature. Must be a multiple of :code:`n_heads=8`.
1364 | n_blocks: the number of Transformer blocks
1365 | attention_dropout: the dropout for attention blocks (see `MultiheadAttention`).
1366 | Usually, positive values work better (even when the number of features is low).
1367 | ffn_d_hidden: the *input* size for the *second* linear layer in `Transformer.FFN`.
1368 | Note that it can be different from the output size of the first linear
1369 | layer, since activations such as ReGLU or GEGLU change the size of input.
1370 | For example, if :code:`ffn_d_hidden=10` and the activation is ReGLU (which
1371 | is always true for the baseline and default configurations), then the
1372 | output size of the first linear layer will be set to :code:`20`.
1373 | ffn_dropout: the dropout rate after the first linear layer in `Transformer.FFN`.
1374 | residual_dropout: the dropout rate for the output of each residual branch of
1375 | all Transformer blocks.
1376 | last_layer_query_idx: indices of tokens that should be processed by the last
1377 | Transformer block. Note that for most cases there is no need to apply
1378 | the last Transformer block to anything except for the [CLS]-token. Hence,
1379 | runtime and memory can be saved by setting :code:`last_layer_query_idx=[-1]`,
1380 | since the :code:`-1` is the position of [CLS]-token in FT-Transformer.
1381 | Note that this will not affect the result in any way.
1382 | kv_compression_ratio: apply the technique from [wang2020linformer] to speed
1383 | up attention modules when the number of features is large. Can actually
1384 | slow things down if the number of features is too low. Note that this
1385 | option can affect task metrics in an unpredictable way. Overall, use this
1386 | option with caution. See `MultiheadAttention` for some examples and the
1387 | implementation of `Transformer` to see how this option is used.
1388 | kv_compression_sharing: weight sharing policy for :code:`kv_compression_ratio`.
1389 | Must be one of :code:`[None, 'headwise', 'key-value', 'layerwise']`.
1390 | See [wang2020linformer] to learn more about sharing policies.
1391 | :code:`headwise` and :code:`key-value` are reasonable default choices. If
1392 | :code:`kv_compression_ratio` is `None`, then this parameter also must be
1393 | `None`. Otherwise, it must not be `None` (compression parameters must be
1394 | shared in some way).
1395 | d_out: the output size.
1396 |
1397 | References:
1398 | * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko, "Revisiting Deep Learning Models for Tabular Data", 2021
1399 | * [wang2020linformer] Sinong Wang, Belinda Z. Li, Madian Khabsa, Han Fang, Hao Ma "Linformer: Self-Attention with Linear Complexity", 2020
1400 | """
1401 | transformer_config = cls.get_baseline_transformer_subconfig()
1402 | for arg_name in [
1403 | 'n_blocks',
1404 | 'd_token',
1405 | 'attention_dropout',
1406 | 'ffn_d_hidden',
1407 | 'ffn_dropout',
1408 | 'residual_dropout',
1409 | 'last_layer_query_idx',
1410 | 'kv_compression_ratio',
1411 | 'kv_compression_sharing',
1412 | 'd_out',
1413 | ]:
1414 | transformer_config[arg_name] = locals()[arg_name]
1415 | return cls._make(n_num_features, cat_cardinalities, transformer_config)
1416 |
1417 | @classmethod
1418 | def make_default(
1419 | cls: Type['FTTransformer'],
1420 | *,
1421 | n_num_features: int,
1422 | cat_cardinalities: Optional[List[int]],
1423 | n_blocks: int = 3,
1424 | last_layer_query_idx: Union[None, List[int], slice] = None,
1425 | kv_compression_ratio: Optional[float] = None,
1426 | kv_compression_sharing: Optional[str] = None,
1427 | d_out: int,
1428 | ) -> 'FTTransformer':
1429 | """Create the default `FTTransformer`.
1430 |
1431 | With :code:`n_blocks=3` (default) it is the FT-Transformer variation that is
1432 | referred to as "default FT-Transformer" in [gorishniy2021revisiting]. See
1433 | `FTTransformer` for usage examples. See `FTTransformer.make_baseline` for
1434 | parameter descriptions.
1435 |
1436 | Note:
1437 | The second component of the default FT-Transformer is the default optimizer,
1438 | which can be created with the `make_default_optimizer` method.
1439 |
1440 | Note:
1441 | According to [gorishniy2021revisiting], the default FT-Transformer is
1442 | effective in the ensembling mode (i.e. when predictions of several default
1443 | FT-Transformers are averaged). For a single FT-Transformer, it is still
1444 | possible to achieve better results by tuning hyperparameter for the
1445 | `make_baseline` constructor.
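
An illustrative sketch that pairs the default model with its default optimizer
(see `make_default_optimizer`):

.. testcode::

    model = FTTransformer.make_default(
        n_num_features=3,
        cat_cardinalities=[2, 3],
        d_out=1,
    )
    optimizer = model.make_default_optimizer()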
1446 |
1447 | References:
1448 | * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko, "Revisiting Deep Learning Models for Tabular Data", 2021
1449 | """
1450 | transformer_config = cls.get_default_transformer_config(n_blocks=n_blocks)
1451 | for arg_name in [
1452 | 'last_layer_query_idx',
1453 | 'kv_compression_ratio',
1454 | 'kv_compression_sharing',
1455 | 'd_out',
1456 | ]:
1457 | transformer_config[arg_name] = locals()[arg_name]
1458 | return cls._make(n_num_features, cat_cardinalities, transformer_config)
1459 |
1460 | def optimization_param_groups(self) -> List[Dict[str, Any]]:
1461 | """The replacement for :code:`.parameters()` when creating optimizers.
1462 |
1463 | Example::
1464 |
1465 | optimizer = torch.optim.AdamW(
1466 | model.optimization_param_groups(), lr=1e-4, weight_decay=1e-5
1467 | )
1468 | """
1469 | no_wd_names = ['feature_tokenizer', 'normalization', '.bias']
1470 | assert isinstance(
1471 | getattr(self, no_wd_names[0], None), FeatureTokenizer
1472 | ), _INTERNAL_ERROR_MESSAGE
1473 | assert (
1474 | sum(1 for name, _ in self.named_modules() if no_wd_names[1] in name)
1475 | == len(self.transformer.blocks) * 2
1476 | - int('attention_normalization' not in self.transformer.blocks[0]) # type: ignore
1477 | + 1
1478 | ), _INTERNAL_ERROR_MESSAGE
1479 |
1480 | def needs_wd(name):
1481 | return all(x not in name for x in no_wd_names)
1482 |
1483 | return [
1484 | {'params': [v for k, v in self.named_parameters() if needs_wd(k)]},
1485 | {
1486 | 'params': [v for k, v in self.named_parameters() if not needs_wd(k)],
1487 | 'weight_decay': 0.0,
1488 | },
1489 | ]
1490 |
1491 | def make_default_optimizer(self) -> torch.optim.AdamW:
1492 | """Make the optimizer for the default FT-Transformer."""
1493 | return torch.optim.AdamW(
1494 | self.optimization_param_groups(),
1495 | lr=1e-4,
1496 | weight_decay=1e-5,
1497 | )
1498 |
1499 | def forward(self, x_num: Optional[Tensor], x_cat: Optional[Tensor]) -> Tensor:
1500 | x = self.feature_tokenizer(x_num, x_cat)
1501 | x = self.cls_token(x)
1502 | x = self.transformer(x)
1503 | return x
1504 |
--------------------------------------------------------------------------------