├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── setup.py ├── tab-vs-ft.png ├── tab.png └── tab_transformer_pytorch ├── __init__.py ├── ft_transformer.py └── tab_transformer_pytorch.py /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | jobs: 16 | deploy: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: '3.x' 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install build 30 | - name: Build package 31 | run: python -m build 32 | - name: Publish package 33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 34 | with: 35 | user: __token__ 36 | password: ${{ secrets.PYPI_API_TOKEN }} 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Tab Transformer 4 | 5 | Implementation of Tab Transformer, an attention network for tabular data, in Pytorch. This simple architecture came within a hair's breadth of GBDT's performance. 6 | 7 | Update: Amazon AI claims to have beaten GBDT with Attention on a real-world tabular dataset (predicting shipping cost).
8 | 9 | ## Install 10 | 11 | ```bash 12 | $ pip install tab-transformer-pytorch 13 | ``` 14 | 15 | ## Usage 16 | 17 | ```python 18 | import torch 19 | import torch.nn as nn 20 | from tab_transformer_pytorch import TabTransformer 21 | 22 | cont_mean_std = torch.randn(10, 2) 23 | 24 | model = TabTransformer( 25 | categories = (10, 5, 6, 5, 8), # tuple containing the number of unique values within each category 26 | num_continuous = 10, # number of continuous values 27 | dim = 32, # dimension, paper set at 32 28 | dim_out = 1, # binary prediction, but could be anything 29 | depth = 6, # depth, paper recommends 6 30 | heads = 8, # heads, paper recommends 8 31 | attn_dropout = 0.1, # post-attention dropout 32 | ff_dropout = 0.1, # feed forward dropout 33 | mlp_hidden_mults = (4, 2), # relative multiples of each hidden dimension of the last mlp to logits 34 | mlp_act = nn.ReLU(), # activation for final mlp, defaults to relu, but could be anything else (selu etc) 35 | continuous_mean_std = cont_mean_std # (optional) - normalize the continuous values before layer norm 36 | ) 37 | 38 | x_categ = torch.randint(0, 5, (1, 5)) # category values, each an id from 0 to the number of unique values of that category minus 1, in the order passed into the constructor above 39 | x_cont = torch.randn(1, 10) # assume continuous values are already normalized individually 40 | 41 | pred = model(x_categ, x_cont) # (1, 1) 42 | ``` 43 | 44 | ## FT Transformer 45 | 46 | 47 | 48 | This paper from Yandex improves on Tab Transformer by using a simpler scheme for embedding the continuous numerical values, as shown in the diagram above (courtesy of this reddit post). 49 | 50 | It is included in this repository just for convenient comparison with Tab Transformer. 51 | 52 | ```python 53 | import torch 54 | from tab_transformer_pytorch import FTTransformer 55 | 56 | model = FTTransformer( 57 | categories = (10, 5, 6, 5, 8), # tuple containing the number of unique values within each category 58 | num_continuous = 10, # number of continuous values 59 | dim = 32, # dimension, paper set at 32 60 | dim_out = 1, # binary prediction, but could be anything 61 | depth = 6, # depth, paper recommends 6 62 | heads = 8, # heads, paper recommends 8 63 | attn_dropout = 0.1, # post-attention dropout 64 | ff_dropout = 0.1 # feed forward dropout 65 | ) 66 | 67 | x_categ = torch.randint(0, 5, (1, 5)) # category values, each an id from 0 to the number of unique values of that category minus 1, in the order passed into the constructor above 68 | x_numer = torch.randn(1, 10) # numerical values 69 | 70 | pred = model(x_categ, x_numer) # (1, 1) 71 | ``` 72 | 73 | ## Unsupervised Training 74 | 75 | To undergo the type of unsupervised training described in the paper, you can first convert your category tokens to the appropriate unique ids, and then use Electra on `model.transformer` (a hedged sketch of one possible setup is given at the very end of this document). 76 | 77 | ## Todo 78 | 79 | - [ ] consider https://arxiv.org/abs/2203.05556 80 | 81 | ## Citations 82 | 83 | ```bibtex 84 | @misc{huang2020tabtransformer, 85 | title = {TabTransformer: Tabular Data Modeling Using Contextual Embeddings}, 86 | author = {Xin Huang and Ashish Khetan and Milan Cvitkovic and Zohar Karnin}, 87 | year = {2020}, 88 | eprint = {2012.06678}, 89 | archivePrefix = {arXiv}, 90 | primaryClass = {cs.LG} 91 | } 92 | ``` 93 | 94 | ```bibtex 95 | @article{Gorishniy2021RevisitingDL, 96 | title = {Revisiting Deep Learning Models for Tabular Data}, 97 | author = {Yu. V.
Gorishniy and Ivan Rubachev and Valentin Khrulkov and Artem Babenko}, 98 | journal = {ArXiv}, 99 | year = {2021}, 100 | volume = {abs/2106.11959} 101 | } 102 | ``` 103 | 104 | ```bibtex 105 | @article{Zhu2024HyperConnections, 106 | title = {Hyper-Connections}, 107 | author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou}, 108 | journal = {ArXiv}, 109 | year = {2024}, 110 | volume = {abs/2409.19606}, 111 | url = {https://api.semanticscholar.org/CorpusID:272987528} 112 | } 113 | ``` 114 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'tab-transformer-pytorch', 5 | packages = find_packages(), 6 | version = '0.4.2', 7 | license='MIT', 8 | description = 'Tab Transformer - Pytorch', 9 | long_description_content_type = 'text/markdown', 10 | author = 'Phil Wang', 11 | author_email = 'lucidrains@gmail.com', 12 | url = 'https://github.com/lucidrains/tab-transformer-pytorch', 13 | keywords = [ 14 | 'artificial intelligence', 15 | 'transformers', 16 | 'attention mechanism', 17 | 'tabular data' 18 | ], 19 | install_requires=[ 20 | 'einops>=0.8', 21 | 'hyper-connections>=0.1.15', 22 | 'torch>=2.3' 23 | ], 24 | classifiers=[ 25 | 'Development Status :: 4 - Beta', 26 | 'Intended Audience :: Developers', 27 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 28 | 'License :: OSI Approved :: MIT License', 29 | 'Programming Language :: Python :: 3.6', 30 | ], 31 | ) 32 | -------------------------------------------------------------------------------- /tab-vs-ft.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/tab-transformer-pytorch/c1b9aa9c28f530d22ee95410842685525646bf64/tab-vs-ft.png -------------------------------------------------------------------------------- /tab.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/tab-transformer-pytorch/c1b9aa9c28f530d22ee95410842685525646bf64/tab.png -------------------------------------------------------------------------------- /tab_transformer_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from tab_transformer_pytorch.tab_transformer_pytorch import TabTransformer 2 | from tab_transformer_pytorch.ft_transformer import FTTransformer 3 | -------------------------------------------------------------------------------- /tab_transformer_pytorch/ft_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | from torch.nn import Module, ModuleList 4 | import torch.nn.functional as F 5 | 6 | from einops import rearrange, repeat 7 | 8 | from hyper_connections import HyperConnections 9 | 10 | # feedforward and attention 11 | 12 | class GEGLU(Module): 13 | def forward(self, x): 14 | x, gates = x.chunk(2, dim = -1) 15 | return x * F.gelu(gates) 16 | 17 | def FeedForward(dim, mult = 4, dropout = 0.): 18 | return nn.Sequential( 19 | nn.LayerNorm(dim), 20 | nn.Linear(dim, dim * mult * 2), 21 | GEGLU(), 22 | nn.Dropout(dropout), 23 | nn.Linear(dim * mult, dim) 24 | ) 25 | 26 | class Attention(Module): 27 | def __init__( 28 | self, 29 | dim, 30 | heads = 8, 31 | dim_head = 64, 32 | dropout = 0. 
33 | ): 34 | super().__init__() 35 | inner_dim = dim_head * heads 36 | self.heads = heads 37 | self.scale = dim_head ** -0.5 38 | 39 | self.norm = nn.LayerNorm(dim) 40 | 41 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 42 | self.to_out = nn.Linear(inner_dim, dim, bias = False) 43 | 44 | self.dropout = nn.Dropout(dropout) 45 | 46 | def forward(self, x): 47 | h = self.heads 48 | 49 | x = self.norm(x) 50 | 51 | q, k, v = self.to_qkv(x).chunk(3, dim = -1) 52 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) 53 | q = q * self.scale 54 | 55 | sim = einsum('b h i d, b h j d -> b h i j', q, k) 56 | 57 | attn = sim.softmax(dim = -1) 58 | dropped_attn = self.dropout(attn) 59 | 60 | out = einsum('b h i j, b h j d -> b h i d', dropped_attn, v) 61 | out = rearrange(out, 'b h n d -> b n (h d)', h = h) 62 | out = self.to_out(out) 63 | 64 | return out, attn 65 | 66 | # transformer 67 | 68 | class Transformer(Module): 69 | def __init__( 70 | self, 71 | dim, 72 | depth, 73 | heads, 74 | dim_head, 75 | attn_dropout, 76 | ff_dropout, 77 | num_residual_streams = 4 78 | ): 79 | super().__init__() 80 | 81 | init_hyper_conn, self.expand_streams, self.reduce_streams = HyperConnections.get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1) 82 | 83 | self.layers = ModuleList([]) 84 | 85 | for _ in range(depth): 86 | self.layers.append(ModuleList([ 87 | init_hyper_conn(dim = dim, branch = Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout)), 88 | init_hyper_conn(dim = dim, branch = FeedForward(dim, dropout = ff_dropout)), 89 | ])) 90 | 91 | def forward(self, x, return_attn = False): 92 | post_softmax_attns = [] 93 | 94 | x = self.expand_streams(x) 95 | 96 | for attn, ff in self.layers: 97 | x, post_softmax_attn = attn(x) 98 | post_softmax_attns.append(post_softmax_attn) 99 | 100 | x = ff(x) 101 | 102 | x = self.reduce_streams(x) 103 | 104 | if not return_attn: 105 | return x 106 | 107 | return x, torch.stack(post_softmax_attns) 108 | 109 | # numerical embedder 110 | 111 | class NumericalEmbedder(Module): 112 | def __init__(self, dim, num_numerical_types): 113 | super().__init__() 114 | self.weights = nn.Parameter(torch.randn(num_numerical_types, dim)) 115 | self.biases = nn.Parameter(torch.randn(num_numerical_types, dim)) 116 | 117 | def forward(self, x): 118 | x = rearrange(x, 'b n -> b n 1') 119 | return x * self.weights + self.biases 120 | 121 | # main class 122 | 123 | class FTTransformer(Module): 124 | def __init__( 125 | self, 126 | *, 127 | categories, 128 | num_continuous, 129 | dim, 130 | depth, 131 | heads, 132 | dim_head = 16, 133 | dim_out = 1, 134 | num_special_tokens = 2, 135 | attn_dropout = 0., 136 | ff_dropout = 0., 137 | num_residual_streams = 4 138 | ): 139 | super().__init__() 140 | assert all(map(lambda n: n > 0, categories)), 'number of each category must be positive' 141 | assert len(categories) + num_continuous > 0, 'input shape must not be null' 142 | 143 | # categories related calculations 144 | 145 | self.num_categories = len(categories) 146 | self.num_unique_categories = sum(categories) 147 | 148 | # create category embeddings table 149 | 150 | self.num_special_tokens = num_special_tokens 151 | total_tokens = self.num_unique_categories + num_special_tokens 152 | 153 | # for automatically offsetting unique category ids to the correct position in the categories embedding table 154 | 155 | if self.num_unique_categories > 0: 156 | categories_offset = 
F.pad(torch.tensor(list(categories)), (1, 0), value = num_special_tokens) 157 | categories_offset = categories_offset.cumsum(dim = -1)[:-1] 158 | self.register_buffer('categories_offset', categories_offset) 159 | 160 | # categorical embedding 161 | 162 | self.categorical_embeds = nn.Embedding(total_tokens, dim) 163 | 164 | # continuous 165 | 166 | self.num_continuous = num_continuous 167 | 168 | if self.num_continuous > 0: 169 | self.numerical_embedder = NumericalEmbedder(dim, self.num_continuous) 170 | 171 | # cls token 172 | 173 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 174 | 175 | # transformer 176 | 177 | self.transformer = Transformer( 178 | dim = dim, 179 | depth = depth, 180 | heads = heads, 181 | dim_head = dim_head, 182 | attn_dropout = attn_dropout, 183 | ff_dropout = ff_dropout, 184 | num_residual_streams = num_residual_streams 185 | ) 186 | 187 | # to logits 188 | 189 | self.to_logits = nn.Sequential( 190 | nn.LayerNorm(dim), 191 | nn.ReLU(), 192 | nn.Linear(dim, dim_out) 193 | ) 194 | 195 | def forward(self, x_categ, x_numer, return_attn = False): 196 | assert x_categ.shape[-1] == self.num_categories, f'you must pass in {self.num_categories} values for your categories input' 197 | 198 | xs = [] 199 | if self.num_unique_categories > 0: 200 | x_categ = x_categ + self.categories_offset 201 | 202 | x_categ = self.categorical_embeds(x_categ) 203 | 204 | xs.append(x_categ) 205 | 206 | # add numerically embedded tokens 207 | if self.num_continuous > 0: 208 | x_numer = self.numerical_embedder(x_numer) 209 | 210 | xs.append(x_numer) 211 | 212 | # concat categorical and numerical 213 | 214 | x = torch.cat(xs, dim = 1) 215 | 216 | # append cls tokens 217 | b = x.shape[0] 218 | cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b) 219 | x = torch.cat((cls_tokens, x), dim = 1) 220 | 221 | # attend 222 | 223 | x, attns = self.transformer(x, return_attn = True) 224 | 225 | # get cls token 226 | 227 | x = x[:, 0] 228 | 229 | # out in the paper is linear(relu(ln(cls))) 230 | 231 | logits = self.to_logits(x) 232 | 233 | if not return_attn: 234 | return logits 235 | 236 | return logits, attns 237 | -------------------------------------------------------------------------------- /tab_transformer_pytorch/tab_transformer_pytorch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | from torch.nn import Module, ModuleList 4 | import torch.nn.functional as F 5 | 6 | from einops import rearrange, repeat 7 | 8 | from hyper_connections import HyperConnections 9 | 10 | # helpers 11 | 12 | def exists(val): 13 | return val is not None 14 | 15 | def default(val, d): 16 | return val if exists(val) else d 17 | 18 | # classes 19 | 20 | class PreNorm(Module): 21 | def __init__(self, dim, fn): 22 | super().__init__() 23 | self.norm = nn.LayerNorm(dim) 24 | self.fn = fn 25 | 26 | def forward(self, x, **kwargs): 27 | return self.fn(self.norm(x), **kwargs) 28 | 29 | # attention 30 | 31 | class GEGLU(Module): 32 | def forward(self, x): 33 | x, gates = x.chunk(2, dim = -1) 34 | return x * F.gelu(gates) 35 | 36 | class FeedForward(Module): 37 | def __init__(self, dim, mult = 4, dropout = 0.): 38 | super().__init__() 39 | self.net = nn.Sequential( 40 | nn.Linear(dim, dim * mult * 2), 41 | GEGLU(), 42 | nn.Dropout(dropout), 43 | nn.Linear(dim * mult, dim) 44 | ) 45 | 46 | def forward(self, x, **kwargs): 47 | return self.net(x) 48 | 49 | class Attention(Module): 50 | def __init__( 51 | self, 52 | dim, 53 | heads = 8, 
54 | dim_head = 16, 55 | dropout = 0. 56 | ): 57 | super().__init__() 58 | inner_dim = dim_head * heads 59 | self.heads = heads 60 | self.scale = dim_head ** -0.5 61 | 62 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 63 | self.to_out = nn.Linear(inner_dim, dim) 64 | 65 | self.dropout = nn.Dropout(dropout) 66 | 67 | def forward(self, x): 68 | h = self.heads 69 | q, k, v = self.to_qkv(x).chunk(3, dim = -1) 70 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) 71 | sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 72 | 73 | attn = sim.softmax(dim = -1) 74 | dropped_attn = self.dropout(attn) 75 | 76 | out = einsum('b h i j, b h j d -> b h i d', dropped_attn, v) 77 | out = rearrange(out, 'b h n d -> b n (h d)', h = h) 78 | return self.to_out(out), attn 79 | 80 | # transformer 81 | 82 | class Transformer(Module): 83 | def __init__( 84 | self, 85 | dim, 86 | depth, 87 | heads, 88 | dim_head, 89 | attn_dropout, 90 | ff_dropout, 91 | num_residual_streams = 4 92 | ): 93 | super().__init__() 94 | self.layers = ModuleList([]) 95 | 96 | init_hyper_conn, self.expand_streams, self.reduce_streams = HyperConnections.get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1) 97 | 98 | for _ in range(depth): 99 | self.layers.append(ModuleList([ 100 | init_hyper_conn(dim = dim, branch = PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout))), 101 | init_hyper_conn(dim = dim, branch = PreNorm(dim, FeedForward(dim, dropout = ff_dropout))), 102 | ])) 103 | 104 | def forward(self, x, return_attn = False): 105 | post_softmax_attns = [] 106 | 107 | x = self.expand_streams(x) 108 | 109 | for attn, ff in self.layers: 110 | x, post_softmax_attn = attn(x) 111 | post_softmax_attns.append(post_softmax_attn) 112 | 113 | x = ff(x) 114 | 115 | x = self.reduce_streams(x) 116 | 117 | if not return_attn: 118 | return x 119 | 120 | return x, torch.stack(post_softmax_attns) 121 | # mlp 122 | 123 | class MLP(Module): 124 | def __init__(self, dims, act = None): 125 | super().__init__() 126 | dims_pairs = list(zip(dims[:-1], dims[1:])) 127 | layers = [] 128 | for ind, (dim_in, dim_out) in enumerate(dims_pairs): 129 | is_last = ind >= (len(dims_pairs) - 1) 130 | linear = nn.Linear(dim_in, dim_out) 131 | layers.append(linear) 132 | 133 | if is_last: 134 | continue 135 | 136 | act = default(act, nn.ReLU()) 137 | layers.append(act) 138 | 139 | self.mlp = nn.Sequential(*layers) 140 | 141 | def forward(self, x): 142 | return self.mlp(x) 143 | 144 | # main class 145 | 146 | class TabTransformer(Module): 147 | def __init__( 148 | self, 149 | *, 150 | categories, 151 | num_continuous, 152 | dim, 153 | depth, 154 | heads, 155 | dim_head = 16, 156 | dim_out = 1, 157 | mlp_hidden_mults = (4, 2), 158 | mlp_act = None, 159 | num_special_tokens = 2, 160 | continuous_mean_std = None, 161 | attn_dropout = 0., 162 | ff_dropout = 0., 163 | use_shared_categ_embed = True, 164 | shared_categ_dim_divisor = 8., # in paper, they reserve dimension / 8 for category shared embedding 165 | num_residual_streams = 4 166 | ): 167 | super().__init__() 168 | assert all(map(lambda n: n > 0, categories)), 'number of each category must be positive' 169 | assert len(categories) + num_continuous > 0, 'input shape must not be null' 170 | 171 | # categories related calculations 172 | 173 | self.num_categories = len(categories) 174 | self.num_unique_categories = sum(categories) 175 | 176 | # create category embeddings table 177 | 178 | 
self.num_special_tokens = num_special_tokens 179 | total_tokens = self.num_unique_categories + num_special_tokens 180 | 181 | shared_embed_dim = 0 if not use_shared_categ_embed else int(dim // shared_categ_dim_divisor) 182 | 183 | self.category_embed = nn.Embedding(total_tokens, dim - shared_embed_dim) 184 | 185 | # take care of shared category embed 186 | 187 | self.use_shared_categ_embed = use_shared_categ_embed 188 | 189 | if use_shared_categ_embed: 190 | self.shared_category_embed = nn.Parameter(torch.zeros(self.num_categories, shared_embed_dim)) 191 | nn.init.normal_(self.shared_category_embed, std = 0.02) 192 | 193 | # for automatically offsetting unique category ids to the correct position in the categories embedding table 194 | 195 | if self.num_unique_categories > 0: 196 | categories_offset = F.pad(torch.tensor(list(categories)), (1, 0), value = num_special_tokens) 197 | categories_offset = categories_offset.cumsum(dim = -1)[:-1] 198 | self.register_buffer('categories_offset', categories_offset) 199 | 200 | # continuous 201 | 202 | self.num_continuous = num_continuous 203 | 204 | if self.num_continuous > 0: 205 | if exists(continuous_mean_std): 206 | assert continuous_mean_std.shape == (num_continuous, 2), f'continuous_mean_std must have a shape of ({num_continuous}, 2) where the last dimension contains the mean and variance respectively' 207 | self.register_buffer('continuous_mean_std', continuous_mean_std) 208 | 209 | self.norm = nn.LayerNorm(num_continuous) 210 | 211 | # transformer 212 | 213 | self.transformer = Transformer( 214 | dim = dim, 215 | depth = depth, 216 | heads = heads, 217 | dim_head = dim_head, 218 | attn_dropout = attn_dropout, 219 | ff_dropout = ff_dropout, 220 | num_residual_streams = num_residual_streams 221 | ) 222 | 223 | # mlp to logits 224 | 225 | input_size = (dim * self.num_categories) + num_continuous 226 | 227 | hidden_dimensions = [input_size * t for t in mlp_hidden_mults] 228 | all_dimensions = [input_size, *hidden_dimensions, dim_out] 229 | 230 | self.mlp = MLP(all_dimensions, act = mlp_act) 231 | 232 | def forward(self, x_categ, x_cont, return_attn = False): 233 | xs = [] 234 | 235 | assert x_categ.shape[-1] == self.num_categories, f'you must pass in {self.num_categories} values for your categories input' 236 | 237 | if self.num_unique_categories > 0: 238 | x_categ = x_categ + self.categories_offset 239 | 240 | categ_embed = self.category_embed(x_categ) 241 | 242 | if self.use_shared_categ_embed: 243 | shared_categ_embed = repeat(self.shared_category_embed, 'n d -> b n d', b = categ_embed.shape[0]) 244 | categ_embed = torch.cat((categ_embed, shared_categ_embed), dim = -1) 245 | 246 | x, attns = self.transformer(categ_embed, return_attn = True) 247 | 248 | flat_categ = rearrange(x, 'b ... -> b (...)') 249 | xs.append(flat_categ) 250 | 251 | assert x_cont.shape[1] == self.num_continuous, f'you must pass in {self.num_continuous} values for your continuous input' 252 | 253 | if self.num_continuous > 0: 254 | if exists(self.continuous_mean_std): 255 | mean, std = self.continuous_mean_std.unbind(dim = -1) 256 | x_cont = (x_cont - mean) / std 257 | 258 | normed_cont = self.norm(x_cont) 259 | xs.append(normed_cont) 260 | 261 | x = torch.cat(xs, dim = -1) 262 | logits = self.mlp(x) 263 | 264 | if not return_attn: 265 | return logits 266 | 267 | return logits, attns 268 | --------------------------------------------------------------------------------
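The README's Unsupervised Training section above only gestures at how Electra-style pretraining could be wired onto `model.transformer`. Below is a minimal, hypothetical sketch of a single replaced-value-detection step, not part of this repository: the helpers `embed_categories` and `random_corrupt`, the `to_detection_logits` head, the 15% uniform-random replacement (a crude stand-in for ELECTRA's learned generator), the random batch and the optimizer settings are all illustrative assumptions, and the sketch leans on current internals (`categories_offset`, `category_embed`, `shared_category_embed`) that could change.

```python
import torch
import torch.nn.functional as F
from torch import nn
from tab_transformer_pytorch import TabTransformer

categories = (10, 5, 6, 5, 8)

model = TabTransformer(
    categories = categories,
    num_continuous = 10,
    dim = 32,
    depth = 6,
    heads = 8
)

def embed_categories(x_categ):
    # mirror the categorical embedding steps at the top of TabTransformer.forward
    x = x_categ + model.categories_offset
    embed = model.category_embed(x)
    if model.use_shared_categ_embed:
        shared = model.shared_category_embed.unsqueeze(0).expand(embed.shape[0], -1, -1)
        embed = torch.cat((embed, shared), dim = -1)
    return embed

def random_corrupt(x_categ, replace_prob = 0.15):
    # replace a random subset of category ids with uniform random ones
    # (assumption: a stand-in for ELECTRA's generator)
    rand = torch.cat([torch.randint(0, c, (x_categ.shape[0], 1)) for c in categories], dim = 1)
    replace_mask = torch.rand(x_categ.shape) < replace_prob
    return torch.where(replace_mask, rand, x_categ), replace_mask

to_detection_logits = nn.Linear(32, 1)  # per-token "was this value replaced?" head (hypothetical)

optimizer = torch.optim.Adam([*model.parameters(), *to_detection_logits.parameters()], lr = 3e-4)

x_categ = torch.randint(0, 5, (256, 5))  # stand-in for a batch of real category ids

corrupted, replace_mask = random_corrupt(x_categ)

tokens = model.transformer(embed_categories(corrupted))  # (batch, num categories, dim) contextual embeddings
logits = to_detection_logits(tokens).squeeze(-1)         # (batch, num categories)

loss = F.binary_cross_entropy_with_logits(logits, replace_mask.float())

optimizer.zero_grad()
loss.backward()
optimizer.step()
```

A full ELECTRA setup would additionally train a small generator to propose harder replacements and would repeat this step over an unlabeled dataset before fine-tuning the model with labels.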