├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── setup.py
├── tab-vs-ft.png
├── tab.png
└── tab_transformer_pytorch
    ├── __init__.py
    ├── ft_transformer.py
    └── tab_transformer_pytorch.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | jobs:
16 | deploy:
17 |
18 | runs-on: ubuntu-latest
19 |
20 | steps:
21 | - uses: actions/checkout@v2
22 | - name: Set up Python
23 | uses: actions/setup-python@v2
24 | with:
25 | python-version: '3.x'
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | pip install build
30 | - name: Build package
31 | run: python -m build
32 | - name: Publish package
33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
34 | with:
35 | user: __token__
36 | password: ${{ secrets.PYPI_API_TOKEN }}
37 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ![Tab Transformer architecture](./tab.png)
2 |
3 | ## Tab Transformer
4 |
5 | Implementation of the Tab Transformer, an attention network for tabular data, in Pytorch. This simple architecture came within a hair's breadth of GBDT's performance.
6 |
7 | Update: Amazon AI claims to have beaten GBDT with Attention on a real-world tabular dataset (predicting shipping cost).
8 |
9 | ## Install
10 |
11 | ```bash
12 | $ pip install tab-transformer-pytorch
13 | ```
14 |
15 | ## Usage
16 |
17 | ```python
18 | import torch
19 | import torch.nn as nn
20 | from tab_transformer_pytorch import TabTransformer
21 |
22 | cont_mean_std = torch.randn(10, 2)
23 |
24 | model = TabTransformer(
25 | categories = (10, 5, 6, 5, 8), # tuple containing the number of unique values within each category
26 | num_continuous = 10, # number of continuous values
27 | dim = 32, # dimension, paper set at 32
28 | dim_out = 1, # binary prediction, but could be anything
29 | depth = 6, # depth, paper recommended 6
30 | heads = 8, # heads, paper recommends 8
31 | attn_dropout = 0.1, # post-attention dropout
32 | ff_dropout = 0.1, # feed forward dropout
33 | mlp_hidden_mults = (4, 2), # relative multiples of each hidden dimension of the last mlp to logits
34 | mlp_act = nn.ReLU(), # activation for final mlp, defaults to relu, but could be anything else (selu etc)
35 | continuous_mean_std = cont_mean_std # (optional) - normalize the continuous values before layer norm; see the note after this example
36 | )
37 |
38 | x_categ = torch.randint(0, 5, (1, 5)) # category values, each from 0 up to the number of unique values for that column, in the order passed into the constructor above
39 | x_cont = torch.randn(1, 10) # assume continuous values are already normalized individually
40 |
41 | pred = model(x_categ, x_cont) # (1, 1)
42 | ```
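The `continuous_mean_std` tensor above holds the per-column mean and standard deviation of the continuous features, stacked into shape `(num_continuous, 2)`; internally the model uses them to standardize `x_cont` before its layer norm. A minimal sketch of how one might build it, where `train_cont` is a stand-in for your own continuous training columns:

```python
import torch

train_cont = torch.randn(1024, 10)   # stand-in for the continuous columns of a training set, shape (num_rows, num_continuous)

cont_mean_std = torch.stack((
    train_cont.mean(dim = 0),        # per-column mean
    train_cont.std(dim = 0)          # per-column standard deviation
), dim = -1)                         # shape (10, 2), suitable for the continuous_mean_std argument above
```

As a side note on `mlp_hidden_mults`: with the configuration above, the flattened transformer output plus the continuous features give an MLP input of `32 * 5 + 10 = 170`, so hidden multipliers `(4, 2)` produce hidden layers of size 680 and 340 before the final `dim_out = 1` logit.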
43 |
44 | ## FT Transformer
45 |
46 | ![Tab Transformer vs FT Transformer](./tab-vs-ft.png)
47 |
48 | This paper from Yandex improves on the Tab Transformer by using a simpler scheme for embedding the continuous numerical values, as shown in the diagram above (courtesy of a reddit post). A short sketch of this embedding scheme follows the usage example below.
49 |
50 | It is included in this repository for convenient comparison with the Tab Transformer.
51 |
52 | ```python
53 | import torch
54 | from tab_transformer_pytorch import FTTransformer
55 |
56 | model = FTTransformer(
57 | categories = (10, 5, 6, 5, 8), # tuple containing the number of unique values within each category
58 | num_continuous = 10, # number of continuous values
59 | dim = 32, # dimension, paper set at 32
60 | dim_out = 1, # binary prediction, but could be anything
61 | depth = 6, # depth, paper recommended 6
62 | heads = 8, # heads, paper recommends 8
63 | attn_dropout = 0.1, # post-attention dropout
64 | ff_dropout = 0.1 # feed forward dropout
65 | )
66 |
67 | x_categ = torch.randint(0, 5, (1, 5)) # category values, each from 0 up to the number of unique values for that column, in the order passed into the constructor above
68 | x_numer = torch.randn(1, 10) # numerical values
69 |
70 | pred = model(x_categ, x_numer) # (1, 1)
71 | ```
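For reference, the continuous-value embedding scheme mentioned above amounts to a per-feature affine map into the model dimension, mirroring `NumericalEmbedder` in `ft_transformer.py`. A minimal sketch:

```python
import torch

dim, num_continuous = 32, 10

weights = torch.randn(num_continuous, dim)   # one learned scale vector per continuous feature
biases = torch.randn(num_continuous, dim)    # one learned bias vector per continuous feature

x_numer = torch.randn(1, num_continuous)                  # raw continuous values
numer_tokens = x_numer.unsqueeze(-1) * weights + biases   # (1, num_continuous, dim) tokens fed to the transformer alongside the categorical embeddings
```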
72 |
73 | ## Unsupervised Training
74 |
75 | To carry out the type of unsupervised pre-training described in the paper, first convert your category tokens to the appropriate unique ids, and then use Electra on `model.transformer`. A rough sketch of one way to set this up is shown below.
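Below is a rough, hypothetical sketch of the kind of replaced-token-detection (Electra-style) objective this refers to. It reuses this repo's `category_embed`, `categories_offset`, `shared_category_embed` and `transformer` attributes, but the corruption scheme, the detection head and the loss are illustrative assumptions rather than part of this repository:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from tab_transformer_pytorch import TabTransformer

model = TabTransformer(categories = (10, 5, 6, 5, 8), num_continuous = 10, dim = 32, depth = 6, heads = 8)

detector = nn.Linear(32, 1)                    # hypothetical per-token "was this value replaced?" head

x_categ = torch.randint(0, 5, (1, 5))          # original category ids
x_corrupt = x_categ.clone()
x_corrupt[:, 0] = torch.randint(0, 10, (1,))   # corrupt the first column with a random plausible value

# embed the corrupted ids the same way TabTransformer does internally:
# offset into the shared embedding table, look up per-category embeddings, append the shared embedding
tokens = model.category_embed(x_corrupt + model.categories_offset)
shared = model.shared_category_embed.unsqueeze(0).expand(tokens.shape[0], -1, -1)
tokens = torch.cat((tokens, shared), dim = -1)

contextual = model.transformer(tokens)         # (1, 5, 32) contextual embeddings
logits = detector(contextual).squeeze(-1)      # (1, 5) replaced / original logits

labels = (x_corrupt != x_categ).float()        # 1 where a value was replaced
loss = F.binary_cross_entropy_with_logits(logits, labels)
loss.backward()
```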
76 |
77 | ## Todo
78 |
79 | - [ ] consider https://arxiv.org/abs/2203.05556
80 |
81 | ## Citations
82 |
83 | ```bibtex
84 | @misc{huang2020tabtransformer,
85 | title = {TabTransformer: Tabular Data Modeling Using Contextual Embeddings},
86 | author = {Xin Huang and Ashish Khetan and Milan Cvitkovic and Zohar Karnin},
87 | year = {2020},
88 | eprint = {2012.06678},
89 | archivePrefix = {arXiv},
90 | primaryClass = {cs.LG}
91 | }
92 | ```
93 |
94 | ```bibtex
95 | @article{Gorishniy2021RevisitingDL,
96 | title = {Revisiting Deep Learning Models for Tabular Data},
97 | author = {Yu. V. Gorishniy and Ivan Rubachev and Valentin Khrulkov and Artem Babenko},
98 | journal = {ArXiv},
99 | year = {2021},
100 | volume = {abs/2106.11959}
101 | }
102 | ```
103 |
104 | ```bibtex
105 | @article{Zhu2024HyperConnections,
106 | title = {Hyper-Connections},
107 | author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
108 | journal = {ArXiv},
109 | year = {2024},
110 | volume = {abs/2409.19606},
111 | url = {https://api.semanticscholar.org/CorpusID:272987528}
112 | }
113 | ```
114 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'tab-transformer-pytorch',
5 | packages = find_packages(),
6 | version = '0.4.2',
7 | license='MIT',
8 | description = 'Tab Transformer - Pytorch',
9 | long_description_content_type = 'text/markdown',
10 | author = 'Phil Wang',
11 | author_email = 'lucidrains@gmail.com',
12 | url = 'https://github.com/lucidrains/tab-transformer-pytorch',
13 | keywords = [
14 | 'artificial intelligence',
15 | 'transformers',
16 | 'attention mechanism',
17 | 'tabular data'
18 | ],
19 | install_requires=[
20 | 'einops>=0.8',
21 | 'hyper-connections>=0.1.15',
22 | 'torch>=2.3'
23 | ],
24 | classifiers=[
25 | 'Development Status :: 4 - Beta',
26 | 'Intended Audience :: Developers',
27 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
28 | 'License :: OSI Approved :: MIT License',
29 | 'Programming Language :: Python :: 3.8',
30 | ],
31 | )
32 |
--------------------------------------------------------------------------------
/tab-vs-ft.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/tab-transformer-pytorch/c1b9aa9c28f530d22ee95410842685525646bf64/tab-vs-ft.png
--------------------------------------------------------------------------------
/tab.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/tab-transformer-pytorch/c1b9aa9c28f530d22ee95410842685525646bf64/tab.png
--------------------------------------------------------------------------------
/tab_transformer_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from tab_transformer_pytorch.tab_transformer_pytorch import TabTransformer
2 | from tab_transformer_pytorch.ft_transformer import FTTransformer
3 |
--------------------------------------------------------------------------------
/tab_transformer_pytorch/ft_transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, einsum
3 | from torch.nn import Module, ModuleList
4 | import torch.nn.functional as F
5 |
6 | from einops import rearrange, repeat
7 |
8 | from hyper_connections import HyperConnections
9 |
10 | # feedforward and attention
11 |
12 | class GEGLU(Module):
13 | def forward(self, x):
14 | x, gates = x.chunk(2, dim = -1)
15 | return x * F.gelu(gates)
16 |
17 | def FeedForward(dim, mult = 4, dropout = 0.):
18 | return nn.Sequential(
19 | nn.LayerNorm(dim),
20 | nn.Linear(dim, dim * mult * 2),
21 | GEGLU(),
22 | nn.Dropout(dropout),
23 | nn.Linear(dim * mult, dim)
24 | )
25 |
26 | class Attention(Module):
27 | def __init__(
28 | self,
29 | dim,
30 | heads = 8,
31 | dim_head = 64,
32 | dropout = 0.
33 | ):
34 | super().__init__()
35 | inner_dim = dim_head * heads
36 | self.heads = heads
37 | self.scale = dim_head ** -0.5
38 |
39 | self.norm = nn.LayerNorm(dim)
40 |
41 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
42 | self.to_out = nn.Linear(inner_dim, dim, bias = False)
43 |
44 | self.dropout = nn.Dropout(dropout)
45 |
46 | def forward(self, x):
47 | h = self.heads
48 |
49 | x = self.norm(x)
50 |
51 | q, k, v = self.to_qkv(x).chunk(3, dim = -1)
52 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
53 | q = q * self.scale
54 |
55 | sim = einsum('b h i d, b h j d -> b h i j', q, k)
56 |
57 | attn = sim.softmax(dim = -1)
58 | dropped_attn = self.dropout(attn)
59 |
60 | out = einsum('b h i j, b h j d -> b h i d', dropped_attn, v)
61 | out = rearrange(out, 'b h n d -> b n (h d)', h = h)
62 | out = self.to_out(out)
63 |
64 | return out, attn
65 |
66 | # transformer
67 |
68 | class Transformer(Module):
69 | def __init__(
70 | self,
71 | dim,
72 | depth,
73 | heads,
74 | dim_head,
75 | attn_dropout,
76 | ff_dropout,
77 | num_residual_streams = 4
78 | ):
79 | super().__init__()
80 |
81 | init_hyper_conn, self.expand_streams, self.reduce_streams = HyperConnections.get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
82 |
83 | self.layers = ModuleList([])
84 |
85 | for _ in range(depth):
86 | self.layers.append(ModuleList([
87 | init_hyper_conn(dim = dim, branch = Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout)),
88 | init_hyper_conn(dim = dim, branch = FeedForward(dim, dropout = ff_dropout)),
89 | ]))
90 |
91 | def forward(self, x, return_attn = False):
92 | post_softmax_attns = []
93 |
94 | x = self.expand_streams(x)
95 |
96 | for attn, ff in self.layers:
97 | x, post_softmax_attn = attn(x)
98 | post_softmax_attns.append(post_softmax_attn)
99 |
100 | x = ff(x)
101 |
102 | x = self.reduce_streams(x)
103 |
104 | if not return_attn:
105 | return x
106 |
107 | return x, torch.stack(post_softmax_attns)
108 |
109 | # numerical embedder
110 |
111 | class NumericalEmbedder(Module):
112 | def __init__(self, dim, num_numerical_types):
113 | super().__init__()
114 | self.weights = nn.Parameter(torch.randn(num_numerical_types, dim))
115 | self.biases = nn.Parameter(torch.randn(num_numerical_types, dim))
116 |
117 | def forward(self, x):
118 | x = rearrange(x, 'b n -> b n 1')
119 | return x * self.weights + self.biases
120 |
121 | # main class
122 |
123 | class FTTransformer(Module):
124 | def __init__(
125 | self,
126 | *,
127 | categories,
128 | num_continuous,
129 | dim,
130 | depth,
131 | heads,
132 | dim_head = 16,
133 | dim_out = 1,
134 | num_special_tokens = 2,
135 | attn_dropout = 0.,
136 | ff_dropout = 0.,
137 | num_residual_streams = 4
138 | ):
139 | super().__init__()
140 | assert all(map(lambda n: n > 0, categories)), 'number of each category must be positive'
141 | assert len(categories) + num_continuous > 0, 'input shape must not be null'
142 |
143 | # categories related calculations
144 |
145 | self.num_categories = len(categories)
146 | self.num_unique_categories = sum(categories)
147 |
148 | # create category embeddings table
149 |
150 | self.num_special_tokens = num_special_tokens
151 | total_tokens = self.num_unique_categories + num_special_tokens
152 |
153 | # for automatically offsetting unique category ids to the correct position in the categories embedding table
154 |
155 | if self.num_unique_categories > 0:
156 | categories_offset = F.pad(torch.tensor(list(categories)), (1, 0), value = num_special_tokens)
157 | categories_offset = categories_offset.cumsum(dim = -1)[:-1]
158 | self.register_buffer('categories_offset', categories_offset)
159 |
160 | # categorical embedding
161 |
162 | self.categorical_embeds = nn.Embedding(total_tokens, dim)
163 |
164 | # continuous
165 |
166 | self.num_continuous = num_continuous
167 |
168 | if self.num_continuous > 0:
169 | self.numerical_embedder = NumericalEmbedder(dim, self.num_continuous)
170 |
171 | # cls token
172 |
173 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
174 |
175 | # transformer
176 |
177 | self.transformer = Transformer(
178 | dim = dim,
179 | depth = depth,
180 | heads = heads,
181 | dim_head = dim_head,
182 | attn_dropout = attn_dropout,
183 | ff_dropout = ff_dropout,
184 | num_residual_streams = num_residual_streams
185 | )
186 |
187 | # to logits
188 |
189 | self.to_logits = nn.Sequential(
190 | nn.LayerNorm(dim),
191 | nn.ReLU(),
192 | nn.Linear(dim, dim_out)
193 | )
194 |
195 | def forward(self, x_categ, x_numer, return_attn = False):
196 | assert x_categ.shape[-1] == self.num_categories, f'you must pass in {self.num_categories} values for your categories input'
197 |
198 | xs = []
199 | if self.num_unique_categories > 0:
200 | x_categ = x_categ + self.categories_offset
201 |
202 | x_categ = self.categorical_embeds(x_categ)
203 |
204 | xs.append(x_categ)
205 |
206 | # add numerically embedded tokens
207 | if self.num_continuous > 0:
208 | x_numer = self.numerical_embedder(x_numer)
209 |
210 | xs.append(x_numer)
211 |
212 | # concat categorical and numerical
213 |
214 | x = torch.cat(xs, dim = 1)
215 |
216 | # append cls tokens
217 | b = x.shape[0]
218 | cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
219 | x = torch.cat((cls_tokens, x), dim = 1)
220 |
221 | # attend
222 |
223 | x, attns = self.transformer(x, return_attn = True)
224 |
225 | # get cls token
226 |
227 | x = x[:, 0]
228 |
229 | # out in the paper is linear(relu(ln(cls)))
230 |
231 | logits = self.to_logits(x)
232 |
233 | if not return_attn:
234 | return logits
235 |
236 | return logits, attns
237 |
--------------------------------------------------------------------------------
/tab_transformer_pytorch/tab_transformer_pytorch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, einsum
3 | from torch.nn import Module, ModuleList
4 | import torch.nn.functional as F
5 |
6 | from einops import rearrange, repeat
7 |
8 | from hyper_connections import HyperConnections
9 |
10 | # helpers
11 |
12 | def exists(val):
13 | return val is not None
14 |
15 | def default(val, d):
16 | return val if exists(val) else d
17 |
18 | # classes
19 |
20 | class PreNorm(Module):
21 | def __init__(self, dim, fn):
22 | super().__init__()
23 | self.norm = nn.LayerNorm(dim)
24 | self.fn = fn
25 |
26 | def forward(self, x, **kwargs):
27 | return self.fn(self.norm(x), **kwargs)
28 |
29 | # attention
30 |
31 | class GEGLU(Module):
32 | def forward(self, x):
33 | x, gates = x.chunk(2, dim = -1)
34 | return x * F.gelu(gates)
35 |
36 | class FeedForward(Module):
37 | def __init__(self, dim, mult = 4, dropout = 0.):
38 | super().__init__()
39 | self.net = nn.Sequential(
40 | nn.Linear(dim, dim * mult * 2),
41 | GEGLU(),
42 | nn.Dropout(dropout),
43 | nn.Linear(dim * mult, dim)
44 | )
45 |
46 | def forward(self, x, **kwargs):
47 | return self.net(x)
48 |
49 | class Attention(Module):
50 | def __init__(
51 | self,
52 | dim,
53 | heads = 8,
54 | dim_head = 16,
55 | dropout = 0.
56 | ):
57 | super().__init__()
58 | inner_dim = dim_head * heads
59 | self.heads = heads
60 | self.scale = dim_head ** -0.5
61 |
62 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
63 | self.to_out = nn.Linear(inner_dim, dim)
64 |
65 | self.dropout = nn.Dropout(dropout)
66 |
67 | def forward(self, x):
68 | h = self.heads
69 | q, k, v = self.to_qkv(x).chunk(3, dim = -1)
70 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
71 | sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
72 |
73 | attn = sim.softmax(dim = -1)
74 | dropped_attn = self.dropout(attn)
75 |
76 | out = einsum('b h i j, b h j d -> b h i d', dropped_attn, v)
77 | out = rearrange(out, 'b h n d -> b n (h d)', h = h)
78 | return self.to_out(out), attn
79 |
80 | # transformer
81 |
82 | class Transformer(Module):
83 | def __init__(
84 | self,
85 | dim,
86 | depth,
87 | heads,
88 | dim_head,
89 | attn_dropout,
90 | ff_dropout,
91 | num_residual_streams = 4
92 | ):
93 | super().__init__()
94 | self.layers = ModuleList([])
95 |
96 | init_hyper_conn, self.expand_streams, self.reduce_streams = HyperConnections.get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
97 |
98 | for _ in range(depth):
99 | self.layers.append(ModuleList([
100 | init_hyper_conn(dim = dim, branch = PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout))),
101 | init_hyper_conn(dim = dim, branch = PreNorm(dim, FeedForward(dim, dropout = ff_dropout))),
102 | ]))
103 |
104 | def forward(self, x, return_attn = False):
105 | post_softmax_attns = []
106 |
107 | x = self.expand_streams(x)
108 |
109 | for attn, ff in self.layers:
110 | x, post_softmax_attn = attn(x)
111 | post_softmax_attns.append(post_softmax_attn)
112 |
113 | x = ff(x)
114 |
115 | x = self.reduce_streams(x)
116 |
117 | if not return_attn:
118 | return x
119 |
120 | return x, torch.stack(post_softmax_attns)
121 | # mlp
122 |
123 | class MLP(Module):
124 | def __init__(self, dims, act = None):
125 | super().__init__()
126 | dims_pairs = list(zip(dims[:-1], dims[1:]))
127 | layers = []
128 | for ind, (dim_in, dim_out) in enumerate(dims_pairs):
129 | is_last = ind >= (len(dims_pairs) - 1)
130 | linear = nn.Linear(dim_in, dim_out)
131 | layers.append(linear)
132 |
133 | if is_last:
134 | continue
135 |
136 | act = default(act, nn.ReLU())
137 | layers.append(act)
138 |
139 | self.mlp = nn.Sequential(*layers)
140 |
141 | def forward(self, x):
142 | return self.mlp(x)
143 |
144 | # main class
145 |
146 | class TabTransformer(Module):
147 | def __init__(
148 | self,
149 | *,
150 | categories,
151 | num_continuous,
152 | dim,
153 | depth,
154 | heads,
155 | dim_head = 16,
156 | dim_out = 1,
157 | mlp_hidden_mults = (4, 2),
158 | mlp_act = None,
159 | num_special_tokens = 2,
160 | continuous_mean_std = None,
161 | attn_dropout = 0.,
162 | ff_dropout = 0.,
163 | use_shared_categ_embed = True,
164 | shared_categ_dim_divisor = 8., # in paper, they reserve dimension / 8 for category shared embedding
165 | num_residual_streams = 4
166 | ):
167 | super().__init__()
168 | assert all(map(lambda n: n > 0, categories)), 'number of each category must be positive'
169 | assert len(categories) + num_continuous > 0, 'input shape must not be null'
170 |
171 | # categories related calculations
172 |
173 | self.num_categories = len(categories)
174 | self.num_unique_categories = sum(categories)
175 |
176 | # create category embeddings table
177 |
178 | self.num_special_tokens = num_special_tokens
179 | total_tokens = self.num_unique_categories + num_special_tokens
180 |
181 | shared_embed_dim = 0 if not use_shared_categ_embed else int(dim // shared_categ_dim_divisor)
182 |
183 | self.category_embed = nn.Embedding(total_tokens, dim - shared_embed_dim)
184 |
185 | # take care of shared category embed
186 |
187 | self.use_shared_categ_embed = use_shared_categ_embed
188 |
189 | if use_shared_categ_embed:
190 | self.shared_category_embed = nn.Parameter(torch.zeros(self.num_categories, shared_embed_dim))
191 | nn.init.normal_(self.shared_category_embed, std = 0.02)
192 |
193 | # for automatically offsetting unique category ids to the correct position in the categories embedding table
194 |
195 | if self.num_unique_categories > 0:
196 | categories_offset = F.pad(torch.tensor(list(categories)), (1, 0), value = num_special_tokens)
197 | categories_offset = categories_offset.cumsum(dim = -1)[:-1]
198 | self.register_buffer('categories_offset', categories_offset)
199 |
200 | # continuous
201 |
202 | self.num_continuous = num_continuous
203 |
204 | if self.num_continuous > 0:
205 | if exists(continuous_mean_std):
206 | assert continuous_mean_std.shape == (num_continuous, 2), f'continuous_mean_std must have a shape of ({num_continuous}, 2) where the last dimension contains the mean and standard deviation respectively'
207 | self.register_buffer('continuous_mean_std', continuous_mean_std)
208 |
209 | self.norm = nn.LayerNorm(num_continuous)
210 |
211 | # transformer
212 |
213 | self.transformer = Transformer(
214 | dim = dim,
215 | depth = depth,
216 | heads = heads,
217 | dim_head = dim_head,
218 | attn_dropout = attn_dropout,
219 | ff_dropout = ff_dropout,
220 | num_residual_streams = num_residual_streams
221 | )
222 |
223 | # mlp to logits
224 |
225 | input_size = (dim * self.num_categories) + num_continuous
226 |
227 | hidden_dimensions = [input_size * t for t in mlp_hidden_mults]
228 | all_dimensions = [input_size, *hidden_dimensions, dim_out]
229 |
230 | self.mlp = MLP(all_dimensions, act = mlp_act)
231 |
232 | def forward(self, x_categ, x_cont, return_attn = False):
233 | xs = []
234 |
235 | assert x_categ.shape[-1] == self.num_categories, f'you must pass in {self.num_categories} values for your categories input'
236 |
237 | if self.num_unique_categories > 0:
238 | x_categ = x_categ + self.categories_offset
239 |
240 | categ_embed = self.category_embed(x_categ)
241 |
242 | if self.use_shared_categ_embed:
243 | shared_categ_embed = repeat(self.shared_category_embed, 'n d -> b n d', b = categ_embed.shape[0])
244 | categ_embed = torch.cat((categ_embed, shared_categ_embed), dim = -1)
245 |
246 | x, attns = self.transformer(categ_embed, return_attn = True)
247 |
248 | flat_categ = rearrange(x, 'b ... -> b (...)')
249 | xs.append(flat_categ)
250 |
251 | assert x_cont.shape[1] == self.num_continuous, f'you must pass in {self.num_continuous} values for your continuous input'
252 |
253 | if self.num_continuous > 0:
254 | if exists(self.continuous_mean_std):
255 | mean, std = self.continuous_mean_std.unbind(dim = -1)
256 | x_cont = (x_cont - mean) / std
257 |
258 | normed_cont = self.norm(x_cont)
259 | xs.append(normed_cont)
260 |
261 | x = torch.cat(xs, dim = -1)
262 | logits = self.mlp(x)
263 |
264 | if not return_attn:
265 | return logits
266 |
267 | return logits, attns
268 |
--------------------------------------------------------------------------------