├── rtdl ├── tests │ ├── __init__.py │ ├── test_data.py │ ├── test_modules.py │ └── test_vs_paper.py ├── _utils.py ├── __init__.py ├── functional.py ├── data.py └── modules.py ├── environment.yaml ├── Makefile ├── pyproject.toml ├── README.md ├── .gitignore └── LICENSE /rtdl/tests/__init__.py: -------------------------------------------------------------------------------- 1 | """@private""" 2 | -------------------------------------------------------------------------------- /rtdl/_utils.py: -------------------------------------------------------------------------------- 1 | from typing import TypeVar 2 | 3 | from typing_extensions import ParamSpec 4 | 5 | INTERNAL_ERROR_MESSAGE = ( 6 | 'Internal error. Please, open an issue here:' 7 | ' https://github.com/Yura52/rtdl/issues/new' 8 | ) 9 | 10 | 11 | def all_or_none(values): 12 | return all(x is None for x in values) or all(x is not None for x in values) 13 | 14 | 15 | P = ParamSpec('P') 16 | T = TypeVar('T') 17 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: rtdl 2 | channels: 3 | - https://conda.anaconda.org/conda-forge 4 | - nodefaults 5 | dependencies: 6 | - black=23.10.1 7 | - flit=3.9.0 8 | - isort=5.12.0 9 | - jupyterlab<5 10 | - mypy=1.6.1 11 | - numpy=1.19.5 12 | - pdoc=14.1.0 13 | - pip<24 14 | - pytest=7.4.2 15 | - python=3.8 16 | - pytorch=1.8.0 17 | - scikit-learn=1.0.2 18 | - ruff=0.1.7 19 | - tomli=2.0.1 20 | - tqdm=4.66.1 21 | - xdoctest=1.1.2 22 | - pip: 23 | - delu==0.0.23 24 | -------------------------------------------------------------------------------- /rtdl/__init__.py: -------------------------------------------------------------------------------- 1 | """Research on tabular deep learning.""" 2 | 3 | __version__ = '0.0.14.dev7' 4 | 5 | import warnings 6 | 7 | warnings.warn( 8 | 'The rtdl package is deprecated. See the GitHub repository to learn more.', 9 | DeprecationWarning, 10 | ) 11 | 12 | from . import data # noqa: F401 13 | from .functional import geglu, reglu # noqa: F401 14 | from .modules import ( # noqa: F401 15 | GEGLU, 16 | MLP, 17 | CategoricalFeatureTokenizer, 18 | CLSToken, 19 | FeatureTokenizer, 20 | FTTransformer, 21 | MultiheadAttention, 22 | NumericalFeatureTokenizer, 23 | ReGLU, 24 | ResNet, 25 | Transformer, 26 | ) 27 | -------------------------------------------------------------------------------- /rtdl/functional.py: -------------------------------------------------------------------------------- 1 | """@private""" 2 | import torch.nn.functional as F 3 | from torch import Tensor 4 | 5 | 6 | def reglu(x: Tensor) -> Tensor: 7 | """The ReGLU activation function from [1]. 8 | 9 | References: 10 | 11 | [1] Noam Shazeer, "GLU Variants Improve Transformer", 2020 12 | """ 13 | assert x.shape[-1] % 2 == 0 14 | a, b = x.chunk(2, dim=-1) 15 | return a * F.relu(b) 16 | 17 | 18 | def geglu(x: Tensor) -> Tensor: 19 | """The GEGLU activation function from [1]. 20 | 21 | References: 22 | 23 | [1] Noam Shazeer, "GLU Variants Improve Transformer", 2020 24 | """ 25 | assert x.shape[-1] % 2 == 0 26 | a, b = x.chunk(2, dim=-1) 27 | return a * F.gelu(b) 28 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: default clean doctest lint pre-commit typecheck 2 | 3 | PACKAGE_ROOT = rtdl 4 | 5 | default: 6 | echo "Hello, World!" 
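# The targets below make up the `pre-commit` pipeline: `clean` removes caches and build
# artifacts, `lint` runs isort/black/ruff in check mode, `doctest` runs xdoctest plus the
# README code-block checks, and `typecheck` runs mypy.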
7 | 8 | clean: 9 | find . -type f -name "*.py[co]" -delete -o -type d -name __pycache__ -delete 10 | rm -rf .ipynb_checkpoints 11 | rm -rf .mypy_cache 12 | rm -rf .pytest_cache 13 | rm -rf .ruff_cache 14 | rm -rf dist 15 | 16 | doctest: 17 | xdoctest $(PACKAGE_ROOT) 18 | python test_code_blocks.py rtdl/revisiting_models/README.md 19 | python test_code_blocks.py rtdl/num_embeddings/README.md 20 | 21 | lint: 22 | isort $(PACKAGE_ROOT) --check-only 23 | black $(PACKAGE_ROOT) --check 24 | ruff check . 25 | 26 | # The order is important. 27 | pre-commit: clean lint doctest typecheck 28 | 29 | typecheck: 30 | mypy $(PACKAGE_ROOT) 31 | -------------------------------------------------------------------------------- /rtdl/tests/test_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | import rtdl.data 5 | 6 | 7 | def test_get_category_sizes(): 8 | get_category_sizes = rtdl.data.get_category_sizes 9 | 10 | # not two dimensions 11 | with pytest.raises(ValueError): 12 | get_category_sizes(np.array([0, 0, 0])) 13 | with pytest.raises(ValueError): 14 | get_category_sizes(np.array([[[0, 0, 0]]])) 15 | 16 | # not signed integers 17 | for dtype in [np.uint32, np.float32, str]: 18 | with pytest.raises(ValueError): 19 | get_category_sizes(np.array([[0, 0, 0]], dtype=dtype)) 20 | 21 | # non-zero min value 22 | for x in [-1, 1]: 23 | with pytest.raises(ValueError): 24 | get_category_sizes(np.array([[0, 0, x]])) 25 | 26 | # not full range 27 | with pytest.raises(ValueError): 28 | get_category_sizes(np.array([[0, 0, 0], [2, 1, 1]])) 29 | 30 | # correctness 31 | assert get_category_sizes(np.array([[0]])) == [1] 32 | assert get_category_sizes(np.array([[0], [1]])) == [2] 33 | assert get_category_sizes(np.array([[0, 0]])) == [1, 1] 34 | assert get_category_sizes(np.array([[0, 0], [0, 1]])) == [1, 2] 35 | assert get_category_sizes(np.array([[1, 0], [0, 1]])) == [2, 2] 36 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | build-backend = "flit_core.buildapi" 3 | requires = ["flit_core >=3.2,<4"] 4 | 5 | [project] 6 | authors = [{ name = "Yury Gorishniy" }] 7 | classifiers = [ 8 | "Intended Audience :: Developers", 9 | "Intended Audience :: Science/Research", 10 | "License :: OSI Approved :: Apache Software License", 11 | "Programming Language :: Python :: 3", 12 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 13 | "Topic :: Software Development :: Libraries :: Python Modules", 14 | ] 15 | dependencies = ["torch >=1.8,<3"] 16 | dynamic = ["version", "description"] 17 | keywords = [ 18 | "artificial intelligence", 19 | "deep learning", 20 | "library", 21 | "python", 22 | "pytorch", 23 | "research", 24 | "torch", 25 | "tabular", 26 | "tabular data", 27 | ] 28 | license = { file = "LICENSE" } 29 | name = "rtdl" 30 | requires-python = ">=3.8" 31 | 32 | [project.urls] 33 | Code = "https://github.com/yandex-research/rtdl" 34 | Documentation = "https://github.com/yandex-research/rtdl" 35 | 36 | [tool.black] 37 | skip_string_normalization = true 38 | 39 | [tool.flit.module] 40 | name = "rtdl" 41 | 42 | [tool.isort] 43 | profile = "black" 44 | multi_line_output = 3 45 | known_first_party = ["rtdl"] 46 | 47 | [tool.mypy] 48 | check_untyped_defs = true 49 | ignore_missing_imports = true 50 | 51 | [tool.ruff] 52 | line-length = 88 53 | extend-select = ["RUF", "UP", "E101", 
"E501"] 54 | target-version = "py38" 55 | 56 | [tool.ruff.per-file-ignores] 57 | "rtdl/_utils.py" = ["E501"] 58 | "rtdl/data.py" = ["E501"] 59 | "rtdl/modules.py" = ["E501"] 60 | "rtdl/nn/*" = ["E501"] 61 | 62 | [[tool.mypy.overrides]] 63 | module = "rtdl.*.tests.*" 64 | ignore_errors = true 65 | -------------------------------------------------------------------------------- /rtdl/data.py: -------------------------------------------------------------------------------- 1 | """Tools for data (pre)processing. @private""" 2 | 3 | __all__ = ['get_category_sizes'] 4 | 5 | from typing import List, TypeVar 6 | 7 | import numpy as np 8 | 9 | Number = TypeVar('Number', int, float) 10 | 11 | 12 | def get_category_sizes(X: np.ndarray) -> List[int]: 13 | """Validate encoded categorical features and count distinct values. 14 | 15 | The function calculates the "category sizes" that can be used to construct 16 | `rtdl.CategoricalFeatureTokenizer` and `rtdl.FTTransformer`. Additionally, the 17 | following conditions are checked: 18 | 19 | * the data is a two-dimensional array of signed integers 20 | * distinct values of each column form zero-based ranges 21 | 22 | Note: 23 | For valid inputs, the result equals :code:`X.max(0) + 1`. 24 | 25 | Args: 26 | X: encoded categorical features (e.g. the output of :code:`sklearn.preprocessing.OrdinalEncoder`) 27 | 28 | Returns: 29 | The counts of distinct values for all columns. 30 | 31 | Examples: 32 | .. testcode:: 33 | 34 | assert get_category_sizes(np.array( 35 | [ 36 | [0, 0, 0], 37 | [1, 0, 0], 38 | [2, 1, 0], 39 | ] 40 | )) == [3, 2, 1] 41 | """ 42 | if X.ndim != 2: 43 | raise ValueError('X must be two-dimensional') 44 | if not issubclass(X.dtype.type, np.signedinteger): 45 | raise ValueError('X data type must be integer') 46 | sizes = [] 47 | for i, column in enumerate(X.T): 48 | unique_values = np.unique(column) 49 | min_value = unique_values.min() 50 | if min_value != 0: 51 | raise ValueError( 52 | f'The minimum value of column {i} is {min_value}, but it must be zero.' 53 | ) 54 | max_value = unique_values.max() 55 | if max_value + 1 != len(unique_values): 56 | raise ValueError( 57 | f'The values of column {i} do not fully cover the range from zero to maximum_value={max_value}' 58 | ) 59 | 60 | sizes.append(len(unique_values)) 61 | return sizes 62 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RTDL (Research on Tabular Deep Learning) 2 | 3 | RTDL (**R**esearch on **T**abular **D**eep **L**earning) is a collection of papers and packages 4 | on deep learning for tabular data. 5 | 6 | :bell: *To follow announcements on new projects, subscribe to releases in this GitHub repository: 7 | "Watch -> Custom -> Releases".* 8 | 9 | > [!NOTE] 10 | > The list of projects below is up-to-date, but the `rtdl` Python package is deprecated. 11 | > If you used the rtdl package, please, read the details. 12 | > 13 | >
14 | >
15 | > 1. First, to clarify, this repository is **NOT** deprecated,
16 | > only the package `rtdl` is deprecated: it is replaced with other packages.
17 | > 2. If you used the latest `rtdl==0.0.13` installed from PyPI (not from GitHub!)
18 | > as `pip install rtdl`, then the same models
19 | > (MLP, ResNet, FT-Transformer) can be found in the `rtdl_revisiting_models` package,
20 | > though the API is slightly different.
21 | > 3. :exclamation: **If you used the unfinished code from the main branch, it is highly**
22 | > **recommended to switch to the new packages.** In particular,
23 | > the unfinished implementation of embeddings for continuous features
24 | > contained many unresolved issues (the `rtdl_num_embeddings` package, in turn,
25 | > is more efficient and correct).
26 | >
27 | >
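For reference, the deprecated package constructs the three models as shown in the minimal sketch below. It is based on the constructor calls exercised in `rtdl/tests/test_modules.py` later in this listing; the feature counts, layer sizes, and dropout rates are arbitrary placeholders rather than recommended values.

```python
import torch
import rtdl  # the deprecated package described in this note

# Placeholder sizes: 10 numerical features, two categorical features, one output.
mlp = rtdl.MLP.make_baseline(d_in=10, d_layers=[64, 64], dropout=0.1, d_out=1)
resnet = rtdl.ResNet.make_baseline(
    d_in=10, d_main=64, d_hidden=128,
    dropout_first=0.2, dropout_second=0.0,
    n_blocks=2, d_out=1,
)
ft = rtdl.FTTransformer.make_default(
    n_num_features=10,
    cat_cardinalities=[3, 5],
    last_layer_query_idx=[-1],  # attend only to the [CLS] token in the last block
    d_out=1,
)

x_num = torch.randn(4, 10)
x_cat = torch.tensor([[0, 1], [2, 4], [1, 0], [2, 2]])  # values below the cardinalities
assert mlp(x_num).shape == (4, 1)
assert resnet(x_num).shape == (4, 1)
assert ft(x_num, x_cat).shape == (4, 1)
```

The `rtdl_revisiting_models` package listed under the 2021 paper below provides the same three architectures; consult its README for the updated constructors.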
28 | 29 | # Papers 30 | 31 | (2024) TabM: Advancing Tabular Deep Learning with Parameter-Efficient Ensembling 32 |
[Paper](https://arxiv.org/abs/2410.24210) 33 |   [Code](https://github.com/yandex-research/tabm) 34 |   [Usage](https://github.com/yandex-research/tabm#using-tabm-in-practice) 35 | 36 | (2024) TabReD: Analyzing Pitfalls and Filling the Gaps in Tabular Deep Learning Benchmarks 37 |
[Paper](https://arxiv.org/abs/2406.19380) 38 |   [Code](https://github.com/yandex-research/tabred) 39 | 40 | (2023) TabR: Tabular Deep Learning Meets Nearest Neighbors 41 |
[Paper](https://arxiv.org/abs/2307.14338) 42 |   [Code](https://github.com/yandex-research/tabular-dl-tabr) 43 | 44 | (2022) TabDDPM: Modelling Tabular Data with Diffusion Models 45 |
[Paper](https://arxiv.org/abs/2209.15421) 46 |   [Code](https://github.com/yandex-research/tab-ddpm) 47 | 48 | (2022) Revisiting Pretraining Objectives for Tabular Deep Learning 49 |
[Paper](https://arxiv.org/abs/2207.03208) 50 |   [Code](https://github.com/puhsu/tabular-dl-pretrain-objectives) 51 | 52 | (2022) On Embeddings for Numerical Features in Tabular Deep Learning 53 |
[Paper](https://arxiv.org/abs/2203.05556) 54 |   [Code](https://github.com/yandex-research/rtdl-num-embeddings) 55 |   [Package (rtdl_num_embeddings)](https://github.com/yandex-research/rtdl-num-embeddings/tree/main/package/README.md) 56 | 57 | (2021) Revisiting Deep Learning Models for Tabular Data 58 |
[Paper](https://arxiv.org/abs/2106.11959) 59 |   [Code](https://github.com/yandex-research/rtdl-revisiting-models) 60 |   [Package (rtdl_revisiting_models)](https://github.com/yandex-research/rtdl-revisiting-models/tree/main/package/README.md) 61 | 62 | (2019) Neural Oblivious Decision Ensembles for Deep Learning on Tabular Data 63 |
[Paper](https://arxiv.org/abs/1909.06312) 64 |   [Code](https://github.com/Qwicen/node) 65 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # >>> GITHUB DEFAULT PYTHON .GIGIGNORE 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | # lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | 135 | # pytype static type analyzer 136 | .pytype/ 137 | 138 | # Cython debug symbols 139 | cython_debug/ 140 | 141 | # <<< GITHUB DEFAULT PYTHON .GIGIGNORE 142 | 143 | # The following directory is automatically generated by Sphnix 144 | docs/api 145 | 146 | # Data, checkpoints, etc. 
147 | **/data/** 148 | **/catboost_cached_datasets/** 149 | *.bin 150 | *.csv 151 | *.cbm 152 | *.npy 153 | *.pickle 154 | *.pt 155 | *.pth 156 | *.rar 157 | *.tar* 158 | *.tmp 159 | *.zip 160 | events.out.tfevents.* 161 | 162 | # Other 163 | .DS_Store 164 | .vscode/ 165 | local/ 166 | -------------------------------------------------------------------------------- /rtdl/tests/test_modules.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | import rtdl 5 | 6 | 7 | def get_devices(): 8 | return ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu'] 9 | 10 | 11 | def test_bad_mlp(): 12 | with pytest.raises(AssertionError): 13 | rtdl.MLP.make_baseline(1, [1, 2, 3, 4], 0.0, 1) 14 | 15 | 16 | @pytest.mark.parametrize('n_blocks', range(5)) 17 | @pytest.mark.parametrize('d_out', [1, 2]) 18 | @pytest.mark.parametrize('constructor', range(2)) 19 | @pytest.mark.parametrize('device', get_devices()) 20 | def test_mlp(n_blocks, d_out, constructor, device): 21 | if not n_blocks and not d_out: 22 | return 23 | d = 4 24 | d_last = d + 1 25 | d_layers = [] 26 | if n_blocks: 27 | d_layers.append(d) 28 | if n_blocks > 2: 29 | d_layers.extend([d + d_out] * (n_blocks - 2)) 30 | if n_blocks > 1: 31 | d_layers.append(d_last) 32 | 33 | def f0(): 34 | dropouts = [0.1 * x for x in range(len(d_layers))] 35 | return rtdl.MLP( 36 | d_in=d, d_layers=d_layers, dropouts=dropouts, activation='GELU', d_out=d_out 37 | ) 38 | 39 | def f1(): 40 | return rtdl.MLP.make_baseline( 41 | d_in=d, d_layers=d_layers, dropout=0.1, d_out=d_out 42 | ) 43 | 44 | model = locals()[f'f{constructor}']().to(device) 45 | n = 2 46 | assert model(torch.randn(n, d, device=device)).shape == ( 47 | (n, d_out) if d_out else (n, d_last) if n_blocks > 1 else (n, d) 48 | ) 49 | 50 | 51 | @pytest.mark.parametrize('n_blocks', [1, 2]) 52 | @pytest.mark.parametrize('d_out', [1, 2]) 53 | @pytest.mark.parametrize('constructor', range(2)) 54 | @pytest.mark.parametrize('device', get_devices()) 55 | def test_resnet(n_blocks, d_out, constructor, device): 56 | d = 4 57 | 58 | def f0(): 59 | return rtdl.ResNet.make_baseline( 60 | d_in=d, 61 | d_main=d, 62 | d_hidden=d * 3, 63 | dropout_first=0.1, 64 | dropout_second=0.2, 65 | n_blocks=n_blocks, 66 | d_out=d_out, 67 | ) 68 | 69 | def f1(): 70 | return rtdl.ResNet( 71 | d_in=d, 72 | d_main=d, 73 | d_hidden=d * 3, 74 | dropout_first=0.1, 75 | dropout_second=0.2, 76 | n_blocks=n_blocks, 77 | normalization='Identity', 78 | activation='ReLU6', 79 | d_out=d_out, 80 | ) 81 | 82 | model = locals()[f'f{constructor}']().to(device) 83 | n = 2 84 | assert model(torch.randn(n, d, device=device)).shape == ( 85 | (n, d_out) if d_out else (n, d) 86 | ) 87 | 88 | 89 | @pytest.mark.parametrize('n_blocks', range(1, 7)) 90 | @pytest.mark.parametrize('d_out', [1, 2]) 91 | @pytest.mark.parametrize('last_layer_query_idx', [None, [-1]]) 92 | @pytest.mark.parametrize('constructor', range(3)) 93 | @pytest.mark.parametrize('device', get_devices()) 94 | def test_ft_transformer(n_blocks, d_out, last_layer_query_idx, constructor, device): 95 | n_num_features = 4 96 | model = rtdl.FTTransformer.make_default( 97 | n_num_features=4, 98 | cat_cardinalities=[2, 3], 99 | n_blocks=n_blocks, 100 | last_layer_query_idx=last_layer_query_idx, 101 | kv_compression_ratio=0.5, 102 | kv_compression_sharing='headwise', 103 | d_out=d_out, 104 | ).to(device) 105 | n = 2 106 | 107 | # check that the following methods do not fail 108 | model.optimization_param_groups() 109 | 
model.make_default_optimizer() 110 | 111 | def f0(): 112 | return rtdl.FTTransformer.make_default( 113 | n_num_features=4, 114 | cat_cardinalities=[2, 3], 115 | n_blocks=n_blocks, 116 | last_layer_query_idx=last_layer_query_idx, 117 | kv_compression_ratio=0.5, 118 | kv_compression_sharing='headwise', 119 | d_out=d_out, 120 | ) 121 | 122 | def f1(): 123 | return rtdl.FTTransformer.make_baseline( 124 | n_num_features=4, 125 | cat_cardinalities=[2, 3], 126 | n_blocks=n_blocks, 127 | d_token=8, 128 | attention_dropout=0.2, 129 | ffn_d_hidden=8, 130 | ffn_dropout=0.3, 131 | residual_dropout=0.4, 132 | last_layer_query_idx=last_layer_query_idx, 133 | kv_compression_ratio=0.5, 134 | kv_compression_sharing='headwise', 135 | d_out=d_out, 136 | ) 137 | 138 | def f2(): 139 | d_token = 8 140 | rtdl.Transformer.WARNINGS['prenormalization'] = False 141 | model = rtdl.FTTransformer( 142 | rtdl.FeatureTokenizer(4, [2, 3], d_token), 143 | rtdl.Transformer( 144 | d_token=d_token, 145 | attention_n_heads=1, 146 | attention_dropout=0.3, 147 | attention_initialization='xavier', 148 | attention_normalization='Identity', 149 | ffn_d_hidden=4, 150 | ffn_dropout=0.3, 151 | ffn_activation='SELU', 152 | residual_dropout=0.2, 153 | ffn_normalization='Identity', 154 | prenormalization=False, 155 | first_prenormalization=False, 156 | n_tokens=7, 157 | head_activation='PReLU', 158 | head_normalization='Identity', 159 | n_blocks=n_blocks, 160 | last_layer_query_idx=last_layer_query_idx, 161 | kv_compression_ratio=0.5, 162 | kv_compression_sharing='headwise', 163 | d_out=d_out, 164 | ), 165 | ) 166 | rtdl.Transformer.WARNINGS['prenormalization'] = True 167 | return model 168 | 169 | model = locals()[f'f{constructor}']().to(device) 170 | x = model( 171 | torch.randn(n, n_num_features, device=device), 172 | torch.tensor([[0, 2], [1, 2]], device=device), 173 | ) 174 | if d_out: 175 | assert x.shape == (n, d_out) 176 | else: 177 | assert x.shape == ( 178 | n, 179 | model.feature_tokenizer.n_tokens + 1 if last_layer_query_idx is None else 1, 180 | model.feature_tokenizer.d_token, 181 | ) 182 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 Yandex LLC 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /rtdl/tests/test_vs_paper.py: -------------------------------------------------------------------------------- 1 | # Tests in this file validate that the models in `rtdl` are LITERALLY THE SAME as 2 | # the ones used in the paper (https://github.com/Yura52/rtdl/tree/main/bin) 3 | # The testing approach: 4 | # (1) copy weights from the correct model to the RTDL model 5 | # (2) check that the two models produce the same output for the same input 6 | import math 7 | import random 8 | import typing as ty 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from pytest import mark 15 | from torch import Tensor 16 | 17 | import rtdl 18 | 19 | 20 | class Model(nn.Module): 21 | def __init__(self, cat_input_module: nn.Module, model: nn.Module): 22 | super().__init__() 23 | self.cat_input_module = cat_input_module 24 | self.model = model 25 | 26 | def forward(self, x_num, x_cat): 27 | return self.model( 28 | torch.cat([x_num, self.cat_input_module(x_cat).flatten(1, -1)], dim=1) 29 | ) 30 | 31 | 32 | def set_seeds(seed): 33 | random.seed(seed) 34 | np.random.seed(seed) 35 | torch.manual_seed(seed) 36 | 37 | 38 | @torch.no_grad() 39 | def copy_layer(dst: nn.Module, src: nn.Module): 40 | for key, _ in dst.named_parameters(): 41 | getattr(dst, key).copy_(getattr(src, key)) 42 | 43 | 44 | @torch.no_grad() 45 | @mark.parametrize('seed', range(10)) 46 | def test_mlp(seed): 47 | # Source: https://github.com/Yura52/rtdl/blob/0e5169659c7ce552bc05bbaa85f7e204adc3d88e/bin/mlp.py 48 | 49 | class CorrectMLP(nn.Module): 50 | def __init__( 51 | self, 52 | *, 53 | d_in: int, 54 | d_layers: ty.List[int], 55 | dropout: float, 56 | d_out: int, 57 | categories: ty.Optional[ty.List[int]], 58 | d_embedding: int, 59 | ) -> None: 60 | super().__init__() 61 | 62 | if categories is not None: 63 | d_in += len(categories) * d_embedding 64 | category_offsets = torch.tensor([0] + categories[:-1]).cumsum(0) 65 | self.register_buffer('category_offsets', category_offsets) 66 | self.category_embeddings = nn.Embedding(sum(categories), d_embedding) 67 | nn.init.kaiming_uniform_( 68 | self.category_embeddings.weight, a=math.sqrt(5) 69 | ) 70 | # print(f'{self.category_embeddings.weight.shape=}') 71 | 72 | self.layers = nn.ModuleList( 73 | [ 74 | nn.Linear(d_layers[i - 1] if i else d_in, x) 75 | for i, x in enumerate(d_layers) 76 | ] 77 | ) 78 | self.dropout = dropout 79 | self.head = nn.Linear(d_layers[-1] if d_layers else d_in, d_out) 80 | 81 | def forward(self, x_num, x_cat): 82 | if x_cat is not None: 83 | x_cat = self.category_embeddings(x_cat + self.category_offsets[None]) # type: ignore 84 | x = torch.cat([x_num, x_cat.view(x_cat.size(0), -1)], dim=-1) 85 | else: 86 | x = x_num 87 | 88 | for layer in self.layers: 89 | x = layer(x) 90 | x = F.relu(x) 91 | if self.dropout: 92 | x = F.dropout(x, self.dropout, self.training) 93 | x = self.head(x) 94 | return x 95 | 96 | n = 32 97 | d_num = 2 98 | categories = [2, 3] 99 | d_embedding = 3 100 | d_in = d_num + len(categories) * d_embedding 101 | d_layers = [3, 4, 5] 102 | dropout = 0.1 103 | d_out = 2 104 | 105 | set_seeds(seed) 106 | correct_model = CorrectMLP( 107 | d_in=d_num, 108 | d_layers=d_layers, 109 | dropout=dropout, 110 | d_out=d_out, 111 | categories=categories, 112 | d_embedding=d_embedding, 113 | ) 114 | rtdl_tokenizer = rtdl.CategoricalFeatureTokenizer( 115 | categories, d_embedding, False, 'uniform' 116 | ) 117 | 
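# Build the equivalent rtdl backbone; together with `rtdl_tokenizer` above, its parameters
# are copied from `correct_model` below so that both models are numerically identical.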
rtdl_backbone = rtdl.MLP( 118 | d_in=d_in, 119 | d_layers=d_layers, 120 | dropouts=dropout, 121 | activation='ReLU', 122 | d_out=d_out, 123 | ) 124 | 125 | rtdl_tokenizer.embeddings.weight.copy_(correct_model.category_embeddings.weight) 126 | for correct_layer, block in zip(correct_model.layers, rtdl_backbone.blocks): 127 | copy_layer(block.linear, correct_layer) 128 | copy_layer(rtdl_backbone.head, correct_model.head) 129 | 130 | rtdl_model = Model(rtdl_tokenizer, rtdl_backbone) 131 | x_num = torch.randn(n, d_num) 132 | x_cat = torch.cat([torch.randint(x, (n, 1)) for x in categories], dim=1) 133 | set_seeds(seed) 134 | correct_result = correct_model(x_num, x_cat) 135 | set_seeds(seed) 136 | rtdl_result = rtdl_model(x_num, x_cat) 137 | assert (correct_result == rtdl_result).all() 138 | 139 | 140 | @torch.no_grad() 141 | @mark.parametrize('seed', range(10)) 142 | def test_resnet(seed): 143 | # Source: https://github.com/Yura52/rtdl/blob/0e5169659c7ce552bc05bbaa85f7e204adc3d88e/bin/resnet.py 144 | 145 | class CorrectResNet(nn.Module): 146 | def __init__( 147 | self, 148 | *, 149 | d_numerical: int, 150 | categories: ty.Optional[ty.List[int]], 151 | d_embedding: int, 152 | d: int, 153 | d_hidden_factor: float, 154 | n_layers: int, 155 | activation: str, 156 | normalization: str, 157 | hidden_dropout: float, 158 | residual_dropout: float, 159 | d_out: int, 160 | ) -> None: 161 | super().__init__() 162 | 163 | def make_normalization(): 164 | return {'BatchNorm1d': nn.BatchNorm1d}[normalization](d) 165 | 166 | assert activation == 'ReLU' 167 | self.main_activation = F.relu 168 | self.last_activation = F.relu 169 | self.residual_dropout = residual_dropout 170 | self.hidden_dropout = hidden_dropout 171 | 172 | d_in = d_numerical 173 | d_hidden = int(d * d_hidden_factor) 174 | 175 | if categories is not None: 176 | d_in += len(categories) * d_embedding 177 | category_offsets = torch.tensor([0] + categories[:-1]).cumsum(0) 178 | self.register_buffer('category_offsets', category_offsets) 179 | self.category_embeddings = nn.Embedding(sum(categories), d_embedding) 180 | nn.init.kaiming_uniform_( 181 | self.category_embeddings.weight, a=math.sqrt(5) 182 | ) 183 | # print(f'{self.category_embeddings.weight.shape=}') 184 | 185 | self.first_layer = nn.Linear(d_in, d) 186 | self.layers = nn.ModuleList( 187 | [ 188 | nn.ModuleDict( 189 | { 190 | 'norm': make_normalization(), 191 | 'linear0': nn.Linear( 192 | d, d_hidden * (2 if activation.endswith('glu') else 1) 193 | ), 194 | 'linear1': nn.Linear(d_hidden, d), 195 | } 196 | ) 197 | for _ in range(n_layers) 198 | ] 199 | ) 200 | self.last_normalization = make_normalization() 201 | self.head = nn.Linear(d, d_out) 202 | 203 | def forward(self, x_num: Tensor, x_cat: ty.Optional[Tensor]) -> Tensor: 204 | if x_cat is not None: 205 | x_cat = self.category_embeddings(x_cat + self.category_offsets[None]) # type: ignore 206 | x = torch.cat([x_num, x_cat.view(x_cat.size(0), -1)], dim=-1) # type: ignore 207 | else: 208 | x = x_num 209 | 210 | x = self.first_layer(x) 211 | for layer in self.layers: 212 | layer = ty.cast(ty.Dict[str, nn.Module], layer) 213 | z = x 214 | z = layer['norm'](z) 215 | z = layer['linear0'](z) 216 | z = self.main_activation(z) 217 | if self.hidden_dropout: 218 | z = F.dropout(z, self.hidden_dropout, self.training) 219 | z = layer['linear1'](z) 220 | if self.residual_dropout: 221 | z = F.dropout(z, self.residual_dropout, self.training) 222 | x = x + z 223 | x = self.last_normalization(x) 224 | x = self.last_activation(x) 225 | x = self.head(x) 
226 | return x 227 | 228 | n = 32 229 | d_num = 2 230 | categories = [2, 3] 231 | d_embedding = 3 232 | d_in = d_num + len(categories) * d_embedding 233 | d = 4 234 | d_hidden_factor = 1.5 235 | n_layers = 2 236 | activation = 'ReLU' 237 | normalization = 'BatchNorm1d' 238 | hidden_dropout = 0.1 239 | residual_dropout = 0.2 240 | d_out = 2 241 | 242 | set_seeds(seed) 243 | correct_model = CorrectResNet( 244 | d_numerical=d_num, 245 | categories=categories, 246 | d_embedding=d_embedding, 247 | d=d, 248 | d_hidden_factor=d_hidden_factor, 249 | n_layers=n_layers, 250 | activation=activation, 251 | normalization=normalization, 252 | hidden_dropout=0.1, 253 | residual_dropout=0.2, 254 | d_out=d_out, 255 | ) 256 | rtdl_tokenizer = rtdl.CategoricalFeatureTokenizer( 257 | categories, d_embedding, False, 'uniform' 258 | ) 259 | rtdl_backbone = rtdl.ResNet( 260 | d_in=d_in, 261 | n_blocks=n_layers, 262 | d_main=d, 263 | d_hidden=int(d * d_hidden_factor), 264 | dropout_first=hidden_dropout, 265 | dropout_second=residual_dropout, 266 | normalization=normalization, 267 | activation=activation, 268 | d_out=d_out, 269 | ) 270 | 271 | rtdl_tokenizer.embeddings.weight.copy_(correct_model.category_embeddings.weight) 272 | copy_layer(rtdl_backbone.first_layer, correct_model.first_layer) 273 | for correct_layer, block in zip(correct_model.layers, rtdl_backbone.blocks): 274 | copy_layer(block.normalization, correct_layer['norm']) 275 | copy_layer(block.linear_first, correct_layer['linear0']) 276 | copy_layer(block.linear_second, correct_layer['linear1']) 277 | 278 | copy_layer(rtdl_backbone.head.normalization, correct_model.last_normalization) 279 | copy_layer(rtdl_backbone.head.linear, correct_model.head) 280 | 281 | rtdl_model = Model(rtdl_tokenizer, rtdl_backbone) 282 | x_num = torch.randn(n, d_num) 283 | x_cat = torch.cat([torch.randint(x, (n, 1)) for x in categories], dim=1) 284 | set_seeds(seed) 285 | correct_result = correct_model(x_num, x_cat) 286 | set_seeds(seed) 287 | rtdl_result = rtdl_model(x_num, x_cat) 288 | assert (correct_result == rtdl_result).all() 289 | 290 | 291 | @torch.no_grad() 292 | @mark.parametrize('seed', range(10)) 293 | @mark.parametrize('kv_compression_ratio', [None, 0.5]) 294 | def test_ft_transformer(seed, kv_compression_ratio): 295 | # """Source: https://github.com/Yura52/rtdl/blob/0e5169659c7ce552bc05bbaa85f7e204adc3d88e/bin/ft_transformer.py""" 296 | # The only difference is that [CLS] is now the last token. 
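# (In bin/ft_transformer.py, a constant column prepended to x_num made [CLS] the first token;
# rtdl's CLSToken appends it instead, so CorrectTokenizer below places `self.weight[-1]` last.
# See the commented-out lines in CorrectTokenizer.forward.)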
297 | 298 | def correct_reglu(x: Tensor) -> Tensor: 299 | a, b = x.chunk(2, dim=-1) 300 | return a * F.relu(b) 301 | 302 | class CorrectTokenizer(nn.Module): 303 | category_offsets: ty.Optional[Tensor] 304 | 305 | def __init__( 306 | self, 307 | d_numerical: int, 308 | categories: ty.Optional[ty.List[int]], 309 | d_token: int, 310 | bias: bool, 311 | ) -> None: 312 | super().__init__() 313 | if categories is None: 314 | d_bias = d_numerical 315 | self.category_offsets = None 316 | self.category_embeddings = None 317 | else: 318 | d_bias = d_numerical + len(categories) 319 | category_offsets = torch.tensor([0] + categories[:-1]).cumsum(0) 320 | self.register_buffer('category_offsets', category_offsets) 321 | self.category_embeddings = nn.Embedding(sum(categories), d_token) 322 | nn.init.kaiming_uniform_( 323 | self.category_embeddings.weight, a=math.sqrt(5) 324 | ) 325 | 326 | # take [CLS] token into account 327 | self.weight = nn.Parameter(Tensor(d_numerical + 1, d_token)) 328 | self.bias = nn.Parameter(Tensor(d_bias, d_token)) if bias else None 329 | # The initialization is inspired by nn.Linear 330 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 331 | if self.bias is not None: 332 | nn.init.kaiming_uniform_(self.bias, a=math.sqrt(5)) 333 | 334 | @property 335 | def n_tokens(self) -> int: 336 | return len(self.weight) + ( 337 | 0 if self.category_offsets is None else len(self.category_offsets) 338 | ) 339 | 340 | def forward(self, x_num: Tensor, x_cat: ty.Optional[Tensor]) -> Tensor: 341 | # x_num = torch.cat( 342 | # [ 343 | # torch.ones(len(x_num), 1, device=x_num.device), 344 | # x_num, 345 | # ], 346 | # dim=1, 347 | # ) 348 | x = self.weight[:-1][None] * x_num[:, :, None] 349 | if x_cat is not None: 350 | x = torch.cat( 351 | [x, self.category_embeddings(x_cat + self.category_offsets[None])], 352 | dim=1, 353 | ) 354 | x = torch.cat( 355 | [x, self.weight[-1][None, None].repeat(len(x), 1, 1)], 356 | dim=1, 357 | ) 358 | if self.bias is not None: 359 | bias = torch.cat( 360 | [ 361 | self.bias, 362 | torch.zeros(1, self.bias.shape[1], device=x_num.device), 363 | ] 364 | ) 365 | x = x + bias[None] 366 | return x 367 | 368 | class CorrectMultiheadAttention(nn.Module): 369 | def __init__( 370 | self, d: int, n_heads: int, dropout: float, initialization: str 371 | ) -> None: 372 | if n_heads > 1: 373 | assert d % n_heads == 0 374 | assert initialization in ['xavier', 'kaiming'] 375 | 376 | super().__init__() 377 | self.W_q = nn.Linear(d, d) 378 | self.W_k = nn.Linear(d, d) 379 | self.W_v = nn.Linear(d, d) 380 | self.W_out = nn.Linear(d, d) if n_heads > 1 else None 381 | self.n_heads = n_heads 382 | self.dropout = nn.Dropout(dropout) if dropout else None 383 | 384 | for m in [self.W_q, self.W_k, self.W_v]: 385 | if initialization == 'xavier' and (n_heads > 1 or m is not self.W_v): 386 | # gain is needed since W_qkv is represented with 3 separate layers 387 | nn.init.xavier_uniform_(m.weight, gain=1 / math.sqrt(2)) 388 | nn.init.zeros_(m.bias) 389 | if self.W_out is not None: 390 | nn.init.zeros_(self.W_out.bias) 391 | 392 | def _reshape(self, x: Tensor) -> Tensor: 393 | batch_size, n_tokens, d = x.shape 394 | d_head = d // self.n_heads 395 | return ( 396 | x.reshape(batch_size, n_tokens, self.n_heads, d_head) 397 | .transpose(1, 2) 398 | .reshape(batch_size * self.n_heads, n_tokens, d_head) 399 | ) 400 | 401 | def forward( 402 | self, 403 | x_q: Tensor, 404 | x_kv: Tensor, 405 | key_compression: ty.Optional[nn.Linear], 406 | value_compression: ty.Optional[nn.Linear], 407 | ) -> 
Tensor: 408 | q, k, v = self.W_q(x_q), self.W_k(x_kv), self.W_v(x_kv) 409 | for tensor in [q, k, v]: 410 | assert tensor.shape[-1] % self.n_heads == 0 411 | if key_compression is not None: 412 | assert value_compression is not None 413 | k = key_compression(k.transpose(1, 2)).transpose(1, 2) 414 | v = value_compression(v.transpose(1, 2)).transpose(1, 2) 415 | else: 416 | assert value_compression is None 417 | 418 | batch_size = len(q) 419 | d_head_key = k.shape[-1] // self.n_heads 420 | d_head_value = v.shape[-1] // self.n_heads 421 | n_q_tokens = q.shape[1] 422 | 423 | q = self._reshape(q) 424 | k = self._reshape(k) 425 | attention = F.softmax(q @ k.transpose(1, 2) / math.sqrt(d_head_key), dim=-1) 426 | if self.dropout is not None: 427 | attention = self.dropout(attention) 428 | x = attention @ self._reshape(v) 429 | x = ( 430 | x.reshape(batch_size, self.n_heads, n_q_tokens, d_head_value) 431 | .transpose(1, 2) 432 | .reshape(batch_size, n_q_tokens, self.n_heads * d_head_value) 433 | ) 434 | if self.W_out is not None: 435 | x = self.W_out(x) 436 | return x 437 | 438 | class CorrectFTTransformer(nn.Module): 439 | def __init__( 440 | self, 441 | *, 442 | # tokenizer 443 | d_numerical: int, 444 | categories: ty.Optional[ty.List[int]], 445 | token_bias: bool, 446 | # transformer 447 | n_layers: int, 448 | d_token: int, 449 | n_heads: int, 450 | d_ffn_factor: float, 451 | attention_dropout: float, 452 | ffn_dropout: float, 453 | residual_dropout: float, 454 | activation: str, 455 | prenormalization: bool, 456 | initialization: str, 457 | # linformer 458 | kv_compression: ty.Optional[float], 459 | kv_compression_sharing: ty.Optional[str], 460 | # 461 | d_out: int, 462 | ) -> None: 463 | assert (kv_compression is None) ^ (kv_compression_sharing is not None) 464 | 465 | super().__init__() 466 | self.tokenizer = CorrectTokenizer( 467 | d_numerical, categories, d_token, token_bias 468 | ) 469 | n_tokens = self.tokenizer.n_tokens 470 | 471 | def make_kv_compression(): 472 | assert kv_compression 473 | compression = nn.Linear( 474 | n_tokens, int(n_tokens * kv_compression), bias=False 475 | ) 476 | if initialization == 'xavier': 477 | nn.init.xavier_uniform_(compression.weight) 478 | return compression 479 | 480 | self.shared_kv_compression = ( 481 | make_kv_compression() 482 | if kv_compression and kv_compression_sharing == 'layerwise' 483 | else None 484 | ) 485 | 486 | def make_normalization(): 487 | return nn.LayerNorm(d_token) 488 | 489 | d_hidden = int(d_token * d_ffn_factor) 490 | self.layers = nn.ModuleList([]) 491 | for layer_idx in range(n_layers): 492 | layer = nn.ModuleDict( 493 | { 494 | 'attention': CorrectMultiheadAttention( 495 | d_token, n_heads, attention_dropout, initialization 496 | ), 497 | 'linear0': nn.Linear( 498 | d_token, d_hidden * (2 if activation.endswith('glu') else 1) 499 | ), 500 | 'linear1': nn.Linear(d_hidden, d_token), 501 | 'norm1': make_normalization(), 502 | } 503 | ) 504 | if not prenormalization or layer_idx: 505 | layer['norm0'] = make_normalization() 506 | if kv_compression and self.shared_kv_compression is None: 507 | layer['key_compression'] = make_kv_compression() 508 | if kv_compression_sharing == 'headwise': 509 | layer['value_compression'] = make_kv_compression() 510 | else: 511 | assert kv_compression_sharing == 'key-value' 512 | self.layers.append(layer) 513 | 514 | assert activation == 'reglu' 515 | self.activation = correct_reglu 516 | self.last_activation = F.relu 517 | self.prenormalization = prenormalization 518 | self.last_normalization = 
make_normalization() if prenormalization else None 519 | self.ffn_dropout = ffn_dropout 520 | self.residual_dropout = residual_dropout 521 | self.head = nn.Linear(d_token, d_out) 522 | 523 | def _get_kv_compressions(self, layer): 524 | return ( 525 | (self.shared_kv_compression, self.shared_kv_compression) 526 | if self.shared_kv_compression is not None 527 | else (layer['key_compression'], layer['value_compression']) 528 | if 'key_compression' in layer and 'value_compression' in layer 529 | else (layer['key_compression'], layer['key_compression']) 530 | if 'key_compression' in layer 531 | else (None, None) 532 | ) 533 | 534 | def _start_residual(self, x, layer, norm_idx): 535 | x_residual = x 536 | if self.prenormalization: 537 | norm_key = f'norm{norm_idx}' 538 | if norm_key in layer: 539 | x_residual = layer[norm_key](x_residual) 540 | return x_residual 541 | 542 | def _end_residual(self, x, x_residual, layer, norm_idx): 543 | if self.residual_dropout: 544 | x_residual = F.dropout(x_residual, self.residual_dropout, self.training) 545 | x = x + x_residual 546 | if not self.prenormalization: 547 | x = layer[f'norm{norm_idx}'](x) 548 | return x 549 | 550 | def forward(self, x_num: Tensor, x_cat: ty.Optional[Tensor]) -> Tensor: 551 | x = self.tokenizer(x_num, x_cat) 552 | 553 | for layer_idx, layer in enumerate(self.layers): 554 | is_last_layer = layer_idx + 1 == len(self.layers) 555 | layer = ty.cast(ty.Dict[str, nn.Module], layer) 556 | 557 | x_residual = self._start_residual(x, layer, 0) 558 | x_residual = layer['attention']( 559 | # for the last attention, it is enough to process only [CLS] 560 | (x_residual[:, -1:] if is_last_layer else x_residual), 561 | x_residual, 562 | *self._get_kv_compressions(layer), 563 | ) 564 | if is_last_layer: 565 | x = x[:, -1:] 566 | x = self._end_residual(x, x_residual, layer, 0) 567 | 568 | x_residual = self._start_residual(x, layer, 1) 569 | x_residual = layer['linear0'](x_residual) 570 | x_residual = self.activation(x_residual) 571 | if self.ffn_dropout: 572 | x_residual = F.dropout(x_residual, self.ffn_dropout, self.training) 573 | x_residual = layer['linear1'](x_residual) 574 | x = self._end_residual(x, x_residual, layer, 1) 575 | 576 | assert x.shape[1] == 1 577 | x = x[:, 0] 578 | if self.last_normalization is not None: 579 | x = self.last_normalization(x) 580 | x = self.last_activation(x) 581 | x = self.head(x) 582 | x = x.squeeze(-1) 583 | return x 584 | 585 | # Source: https://github.com/Yura52/rtdl/blob/0e5169659c7ce552bc05bbaa85f7e204adc3d88e/output/california_housing/ft_transformer/default/0.toml 586 | default_config = { 587 | 'seed': 0, 588 | 'data': { 589 | 'normalization': 'quantile_normal', 590 | 'path': 'data/california_housing', 591 | 'y_policy': 'mean_std', 592 | }, 593 | 'model': { 594 | 'activation': 'reglu', 595 | 'attention_dropout': 0.2, 596 | 'd_ffn_factor': 4 / 3, 597 | 'd_token': 192, 598 | 'ffn_dropout': 0.1, 599 | 'initialization': 'kaiming', 600 | 'n_heads': 8, 601 | 'n_layers': 3, 602 | 'prenormalization': True, 603 | 'residual_dropout': 0.0, 604 | }, 605 | 'training': { 606 | 'batch_size': 256, 607 | 'eval_batch_size': 8192, 608 | 'lr': 0.0001, 609 | 'lr_n_decays': 0, 610 | 'n_epochs': 1000000000, 611 | 'optimizer': 'adamw', 612 | 'patience': 16, 613 | 'weight_decay': 1e-05, 614 | }, 615 | } 616 | dcm = default_config['model'] 617 | 618 | n = 4 619 | d_num = 2 620 | categories = [2, 3] 621 | n_tokens = d_num + len(categories) + 1 622 | kv_compression_sharing = 'key-value' if kv_compression_ratio else None 623 | d_out = 2 
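# Three models are compared below: the reference implementation, an rtdl.FTTransformer
# assembled by hand, and rtdl.FTTransformer.make_default. After the weights are copied
# (and mirrored into the default model via load_state_dict), all three must produce
# identical outputs.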
624 | 625 | set_seeds(seed) 626 | correct_model = CorrectFTTransformer( 627 | d_numerical=d_num, 628 | categories=categories, 629 | token_bias=True, 630 | kv_compression=kv_compression_ratio, 631 | kv_compression_sharing=kv_compression_sharing, 632 | d_out=d_out, 633 | **dcm, 634 | ) 635 | rtdl_model = rtdl.FTTransformer( 636 | rtdl.FeatureTokenizer(d_num, categories, dcm['d_token']), 637 | rtdl.Transformer( 638 | d_token=dcm['d_token'], 639 | n_blocks=dcm['n_layers'], 640 | attention_n_heads=dcm['n_heads'], 641 | attention_dropout=dcm['attention_dropout'], 642 | attention_initialization=dcm['initialization'], 643 | attention_normalization='LayerNorm', 644 | ffn_d_hidden=int(dcm['d_token'] * dcm['d_ffn_factor']), 645 | ffn_dropout=dcm['ffn_dropout'], 646 | ffn_activation='ReGLU', 647 | ffn_normalization='LayerNorm', 648 | residual_dropout=dcm['residual_dropout'], 649 | prenormalization=dcm['prenormalization'], 650 | first_prenormalization=False, 651 | last_layer_query_idx=[-1], 652 | n_tokens=n_tokens if kv_compression_ratio else None, 653 | kv_compression_ratio=kv_compression_ratio, 654 | kv_compression_sharing=kv_compression_sharing, 655 | head_activation='ReLU', 656 | head_normalization='LayerNorm', 657 | d_out=d_out, 658 | ), 659 | ) 660 | rtdl_default_model = rtdl.FTTransformer.make_default( 661 | n_num_features=d_num, 662 | cat_cardinalities=categories, 663 | last_layer_query_idx=[-1], 664 | kv_compression_ratio=kv_compression_ratio, 665 | kv_compression_sharing=kv_compression_sharing, 666 | d_out=d_out, 667 | ) 668 | 669 | rtdl_model.feature_tokenizer.num_tokenizer.weight.copy_( 670 | correct_model.tokenizer.weight[:-1] 671 | ) 672 | rtdl_model.feature_tokenizer.num_tokenizer.bias.copy_( 673 | correct_model.tokenizer.bias[:d_num] 674 | ) 675 | rtdl_model.feature_tokenizer.cat_tokenizer.embeddings.weight.copy_( 676 | correct_model.tokenizer.category_embeddings.weight 677 | ) 678 | rtdl_model.feature_tokenizer.cat_tokenizer.bias.copy_( 679 | correct_model.tokenizer.bias[-len(categories) :] 680 | ) 681 | rtdl_model.cls_token.weight.copy_(correct_model.tokenizer.weight[-1]) 682 | for correct_layer, block in zip( 683 | correct_model.layers, rtdl_model.transformer.blocks 684 | ): 685 | for key in ['W_q', 'W_k', 'W_v', 'W_out']: 686 | copy_layer( 687 | getattr(block['attention'], key), 688 | getattr(correct_layer['attention'], key), 689 | ) 690 | copy_layer(block['ffn'].linear_first, correct_layer['linear0']) 691 | copy_layer(block['ffn'].linear_second, correct_layer['linear1']) 692 | copy_layer(block['ffn_normalization'], correct_layer['norm1']) 693 | if 'norm0' in correct_layer: 694 | copy_layer(block['attention_normalization'], correct_layer['norm0']) 695 | for key in ['key_compression', 'value_compression']: 696 | if key in correct_layer: 697 | copy_layer(block[key], correct_layer[key]) 698 | copy_layer( 699 | rtdl_model.transformer.head.normalization, correct_model.last_normalization 700 | ) 701 | copy_layer(rtdl_model.transformer.head.linear, correct_model.head) 702 | rtdl_default_model.load_state_dict(rtdl_model.state_dict()) 703 | 704 | x_num = torch.randn(n, d_num) 705 | x_cat = torch.cat([torch.randint(x, (n, 1)) for x in categories], dim=1) 706 | results = [] 707 | for m in [correct_model, rtdl_model, rtdl_default_model]: 708 | set_seeds(seed) 709 | results.append(m(x_num, x_cat)) 710 | correct_result = results[0] 711 | assert (results[1] == results[2]).all() 712 | assert (results[1] == correct_result).all() 713 | 
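The tests above all follow the same recipe: build the reference model and the rtdl model, copy the parameters by name, and then require bitwise-identical outputs on the same input. A minimal self-contained sketch of that pattern follows; the two toy modules are illustrative only and are not part of rtdl.

```python
import torch
import torch.nn as nn


class Reference(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.layer(x))


class Refactored(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.block = nn.Sequential(nn.Linear(4, 2), nn.ReLU())

    def forward(self, x):
        return self.block(x)


@torch.no_grad()
def test_equivalence() -> None:
    ref, new = Reference(), Refactored()
    # Copy parameters so that both models are numerically identical.
    new.block[0].weight.copy_(ref.layer.weight)
    new.block[0].bias.copy_(ref.layer.bias)
    x = torch.randn(8, 4)
    # Exact equality is expected because both models run the same ops on the same weights.
    assert (ref(x) == new(x)).all()
```

Exact equality (rather than `torch.allclose`) is a deliberately strict check: it only holds when the two implementations perform the same floating-point operations in the same order.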
-------------------------------------------------------------------------------- /rtdl/modules.py: -------------------------------------------------------------------------------- 1 | """@private""" 2 | import enum 3 | import math 4 | import time 5 | import warnings 6 | from typing import ( 7 | Any, 8 | Callable, 9 | ClassVar, 10 | Dict, 11 | List, 12 | Optional, 13 | Tuple, 14 | Type, 15 | Union, 16 | cast, 17 | ) 18 | 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | import torch.optim 23 | from torch import Tensor 24 | 25 | from . import functional as rtdlF 26 | 27 | ModuleType = Union[str, Callable[..., nn.Module]] 28 | _INTERNAL_ERROR_MESSAGE = 'Internal error. Please, open an issue.' 29 | 30 | 31 | def _is_glu_activation(activation: ModuleType): 32 | return ( 33 | isinstance(activation, str) 34 | and activation.endswith('GLU') 35 | or activation in [ReGLU, GEGLU] 36 | ) 37 | 38 | 39 | def _all_or_none(values): 40 | return all(x is None for x in values) or all(x is not None for x in values) 41 | 42 | 43 | class ReGLU(nn.Module): 44 | """The ReGLU activation function from [shazeer2020glu]. 45 | 46 | Examples: 47 | .. testcode:: 48 | 49 | module = ReGLU() 50 | x = torch.randn(3, 4) 51 | assert module(x).shape == (3, 2) 52 | 53 | References: 54 | * [shazeer2020glu] Noam Shazeer, "GLU Variants Improve Transformer", 2020 55 | """ 56 | 57 | def forward(self, x: Tensor) -> Tensor: 58 | return rtdlF.reglu(x) 59 | 60 | 61 | class GEGLU(nn.Module): 62 | """The GEGLU activation function from [shazeer2020glu]. 63 | 64 | Examples: 65 | .. testcode:: 66 | 67 | module = GEGLU() 68 | x = torch.randn(3, 4) 69 | assert module(x).shape == (3, 2) 70 | 71 | References: 72 | * [shazeer2020glu] Noam Shazeer, "GLU Variants Improve Transformer", 2020 73 | """ 74 | 75 | def forward(self, x: Tensor) -> Tensor: 76 | return rtdlF.geglu(x) 77 | 78 | 79 | class _TokenInitialization(enum.Enum): 80 | UNIFORM = 'uniform' 81 | NORMAL = 'normal' 82 | 83 | @classmethod 84 | def from_str(cls, initialization: str) -> '_TokenInitialization': 85 | try: 86 | return cls(initialization) 87 | except ValueError: 88 | valid_values = [x.value for x in _TokenInitialization] 89 | raise ValueError(f'initialization must be one of {valid_values}') 90 | 91 | def apply(self, x: Tensor, d: int) -> None: 92 | d_sqrt_inv = 1 / math.sqrt(d) 93 | if self == _TokenInitialization.UNIFORM: 94 | # used in the paper "Revisiting Deep Learning Models for Tabular Data"; 95 | # is equivalent to `nn.init.kaiming_uniform_(x, a=math.sqrt(5))` (which is 96 | # used by torch to initialize nn.Linear.weight, for example) 97 | nn.init.uniform_(x, a=-d_sqrt_inv, b=d_sqrt_inv) 98 | elif self == _TokenInitialization.NORMAL: 99 | nn.init.normal_(x, std=d_sqrt_inv) 100 | 101 | 102 | class NumericalFeatureTokenizer(nn.Module): 103 | """Transforms continuous features to tokens (embeddings). 104 | 105 | See `FeatureTokenizer` for the illustration. 106 | 107 | For one feature, the transformation consists of two steps: 108 | 109 | * the feature is multiplied by a trainable vector 110 | * another trainable vector is added 111 | 112 | Note that each feature has its separate pair of trainable vectors, i.e. the vectors 113 | are not shared between features. 114 | 115 | Examples: 116 | .. 
testcode:: 117 | 118 | x = torch.randn(4, 2) 119 | n_objects, n_features = x.shape 120 | d_token = 3 121 | tokenizer = NumericalFeatureTokenizer(n_features, d_token, True, 'uniform') 122 | tokens = tokenizer(x) 123 | assert tokens.shape == (n_objects, n_features, d_token) 124 | """ 125 | 126 | def __init__( 127 | self, 128 | n_features: int, 129 | d_token: int, 130 | bias: bool, 131 | initialization: str, 132 | ) -> None: 133 | """ 134 | Args: 135 | n_features: the number of continuous (scalar) features 136 | d_token: the size of one token 137 | bias: if `False`, then the transformation will include only multiplication. 138 | **Warning**: :code:`bias=False` leads to significantly worse results for 139 | Transformer-like (token-based) architectures. 140 | initialization: initialization policy for parameters. Must be one of 141 | :code:`['uniform', 'normal']`. Let :code:`s = d ** -0.5`. Then, the 142 | corresponding distributions are :code:`Uniform(-s, s)` and :code:`Normal(0, s)`. 143 | In [gorishniy2021revisiting], the 'uniform' initialization was used. 144 | 145 | References: 146 | * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko, "Revisiting Deep Learning Models for Tabular Data", 2021 147 | """ 148 | super().__init__() 149 | initialization_ = _TokenInitialization.from_str(initialization) 150 | self.weight = nn.Parameter(Tensor(n_features, d_token)) 151 | self.bias = nn.Parameter(Tensor(n_features, d_token)) if bias else None 152 | for parameter in [self.weight, self.bias]: 153 | if parameter is not None: 154 | initialization_.apply(parameter, d_token) 155 | 156 | @property 157 | def n_tokens(self) -> int: 158 | """The number of tokens.""" 159 | return len(self.weight) 160 | 161 | @property 162 | def d_token(self) -> int: 163 | """The size of one token.""" 164 | return self.weight.shape[1] 165 | 166 | def forward(self, x: Tensor) -> Tensor: 167 | x = self.weight[None] * x[..., None] 168 | if self.bias is not None: 169 | x = x + self.bias[None] 170 | return x 171 | 172 | 173 | class CategoricalFeatureTokenizer(nn.Module): 174 | """Transforms categorical features to tokens (embeddings). 175 | 176 | See `FeatureTokenizer` for the illustration. 177 | 178 | The module efficiently implements a collection of `torch.nn.Embedding` (with 179 | optional biases). 180 | 181 | Examples: 182 | .. testcode:: 183 | 184 | # the input must contain integers. For example, if the first feature can 185 | # take 3 distinct values, then its cardinality is 3 and the first column 186 | # must contain values from the range `[0, 1, 2]`. 187 | cardinalities = [3, 10] 188 | x = torch.tensor([ 189 | [0, 5], 190 | [1, 7], 191 | [0, 2], 192 | [2, 4] 193 | ]) 194 | n_objects, n_features = x.shape 195 | d_token = 3 196 | tokenizer = CategoricalFeatureTokenizer(cardinalities, d_token, True, 'uniform') 197 | tokens = tokenizer(x) 198 | assert tokens.shape == (n_objects, n_features, d_token) 199 | """ 200 | 201 | category_offsets: Tensor 202 | 203 | def __init__( 204 | self, 205 | cardinalities: List[int], 206 | d_token: int, 207 | bias: bool, 208 | initialization: str, 209 | ) -> None: 210 | """ 211 | Args: 212 | cardinalities: the number of distinct values for each feature. For example, 213 | :code:`cardinalities=[3, 4]` describes two features: the first one can 214 | take values in the range :code:`[0, 1, 2]` and the second one can take 215 | values in the range :code:`[0, 1, 2, 3]`. 216 | d_token: the size of one token. 
217 | bias: if `True`, for each feature, a trainable vector is added to the 218 | embedding regardless of feature value. The bias vectors are not shared 219 | between features. 220 | initialization: initialization policy for parameters. Must be one of 221 | :code:`['uniform', 'normal']`. Let :code:`s = d ** -0.5`. Then, the 222 | corresponding distributions are :code:`Uniform(-s, s)` and :code:`Normal(0, s)`. In 223 | the paper [gorishniy2021revisiting], the 'uniform' initialization was 224 | used. 225 | 226 | References: 227 | * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko, "Revisiting Deep Learning Models for Tabular Data", 2021 228 | """ 229 | super().__init__() 230 | assert cardinalities, 'cardinalities must be non-empty' 231 | assert d_token > 0, 'd_token must be positive' 232 | initialization_ = _TokenInitialization.from_str(initialization) 233 | 234 | category_offsets = torch.tensor([0] + cardinalities[:-1]).cumsum(0) 235 | self.register_buffer('category_offsets', category_offsets, persistent=False) 236 | self.embeddings = nn.Embedding(sum(cardinalities), d_token) 237 | self.bias = nn.Parameter(Tensor(len(cardinalities), d_token)) if bias else None 238 | 239 | for parameter in [self.embeddings.weight, self.bias]: 240 | if parameter is not None: 241 | initialization_.apply(parameter, d_token) 242 | 243 | @property 244 | def n_tokens(self) -> int: 245 | """The number of tokens.""" 246 | return len(self.category_offsets) 247 | 248 | @property 249 | def d_token(self) -> int: 250 | """The size of one token.""" 251 | return self.embeddings.embedding_dim 252 | 253 | def forward(self, x: Tensor) -> Tensor: 254 | x = self.embeddings(x + self.category_offsets[None]) 255 | if self.bias is not None: 256 | x = x + self.bias[None] 257 | return x 258 | 259 | 260 | class FeatureTokenizer(nn.Module): 261 | """Combines `NumericalFeatureTokenizer` and `CategoricalFeatureTokenizer`. 262 | 263 | The "Feature Tokenizer" module from [gorishniy2021revisiting]. The module transforms 264 | continuous and categorical features to tokens (embeddings). 265 | 266 | In the illustration below, the red module in the upper brackets represents 267 | `NumericalFeatureTokenizer` and the green module in the lower brackets represents 268 | `CategoricalFeatureTokenizer`. 269 | 270 | .. image:: ../images/feature_tokenizer.png 271 | :scale: 33% 272 | :alt: Feature Tokenizer 273 | 274 | Examples: 275 | .. testcode:: 276 | 277 | n_objects = 4 278 | n_num_features = 3 279 | n_cat_features = 2 280 | d_token = 7 281 | x_num = torch.randn(n_objects, n_num_features) 282 | x_cat = torch.tensor([[0, 1], [1, 0], [0, 2], [1, 1]]) 283 | # [2, 3] reflects cardinalities fr 284 | tokenizer = FeatureTokenizer(n_num_features, [2, 3], d_token) 285 | tokens = tokenizer(x_num, x_cat) 286 | assert tokens.shape == (n_objects, n_num_features + n_cat_features, d_token) 287 | 288 | References: 289 | * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko "Revisiting Deep Learning Models for Tabular Data", 2021 290 | """ 291 | 292 | def __init__( 293 | self, 294 | n_num_features: int, 295 | cat_cardinalities: List[int], 296 | d_token: int, 297 | ) -> None: 298 | """ 299 | Args: 300 | n_num_features: the number of continuous features. Pass :code:`0` if there 301 | are no numerical features. 302 | cat_cardinalities: the number of unique values for each feature. See 303 | `CategoricalFeatureTokenizer` for details. 
Pass an empty list if there 304 | are no categorical features. 305 | d_token: the size of one token. 306 | """ 307 | super().__init__() 308 | assert n_num_features >= 0, 'n_num_features must be non-negative' 309 | assert ( 310 | n_num_features or cat_cardinalities 311 | ), 'at least one of n_num_features or cat_cardinalities must be positive/non-empty' 312 | self.initialization = 'uniform' 313 | self.num_tokenizer = ( 314 | NumericalFeatureTokenizer( 315 | n_features=n_num_features, 316 | d_token=d_token, 317 | bias=True, 318 | initialization=self.initialization, 319 | ) 320 | if n_num_features 321 | else None 322 | ) 323 | self.cat_tokenizer = ( 324 | CategoricalFeatureTokenizer( 325 | cat_cardinalities, d_token, True, self.initialization 326 | ) 327 | if cat_cardinalities 328 | else None 329 | ) 330 | 331 | @property 332 | def n_tokens(self) -> int: 333 | """The number of tokens.""" 334 | return sum( 335 | x.n_tokens 336 | for x in [self.num_tokenizer, self.cat_tokenizer] 337 | if x is not None 338 | ) 339 | 340 | @property 341 | def d_token(self) -> int: 342 | """The size of one token.""" 343 | return ( 344 | self.cat_tokenizer.d_token # type: ignore 345 | if self.num_tokenizer is None 346 | else self.num_tokenizer.d_token 347 | ) 348 | 349 | def forward(self, x_num: Optional[Tensor], x_cat: Optional[Tensor]) -> Tensor: 350 | """Perform the forward pass. 351 | 352 | Args: 353 | x_num: continuous features. Must be presented if :code:`n_num_features > 0` 354 | was passed to the constructor. 355 | x_cat: categorical features (see `CategoricalFeatureTokenizer.forward` for 356 | details). Must be presented if non-empty :code:`cat_cardinalities` was 357 | passed to the constructor. 358 | Returns: 359 | tokens 360 | Raises: 361 | AssertionError: if the described requirements for the inputs are not met. 362 | """ 363 | assert ( 364 | x_num is not None or x_cat is not None 365 | ), 'At least one of x_num and x_cat must be presented' 366 | assert _all_or_none( 367 | [self.num_tokenizer, x_num] 368 | ), 'If self.num_tokenizer is (not) None, then x_num must (not) be None' 369 | assert _all_or_none( 370 | [self.cat_tokenizer, x_cat] 371 | ), 'If self.cat_tokenizer is (not) None, then x_cat must (not) be None' 372 | x = [] 373 | if self.num_tokenizer is not None: 374 | x.append(self.num_tokenizer(x_num)) 375 | if self.cat_tokenizer is not None: 376 | x.append(self.cat_tokenizer(x_cat)) 377 | return x[0] if len(x) == 1 else torch.cat(x, dim=1) 378 | 379 | 380 | class CLSToken(nn.Module): 381 | """[CLS]-token for BERT-like inference. 382 | 383 | To learn about the [CLS]-based inference, see [devlin2018bert]. 384 | 385 | When used as a module, the [CLS]-token is appended **to the end** of each item in 386 | the batch. 387 | 388 | Examples: 389 | .. testcode:: 390 | 391 | batch_size = 2 392 | n_tokens = 3 393 | d_token = 4 394 | cls_token = CLSToken(d_token, 'uniform') 395 | x = torch.randn(batch_size, n_tokens, d_token) 396 | x = cls_token(x) 397 | assert x.shape == (batch_size, n_tokens + 1, d_token) 398 | assert (x[:, -1, :] == cls_token.expand(len(x))).all() 399 | 400 | References: 401 | * [devlin2018bert] Jacob Devlin, Ming-Wei Chang, Kenton Lee, Kristina Toutanova "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding" 2018 402 | """ 403 | 404 | def __init__(self, d_token: int, initialization: str) -> None: 405 | """ 406 | Args: 407 | d_token: the size of token 408 | initialization: initialization policy for parameters. Must be one of 409 | :code:`['uniform', 'normal']`. 
Let :code:`s = d ** -0.5`. Then, the 410 | corresponding distributions are :code:`Uniform(-s, s)` and :code:`Normal(0, s)`. In 411 | the paper [gorishniy2021revisiting], the 'uniform' initialization was 412 | used. 413 | 414 | References: 415 | * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko "Revisiting Deep Learning Models for Tabular Data", 2021 416 | """ 417 | super().__init__() 418 | initialization_ = _TokenInitialization.from_str(initialization) 419 | self.weight = nn.Parameter(Tensor(d_token)) 420 | initialization_.apply(self.weight, d_token) 421 | 422 | def expand(self, *leading_dimensions: int) -> Tensor: 423 | """Expand (repeat) the underlying [CLS]-token to a tensor with the given leading dimensions. 424 | 425 | A possible use case is building a batch of [CLS]-tokens. See `CLSToken` for 426 | examples of usage. 427 | 428 | Note: 429 | Under the hood, the `torch.Tensor.expand` method is applied to the 430 | underlying :code:`weight` parameter, so gradients will be propagated as 431 | expected. 432 | 433 | Args: 434 | leading_dimensions: the additional new dimensions 435 | 436 | Returns: 437 | tensor of the shape :code:`(*leading_dimensions, len(self.weight))` 438 | """ 439 | if not leading_dimensions: 440 | return self.weight 441 | new_dims = (1,) * (len(leading_dimensions) - 1) 442 | return self.weight.view(*new_dims, -1).expand(*leading_dimensions, -1) 443 | 444 | def forward(self, x: Tensor) -> Tensor: 445 | """Append self **to the end** of each item in the batch (see `CLSToken`).""" 446 | return torch.cat([x, self.expand(len(x), 1)], dim=1) 447 | 448 | 449 | def _make_nn_module(module_type: ModuleType, *args) -> nn.Module: 450 | if isinstance(module_type, str): 451 | if module_type == 'ReGLU': 452 | return ReGLU() 453 | elif module_type == 'GEGLU': 454 | return GEGLU() 455 | else: 456 | try: 457 | cls = getattr(nn, module_type) 458 | except AttributeError as err: 459 | raise ValueError( 460 | f'Failed to construct the module {module_type} with the arguments {args}' 461 | ) from err 462 | return cls(*args) 463 | else: 464 | return module_type(*args) 465 | 466 | 467 | class MLP(nn.Module): 468 | """The MLP model used in [gorishniy2021revisiting]. 469 | 470 | The following scheme describes the architecture: 471 | 472 | .. code-block:: text 473 | 474 | MLP: (in) -> Block -> ... -> Block -> Linear -> (out) 475 | Block: (in) -> Linear -> Activation -> Dropout -> (out) 476 | 477 | Examples: 478 | .. 
testcode:: 479 | 480 | x = torch.randn(4, 2) 481 | module = MLP.make_baseline(x.shape[1], [3, 5], 0.1, 1) 482 | assert module(x).shape == (len(x), 1) 483 | 484 | References: 485 | * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko, "Revisiting Deep Learning Models for Tabular Data", 2021 486 | """ 487 | 488 | class Block(nn.Module): 489 | """The main building block of `MLP`.""" 490 | 491 | def __init__( 492 | self, 493 | *, 494 | d_in: int, 495 | d_out: int, 496 | bias: bool, 497 | activation: ModuleType, 498 | dropout: float, 499 | ) -> None: 500 | super().__init__() 501 | self.linear = nn.Linear(d_in, d_out, bias) 502 | self.activation = _make_nn_module(activation) 503 | self.dropout = nn.Dropout(dropout) 504 | 505 | def forward(self, x: Tensor) -> Tensor: 506 | return self.dropout(self.activation(self.linear(x))) 507 | 508 | def __init__( 509 | self, 510 | *, 511 | d_in: int, 512 | d_layers: List[int], 513 | dropouts: Union[float, List[float]], 514 | activation: Union[str, Callable[[], nn.Module]], 515 | d_out: int, 516 | ) -> None: 517 | """ 518 | Note: 519 | `make_baseline` is the recommended constructor. 520 | """ 521 | super().__init__() 522 | if isinstance(dropouts, float): 523 | dropouts = [dropouts] * len(d_layers) 524 | assert len(d_layers) == len(dropouts) 525 | 526 | self.blocks = nn.Sequential( 527 | *[ 528 | MLP.Block( 529 | d_in=d_layers[i - 1] if i else d_in, 530 | d_out=d, 531 | bias=True, 532 | activation=activation, 533 | dropout=dropout, 534 | ) 535 | for i, (d, dropout) in enumerate(zip(d_layers, dropouts)) 536 | ] 537 | ) 538 | self.head = nn.Linear(d_layers[-1] if d_layers else d_in, d_out) 539 | 540 | @classmethod 541 | def make_baseline( 542 | cls: Type['MLP'], 543 | d_in: int, 544 | d_layers: List[int], 545 | dropout: float, 546 | d_out: int, 547 | ) -> 'MLP': 548 | """Create a "baseline" `MLP`. 549 | 550 | This variation of MLP was used in [gorishniy2021revisiting]. Features: 551 | 552 | * :code:`Activation` = :code:`ReLU` 553 | * all linear layers except for the first one and the last one are of the same dimension 554 | * the dropout rate is the same for all dropout layers 555 | 556 | Args: 557 | d_in: the input size 558 | d_layers: the dimensions of the linear layers. If there are more than two 559 | layers, then all of them except for the first and the last ones must 560 | have the same dimension. Valid examples: :code:`[]`, :code:`[8]`, 561 | :code:`[8, 16]`, :code:`[2, 2, 2, 2]`, :code:`[1, 2, 2, 4]`. Invalid 562 | example: :code:`[1, 2, 3, 4]`. 563 | dropout: the dropout rate for all hidden layers 564 | d_out: the output size 565 | Returns: 566 | MLP 567 | 568 | References: 569 | * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko, "Revisiting Deep Learning Models for Tabular Data", 2021 570 | """ 571 | assert isinstance(dropout, float), 'In this constructor, dropout must be float' 572 | if len(d_layers) > 2: 573 | assert len(set(d_layers[1:-1])) == 1, ( 574 | 'In this constructor, if d_layers contains more than two elements, then' 575 | ' all elements except for the first and the last ones must be equal.' 576 | ) 577 | return MLP( 578 | d_in=d_in, 579 | d_layers=d_layers, # type: ignore 580 | dropouts=dropout, 581 | activation='ReLU', 582 | d_out=d_out, 583 | ) 584 | 585 | def forward(self, x: Tensor) -> Tensor: 586 | x = self.blocks(x) 587 | x = self.head(x) 588 | return x 589 | 590 | 591 | class ResNet(nn.Module): 592 | """The ResNet model used in [gorishniy2021revisiting]. 
593 | 594 | The following scheme describes the architecture: 595 | 596 | .. code-block:: text 597 | 598 | ResNet: (in) -> Linear -> Block -> ... -> Block -> Head -> (out) 599 | 600 | |-> Norm -> Linear -> Activation -> Dropout -> Linear -> Dropout ->| 601 | | | 602 | Block: (in) ------------------------------------------------------------> Add -> (out) 603 | 604 | Head: (in) -> Norm -> Activation -> Linear -> (out) 605 | 606 | Examples: 607 | .. testcode:: 608 | 609 | x = torch.randn(4, 2) 610 | module = ResNet.make_baseline( 611 | d_in=x.shape[1], 612 | n_blocks=2, 613 | d_main=3, 614 | d_hidden=4, 615 | dropout_first=0.25, 616 | dropout_second=0.0, 617 | d_out=1 618 | ) 619 | assert module(x).shape == (len(x), 1) 620 | 621 | References: 622 | * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko, "Revisiting Deep Learning Models for Tabular Data", 2021 623 | """ 624 | 625 | class Block(nn.Module): 626 | """The main building block of `ResNet`.""" 627 | 628 | def __init__( 629 | self, 630 | *, 631 | d_main: int, 632 | d_hidden: int, 633 | bias_first: bool, 634 | bias_second: bool, 635 | dropout_first: float, 636 | dropout_second: float, 637 | normalization: ModuleType, 638 | activation: ModuleType, 639 | skip_connection: bool, 640 | ) -> None: 641 | super().__init__() 642 | self.normalization = _make_nn_module(normalization, d_main) 643 | self.linear_first = nn.Linear(d_main, d_hidden, bias_first) 644 | self.activation = _make_nn_module(activation) 645 | self.dropout_first = nn.Dropout(dropout_first) 646 | self.linear_second = nn.Linear(d_hidden, d_main, bias_second) 647 | self.dropout_second = nn.Dropout(dropout_second) 648 | self.skip_connection = skip_connection 649 | 650 | def forward(self, x: Tensor) -> Tensor: 651 | x_input = x 652 | x = self.normalization(x) 653 | x = self.linear_first(x) 654 | x = self.activation(x) 655 | x = self.dropout_first(x) 656 | x = self.linear_second(x) 657 | x = self.dropout_second(x) 658 | if self.skip_connection: 659 | x = x_input + x 660 | return x 661 | 662 | class Head(nn.Module): 663 | """The final module of `ResNet`.""" 664 | 665 | def __init__( 666 | self, 667 | *, 668 | d_in: int, 669 | d_out: int, 670 | bias: bool, 671 | normalization: ModuleType, 672 | activation: ModuleType, 673 | ) -> None: 674 | super().__init__() 675 | self.normalization = _make_nn_module(normalization, d_in) 676 | self.activation = _make_nn_module(activation) 677 | self.linear = nn.Linear(d_in, d_out, bias) 678 | 679 | def forward(self, x: Tensor) -> Tensor: 680 | if self.normalization is not None: 681 | x = self.normalization(x) 682 | x = self.activation(x) 683 | x = self.linear(x) 684 | return x 685 | 686 | def __init__( 687 | self, 688 | *, 689 | d_in: int, 690 | n_blocks: int, 691 | d_main: int, 692 | d_hidden: int, 693 | dropout_first: float, 694 | dropout_second: float, 695 | normalization: ModuleType, 696 | activation: ModuleType, 697 | d_out: int, 698 | ) -> None: 699 | """ 700 | Note: 701 | `make_baseline` is the recommended constructor. 
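
        For illustration, here is a direct construction that mirrors what
        `make_baseline` fills in (:code:`normalization='BatchNorm1d'` and
        :code:`activation='ReLU'`); the remaining values are arbitrary and chosen
        only to keep the example small:

        .. testcode::

            module = ResNet(
                d_in=4,
                n_blocks=2,
                d_main=3,
                d_hidden=4,
                dropout_first=0.25,
                dropout_second=0.0,
                normalization='BatchNorm1d',
                activation='ReLU',
                d_out=1,
            )
            assert module(torch.randn(2, 4)).shape == (2, 1)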
702 | """ 703 | super().__init__() 704 | 705 | self.first_layer = nn.Linear(d_in, d_main) 706 | if d_main is None: 707 | d_main = d_in 708 | self.blocks = nn.Sequential( 709 | *[ 710 | ResNet.Block( 711 | d_main=d_main, 712 | d_hidden=d_hidden, 713 | bias_first=True, 714 | bias_second=True, 715 | dropout_first=dropout_first, 716 | dropout_second=dropout_second, 717 | normalization=normalization, 718 | activation=activation, 719 | skip_connection=True, 720 | ) 721 | for _ in range(n_blocks) 722 | ] 723 | ) 724 | self.head = ResNet.Head( 725 | d_in=d_main, 726 | d_out=d_out, 727 | bias=True, 728 | normalization=normalization, 729 | activation=activation, 730 | ) 731 | 732 | @classmethod 733 | def make_baseline( 734 | cls: Type['ResNet'], 735 | *, 736 | d_in: int, 737 | n_blocks: int, 738 | d_main: int, 739 | d_hidden: int, 740 | dropout_first: float, 741 | dropout_second: float, 742 | d_out: int, 743 | ) -> 'ResNet': 744 | """Create a "baseline" `ResNet`. 745 | 746 | This variation of ResNet was used in [gorishniy2021revisiting]. Features: 747 | 748 | * :code:`Activation` = :code:`ReLU` 749 | * :code:`Norm` = :code:`BatchNorm1d` 750 | 751 | Args: 752 | d_in: the input size 753 | n_blocks: the number of Blocks 754 | d_main: the input size (or, equivalently, the output size) of each Block 755 | d_hidden: the output size of the first linear layer in each Block 756 | dropout_first: the dropout rate of the first dropout layer in each Block. 757 | dropout_second: the dropout rate of the second dropout layer in each Block. 758 | 759 | References: 760 | * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko, "Revisiting Deep Learning Models for Tabular Data", 2021 761 | """ 762 | return cls( 763 | d_in=d_in, 764 | n_blocks=n_blocks, 765 | d_main=d_main, 766 | d_hidden=d_hidden, 767 | dropout_first=dropout_first, 768 | dropout_second=dropout_second, 769 | normalization='BatchNorm1d', 770 | activation='ReLU', 771 | d_out=d_out, 772 | ) 773 | 774 | def forward(self, x: Tensor) -> Tensor: 775 | x = self.first_layer(x) 776 | x = self.blocks(x) 777 | x = self.head(x) 778 | return x 779 | 780 | 781 | class MultiheadAttention(nn.Module): 782 | """Multihead Attention (self-/cross-) with optional 'linear' attention. 783 | 784 | To learn more about Multihead Attention, see [devlin2018bert]. See the implementation 785 | of `Transformer` and the examples below to learn how to use the compression technique 786 | from [wang2020linformer] to speed up the module when the number of tokens is large. 787 | 788 | Examples: 789 | .. 
testcode:: 790 | 791 | n_objects, n_tokens, d_token = 2, 3, 12 792 | n_heads = 6 793 | a = torch.randn(n_objects, n_tokens, d_token) 794 | b = torch.randn(n_objects, n_tokens * 2, d_token) 795 | module = MultiheadAttention( 796 | d_token=d_token, n_heads=n_heads, dropout=0.2, bias=True, initialization='kaiming' 797 | ) 798 | 799 | # self-attention 800 | x, attention_stats = module(a, a, None, None) 801 | assert x.shape == a.shape 802 | assert attention_stats['attention_probs'].shape == (n_objects * n_heads, n_tokens, n_tokens) 803 | assert attention_stats['attention_logits'].shape == (n_objects * n_heads, n_tokens, n_tokens) 804 | 805 | # cross-attention 806 | assert module(a, b, None, None) 807 | 808 | # Linformer self-attention with the 'headwise' sharing policy 809 | k_compression = torch.nn.Linear(n_tokens, n_tokens // 4) 810 | v_compression = torch.nn.Linear(n_tokens, n_tokens // 4) 811 | assert module(a, a, k_compression, v_compression) 812 | 813 | # Linformer self-attention with the 'key-value' sharing policy 814 | kv_compression = torch.nn.Linear(n_tokens, n_tokens // 4) 815 | assert module(a, a, kv_compression, kv_compression) 816 | 817 | References: 818 | * [devlin2018bert] Jacob Devlin, Ming-Wei Chang, Kenton Lee, Kristina Toutanova "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding" 2018 819 | * [wang2020linformer] Sinong Wang, Belinda Z. Li, Madian Khabsa, Han Fang, Hao Ma "Linformer: Self-Attention with Linear Complexity", 2020 820 | """ 821 | 822 | def __init__( 823 | self, 824 | *, 825 | d_token: int, 826 | n_heads: int, 827 | dropout: float, 828 | bias: bool, 829 | initialization: str, 830 | ) -> None: 831 | """ 832 | Args: 833 | d_token: the token size. Must be a multiple of :code:`n_heads`. 834 | n_heads: the number of heads. If greater than 1, then the module will have 835 | an addition output layer (so called "mixing" layer). 836 | dropout: dropout rate for the attention map. The dropout is applied to 837 | *probabilities* and do not affect logits. 838 | bias: if `True`, then input (and output, if presented) layers also have bias. 839 | `True` is a reasonable default choice. 840 | initialization: initialization for input projection layers. Must be one of 841 | :code:`['kaiming', 'xavier']`. `kaiming` is a reasonable default choice. 842 | Raises: 843 | AssertionError: if requirements for the inputs are not met. 
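
        Note:
            The optional key/value compression modules passed to `forward` act on
            the *token* dimension, not the token size. A sketch of how `Transformer`
            constructs such a module from a compression ratio (see
            :code:`make_kv_compression` in its implementation):

            .. testcode::

                n_tokens, kv_compression_ratio = 10, 0.5
                compression = torch.nn.Linear(
                    n_tokens, int(n_tokens * kv_compression_ratio), bias=False
                )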
844 | """ 845 | super().__init__() 846 | if n_heads > 1: 847 | assert d_token % n_heads == 0, 'd_token must be a multiple of n_heads' 848 | assert initialization in ['kaiming', 'xavier'] 849 | 850 | self.W_q = nn.Linear(d_token, d_token, bias) 851 | self.W_k = nn.Linear(d_token, d_token, bias) 852 | self.W_v = nn.Linear(d_token, d_token, bias) 853 | self.W_out = nn.Linear(d_token, d_token, bias) if n_heads > 1 else None 854 | self.n_heads = n_heads 855 | self.dropout = nn.Dropout(dropout) if dropout else None 856 | 857 | for m in [self.W_q, self.W_k, self.W_v]: 858 | # the "xavier" branch tries to follow torch.nn.MultiheadAttention; 859 | # the second condition checks if W_v plays the role of W_out; the latter one 860 | # is initialized with Kaiming in torch 861 | if initialization == 'xavier' and ( 862 | m is not self.W_v or self.W_out is not None 863 | ): 864 | # gain is needed since W_qkv is represented with 3 separate layers (it 865 | # implies different fan_out) 866 | nn.init.xavier_uniform_(m.weight, gain=1 / math.sqrt(2)) 867 | if m.bias is not None: 868 | nn.init.zeros_(m.bias) 869 | if self.W_out is not None: 870 | nn.init.zeros_(self.W_out.bias) 871 | 872 | def _reshape(self, x: Tensor) -> Tensor: 873 | batch_size, n_tokens, d = x.shape 874 | d_head = d // self.n_heads 875 | return ( 876 | x.reshape(batch_size, n_tokens, self.n_heads, d_head) 877 | .transpose(1, 2) 878 | .reshape(batch_size * self.n_heads, n_tokens, d_head) 879 | ) 880 | 881 | def forward( 882 | self, 883 | x_q: Tensor, 884 | x_kv: Tensor, 885 | key_compression: Optional[nn.Linear], 886 | value_compression: Optional[nn.Linear], 887 | ) -> Tuple[Tensor, Dict[str, Tensor]]: 888 | """Perform the forward pass. 889 | 890 | Args: 891 | x_q: query tokens 892 | x_kv: key-value tokens 893 | key_compression: Linformer-style compression for keys 894 | value_compression: Linformer-style compression for values 895 | Returns: 896 | (tokens, attention_stats) 897 | """ 898 | assert _all_or_none( 899 | [key_compression, value_compression] 900 | ), 'If key_compression is (not) None, then value_compression must (not) be None' 901 | q, k, v = self.W_q(x_q), self.W_k(x_kv), self.W_v(x_kv) 902 | for tensor in [q, k, v]: 903 | assert tensor.shape[-1] % self.n_heads == 0, _INTERNAL_ERROR_MESSAGE 904 | if key_compression is not None: 905 | k = key_compression(k.transpose(1, 2)).transpose(1, 2) 906 | v = value_compression(v.transpose(1, 2)).transpose(1, 2) # type: ignore 907 | 908 | batch_size = len(q) 909 | d_head_key = k.shape[-1] // self.n_heads 910 | d_head_value = v.shape[-1] // self.n_heads 911 | n_q_tokens = q.shape[1] 912 | 913 | q = self._reshape(q) 914 | k = self._reshape(k) 915 | attention_logits = q @ k.transpose(1, 2) / math.sqrt(d_head_key) 916 | attention_probs = F.softmax(attention_logits, dim=-1) 917 | if self.dropout is not None: 918 | attention_probs = self.dropout(attention_probs) 919 | x = attention_probs @ self._reshape(v) 920 | x = ( 921 | x.reshape(batch_size, self.n_heads, n_q_tokens, d_head_value) 922 | .transpose(1, 2) 923 | .reshape(batch_size, n_q_tokens, self.n_heads * d_head_value) 924 | ) 925 | if self.W_out is not None: 926 | x = self.W_out(x) 927 | return x, { 928 | 'attention_logits': attention_logits, 929 | 'attention_probs': attention_probs, 930 | } 931 | 932 | 933 | class Transformer(nn.Module): 934 | """Transformer with extra features. 
935 | 936 | This module is the backbone of `FTTransformer`.""" 937 | 938 | WARNINGS: ClassVar[Dict[str, bool]] = { 939 | 'first_prenormalization': True, 940 | 'prenormalization': True, 941 | } 942 | 943 | class FFN(nn.Module): 944 | """The Feed-Forward Network module used in every `Transformer` block.""" 945 | 946 | def __init__( 947 | self, 948 | *, 949 | d_token: int, 950 | d_hidden: int, 951 | bias_first: bool, 952 | bias_second: bool, 953 | dropout: float, 954 | activation: ModuleType, 955 | ): 956 | super().__init__() 957 | self.linear_first = nn.Linear( 958 | d_token, 959 | d_hidden * (2 if _is_glu_activation(activation) else 1), 960 | bias_first, 961 | ) 962 | self.activation = _make_nn_module(activation) 963 | self.dropout = nn.Dropout(dropout) 964 | self.linear_second = nn.Linear(d_hidden, d_token, bias_second) 965 | 966 | def forward(self, x: Tensor) -> Tensor: 967 | x = self.linear_first(x) 968 | x = self.activation(x) 969 | x = self.dropout(x) 970 | x = self.linear_second(x) 971 | return x 972 | 973 | class Head(nn.Module): 974 | """The final module of the `Transformer` that performs BERT-like inference.""" 975 | 976 | def __init__( 977 | self, 978 | *, 979 | d_in: int, 980 | bias: bool, 981 | activation: ModuleType, 982 | normalization: ModuleType, 983 | d_out: int, 984 | ): 985 | super().__init__() 986 | self.normalization = _make_nn_module(normalization, d_in) 987 | self.activation = _make_nn_module(activation) 988 | self.linear = nn.Linear(d_in, d_out, bias) 989 | 990 | def forward(self, x: Tensor) -> Tensor: 991 | x = x[:, -1] 992 | x = self.normalization(x) 993 | x = self.activation(x) 994 | x = self.linear(x) 995 | return x 996 | 997 | def __init__( 998 | self, 999 | *, 1000 | d_token: int, 1001 | n_blocks: int, 1002 | attention_n_heads: int, 1003 | attention_dropout: float, 1004 | attention_initialization: str, 1005 | attention_normalization: str, 1006 | ffn_d_hidden: int, 1007 | ffn_dropout: float, 1008 | ffn_activation: str, 1009 | ffn_normalization: str, 1010 | residual_dropout: float, 1011 | prenormalization: bool, 1012 | first_prenormalization: bool, 1013 | last_layer_query_idx: Union[None, List[int], slice], 1014 | n_tokens: Optional[int], 1015 | kv_compression_ratio: Optional[float], 1016 | kv_compression_sharing: Optional[str], 1017 | head_activation: ModuleType, 1018 | head_normalization: ModuleType, 1019 | d_out: int, 1020 | ) -> None: 1021 | super().__init__() 1022 | if isinstance(last_layer_query_idx, int): 1023 | raise ValueError( 1024 | 'last_layer_query_idx must be None, list[int] or slice. ' 1025 | f'Do you mean last_layer_query_idx=[{last_layer_query_idx}] ?' 1026 | ) 1027 | if not prenormalization: 1028 | assert ( 1029 | not first_prenormalization 1030 | ), 'If `prenormalization` is False, then `first_prenormalization` must be False' 1031 | assert _all_or_none([n_tokens, kv_compression_ratio, kv_compression_sharing]), ( 1032 | 'If any of the following arguments is (not) None, then all of them must (not) be None: ' 1033 | 'n_tokens, kv_compression_ratio, kv_compression_sharing' 1034 | ) 1035 | assert kv_compression_sharing in [None, 'headwise', 'key-value', 'layerwise'] 1036 | if not prenormalization: 1037 | if self.WARNINGS['prenormalization']: 1038 | warnings.warn( 1039 | 'prenormalization is set to False. Are you sure about this? ' 1040 | 'The training can become less stable. 
' 1041 | 'You can turn off this warning by tweaking the ' 1042 | 'rtdl.Transformer.WARNINGS dictionary.', 1043 | UserWarning, 1044 | ) 1045 | assert ( 1046 | not first_prenormalization 1047 | ), 'If prenormalization is False, then first_prenormalization is ignored and must be set to False' 1048 | if ( 1049 | prenormalization 1050 | and first_prenormalization 1051 | and self.WARNINGS['first_prenormalization'] 1052 | ): 1053 | warnings.warn( 1054 | 'first_prenormalization is set to True. Are you sure about this? ' 1055 | 'For example, the vanilla FTTransformer with ' 1056 | 'first_prenormalization=True performs SIGNIFICANTLY worse. ' 1057 | 'You can turn off this warning by tweaking the ' 1058 | 'rtdl.Transformer.WARNINGS dictionary.', 1059 | UserWarning, 1060 | ) 1061 | time.sleep(3) 1062 | 1063 | def make_kv_compression(): 1064 | assert ( 1065 | n_tokens and kv_compression_ratio 1066 | ), _INTERNAL_ERROR_MESSAGE # for mypy 1067 | # https://github.com/pytorch/fairseq/blob/1bba712622b8ae4efb3eb793a8a40da386fe11d0/examples/linformer/linformer_src/modules/multihead_linear_attention.py#L83 1068 | return nn.Linear(n_tokens, int(n_tokens * kv_compression_ratio), bias=False) 1069 | 1070 | self.shared_kv_compression = ( 1071 | make_kv_compression() 1072 | if kv_compression_ratio and kv_compression_sharing == 'layerwise' 1073 | else None 1074 | ) 1075 | 1076 | self.prenormalization = prenormalization 1077 | self.last_layer_query_idx = last_layer_query_idx 1078 | 1079 | self.blocks = nn.ModuleList([]) 1080 | for layer_idx in range(n_blocks): 1081 | layer = nn.ModuleDict( 1082 | { 1083 | 'attention': MultiheadAttention( 1084 | d_token=d_token, 1085 | n_heads=attention_n_heads, 1086 | dropout=attention_dropout, 1087 | bias=True, 1088 | initialization=attention_initialization, 1089 | ), 1090 | 'ffn': Transformer.FFN( 1091 | d_token=d_token, 1092 | d_hidden=ffn_d_hidden, 1093 | bias_first=True, 1094 | bias_second=True, 1095 | dropout=ffn_dropout, 1096 | activation=ffn_activation, 1097 | ), 1098 | 'attention_residual_dropout': nn.Dropout(residual_dropout), 1099 | 'ffn_residual_dropout': nn.Dropout(residual_dropout), 1100 | 'output': nn.Identity(), # for hooks-based introspection 1101 | } 1102 | ) 1103 | if layer_idx or not prenormalization or first_prenormalization: 1104 | layer['attention_normalization'] = _make_nn_module( 1105 | attention_normalization, d_token 1106 | ) 1107 | layer['ffn_normalization'] = _make_nn_module(ffn_normalization, d_token) 1108 | if kv_compression_ratio and self.shared_kv_compression is None: 1109 | layer['key_compression'] = make_kv_compression() 1110 | if kv_compression_sharing == 'headwise': 1111 | layer['value_compression'] = make_kv_compression() 1112 | else: 1113 | assert ( 1114 | kv_compression_sharing == 'key-value' 1115 | ), _INTERNAL_ERROR_MESSAGE 1116 | self.blocks.append(layer) 1117 | 1118 | self.head = Transformer.Head( 1119 | d_in=d_token, 1120 | d_out=d_out, 1121 | bias=True, 1122 | activation=head_activation, # type: ignore 1123 | normalization=head_normalization if prenormalization else 'Identity', 1124 | ) 1125 | 1126 | def _get_kv_compressions(self, layer): 1127 | return ( 1128 | (self.shared_kv_compression, self.shared_kv_compression) 1129 | if self.shared_kv_compression is not None 1130 | else (layer['key_compression'], layer['value_compression']) 1131 | if 'key_compression' in layer and 'value_compression' in layer 1132 | else (layer['key_compression'], layer['key_compression']) 1133 | if 'key_compression' in layer 1134 | else (None, None) 1135 | ) 1136 | 
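    # A standalone usage sketch of this backbone (normally it is driven by
    # `FTTransformer`, which produces the token input via `FeatureTokenizer` and
    # `CLSToken`). The hyperparameter values below are arbitrary and chosen only
    # to keep the example small; the input must have the shape
    # (n_objects, n_tokens, d_token):
    #
    #     transformer = Transformer(
    #         d_token=8,
    #         n_blocks=2,
    #         attention_n_heads=2,
    #         attention_dropout=0.1,
    #         attention_initialization='kaiming',
    #         attention_normalization='LayerNorm',
    #         ffn_d_hidden=6,
    #         ffn_dropout=0.1,
    #         ffn_activation='ReGLU',
    #         ffn_normalization='LayerNorm',
    #         residual_dropout=0.0,
    #         prenormalization=True,
    #         first_prenormalization=False,
    #         last_layer_query_idx=None,
    #         n_tokens=None,
    #         kv_compression_ratio=None,
    #         kv_compression_sharing=None,
    #         head_activation='ReLU',
    #         head_normalization='LayerNorm',
    #         d_out=1,
    #     )
    #     x = torch.randn(4, 5, 8)  # (n_objects, n_tokens, d_token)
    #     assert transformer(x).shape == (4, 1)
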
1137 | def _start_residual(self, layer, stage, x): 1138 | assert stage in ['attention', 'ffn'], _INTERNAL_ERROR_MESSAGE 1139 | x_residual = x 1140 | if self.prenormalization: 1141 | norm_key = f'{stage}_normalization' 1142 | if norm_key in layer: 1143 | x_residual = layer[norm_key](x_residual) 1144 | return x_residual 1145 | 1146 | def _end_residual(self, layer, stage, x, x_residual): 1147 | assert stage in ['attention', 'ffn'], _INTERNAL_ERROR_MESSAGE 1148 | x_residual = layer[f'{stage}_residual_dropout'](x_residual) 1149 | x = x + x_residual 1150 | if not self.prenormalization: 1151 | x = layer[f'{stage}_normalization'](x) 1152 | return x 1153 | 1154 | def forward(self, x: Tensor) -> Tensor: 1155 | assert ( 1156 | x.ndim == 3 1157 | ), 'The input must have 3 dimensions: (n_objects, n_tokens, d_token)' 1158 | for layer_idx, layer in enumerate(self.blocks): 1159 | layer = cast(nn.ModuleDict, layer) 1160 | 1161 | query_idx = ( 1162 | self.last_layer_query_idx if layer_idx + 1 == len(self.blocks) else None 1163 | ) 1164 | x_residual = self._start_residual(layer, 'attention', x) 1165 | x_residual, _ = layer['attention']( 1166 | x_residual if query_idx is None else x_residual[:, query_idx], 1167 | x_residual, 1168 | *self._get_kv_compressions(layer), 1169 | ) 1170 | if query_idx is not None: 1171 | x = x[:, query_idx] 1172 | x = self._end_residual(layer, 'attention', x, x_residual) 1173 | 1174 | x_residual = self._start_residual(layer, 'ffn', x) 1175 | x_residual = layer['ffn'](x_residual) 1176 | x = self._end_residual(layer, 'ffn', x, x_residual) 1177 | x = layer['output'](x) 1178 | 1179 | x = self.head(x) 1180 | return x 1181 | 1182 | 1183 | class FTTransformer(nn.Module): 1184 | """The FT-Transformer model proposed in [gorishniy2021revisiting]. 1185 | 1186 | Transforms features to tokens with `FeatureTokenizer` and applies `Transformer` [vaswani2017attention] 1187 | to the tokens. The following illustration provides a high-level overview of the 1188 | architecture: 1189 | 1190 | .. image:: ../images/ft_transformer.png 1191 | :scale: 25% 1192 | :alt: FT-Transformer 1193 | 1194 | The following illustration demonstrates one Transformer block for :code:`prenormalization=True`: 1195 | 1196 | .. image:: ../images/transformer_block.png 1197 | :scale: 25% 1198 | :alt: PreNorm Transformer block 1199 | 1200 | Examples: 1201 | .. testcode:: 1202 | 1203 | x_num = torch.randn(4, 3) 1204 | x_cat = torch.tensor([[0, 1], [1, 0], [0, 2], [1, 1]]) 1205 | 1206 | module = FTTransformer.make_baseline( 1207 | n_num_features=3, 1208 | cat_cardinalities=[2, 3], 1209 | d_token=8, 1210 | n_blocks=2, 1211 | attention_dropout=0.2, 1212 | ffn_d_hidden=6, 1213 | ffn_dropout=0.2, 1214 | residual_dropout=0.0, 1215 | d_out=1, 1216 | ) 1217 | x = module(x_num, x_cat) 1218 | assert x.shape == (4, 1) 1219 | 1220 | module = FTTransformer.make_default( 1221 | n_num_features=3, 1222 | cat_cardinalities=[2, 3], 1223 | d_out=1, 1224 | ) 1225 | x = module(x_num, x_cat) 1226 | assert x.shape == (4, 1) 1227 | 1228 | To learn more about the baseline and default parameters: 1229 | 1230 | .. 
testcode:: 1231 | 1232 | baseline_parameters = FTTransformer.get_baseline_transformer_subconfig() 1233 | default_parameters = FTTransformer.get_default_transformer_config() 1234 | 1235 | References: 1236 | * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko, "Revisiting Deep Learning Models for Tabular Data", 2021 1237 | * [vaswani2017attention] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin, "Attention Is All You Need", 2017 1238 | """ 1239 | 1240 | def __init__( 1241 | self, feature_tokenizer: FeatureTokenizer, transformer: Transformer 1242 | ) -> None: 1243 | """ 1244 | Note: 1245 | `make_baseline` and `make_default` are the recommended constructors. 1246 | """ 1247 | super().__init__() 1248 | if transformer.prenormalization: 1249 | assert 'attention_normalization' not in transformer.blocks[0], ( 1250 | 'In the prenormalization setting, FT-Transformer does not ' 1251 | 'allow using the first normalization layer ' 1252 | 'in the first transformer block' 1253 | ) 1254 | self.feature_tokenizer = feature_tokenizer 1255 | self.cls_token = CLSToken( 1256 | feature_tokenizer.d_token, feature_tokenizer.initialization 1257 | ) 1258 | self.transformer = transformer 1259 | 1260 | @classmethod 1261 | def get_baseline_transformer_subconfig( 1262 | cls: Type['FTTransformer'], 1263 | ) -> Dict[str, Any]: 1264 | """Get the baseline subset of parameters for the backbone.""" 1265 | return { 1266 | 'attention_n_heads': 8, 1267 | 'attention_initialization': 'kaiming', 1268 | 'ffn_activation': 'ReGLU', 1269 | 'attention_normalization': 'LayerNorm', 1270 | 'ffn_normalization': 'LayerNorm', 1271 | 'prenormalization': True, 1272 | 'first_prenormalization': False, 1273 | 'last_layer_query_idx': None, 1274 | 'n_tokens': None, 1275 | 'kv_compression_ratio': None, 1276 | 'kv_compression_sharing': None, 1277 | 'head_activation': 'ReLU', 1278 | 'head_normalization': 'LayerNorm', 1279 | } 1280 | 1281 | @classmethod 1282 | def get_default_transformer_config( 1283 | cls: Type['FTTransformer'], *, n_blocks: int = 3 1284 | ) -> Dict[str, Any]: 1285 | """Get the default parameters for the backbone. 1286 | 1287 | Note: 1288 | The configurations are different for different values of:code:`n_blocks`. 1289 | """ 1290 | assert 1 <= n_blocks <= 6 1291 | grid = { 1292 | 'd_token': [96, 128, 192, 256, 320, 384], 1293 | 'attention_dropout': [0.1, 0.15, 0.2, 0.25, 0.3, 0.35], 1294 | 'ffn_dropout': [0.0, 0.05, 0.1, 0.15, 0.2, 0.25], 1295 | } 1296 | arch_subconfig = {k: v[n_blocks - 1] for k, v in grid.items()} # type: ignore 1297 | baseline_subconfig = cls.get_baseline_transformer_subconfig() 1298 | # (4 / 3) for ReGLU/GEGLU activations results in almost the same parameter count 1299 | # as (2.0) for element-wise activations (e.g. 
ReLU or GELU; see the "else" branch) 1300 | ffn_d_hidden_factor = ( 1301 | (4 / 3) if _is_glu_activation(baseline_subconfig['ffn_activation']) else 2.0 1302 | ) 1303 | return { 1304 | 'n_blocks': n_blocks, 1305 | 'residual_dropout': 0.0, 1306 | 'ffn_d_hidden': int(arch_subconfig['d_token'] * ffn_d_hidden_factor), 1307 | **arch_subconfig, 1308 | **baseline_subconfig, 1309 | } 1310 | 1311 | @classmethod 1312 | def _make( 1313 | cls, 1314 | n_num_features, 1315 | cat_cardinalities, 1316 | transformer_config, 1317 | ): 1318 | feature_tokenizer = FeatureTokenizer( 1319 | n_num_features=n_num_features, 1320 | cat_cardinalities=cat_cardinalities, 1321 | d_token=transformer_config['d_token'], 1322 | ) 1323 | if transformer_config['d_out'] is None: 1324 | transformer_config['head_activation'] = None 1325 | if transformer_config['kv_compression_ratio'] is not None: 1326 | transformer_config['n_tokens'] = feature_tokenizer.n_tokens + 1 1327 | return FTTransformer( 1328 | feature_tokenizer, 1329 | Transformer(**transformer_config), 1330 | ) 1331 | 1332 | @classmethod 1333 | def make_baseline( 1334 | cls: Type['FTTransformer'], 1335 | *, 1336 | n_num_features: int, 1337 | cat_cardinalities: Optional[List[int]], 1338 | d_token: int, 1339 | n_blocks: int, 1340 | attention_dropout: float, 1341 | ffn_d_hidden: int, 1342 | ffn_dropout: float, 1343 | residual_dropout: float, 1344 | last_layer_query_idx: Union[None, List[int], slice] = None, 1345 | kv_compression_ratio: Optional[float] = None, 1346 | kv_compression_sharing: Optional[str] = None, 1347 | d_out: int, 1348 | ) -> 'FTTransformer': 1349 | """Create a "baseline" `FTTransformer`. 1350 | 1351 | This variation of FT-Transformer was used in [gorishniy2021revisiting]. See 1352 | `get_baseline_transformer_subconfig` to learn the values of other parameters. 1353 | See `FTTransformer` for usage examples. 1354 | 1355 | Tip: 1356 | `get_default_transformer_config` can serve as a starting point for choosing 1357 | hyperparameter values. 1358 | 1359 | Args: 1360 | n_num_features: the number of continuous features 1361 | cat_cardinalities: the cardinalities of categorical features (see 1362 | `CategoricalFeatureTokenizer` to learn more about cardinalities) 1363 | d_token: the token size for each feature. Must be a multiple of :code:`n_heads=8`. 1364 | n_blocks: the number of Transformer blocks 1365 | attention_dropout: the dropout for attention blocks (see `MultiheadAttention`). 1366 | Usually, positive values work better (even when the number of features is low). 1367 | ffn_d_hidden: the *input* size for the *second* linear layer in `Transformer.FFN`. 1368 | Note that it can be different from the output size of the first linear 1369 | layer, since activations such as ReGLU or GEGLU change the size of input. 1370 | For example, if :code:`ffn_d_hidden=10` and the activation is ReGLU (which 1371 | is always true for the baseline and default configurations), then the 1372 | output size of the first linear layer will be set to :code:`20`. 1373 | ffn_dropout: the dropout rate after the first linear layer in `Transformer.FFN`. 1374 | residual_dropout: the dropout rate for the output of each residual branch of 1375 | all Transformer blocks. 1376 | last_layer_query_idx: indices of tokens that should be processed by the last 1377 | Transformer block. Note that for most cases there is no need to apply 1378 | the last Transformer block to anything except for the [CLS]-token. 
Hence, 1379 | runtime and memory can be saved by setting :code:`last_layer_query_idx=[-1]`, 1380 | since the :code:`-1` is the position of [CLS]-token in FT-Transformer. 1381 | Note that this will not affect the result in any way. 1382 | kv_compression_ratio: apply the technique from [wang2020linformer] to speed 1383 | up attention modules when the number of features is large. Can actually 1384 | slow things down if the number of features is too low. Note that this 1385 | option can affect task metrics in unpredictable way. Overall, use this 1386 | option with caution. See `MultiheadAttention` for some examples and the 1387 | implementation of `Transformer` to see how this option is used. 1388 | kv_compression_sharing: weight sharing policy for :code:`kv_compression_ratio`. 1389 | Must be one of :code:`[None, 'headwise', 'key-value', 'layerwise']`. 1390 | See [wang2020linformer] to learn more about sharing policies. 1391 | :code:`headwise` and :code:`key-value` are reasonable default choices. If 1392 | :code:`kv_compression_ratio` is `None`, then this parameter also must be 1393 | `None`. Otherwise, it must not be `None` (compression parameters must be 1394 | shared in some way). 1395 | d_out: the output size. 1396 | 1397 | References: 1398 | * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko, "Revisiting Deep Learning Models for Tabular Data", 2021 1399 | * [wang2020linformer] Sinong Wang, Belinda Z. Li, Madian Khabsa, Han Fang, Hao Ma "Linformer: Self-Attention with Linear Complexity", 2020 1400 | """ 1401 | transformer_config = cls.get_baseline_transformer_subconfig() 1402 | for arg_name in [ 1403 | 'n_blocks', 1404 | 'd_token', 1405 | 'attention_dropout', 1406 | 'ffn_d_hidden', 1407 | 'ffn_dropout', 1408 | 'residual_dropout', 1409 | 'last_layer_query_idx', 1410 | 'kv_compression_ratio', 1411 | 'kv_compression_sharing', 1412 | 'd_out', 1413 | ]: 1414 | transformer_config[arg_name] = locals()[arg_name] 1415 | return cls._make(n_num_features, cat_cardinalities, transformer_config) 1416 | 1417 | @classmethod 1418 | def make_default( 1419 | cls: Type['FTTransformer'], 1420 | *, 1421 | n_num_features: int, 1422 | cat_cardinalities: Optional[List[int]], 1423 | n_blocks: int = 3, 1424 | last_layer_query_idx: Union[None, List[int], slice] = None, 1425 | kv_compression_ratio: Optional[float] = None, 1426 | kv_compression_sharing: Optional[str] = None, 1427 | d_out: int, 1428 | ) -> 'FTTransformer': 1429 | """Create the default `FTTransformer`. 1430 | 1431 | With :code:`n_blocks=3` (default) it is the FT-Transformer variation that is 1432 | referred to as "default FT-Transformer" in [gorishniy2021revisiting]. See 1433 | `FTTransformer` for usage examples. See `FTTransformer.make_baseline` for 1434 | parameter descriptions. 1435 | 1436 | Note: 1437 | The second component of the default FT-Transformer is the default optimizer, 1438 | which can be created with the `make_default_optimizer` method. 1439 | 1440 | Note: 1441 | According to [gorishniy2021revisiting], the default FT-Transformer is 1442 | effective in the ensembling mode (i.e. when predictions of several default 1443 | FT-Transformers are averaged). For a single FT-Transformer, it is still 1444 | possible to achieve better results by tuning hyperparameter for the 1445 | `make_baseline` constructor. 
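
        A minimal sketch of the intended pairing of the default model with the
        default optimizer:

        .. testcode::

            model = FTTransformer.make_default(
                n_num_features=3,
                cat_cardinalities=[2, 3],
                d_out=1,
            )
            optimizer = model.make_default_optimizer()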
1446 | 1447 | References: 1448 | * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko, "Revisiting Deep Learning Models for Tabular Data", 2021 1449 | """ 1450 | transformer_config = cls.get_default_transformer_config(n_blocks=n_blocks) 1451 | for arg_name in [ 1452 | 'last_layer_query_idx', 1453 | 'kv_compression_ratio', 1454 | 'kv_compression_sharing', 1455 | 'd_out', 1456 | ]: 1457 | transformer_config[arg_name] = locals()[arg_name] 1458 | return cls._make(n_num_features, cat_cardinalities, transformer_config) 1459 | 1460 | def optimization_param_groups(self) -> List[Dict[str, Any]]: 1461 | """The replacement for :code:`.parameters()` when creating optimizers. 1462 | 1463 | Example:: 1464 | 1465 | optimizer = AdamW( 1466 | model.optimization_param_groups(), lr=1e-4, weight_decay=1e-5 1467 | ) 1468 | """ 1469 | no_wd_names = ['feature_tokenizer', 'normalization', '.bias'] 1470 | assert isinstance( 1471 | getattr(self, no_wd_names[0], None), FeatureTokenizer 1472 | ), _INTERNAL_ERROR_MESSAGE 1473 | assert ( 1474 | sum(1 for name, _ in self.named_modules() if no_wd_names[1] in name) 1475 | == len(self.transformer.blocks) * 2 1476 | - int('attention_normalization' not in self.transformer.blocks[0]) # type: ignore 1477 | + 1 1478 | ), _INTERNAL_ERROR_MESSAGE 1479 | 1480 | def needs_wd(name): 1481 | return all(x not in name for x in no_wd_names) 1482 | 1483 | return [ 1484 | {'params': [v for k, v in self.named_parameters() if needs_wd(k)]}, 1485 | { 1486 | 'params': [v for k, v in self.named_parameters() if not needs_wd(k)], 1487 | 'weight_decay': 0.0, 1488 | }, 1489 | ] 1490 | 1491 | def make_default_optimizer(self) -> torch.optim.AdamW: 1492 | """Make the optimizer for the default FT-Transformer.""" 1493 | return torch.optim.AdamW( 1494 | self.optimization_param_groups(), 1495 | lr=1e-4, 1496 | weight_decay=1e-5, 1497 | ) 1498 | 1499 | def forward(self, x_num: Optional[Tensor], x_cat: Optional[Tensor]) -> Tensor: 1500 | x = self.feature_tokenizer(x_num, x_cat) 1501 | x = self.cls_token(x) 1502 | x = self.transformer(x) 1503 | return x 1504 | --------------------------------------------------------------------------------
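A short end-to-end sketch of how the pieces above fit together. The synthetic data
and the regression loss are illustrative; the model, the optimizer, and the forward
signature come from rtdl as shown above:

import torch
import torch.nn.functional as F

from rtdl import FTTransformer

# Toy regression data: 3 numerical features and 2 categorical features.
x_num = torch.randn(32, 3)
x_cat = torch.stack([torch.randint(c, (32,)) for c in (2, 3)], dim=1)
y = torch.randn(32, 1)

model = FTTransformer.make_default(
    n_num_features=3,
    cat_cardinalities=[2, 3],
    d_out=1,
)
optimizer = model.make_default_optimizer()

# One training step.
model.train()
optimizer.zero_grad()
loss = F.mse_loss(model(x_num, x_cat), y)
loss.backward()
optimizer.step()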