├── figures
├── scaling.png
└── performance-comparison.png
├── tabdpt_datasets
├── __pycache__
│ ├── dataset.cpython-311.pyc
│ ├── openml.cpython-311.pyc
│ └── __init__.cpython-311.pyc
├── cc_test_datasets_multiclass.pickle
├── cc_valid_datasets_multiclass.pickle
├── tabred.py
├── annotated_tables.py
├── dataset.py
├── __init__.py
├── talent.py
├── openml.py
└── catalogue.py
├── predict.py
├── data_splits
├── noleak_training_datasets_anysize.csv
├── noleak_training_datasets.csv
├── reg_datasets.csv
└── cls_datasets.csv
├── configs
└── default_config.yaml
├── LICENSE
├── README.md
├── transformer_layer.py
├── requirements.txt
├── eval_full.py
├── model.py
├── utils.py
├── train.py
├── dataset.py
└── tabdpt.py
/figures/scaling.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/layer6ai-labs/TabDPT-training/HEAD/figures/scaling.png
--------------------------------------------------------------------------------
/figures/performance-comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/layer6ai-labs/TabDPT-training/HEAD/figures/performance-comparison.png
--------------------------------------------------------------------------------
/tabdpt_datasets/__pycache__/dataset.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/layer6ai-labs/TabDPT-training/HEAD/tabdpt_datasets/__pycache__/dataset.cpython-311.pyc
--------------------------------------------------------------------------------
/tabdpt_datasets/__pycache__/openml.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/layer6ai-labs/TabDPT-training/HEAD/tabdpt_datasets/__pycache__/openml.cpython-311.pyc
--------------------------------------------------------------------------------
/tabdpt_datasets/cc_test_datasets_multiclass.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/layer6ai-labs/TabDPT-training/HEAD/tabdpt_datasets/cc_test_datasets_multiclass.pickle
--------------------------------------------------------------------------------
/tabdpt_datasets/cc_valid_datasets_multiclass.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/layer6ai-labs/TabDPT-training/HEAD/tabdpt_datasets/cc_valid_datasets_multiclass.pickle
--------------------------------------------------------------------------------
/tabdpt_datasets/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/layer6ai-labs/TabDPT-training/HEAD/tabdpt_datasets/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | from sklearn.datasets import fetch_california_housing, load_breast_cancer
2 | from sklearn.metrics import accuracy_score, r2_score
3 | from sklearn.model_selection import train_test_split
4 |
5 | from tabdpt import TabDPTClassifier, TabDPTRegressor
6 |
# --- Classification demo: breast-cancer dataset, 67/33 train/test split ---
cls_features, cls_labels = load_breast_cancer(return_X_y=True)
cls_split = train_test_split(cls_features, cls_labels, test_size=0.33, random_state=42)
cls_X_train, cls_X_test, cls_y_train, cls_y_test = cls_split

classifier = TabDPTClassifier(path="./checkpoints/latest.ckpt")
classifier.fit(cls_X_train, cls_y_train)
cls_predictions = classifier.predict(
    cls_X_test,
    temperature=0.8,
    context_size=1024,
    use_retrieval=True,
)
print("classification accuracy score = ", accuracy_score(cls_y_test, cls_predictions))


# --- Regression demo: California housing dataset, same split settings ---
reg_features, reg_targets = fetch_california_housing(return_X_y=True)
reg_split = train_test_split(reg_features, reg_targets, test_size=0.33, random_state=42)
reg_X_train, reg_X_test, reg_y_train, reg_y_test = reg_split

regressor = TabDPTRegressor(path="./checkpoints/latest.ckpt")
regressor.fit(reg_X_train, reg_y_train)
reg_predictions = regressor.predict(reg_X_test, context_size=512, use_retrieval=True)
print("regression r2 score =", r2_score(reg_y_test, reg_predictions))
25 |
--------------------------------------------------------------------------------
/data_splits/noleak_training_datasets_anysize.csv:
--------------------------------------------------------------------------------
1 | did,task
2 | 4135,cls
3 | 41434,cls
4 | 375,cls
5 | 1120,cls
6 | 40900,cls
7 | 1043,cls
8 | 1119,cls
9 | 1,cls
10 | 1459,cls
11 | 1466,cls
12 | 23380,cls
13 | 1471,cls
14 | 846,cls
15 | 1044,cls
16 | 821,cls
17 | 40679,cls
18 | 334,cls
19 | 24,cls
20 | 1568,cls
21 | 30,cls
22 | 871,cls
23 | 470,cls
24 | 41146,cls
25 | 377,cls
26 | 40733,cls
27 | 44089,cls
28 | 44122,cls
29 | 45020,cls
30 | 45028,cls
31 | 45026,cls
32 | 45039,cls
33 | 41156,cls
34 | 137,cls
35 | 311,cls
36 | 333,cls
37 | 335,cls
38 | 382,cls
39 | 717,cls
40 | 757,cls
41 | 802,cls
42 | 816,cls
43 | 825,cls
44 | 839,cls
45 | 841,cls
46 | 843,cls
47 | 930,cls
48 | 940,cls
49 | 949,cls
50 | 966,cls
51 | 981,cls
52 | 1002,cls
53 | 1018,cls
54 | 1037,cls
55 | 1443,cls
56 | 1444,cls
57 | 1451,cls
58 | 1452,cls
59 | 1453,cls
60 | 1507,cls
61 | 40646,cls
62 | 40680,cls
63 | 40690,cls
64 | 40693,cls
65 | 40705,cls
66 | 40706,cls
67 | 41919,cls
68 | 42638,cls
69 | 44055,reg
70 | 44056,reg
71 | 44063,reg
72 | 44136,reg
73 | 44137,reg
74 | 44145,reg
75 | 45032,reg
76 | 546,reg
77 | 42724,reg
78 | 42727,reg
79 | 531,reg
80 | 42563,reg
81 |
--------------------------------------------------------------------------------
/data_splits/noleak_training_datasets.csv:
--------------------------------------------------------------------------------
1 | did
2 | 41138
3 | 4135
4 | 4535
5 | 41434
6 | 375
7 | 1120
8 | 41150
9 | 40900
10 | 40536
11 | 1043
12 | 1169
13 | 41147
14 | 1459
15 | 1466
16 | 1118
17 | 41142
18 | 23380
19 | 1596
20 | 41163
21 | 1471
22 | 846
23 | 1044
24 | 41164
25 | 1477
26 | 1476
27 | 41159
28 | 23512
29 | 1479
30 | 821
31 | 41168
32 | 41143
33 | 184
34 | 1483
35 | 40679
36 | 24
37 | 1116
38 | 1568
39 | 1493
40 | 30
41 | 41145
42 | 1567
43 | 871
44 | 41161
45 | 41165
46 | 312
47 | 40685
48 | 1036
49 | 41146
50 | 41166
51 | 1509
52 | 40733
53 | 44089
54 | 44122
55 | 45022
56 | 45020
57 | 45026
58 | 45038
59 | 45039
60 | 1111
61 | 1457
62 | 41167
63 | 41144
64 | 41156
65 | 41169
66 | 41162
67 | 42734
68 | 42732
69 | 42746
70 | 42742
71 | 43072
72 | 273
73 | 382
74 | 389
75 | 396
76 | 802
77 | 816
78 | 843
79 | 930
80 | 966
81 | 981
82 | 1002
83 | 1018
84 | 1037
85 | 1112
86 | 1130
87 | 1142
88 | 1444
89 | 1453
90 | 1481
91 | 1503
92 | 1507
93 | 40646
94 | 40680
95 | 40706
96 | 44055
97 | 44056
98 | 44061
99 | 44063
100 | 44065
101 | 44068
102 | 44069
103 | 45041
104 | 45043
105 | 45045
106 | 45046
107 | 45047
108 | 44136
109 | 44137
110 | 44145
111 | 45032
112 | 4549
113 | 42572
114 | 42705
115 | 42728
116 | 41540
117 | 42724
118 | 42727
119 | 42730
120 | 41980
121 | 42563
122 | 3050
123 | 3277
124 | 43071
125 |
--------------------------------------------------------------------------------
/configs/default_config.yaml:
--------------------------------------------------------------------------------
1 | version: 0.0.1
2 | description: TabDPT
3 | seed: 42
4 | exp_name: "default"
5 | folder: "default"
6 | exp_path: ''
7 |
8 | env:
9 | device: 'cuda:0'
10 | gpus: [0, 1]
11 | num_workers: 32
12 |
13 | model:
14 | emsize: 512
15 | max_num_classes: 10
16 | max_num_features: 100
17 | # number of heads in the transformer, 8 worked slightly better than 4
18 | nhead: 8
19 | nhid_factor: 2
20 | nlayers: 12
21 |
22 |
23 | training:
24 | num_epochs: 2048
25 | num_model_updates: 128
26 | num_agg: 1 # gradient aggregation steps: to be adjusted based on the number of gpus
27 | batch_size: 256 # to be adjusted based on the number of gpus
28 | lr: 0.0005
29 | weight_decay: 0.05
30 | dropout: 0.0
31 | # minimum number of elements in the context: too low means lots of noise, too high means we never see small contexts
32 | min_eval_pos: 50
33 | # maximum number of elements in the context
34 | max_eval_pos: 1024
35 | # Fixed size sequence length. Number of queries is seq_len - max_eval_pos
36 | # ensure seq_len >= max_eval_pos + constant where constant is the minimum number of queries
37 | # Too small also means a lot of noise
38 | seq_len: 1536
39 | # seq len for eval
40 | eval_seq_len: 1024
41 | label_smoothing: 0.1
42 | reset_policy: 'rm' # reset policy: 'rm' for remove, 'cnt' for continue
43 | compile: true
44 | clip_grad_norm: 1
45 |
46 | data:
47 | y_reg_augment: true
48 | # retrieval during training
49 | retrieval: true
50 | # retrieval during eval: this may take more memory than the training, be careful with memory
51 | eval_retrieval: true
52 |
53 | logging:
54 | eval_every: 10
55 | save_metrics:
56 | - valid_agg
57 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright Notice: © Copyright 2018 The Toronto-Dominion Bank and/or its affiliates
2 |
3 | Permission is hereby granted, subject to the conditions below, free of charge, to
4 | any person obtaining a copy of this software and associated documentation files
5 | (the "Software"), to use, copy, distribute, publish and modify the Software, only
6 | for research and for no other purpose. For clarity, and without limitation, this
7 | licence does not permit use of Software or any part thereof for commercial purposes.
8 |
9 | Patents: This permission does not grant any patent licenses in the Software.
10 |
11 | Conditions:
12 | 1. The above copyright notice and the following disclaimer shall be included in all
13 | copies or substantial portions of the Software.
14 | 2. You must give appropriate credit, provide a link to the license, and indicate if
15 | changes were made. You may do so in any reasonable manner, but not in any way that
16 | suggests that the copyright holder endorses you or your use of the Software.
17 |
18 | Names and Trademarks: No permission to use the names or trademarks of the copyright
19 | holder are granted, except as required for reasonable and customary use in describing
20 | the origin of the Software and reproducing the content of the copyright notice.
21 |
22 | DISCLAIMER: THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
24 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
25 | HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
26 | CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
27 | OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
28 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # TabDPT: Scaling Tabular Foundation Models on Real Data
4 |
5 | [arXiv paper](https://arxiv.org/abs/2410.18164)
6 | [Hugging Face model](https://huggingface.co/Layer6/TabDPT)
7 |
8 |
9 |
10 | **TabDPT** is an open-source foundation model for tabular data based on in-context learning. It is trained on real-world data and can generalize to new tasks **without** additional training or hyperparameter tuning.
11 |
12 | This repository provides the full training code to build your own TabDPT model. A lightweight inference interface is available [here](https://github.com/layer6ai-labs/TabDPT-inference), which can support the evaluation of either the existing TabDPT model or any new models that are trained using this repository.
13 |
14 |
15 | ## Usage
16 |
17 | We provide basic usage tips below. The details can be found by stepping through the code.
18 |
19 | ### Installation
20 |
21 | Before running the code, make sure to install the required Python packages:
22 |
23 | ```
24 | pip install -r requirements.txt
25 | ```
26 |
27 | You will also need a C compiler such as `gcc` for building some dependencies. On Ubuntu, you can install it with:
28 |
29 | ```
30 | sudo apt-get update
31 | sudo apt-get install build-essential
32 | ```
33 |
34 | ### Training Example
35 |
36 |
37 | To train a fresh TabDPT model with default hyperparameters on a single GPU, use the following command:
38 |
39 | ```
40 | CUDA_VISIBLE_DEVICES=0 python train.py exp_name="TabDPT"
41 | ```
42 |
43 | #### Multi GPU Training Example
44 |
45 | If instead you want to use Multi-GPU, do the following:
46 | ```
47 | CUDA_VISIBLE_DEVICES=4,5,6,7 torchrun --nproc_per_node=4 --rdzv_endpoint=localhost:29500 train.py \
48 | env.gpus="[4,5,6,7]" \
49 | exp_name="my_multi_gpu_test"
50 | ```
51 |
52 | **Notes:**
53 | - Adjust `nproc_per_node` to the number of GPUs.
54 | - If there are communication issues when launching several multi-GPU training runs on the same node, change the `rdzv_endpoint` port, as it can be maxed out.
55 |
56 |
57 | ## Citation and Acknowledgements
58 |
59 | If citing the paper, please use the following BibTeX:
60 |
61 | ```
62 | @article{ma2024tabdpt,
63 | title={TabDPT: Scaling Tabular Foundation Models on Real Data},
64 | author={Ma, Junwei and Thomas, Valentin and Hosseinzadeh, Rasa and Kamkari, Hamidreza and Labach, Alex and Cresswell, Jesse C and Golestan, Keyvan and Yu, Guangwei and Caterini, Anthony L and Volkovs, Maksims},
65 | journal={arXiv preprint arXiv:2410.18164},
66 | year={2024}
67 | }
68 | ```
69 |
70 | Additionally, a huge thank you to [Nafiseh Ghoroghchian](https://github.com/NaGho) for spearheading the effort of refactoring and making this codebase fit for public consumption, and thank you to [Roc Zhang](https://github.com/Zhang-Haipeng) for making the codebase compatible with `safetensors`.
71 |
--------------------------------------------------------------------------------
/transformer_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 | from torch.nn import GELU, LayerNorm, Linear
5 |
6 |
class TransformerEncoderLayer(nn.Module):
    """Transformer encoder layer used in TabDPT.

    A pre-norm block: multi-head attention followed by a feed-forward network,
    each wrapped in a residual connection. Queries are computed for every
    position, while keys and values are computed only from the first
    ``eval_pos`` positions of the sequence.
    """

    def __init__(self, embed_dim: int, num_heads: int, ff_dim: int) -> None:
        """
        Args:
            embed_dim (int): Dimension of the embedding.
            num_heads (int): Number of attention heads.
            ff_dim (int): Dimension of the feed-forward network.

        Raises:
            ValueError: If ``num_heads`` is not positive or does not evenly
                divide ``embed_dim``.
        """
        super().__init__()
        # Fail at construction time instead of with an opaque reshape error
        # inside forward().
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.head_dim = embed_dim // num_heads
        self.num_heads = num_heads
        self.kv_proj = nn.Linear(embed_dim, 2 * embed_dim, bias=False)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.attn_norm = LayerNorm(embed_dim)
        self.ff_norm = LayerNorm(embed_dim)
        self.ff = nn.Sequential(Linear(embed_dim, ff_dim), GELU(), Linear(ff_dim, embed_dim))
        # Per-head normalization applied to queries and keys before attention.
        self.q_norm = LayerNorm(self.head_dim)
        self.k_norm = LayerNorm(self.head_dim)

    def forward(self, x: torch.Tensor, eval_pos: int) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input tensor of shape (L, B, D) where L is sequence
                length, B is batch size, and D is embedding dimension.
            eval_pos (int): Evaluation position used for slicing attention keys
                and values; follows Python slice semantics, so values larger
                than L use the whole sequence.
        Returns:
            torch.Tensor: Output tensor of the same shape as input.
        """
        # switch to (B, L, D) for attention computation
        x = x.transpose(0, 1)
        B, L, _ = x.size()

        # Normalize the input
        h = self.attn_norm(x)

        # project to query, key, and value with linear layers
        q = self.q_proj(h)
        # slice the key and value projections to the evaluation position
        k, v = self.kv_proj(h[:, :eval_pos]).chunk(2, dim=-1)
        # The slice is clamped to the sequence length, so use the real number
        # of context rows rather than eval_pos when reshaping.
        ctx_len = k.size(1)

        # reshape and transpose for multi-head attention
        # q: (B, L, D) -> (B, num_heads, L, head_dim)
        # k: (B, ctx_len, D) -> (B, num_heads, ctx_len, head_dim)
        # v: (B, ctx_len, D) -> (B, num_heads, ctx_len, head_dim)
        q = q.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, ctx_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, ctx_len, self.num_heads, self.head_dim).transpose(1, 2)

        # apply layer normalization to query and key
        q, k = self.q_norm(q), self.k_norm(k)

        # compute scaled dot-product attention, then merge heads back to (B, L, D)
        attn = F.scaled_dot_product_attention(q, k, v).transpose(1, 2)
        attn = self.out_proj(attn.reshape(B, L, self.num_heads * self.head_dim))

        # residual connection and feed-forward network
        x = x + attn
        x = x + self.ff(self.ff_norm(x))

        # back to (L, B, D) for output
        return x.transpose(0, 1)
71 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.3.0
2 | antlr4-python3-runtime==4.9.3
3 | anyio==4.9.0
4 | argon2-cffi==23.1.0
5 | argon2-cffi-bindings==21.2.0
6 | arrow==1.3.0
7 | asttokens==3.0.0
8 | async-lru==2.0.5
9 | attrs==25.3.0
10 | babel==2.17.0
11 | beautifulsoup4==4.13.4
12 | bleach==6.2.0
13 | certifi==2025.4.26
14 | cffi==1.17.1
15 | charset-normalizer==3.4.2
16 | comm==0.2.2
17 | config==0.5.1
18 | debugpy==1.8.14
19 | decorator==5.2.1
20 | defusedxml==0.7.1
21 | executing==2.2.0
22 | faiss-cpu==1.11.0
23 | fastjsonschema==2.21.1
24 | filelock==3.18.0
25 | fqdn==1.5.1
26 | fsspec==2025.5.0
27 | gdown==5.2.0
28 | grpcio==1.71.0
29 | h11==0.16.0
30 | hf-xet==1.1.2
31 | httpcore==1.0.9
32 | httpx==0.28.1
33 | huggingface-hub==0.32.2
34 | hydra-core==1.3.2
35 | idna==3.10
36 | ipykernel==6.29.5
37 | ipython==9.2.0
38 | ipython_pygments_lexers==1.1.1
39 | ipywidgets==8.1.7
40 | isoduration==20.11.0
41 | jedi==0.19.2
42 | Jinja2==3.1.6
43 | joblib==1.5.1
44 | json5==0.12.0
45 | jsonpointer==3.0.0
46 | jsonschema==4.24.0
47 | jsonschema-specifications==2025.4.1
48 | jupyter==1.1.1
49 | jupyter-console==6.6.3
50 | jupyter-events==0.12.0
51 | jupyter-lsp==2.2.5
52 | jupyter_client==8.6.3
53 | jupyter_core==5.8.1
54 | jupyter_server==2.16.0
55 | jupyter_server_terminals==0.5.3
56 | jupyterlab==4.4.3
57 | jupyterlab_pygments==0.3.0
58 | jupyterlab_server==2.27.3
59 | jupyterlab_widgets==3.0.15
60 | kaggle==1.7.4.5
61 | liac-arff==2.5.0
62 | Markdown==3.8
63 | MarkupSafe==3.0.2
64 | matplotlib==3.10.0
65 | matplotlib-inline==0.1.7
66 | minio==7.2.15
67 | mistune==3.1.3
68 | mpmath==1.3.0
69 | nbclient==0.10.2
70 | nbconvert==7.16.6
71 | nbformat==5.10.4
72 | nest-asyncio==1.6.0
73 | networkx==3.4.2
74 | notebook==7.4.3
75 | notebook_shim==0.2.4
76 | omegaconf==2.3.0
77 | openml==0.15.1
78 | overrides==7.7.0
79 | pandocfilters==1.5.1
80 | parso==0.8.4
81 | pexpect==4.9.0
82 | platformdirs==4.3.8
83 | polars==1.30.0
84 | prometheus_client==0.22.0
85 | prompt_toolkit==3.0.51
86 | protobuf==6.31.1
87 | psutil==7.0.0
88 | ptyprocess==0.7.0
89 | pure_eval==0.2.3
90 | pyarrow==20.0.0
91 | pycparser==2.22
92 | pycryptodome==3.23.0
93 | Pygments==2.19.1
94 | PyQt6==6.7.1
95 | PySocks==1.7.1
96 | python-json-logger==3.3.0
97 | python-slugify==8.0.4
98 | PyYAML==6.0.2
99 | pyzmq==26.4.0
100 | referencing==0.36.2
101 | regex==2024.11.6
102 | requests==2.32.3
103 | rfc3339-validator==0.1.4
104 | rfc3986-validator==0.1.1
105 | rpds-py==0.25.1
106 | safetensors==0.5.3
107 | schedulefree==1.4.1
108 | scikit-learn==1.6.1
109 | scipy==1.15.3
110 | Send2Trash==1.8.3
111 | setuptools==72.1.0
112 | sniffio==1.3.1
113 | soupsieve==2.7
114 | stack-data==0.6.3
115 | sympy==1.14.0
116 | tensorboard==2.19.0
117 | tensorboard-data-server==0.7.2
118 | terminado==0.18.1
119 | text-unidecode==1.3
120 | threadpoolctl==3.6.0
121 | tinycss2==1.4.0
122 | tokenizers==0.21.1
123 | torch==2.7.0
124 | torchaudio==2.7.0
125 | torchvision==0.22.0
126 | tqdm==4.67.1
127 | traitlets==5.14.3
128 | transformers==4.52.3
129 | types-python-dateutil==2.9.0.20250516
130 | typing_extensions==4.13.2
131 | uri-template==1.3.0
132 | urllib3==2.4.0
133 | wcwidth==0.2.13
134 | webcolors==24.11.1
135 | webencodings==0.5.1
136 | websocket-client==1.8.0
137 | Werkzeug==3.1.3
138 | wheel==0.45.1
139 | widgetsnbextension==4.0.14
140 | xmltodict==0.14.2
141 |
--------------------------------------------------------------------------------
/tabdpt_datasets/tabred.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 | import numpy as np
5 |
6 | import tabdpt_datasets.external.tabred.cooking_time as cooking_time
7 | import tabdpt_datasets.external.tabred.delivery_eta as delivery_eta
8 | import tabdpt_datasets.external.tabred.maps_routing as maps_routing
9 | import tabdpt_datasets.external.tabred.weather as weather
10 | from tabdpt_datasets.dataset import Dataset
11 |
12 |
class TabredDataset(Dataset):
    """Regression datasets from the TabRed suite, loaded from downloaded .npy files."""

    @staticmethod
    def all_names():
        """Return the names of all datasets in this suite."""
        return ["cooking_time", "delivery_eta", "maps_routing", "weather_forecasting"]

    @staticmethod
    def suite_name():
        """Return the identifier of this suite."""
        return "tabred"

    def __init__(self, name, task_id=None):
        """
        task_id options correspond to the splits provided by the dataset:
        "split-default", "split-random-0", "split-random-1", "split-random-2",
        "split-sliding-window-0", "split-sliding-window-1", "split-sliding-window-2"

        When task_id is None, "split-default" is used.
        """

        if task_id is None:
            task_id = "split-default"
        assert task_id in (
            "split-default",
            "split-random-0",
            "split-random-1",
            "split-random-2",
            "split-sliding-window-0",
            "split-sliding-window-1",
            "split-sliding-window-2",
        )
        super().__init__(name, task_id)
        self.metadata["target_type"] = "regression"

    def prepare_data(self, download_dir):
        """
        Download the dataset if needed and load features, targets, and split indices.

        Args:
            download_dir: Root directory for downloaded data; each dataset lives
                in its own sub-directory below it.

        Raises:
            ValueError: If self.name is not one of the supported dataset names.
        """
        download_dir = Path(download_dir)

        # Dataset name -> (download entry point, sub-directory holding its files).
        sources = {
            "cooking_time": (cooking_time.main, "cooking-time"),
            "delivery_eta": (delivery_eta.main, "delivery-eta"),
            "maps_routing": (maps_routing.main, "maps-routing"),
            "weather_forecasting": (weather.main, "weather"),
        }
        if self.name not in sources:
            # Previously an unknown name fell through the if/elif chain and
            # surfaced as an UnboundLocalError further down.
            raise ValueError(
                f"Unknown TabRed dataset {self.name!r}; expected one of {sorted(sources)}"
            )
        download_fn, sub_dir = sources[self.name]
        data_dir = download_dir / sub_dir

        # Only download when the required arrays are not already on disk.
        if not all(os.path.exists(data_dir / f) for f in ("X_num.npy", "Y.npy")):
            download_fn(download_dir)

        # Feature blocks are concatenated column-wise in the order num, bin, cat;
        # the bin and cat blocks are optional.
        X_mats = [np.load(data_dir / "X_num.npy")]
        if os.path.exists(data_dir / "X_bin.npy"):
            X_mats.append(np.load(data_dir / "X_bin.npy"))
        if os.path.exists(data_dir / "X_cat.npy"):
            X_cat = np.load(data_dir / "X_cat.npy")
            # Categorical columns come last, so their indices start after all
            # previously loaded columns.
            n_num = sum(X.shape[1] for X in X_mats)
            self.metadata["categorical_feature_inds"] = [n_num + i for i in range(X_cat.shape[1])]
            X_mats.append(X_cat)
        self.X = np.concatenate(X_mats, axis=1)
        self.y = np.load(data_dir / "Y.npy")

        # Split indices are stored in a directory named after the task id.
        split_dir = data_dir / self._task_id
        self._train_inds = np.load(split_dir / "train_idx.npy")
        self._val_inds = np.load(split_dir / "val_idx.npy")
        self._test_inds = np.load(split_dir / "test_idx.npy")

    def all_instances(self):
        """Return the full feature matrix and target array loaded by prepare_data."""
        return self.X, self.y

    def train_inds(self):
        """Return the training indices loaded by prepare_data."""
        return self._train_inds

    def val_inds(self):
        """Return the validation indices loaded by prepare_data."""
        return self._val_inds

    def test_inds(self):
        """Return the test indices loaded by prepare_data."""
        return self._test_inds
87 |
--------------------------------------------------------------------------------
/tabdpt_datasets/annotated_tables.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import time
4 | from pathlib import Path
5 |
6 | import kaggle
7 | import numpy as np
8 | import pandas as pd
9 | from sklearn.preprocessing import OrdinalEncoder
10 |
11 | from tabdpt_datasets.dataset import Dataset
12 | from tabdpt_datasets.external.annotated_tables.ids_deduped import ANNOTATED_TABLE_IDS
13 |
14 |
class AnnotatedTablesDataset(Dataset):
    """
    Dataset of the tables used to evaluate TabPFN in https://arxiv.org/pdf/2406.16349, with
    duplicate datasets left out. Since Kaggle doesn't have a standard file layout, we just take the
    largest CSV file. Because of that, there's no target or test set, and this dataset should only
    be used for training.
    """

    @staticmethod
    def all_names():
        """Return one dataset name per Kaggle id, prefixed with "annotated_tables_"."""
        return [f"annotated_tables_{table_id}" for table_id in ANNOTATED_TABLE_IDS]

    @staticmethod
    def suite_name():
        """Return the identifier of this suite."""
        return "annotated_tables"

    def __init__(self, name, task_id=None, split_seed=0):
        """
        Note that split_seed is only used for train/val split, since there's no test set.

        If task_id is given it must have the form "random-seed<N>", and N overrides split_seed.
        """
        if task_id is None:
            task_id = f"random-seed{split_seed}"
        else:
            assert task_id.startswith("random-seed")
            split_seed = int(task_id.removeprefix("random-seed"))
        super().__init__(name, task_id)
        # Everything after the "annotated_tables_" prefix is the Kaggle dataset id.
        self.kaggle_id = name[len("annotated_tables_") :]
        self.rng = np.random.default_rng(split_seed)
        self.metadata["kaggle_dataset_name"] = self.kaggle_id.split("/")[-1]

    def prepare_data(self, download_dir):
        """
        Download the Kaggle dataset if its directory does not exist, load its largest CSV file,
        ordinal-encode categorical columns, and create a random 85/15 train/val split.

        Raises:
            ValueError: On HTTP 403 from Kaggle, or if no CSV file is found for the dataset.
        """
        dataset_dir = os.path.join(
            download_dir, "annotated-tables", self.kaggle_id.replace("/", "-")
        )
        if not os.path.exists(dataset_dir):
            os.makedirs(dataset_dir)
            try:
                kaggle.api.dataset_download_files(self.kaggle_id, dataset_dir, unzip=True)
            except kaggle.rest.ApiException as e:
                if e.status == 429:  # Rate limit: wait briefly and retry once
                    time.sleep(0.5)
                    kaggle.api.dataset_download_files(self.kaggle_id, dataset_dir, unzip=True)
                elif e.status == 403:  # Forbidden
                    raise ValueError(
                        f"HTTP 403 when downloading Kaggle dataset {self.kaggle_id}. It was "
                        "probably deleted from Kaggle - try removing it from "
                        "datasets/external/annotated_tables/ids.py"
                    ) from e
                else:
                    raise
        # Sort ascending by file size so the largest CSV is the last entry.
        csv_paths = sorted(Path(dataset_dir).rglob("*.csv"), key=lambda p: p.stat().st_size)
        if len(csv_paths) == 0:
            raise ValueError("Couldn not load data for " + self.name)
        largest_csv = csv_paths[-1]

        # Record a checksum of the exact file the features were loaded from.
        hasher = hashlib.new("sha1")
        with open(largest_csv, "rb") as f:
            hasher.update(f.read())
        self.metadata["file_sha1"] = hasher.hexdigest()
        X = pd.read_csv(largest_csv, low_memory=False)

        categorical_inds = []
        to_drop = []
        n_dropped = 0
        for i, col in enumerate(X.columns):
            # Check the column's dtype: a Series itself is never an instance of
            # CategoricalDtype, so testing X[col] directly could never match.
            if X[col].dtype == "object" or isinstance(X[col].dtype, pd.CategoricalDtype):
                enc = OrdinalEncoder()
                X[[col]] = enc.fit_transform(X[[col]])
                # Drop high cardinality
                if len(enc.categories_[0]) > 100:
                    to_drop.append(col)
                    n_dropped += 1
                else:
                    # Shift the index left by the number of columns dropped so far
                    # so it refers to the column's position after the drop.
                    categorical_inds.append(i - n_dropped)
        X.drop(columns=to_drop, inplace=True)
        self.metadata["categorical_feature_inds"] = categorical_inds
        self.X = X.to_numpy().astype(np.float32)

        # Random split: first 85% of the permutation for training, rest for validation.
        n = len(X)
        perm = self.rng.permutation(n)
        self._train_inds = perm[: int(n * 0.85)]
        self._val_inds = perm[int(n * 0.85) :]

    def all_instances(self):
        """Return the feature matrix and None, since this dataset has no target."""
        return self.X, None

    def train_inds(self):
        """Return the training indices created by prepare_data."""
        return self._train_inds

    def val_inds(self):
        """Return the validation indices created by prepare_data."""
        return self._val_inds

    def test_inds(self):
        """Return an empty list, since this dataset has no test set."""
        return []
108 |
--------------------------------------------------------------------------------
/data_splits/reg_datasets.csv:
--------------------------------------------------------------------------------
1 | did,tid,num_instances,num_features,num_missing_values,source,test
2 | 44956,361234,4177.0,9.0,0.0,['ctr23'],True
3 | 44957,361235,1503.0,6.0,0.0,['ctr23'],True
4 | 44958,361236,2043.0,8.0,0.0,['ctr23'],True
5 | 44959,361237,1030.0,9.0,0.0,['ctr23'],True
6 | 44963,361241,45730.0,10.0,0.0,['ctr23'],True
7 | 44964,361242,21263.0,82.0,0.0,['ctr23'],True
8 | 44965,361243,1059.0,117.0,0.0,['ctr23'],True
9 | 44966,361244,1066.0,11.0,0.0,['ctr23'],True
10 | 44969,361247,11934.0,15.0,0.0,['ctr23'],True
11 | 44971,361249,4898.0,12.0,0.0,['ctr23'],True
12 | 44972,361250,1599.0,12.0,0.0,['ctr23'],True
13 | 44973,361251,10000.0,13.0,0.0,['ctr23'],True
14 | 44974,361252,68784.0,19.0,0.0,['ctr23'],True
15 | 44975,361253,72000.0,49.0,0.0,['ctr23'],True
16 | 44976,361254,48933.0,22.0,0.0,['ctr23'],True
17 | 44977,361255,20640.0,9.0,0.0,['ctr23'],True
18 | 44978,361256,8192.0,22.0,0.0,['ctr23'],True
19 | 44979,361257,53940.0,10.0,0.0,['ctr23'],True
20 | 44980,361258,8192.0,9.0,0.0,['ctr23'],True
21 | 44981,361259,8192.0,33.0,0.0,['ctr23'],True
22 | 44983,361260,13932.0,16.0,0.0,['ctr23'],True
23 | 44984,361261,28155.0,7.0,0.0,['ctr23'],True
24 | 44987,361264,1156.0,6.0,0.0,['ctr23'],True
25 | 44989,361266,21613.0,22.0,0.0,['ctr23'],True
26 | 44990,361267,10692.0,10.0,0.0,['ctr23'],True
27 | 44992,361268,24624.0,44.0,69696.0,['ctr23'],True
28 | 44993,361269,22272.0,12.0,0.0,['ctr23'],True
29 | 45012,361272,19178.0,29.0,0.0,['ctr23'],True
30 | 41021,361616,1232.0,15.0,3600.0,"['ctr23', 'amlb_reg']",True
31 | 44960,361617,768.0,9.0,0.0,['ctr23'],True
32 | 44962,361618,517.0,13.0,0.0,['ctr23'],True
33 | 44967,361619,649.0,31.0,0.0,['ctr23'],True
34 | 44970,361621,908.0,7.0,0.0,['ctr23'],True
35 | 44994,361622,804.0,18.0,0.0,['ctr23'],True
36 | 45402,361623,3107.0,7.0,0.0,['ctr23'],True
37 | 44055,361093,4052.0,8.0,0.0,['grins_reg'],False
38 | 44056,361094,8641.0,5.0,0.0,['grins_reg'],False
39 | 44059,361096,53940.0,10.0,0.0,['grins_reg'],False
40 | 44061,361097,4209.0,360.0,0.0,['grins_reg'],False
41 | 44062,361098,10692.0,12.0,0.0,['grins_reg'],False
42 | 44063,361099,17379.0,12.0,0.0,['grins_reg'],False
43 | 44065,361101,581835.0,17.0,0.0,['grins_reg'],False
44 | 44066,361102,21613.0,18.0,0.0,['grins_reg'],False
45 | 44068,361103,394299.0,7.0,0.0,['grins_reg'],False
46 | 44069,361104,241600.0,10.0,0.0,['grins_reg'],False
47 | 45041,361287,8885.0,256.0,0.0,['grins_reg'],False
48 | 45042,361288,4177.0,9.0,0.0,['grins_reg'],False
49 | 45043,361289,52031.0,5.0,0.0,['grins_reg'],False
50 | 45045,361291,5465575.0,12.0,0.0,['grins_reg'],False
51 | 45046,361292,188318.0,125.0,0.0,['grins_reg'],False
52 | 45047,361293,1000000.0,6.0,0.0,['grins_reg'],False
53 | 45048,361294,163065.0,4.0,0.0,['grins_reg'],False
54 | 44132,361072,8192.0,22.0,0.0,['grins_reg'],False
55 | 44133,361073,15000.0,27.0,0.0,['grins_reg'],False
56 | 44134,361074,16599.0,17.0,0.0,['grins_reg'],False
57 | 44136,361076,6497.0,12.0,0.0,['grins_reg'],False
58 | 44137,361077,13750.0,34.0,0.0,['grins_reg'],False
59 | 44138,361078,20640.0,9.0,0.0,['grins_reg'],False
60 | 44139,361079,22784.0,17.0,0.0,['grins_reg'],False
61 | 44140,361080,53940.0,7.0,0.0,['grins_reg'],False
62 | 44141,361081,10692.0,9.0,0.0,['grins_reg'],False
63 | 44142,361082,17379.0,7.0,0.0,['grins_reg'],False
64 | 44143,361083,581835.0,10.0,0.0,['grins_reg'],False
65 | 44144,361084,21613.0,16.0,0.0,['grins_reg'],False
66 | 44145,361085,10081.0,7.0,0.0,['grins_reg'],False
67 | 44146,361086,163065.0,4.0,0.0,['grins_reg'],False
68 | 44147,361087,13932.0,14.0,0.0,['grins_reg'],False
69 | 44148,361088,21263.0,80.0,0.0,['grins_reg'],False
70 | 45032,361279,8885.0,43.0,0.0,['grins_reg'],False
71 | 45033,361280,4177.0,8.0,0.0,['grins_reg'],False
72 | 45034,361281,5465575.0,9.0,0.0,['grins_reg'],False
73 | 42225,233211,53940.0,10.0,0.0,['amlb_reg'],False
74 | 42571,233212,188318.0,131.0,0.0,['amlb_reg'],False
75 | 4549,233213,583250.0,78.0,0.0,['amlb_reg'],False
76 | 42572,233214,4459.0,4992.0,0.0,['amlb_reg'],False
77 | 42570,233215,4209.0,377.0,0.0,['amlb_reg'],False
78 | 42705,317614,400000.0,101.0,0.0,['amlb_reg'],False
79 | 42728,359929,10000000.0,10.0,0.0,['amlb_reg'],False
80 | 550,359930,2178.0,4.0,0.0,['amlb_reg'],False
81 | 546,359931,576.0,12.0,0.0,['amlb_reg'],False
82 | 541,359932,1156.0,6.0,0.0,['amlb_reg'],False
83 | 507,359933,3107.0,7.0,0.0,['amlb_reg'],False
84 | 505,359934,240.0,125.0,0.0,['amlb_reg'],False
85 | 287,359935,6497.0,12.0,0.0,['amlb_reg'],False
86 | 216,359936,16599.0,19.0,0.0,['amlb_reg'],False
87 | 41540,359937,166821.0,10.0,0.0,['amlb_reg'],False
88 | 42688,359938,10692.0,13.0,0.0,['amlb_reg'],False
89 | 422,359939,8885.0,267.0,0.0,['amlb_reg'],False
90 | 416,359940,8885.0,252.0,0.0,['amlb_reg'],False
91 | 42724,359941,39644.0,60.0,0.0,['amlb_reg'],False
92 | 42727,359942,7063.0,45.0,104249.0,['amlb_reg'],False
93 | 42729,359943,581835.0,19.0,0.0,['amlb_reg'],False
94 | 42726,359944,4177.0,9.0,0.0,['amlb_reg'],False
95 | 42730,359945,1994.0,127.0,39202.0,['amlb_reg'],False
96 | 201,359946,15000.0,49.0,0.0,['amlb_reg'],False
97 | 41980,359948,4440.0,117.0,27150.0,['amlb_reg'],False
98 | 42731,359949,21613.0,22.0,0.0,['amlb_reg'],False
99 | 531,359950,506.0,14.0,0.0,['amlb_reg'],False
100 | 42563,359951,1460.0,80.0,6965.0,['amlb_reg'],False
101 | 574,359952,22784.0,17.0,0.0,['amlb_reg'],False
102 | 3050,360932,5742.0,1026.0,0.0,['amlb_reg'],False
103 | 3277,360933,5766.0,1026.0,0.0,['amlb_reg'],False
104 | 43071,360945,1090.0,145.0,0.0,['amlb_reg'],False
105 |
--------------------------------------------------------------------------------
/tabdpt_datasets/dataset.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from abc import ABC, abstractmethod, abstractstaticmethod
4 |
5 | import numpy as np
6 | import scipy
7 | import torch.utils.data
8 |
9 |
class Dataset(ABC):
    """
    Base class for dataset classes.

    A dataset class represents a suite of datasets and an instance of the class represents a single
    loaded dataset from the suite.

    To support evaluation over a range of methods, the current code assumes that the entire dataset
    can be loaded into memory and accessed as numpy arrays using all_instances, which also supports
    fast indexing. Synthetic datasets with a fixed set of instances can either generate their data at
    initialization or lazily.

    We will probably want to add separate methods for synthetic datasets that generate an indefinite
    number of instances on the fly, since these are less appropriate for evaluation and might not be
    usable for all baseline models anyway (e.g., decision trees that expect all instances to be
    available at once).

    To enable reproducible splits, datasets take an optional parameter task_id. Subclasses should
    handle splitting in prepare_data in a reproducible way given a task_id. They can also implement
    functions to generate splits that return a task_id, ensuring future reproducibility.
    """

    def __init__(self, name: str, task_id: str | None = None):
        """
        Params:
            name - the name of the dataset to load
            task_id - an optional ID that subclasses should use to ensure reproducibility
        """
        self.name = name
        self._task_id = task_id
        self.metadata = {"suite_name": self.suite_name(), "dataset_name": name}
        self.column_names = None

    # `abstractstaticmethod` is deprecated; stacking the two decorators is the supported spelling.
    @staticmethod
    @abstractmethod
    def suite_name() -> str:
        """
        Name of the suite of datasets provided by this class, for cataloguing. Typically using all
        lowercase for consistency.
        """
        pass

    @staticmethod
    @abstractmethod
    def all_names() -> list[str] | None:
        """
        Return the names of all datasets provided by this class, or None if it does not provide a
        fixed set of datasets
        """
        pass

    @abstractmethod
    def prepare_data(self, download_dir: str):
        """
        Download data if needed and do any CPU-side preprocessing, splitting, etc as needed.
        """
        pass

    def task_id(self) -> str | None:
        """
        Returns the task_id that can be used to reproduce the current task, or None if not possible.

        If task_id is specified as a constructor argument, this should always return the same value.
        Otherwise, subclasses can either choose to modify _task_id or override this method directly
        to return the task_id for the settings that the dataset was set up with.
        """
        return self._task_id

    def __len__(self) -> int:
        """Number of rows in the feature matrix returned by all_instances."""
        xs, _ = self.all_instances()
        return xs.shape[0]

    def __getitem__(self, i) -> tuple[np.ndarray, np.ndarray | int | float | None]:
        """Return (features, target) for index i; the target is None for unlabeled datasets."""
        xs, ys = self.all_instances()
        if ys is not None:
            return xs[i], ys[i]
        return xs[i], None

    @abstractmethod
    def all_instances(self) -> tuple[np.ndarray, np.ndarray | None]:
        """
        Return all instances as a feature matrix and target vector.
        """
        pass

    @abstractmethod
    def train_inds(self) -> list[int] | np.ndarray | range:
        """Indices into all_instances() that form the training split."""
        pass

    @abstractmethod
    def val_inds(self) -> list[int] | np.ndarray | range:
        """Indices into all_instances() that form the validation split."""
        pass

    @abstractmethod
    def test_inds(self) -> list[int] | np.ndarray | range:
        """Indices into all_instances() that form the test split."""
        pass

    def _select_instances(self, inds) -> tuple[np.ndarray, np.ndarray | None]:
        """Index features and, when present, targets with the same index set."""
        X, y = self.all_instances()
        # all_instances may return y=None (unlabeled data); indexing None would raise TypeError.
        return X[inds], (None if y is None else y[inds])

    def train_instances(self) -> tuple[np.ndarray, np.ndarray | None]:
        """Features and targets of the training split (targets are None if unlabeled)."""
        return self._select_instances(self.train_inds())

    def val_instances(self) -> tuple[np.ndarray, np.ndarray | None]:
        """Features and targets of the validation split (targets are None if unlabeled)."""
        return self._select_instances(self.val_inds())

    def test_instances(self) -> tuple[np.ndarray, np.ndarray | None]:
        """Features and targets of the test split (targets are None if unlabeled)."""
        return self._select_instances(self.test_inds())

    def auto_populate_metadata(self):
        """
        Fill self.metadata with summary statistics computed from all_instances().

        Records row/feature/cell counts, target mean and variance (when a target exists),
        missing-value fractions, a per-column least-squares fit of the target, and per-column
        mean, variance, skew and kurtosis. Expects a 2-D feature matrix.
        """
        # Import the submodule explicitly: a bare `import scipy` does not guarantee that
        # `scipy.stats` is loaded on every scipy version.
        from scipy import stats as sp_stats

        X, y = self.all_instances()
        n_rows, n_features = X.shape[0], X.shape[1]
        self.metadata["n_rows"] = n_rows
        self.metadata["n_features"] = n_features
        self.metadata["n_cells"] = n_rows * n_features
        if y is None:
            self.metadata["target_type"] = "none"
        else:
            self.metadata["y_mean"] = np.mean(y)
            self.metadata["y_var"] = np.var(y)
            if "target_type" not in self.metadata:
                self.metadata["target_type"] = "unknown"
            # The target column counts towards the number of cells.
            self.metadata["n_cells"] += y.shape[0]

        missing = np.isnan(X)
        self.metadata["frac_missing"] = missing.mean()
        self.metadata["frac_rows_with_missing"] = missing.any(axis=1).mean()
        self.metadata["frac_features_with_missing"] = missing.any(axis=0).mean()

        if y is None:
            # Without a target there is nothing to regress on.
            lin_coeffs = [None] * n_features
        else:
            lin_coeffs = []
            for i in range(n_features):
                col = np.nan_to_num(X[:, i])
                # A (near-)constant column carries no slope information.
                if np.all(np.isclose(col, col[0])):
                    lin_coeffs.append(None)
                    continue
                try:
                    # Fit y ~ slope * col + intercept; stored as [slope, intercept].
                    design = np.stack((col, np.ones_like(col)), axis=1)
                    res = np.linalg.lstsq(design, y, rcond=None)
                    lin_coeffs.append(res[0].tolist())
                except np.linalg.LinAlgError:
                    lin_coeffs.append(None)
        self.metadata["column_lin_coeffs"] = lin_coeffs

        # NOTE: these statistics are computed on the raw matrix, so a column containing NaN
        # yields NaN here (unlike the regression above, which zero-fills missing values).
        self.metadata["column_means"] = X.mean(axis=0).tolist()
        self.metadata["column_vars"] = X.var(axis=0).tolist()
        self.metadata["column_skews"] = sp_stats.skew(X, axis=0).tolist()
        self.metadata["column_kurtoses"] = sp_stats.kurtosis(X, axis=0).tolist()
153 |
154 |
class TorchDataset(torch.utils.data.Dataset):
    """
    Adapter that exposes one split of a Dataset through the torch Dataset interface
    """

    def __init__(self, dataset: Dataset, split):
        """
        Params:
            dataset - the wrapped dataset; items are fetched through dataset[index]
            split - one of "train", "val", "test" or "all"
        """
        assert split in ("train", "val", "test", "all")
        self.dataset = dataset
        # Each entry is a thunk so that only the requested split's index method is called.
        index_sources = {
            "all": lambda: range(len(dataset)),
            "train": lambda: dataset.train_inds(),
            "val": lambda: dataset.val_inds(),
            "test": lambda: dataset.test_inds(),
        }
        resolve = index_sources.get(split)
        if resolve is not None:
            self.inds = resolve()

    def __len__(self):
        """Number of instances in the selected split."""
        return len(self.inds)

    def __getitem__(self, i):
        """Map the split-local position i to the underlying dataset item."""
        position = self.inds[i]
        return self.dataset[position]
177 |
--------------------------------------------------------------------------------
/eval_full.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from typing import Optional
3 |
4 | import numpy as np
5 | import pandas as pd
6 | import scipy
7 | import torch
8 | from sklearn.metrics import accuracy_score, f1_score, r2_score, roc_auc_score
9 | from tqdm import tqdm
10 |
11 | from tabdpt_datasets.openml import OpenMLDataset
12 | from model import TabDPTModel, Task
13 | from tabdpt import TabDPTClassifier, TabDPTRegressor
14 | from utils import FAISS, DataPreprocessor
15 |
16 |
class FullEval:
    """Preload the evaluation datasets listed under ``data_splits/`` and score a TabDPT model on them."""

    def __init__(
        self,
        device,
        max_feat,
        impute_method="mean",
        use_retrieval=False,
    ):
        """Initialize the FullEval class for evaluating tabular data.

        Reads the classification and regression split files, loads every selected OpenML task,
        merges its train and validation splits into one context set, preprocesses the features
        and builds a FAISS index per dataset.

        Args:
            device (str): The device to use for evaluation (e.g., "cuda:0" or "cpu").
            max_feat (int): The maximum number of features to use for evaluation. Stored on the
                instance; it is not otherwise used inside this class.
            impute_method (str, optional): The imputation method to use for missing values. Defaults to "mean".
            use_retrieval (bool, optional): Forwarded as ``use_retrieval`` to the predictors at
                evaluation time. Defaults to False.
        """
        self.device = device
        self.max_feat = max_feat
        self.use_retrieval = use_retrieval

        # loading classification datasets
        df_eval_cls = pd.read_csv("data_splits/cls_datasets.csv")
        # `== True` is kept on purpose: it is an element-wise pandas comparison, not a Python truth test.
        cls_df = df_eval_cls[df_eval_cls["test_all"] == True]  # noqa: E712
        self.cc18_dids = cls_df["did"].values.tolist()  # 72 datasets

        # loading regression datasets
        reg_df = pd.read_csv("data_splits/reg_datasets.csv")
        ctr_df = reg_df[reg_df["test"] == True]  # noqa: E712
        self.ctr_dids = ctr_df["did"].values.tolist()

        # get did (dataset ID) to tid (task ID) mapping
        did_tid_mapping = dict(zip(cls_df["did"], cls_df["tid"]))
        did_tid_mapping.update(dict(zip(ctr_df["did"], ctr_df["tid"])))

        self.datasets = {}
        for did in [*self.cc18_dids, *self.ctr_dids]:
            dataset = OpenMLDataset("openml_dataset", openml_task_id=int(did_tid_mapping[did]))
            dataset.prepare_data("data")
            dataset_name = dataset.openml_dataset.name

            X_train, y_train = dataset.train_instances()
            X_val, y_val = dataset.val_instances()

            # Train and validation rows are merged into a single context set.
            X_train = np.concatenate([X_train, X_val], axis=0)
            y_train = np.concatenate([y_train, y_val], axis=0)

            X_test, y_test = dataset.test_instances()

            # TODO missing convert_cat2num step in full_dataset.py
            # Preprocess data: fit on the context set only, then apply to the test rows.
            preprocessor = DataPreprocessor(impute_method=impute_method)
            X_train = preprocessor.fit_transform(X_train)
            X_test = preprocessor.transform(X_test)

            # Create faiss index
            faiss_knn = FAISS(X_train, use_hnsw=False, metric="L2")

            # back to tensor; y_test stays on the CPU because it is only used for metrics
            X_train = torch.tensor(X_train).to(device)
            X_test = torch.tensor(X_test).to(device)
            y_train = torch.tensor(y_train).to(device)
            y_test = torch.tensor(y_test)

            self.datasets[did] = (dataset_name, faiss_knn, X_train, X_test, y_train, y_test)

    @torch.no_grad()
    def eval(
        self,
        model: TabDPTModel,
        context_length: int = 1024,
        inf_batch_size=512,
        temperature=0.8,
        return_individual_perfs=False,
    ) -> dict | tuple[dict, dict]:
        """Evaluate the model on classification and regression datasets.

        Args:
            model (TabDPTModel): The TabDPT model to evaluate.
            context_length (int, optional): Passed as ``context_size`` to the predictors.
                Defaults to 1024.
            inf_batch_size (int, optional): Inference batch size. Defaults to 512.
            temperature (float, optional): Passed to ``predict_proba`` for classification.
                Defaults to 0.8.
            return_individual_perfs (bool, optional): If True, also return the per-dataset
                metrics. Defaults to False.

        Returns:
            dict: mean performance metrics over the classification and regression datasets.
            When ``return_individual_perfs`` is True, a tuple ``(final_perfs, individual_perfs)``
            is returned instead, where ``individual_perfs`` maps each dataset name to
            ``[acc, f1, auc, ce]`` (classification) or ``[mse, correlation, r2]`` (regression).
        """
        classifier = TabDPTClassifier(
            model=model,
            mode=Task.CLS,
            device=self.device,
            inf_batch_size=inf_batch_size,
            tensor_eval=True,
        )
        regressor = TabDPTRegressor(
            model=model,
            mode=Task.REG,
            device=self.device,
            inf_batch_size=inf_batch_size,
            tensor_eval=True,
        )

        cls_performance = defaultdict(lambda: defaultdict(list))
        reg_performance = defaultdict(lambda: defaultdict(list))

        final_perfs = {}
        individual_perfs = {}

        # evaluation for classification datasets
        for did in tqdm(self.cc18_dids):
            dataset_name, faiss_index, X_train, X_test, y_train, y_test = self.datasets[did]

            classifier.fit(X_train, y_train, faiss_index)
            pred_val = classifier.predict_proba(
                X_test,
                temperature=temperature,
                context_size=context_length,
                use_retrieval=self.use_retrieval,
            )
            pred_labels = np.argmax(pred_val, axis=1)

            # Binary AUC needs the positive-class score only; multi-class uses one-vs-one.
            if len(np.unique(y_test)) == 2:
                auc = roc_auc_score(y_test, pred_val[:, 1])
            else:
                auc = roc_auc_score(y_test, pred_val, multi_class="ovo")

            f1 = f1_score(y_test, pred_labels, average="weighted")
            acc = accuracy_score(y_test, pred_labels)
            # NOTE(review): cross_entropy expects unnormalized logits. If predict_proba returns
            # probabilities, this value is not the true log-loss - confirm against TabDPTClassifier.
            ce = torch.nn.functional.cross_entropy(
                torch.Tensor(pred_val).float(), torch.Tensor(y_test).long()
            )
            # Store a plain float so the metric lists hold one numeric type.
            ce = ce.item()
            cls_performance["cc18"][did] = [acc, f1, auc, ce]
            individual_perfs[dataset_name] = [acc, f1, auc, ce]

        cls_perfs = np.array(list(cls_performance["cc18"].values()))
        cls_perfs_mean = cls_perfs.mean(0)

        final_perfs["cls-cc18-acc"] = cls_perfs_mean[0]
        final_perfs["cls-cc18-f1"] = cls_perfs_mean[1]
        final_perfs["cls-cc18-auc"] = cls_perfs_mean[2]
        final_perfs["cls-cc18-ce"] = cls_perfs_mean[3]

        # evaluation for regression datasets
        for did in self.ctr_dids:
            dataset_name, faiss_index, X_train, X_test, y_train, y_test = self.datasets[did]

            regressor.fit(X_train, y_train, faiss_index)
            pred_val = regressor.predict(
                X_test, context_size=context_length, use_retrieval=self.use_retrieval
            ).flatten()

            # y_test is compared as loaded; no rescaling of the targets happens here.
            y_true = y_test.cpu().numpy()
            mse = np.mean((y_true - pred_val) ** 2)
            correlation = scipy.stats.pearsonr(y_true, pred_val)
            r2 = r2_score(y_true, pred_val)

            reg_performance["ctr"][did] = [mse, correlation[0], r2]
            individual_perfs[dataset_name] = [mse, correlation[0], r2]

        reg_perfs = np.array(list(reg_performance["ctr"].values()))
        reg_perfs_mean = reg_perfs.mean(0)
        final_perfs["reg-ctr-mse"] = reg_perfs_mean[0]
        final_perfs["reg-ctr-cor"] = reg_perfs_mean[1]
        final_perfs["reg-ctr-r2"] = reg_perfs_mean[2]

        if return_individual_perfs:
            return final_perfs, individual_perfs
        return final_perfs
192 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | from omegaconf import DictConfig
7 |
8 | from transformer_layer import TransformerEncoderLayer
9 |
10 |
class Task(Enum):
    """Enum representing the type of task for the model.

    REG: Regression task.
    CLS: Classification task.
    """

    REG = 1  # regression
    CLS = 2  # classification
19 |
20 |
def pad_x(X: torch.Tensor, num_features: int) -> torch.Tensor:
    """Pad the input tensor X with zeros to match the specified number of features.

    The padding is created with the same dtype and device as X, so the result keeps X's dtype
    instead of being promoted by the concatenation (e.g. half-precision inputs stay half).

    Args:
        X (torch.Tensor): Input tensor of shape (seq_len, batch_size, n_features).
        num_features (int): Desired number of features after padding; must be at least
            n_features.
    Returns:
        torch.Tensor: Padded tensor of shape (seq_len, batch_size, num_features).
    """
    seq_len, batch_size, n_features = X.shape
    # new_zeros inherits dtype and device from X.
    zero_feature_padding = X.new_zeros((seq_len, batch_size, num_features - n_features))
    return torch.cat([X, zero_feature_padding], -1)
34 |
35 |
def maskmean(x: torch.Tensor, mask: torch.Tensor, dim: int) -> torch.Tensor:
    """Average x over `dim`, counting only the entries where mask is True.

    Args:
        x (torch.Tensor): Values to average.
        mask (torch.Tensor): Boolean tensor; True marks entries that take part in the mean.
        dim (int): Dimension that is reduced (kept with size 1 in the result).
    Returns:
        torch.Tensor: Masked mean along `dim`. If no entry is valid along `dim`, the
        division is by zero.
    """
    # Masked-out entries (including NaNs) contribute zero to the sum.
    valid_total = torch.where(mask, x, 0).sum(dim=dim, keepdim=True)
    valid_count = mask.sum(dim=dim, keepdim=True)
    return valid_total / valid_count
47 |
48 |
def maskstd(x: torch.Tensor, mask: torch.Tensor, dim: int = 0):
    """Compute the standard deviation of x along the specified dimension, ignoring masked values.

    Uses the sample (n - 1) normalisation. The reduction now honours `dim` throughout;
    previously the mean and the squared-deviation sum were always taken over dimension 0.

    Args:
        x (torch.Tensor): Input tensor.
        mask (torch.Tensor): Boolean tensor of the same shape as x, where True indicates valid values.
        dim (int): Dimension along which to compute the standard deviation. Defaults to 0.
    Returns:
        torch.Tensor: Standard deviation of x along `dim` (kept with size 1), with masked values
        ignored. With fewer than two valid entries along `dim` the result is not finite.
    """
    num = mask.sum(dim=dim, keepdim=True)
    # Masked mean along the same dimension (masked entries contribute zero to the sum).
    mean = torch.where(mask, x, 0).sum(dim=dim, keepdim=True) / num
    # Zero out deviations of masked entries so NaNs never reach the sum.
    diffs = torch.where(mask, mean - x, 0)
    return ((diffs**2).sum(dim=dim, keepdim=True) / (num - 1)) ** 0.5
62 |
63 |
def normalize_data(data: torch.Tensor, eval_pos: int) -> torch.Tensor:
    """Standardize `data` with NaN-aware statistics taken along dimension 0.

    Args:
        data (torch.Tensor): Input data; the callers in this file pass a (T, B, H) tensor.
        eval_pos (int): When positive, only the first `eval_pos` entries along dimension 0
            are used to estimate the mean and standard deviation; otherwise all of `data` is.

    Returns:
        torch.Tensor: `data` shifted by the masked mean and divided by the masked standard
        deviation plus 1e-6, same shape as the input.
    """
    reference = data
    if eval_pos > 0:
        reference = data[:eval_pos]
    observed = ~torch.isnan(reference)
    center = maskmean(reference, observed, dim=0)
    # The small constant keeps the division finite for zero-variance features.
    scale = maskstd(reference, observed, dim=0) + 1e-6
    return (data - center) / scale
80 |
81 |
def clip_outliers(data: torch.Tensor, eval_pos: int, n_sigma: int = 4):
    """
    Clamp values lying more than `n_sigma` standard deviations from the mean.

    The bound is estimated in two passes: a first masked standard deviation flags outliers,
    then the standard deviation is recomputed without them and used for the final clamp.
    The centre of the clamp interval is the mean from the first pass.

    Args:
        data (torch.Tensor): Input data of shape (T, B, H).
        eval_pos (int): When positive, only the first `eval_pos` entries along dimension 0
            are used to estimate the statistics; otherwise all of `data` is.
        n_sigma (int): Number of standard deviations that defines the clamp interval.
    Returns:
        torch.Tensor: `data` clamped to [mean - cutoff, mean + cutoff], same shape as the input.
    """
    assert len(data.shape) == 3, "X must be T,B,H"
    reference = data
    if eval_pos > 0:
        reference = data[:eval_pos]
    keep = ~torch.isnan(reference)
    center = maskmean(reference, keep, dim=0)
    first_pass = n_sigma * maskstd(reference, keep, dim=0)
    # Drop first-pass outliers before re-estimating the spread.
    keep = keep & (first_pass >= torch.abs(reference - center))
    cutoff = n_sigma * maskstd(reference, keep, dim=0)
    lower = center - cutoff
    upper = center + cutoff
    return torch.clip(data, lower, upper)
100 |
101 |
def convert_to_torch_tensor(input: np.ndarray | torch.Tensor) -> torch.Tensor:
    """Return `input` as a PyTorch tensor.

    Args:
        input (np.ndarray | torch.Tensor): Data to convert.
    Returns:
        torch.Tensor: The tensor itself when `input` already is one, otherwise the result of
        ``torch.from_numpy`` on the array.
    Raises:
        TypeError: If the input is neither a NumPy array nor a PyTorch tensor.
    """
    if not isinstance(input, np.ndarray):
        if torch.is_tensor(input):
            return input
        raise TypeError("Input must be a NumPy array or a PyTorch tensor.")
    return torch.from_numpy(input)
117 |
118 |
class TabDPTModel(nn.Module):
    """Transformer encoder over rows of a table with in-context labels.

    The first ``y_src.shape[0]`` rows of the input carry known targets (added to their feature
    embeddings); predictions are produced for the remaining rows.
    """

    def __init__(
        self,
        dropout: float,
        n_out: int,
        nhead: int,
        nhid: int,
        ninp: int,
        nlayers: int,
        num_features: int,
    ):
        """TabDPTModel initialization.

        Args:
            dropout (float): Dropout rate. Stored as ``self.dropout``; the module is not applied
                in ``forward``.
            n_out (int): Number of output classes. The head emits ``n_out + 1`` values.
            nhead (int): Number of attention heads.
            nhid (int): Feed-forward dimension of each transformer layer and hidden width of the head.
            ninp (int): Embedding dimension.
            nlayers (int): Number of transformer layers.
            num_features (int): Number of input features expected by the feature encoder.
        """
        super().__init__()
        self.n_out = n_out  # number of output classes
        self.ninp = ninp  # embedding dimension
        self.transformer_encoder = nn.ModuleList(
            [
                TransformerEncoderLayer(
                    embed_dim=ninp,
                    num_heads=nhead,
                    ff_dim=nhid,
                )
                for _ in range(nlayers)
            ]
        )
        self.num_features = num_features
        self.encoder = nn.Linear(num_features, ninp)
        self.dropout = nn.Dropout(p=dropout)
        self.y_encoder = nn.Linear(1, ninp)
        self.head = nn.Sequential(nn.Linear(ninp, nhid), nn.GELU(), nn.Linear(nhid, n_out + 1))

    def forward(
        self,
        x_src: torch.Tensor,
        y_src: torch.Tensor,
        return_log_act_norms: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, dict]:
        """Forward pass of the TabDPTModel.

        Args:
            x_src (torch.Tensor): Input features of shape (T, B, num_features); the first
                ``y_src.shape[0]`` rows are the labelled context.
            y_src (torch.Tensor): Target values of the context rows, shape (eval_pos, B).
            return_log_act_norms (bool): Whether to return activation norms for logging.
        Returns:
            torch.Tensor: Predictions for the rows after the context, shape
            (T - eval_pos, B, n_out + 1). When ``return_log_act_norms`` is True, a tuple of the
            predictions and a dict of mean activation norms is returned instead.
        """
        eval_pos = y_src.shape[0]
        # In training mode the statistics use every row (-1); in eval mode only the context rows.
        stat_pos = -1 if self.training else eval_pos

        # preprocess features by normalizing and clipping outliers
        x_src = clip_outliers(x_src, stat_pos, n_sigma=4)
        x_src = normalize_data(x_src, stat_pos)
        x_src = clip_outliers(x_src, stat_pos, n_sigma=4)
        x_src = torch.nan_to_num(x_src, nan=0)

        # feature encoding, rescaled to unit RMS along the embedding dimension
        x_src = self.encoder(x_src)
        mean = (x_src**2).mean(dim=-1, keepdim=True)
        rms = torch.sqrt(mean)
        x_src = x_src / rms

        # target encoding: label embeddings are added to the context rows only
        y_src = self.y_encoder(y_src.unsqueeze(-1))
        train_x = x_src[:eval_pos] + y_src
        src = torch.cat([train_x, x_src[eval_pos:]], 0)

        # Activation norms are only needed for logging, so skip the work otherwise.
        log_act_norms = {}
        if return_log_act_norms:
            log_act_norms["y"] = torch.norm(y_src, dim=-1).mean()

        # transformer layers
        for layer_idx, layer in enumerate(self.transformer_encoder):
            if return_log_act_norms and layer_idx in (0, 1, 3, 6, 9):
                log_act_norms[f"layer_{layer_idx}"] = torch.norm(src, dim=-1).mean()
            src = layer(src, eval_pos)

        # final head
        pred = self.head(src)

        if return_log_act_norms:
            return pred[eval_pos:], log_act_norms
        return pred[eval_pos:]

    @classmethod
    def load(cls, model_state: dict, config: DictConfig) -> nn.Module:
        """Load a pre-trained TabDPTModel from a state dictionary.

        Args:
            model_state (dict): state dictionary containing the model parameters. Keys may carry
                the "_orig_mod." marker, which is stripped before loading.
            config (DictConfig): configuration object containing model parameters.

        Returns:
            nn.Module: model instance with loaded parameters, moved to ``config.env.device`` and
            set to eval mode.
        """
        assert config.model.max_num_classes > 2
        # Build through `cls` so subclasses get an instance of their own type.
        model = cls(
            dropout=config.training.dropout,
            n_out=config.model.max_num_classes,
            nhead=config.model.nhead,
            nhid=config.model.emsize * config.model.nhid_factor,
            ninp=config.model.emsize,
            nlayers=config.model.nlayers,
            num_features=config.model.max_num_features,
        )

        module_prefix = "_orig_mod."
        model_state = {k.replace(module_prefix, ""): v for k, v in model_state.items()}
        model.load_state_dict(model_state)
        model.to(config.env.device)
        model.eval()
        return model
240 |
--------------------------------------------------------------------------------
/tabdpt_datasets/__init__.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import openml
3 | import pandas as pd
4 | import torch
5 |
6 |
def get_openml_classification(did, max_samples, multiclass=True, shuffled=True):
    """Download one OpenML classification dataset and return it as torch tensors.

    Params:
        did - OpenML dataset ID
        max_samples - keep at most this many rows (applied in the shuffled path only); a falsy
            value disables the cap
        multiclass - if False, only rows whose label is < 2 are kept
        shuffled - if True, rows are permuted with a fixed seed (13); if False, a class-balanced
            subset is built and the two classes are interleaved (binary labels only)

    Returns:
        (X, y, categorical_feature_indices, attribute_names), or four Nones when the dataset
        is not returned as numpy arrays.

    Raises:
        NotImplementedError: for multiclass=True together with shuffled=False.
    """
    # Reject the unsupported combination before spending time on the download.
    if multiclass and not shuffled:
        raise NotImplementedError("This combination of multiclass and shuffling isn't implemented")

    dataset = openml.datasets.get_dataset(did)
    X, y, categorical_indicator, attribute_names = dataset.get_data(
        dataset_format="array", target=dataset.default_target_attribute
    )

    # Validate the types before any indexing is attempted on X and y.
    if not isinstance(X, np.ndarray) or not isinstance(y, np.ndarray):
        print("Not a NP Array, skipping")
        return None, None, None, None

    if not multiclass:
        binary_rows = y < 2
        X, y = X[binary_rows], y[binary_rows]

    if not shuffled:
        # Sort so the minority class comes last, keep as many majority rows as minority rows,
        # then interleave the two halves and reverse the order.
        sort = np.argsort(y) if y.mean() < 0.5 else np.argsort(-y)
        pos = int(y.sum()) if y.mean() < 0.5 else int((1 - y).sum())
        X, y = X[sort][-pos * 2 :], y[sort][-pos * 2 :]
        y = torch.tensor(y).reshape(2, -1).transpose(0, 1).reshape(-1).flip([0]).float()
        X = (
            torch.tensor(X)
            .reshape(2, -1, X.shape[1])
            .transpose(0, 1)
            .reshape(-1, X.shape[1])
            .flip([0])
            .float()
        )
    else:
        order = np.arange(y.shape[0])
        # A private RandomState yields the same permutation as seeding the global generator
        # with 13, without resetting the caller's global numpy random state.
        np.random.RandomState(13).shuffle(order)
        X, y = torch.tensor(X[order]), torch.tensor(y[order])
        if max_samples:
            X, y = X[:max_samples], y[:max_samples]

    return X, y, list(np.where(categorical_indicator)[0]), attribute_names
46 |
47 |
def load_openml_list(
    dids,
    filter_for_nan=False,
    num_feats=100,
    min_samples=100,
    max_samples=400,
    multiclass=True,
    max_num_classes=10,
    shuffled=True,
    return_capped=False,
):
    """Load several OpenML classification datasets, filtering or capping them by size.

    Params:
        dids - OpenML dataset IDs to load
        filter_for_nan - drop datasets that report instances with missing values
        num_feats - maximum number of features
        min_samples - datasets with fewer rows are skipped
        max_samples - row cap forwarded to get_openml_classification
        multiclass, shuffled - forwarded to get_openml_classification
        max_num_classes - maximum number of distinct labels
        return_capped - if True, datasets exceeding num_feats / max_num_classes are truncated
            (first num_feats columns, lowest max_num_classes label values) instead of skipped

    Returns:
        (datasets, datalist): datasets is a list of
        [name, X, y, categorical_feats, attribute_names, modifications] entries, where
        modifications records which caps were applied; datalist is the OpenML listing as a
        DataFrame (after the optional missing-value filter).

    Raises:
        NotImplementedError: if a listed dataset has NumberOfClasses == 0 (regression).
    """
    datasets = []
    openml_list = openml.datasets.list_datasets(dids)
    print(f"Number of datasets: {len(openml_list)}")

    datalist = pd.DataFrame.from_dict(openml_list, orient="index")
    if filter_for_nan:
        datalist = datalist[datalist["NumberOfInstancesWithMissingValues"] == 0]
        # Only the missing-value filter has run at this point; feature limits are applied below.
        print(f"Number of datasets after NaN filtering: {len(datalist)}")

    for ds in datalist.index:
        modifications = {
            "samples_capped": False,
            "classes_capped": False,
            "feats_capped": False,
        }
        entry = datalist.loc[ds]

        if entry["NumberOfClasses"] == 0.0:
            # NotImplementedError is still an Exception, so existing handlers keep working.
            raise NotImplementedError("Regression not supported")

        X, y, categorical_feats, attribute_names = get_openml_classification(
            int(entry.did), max_samples, multiclass=multiclass, shuffled=shuffled
        )
        if X is None:
            continue

        if X.shape[1] > num_feats:
            if return_capped:
                X = X[:, 0:num_feats]
                categorical_feats = [c for c in categorical_feats if c < num_feats]
                modifications["feats_capped"] = True
            else:
                print("Too many features")
                continue
        if X.shape[0] == max_samples:
            modifications["samples_capped"] = True

        if X.shape[0] < min_samples:
            print("Too few samples left")
            continue

        classes = np.unique(y)
        if len(classes) > max_num_classes:
            if return_capped:
                # Keep the max_num_classes smallest label values. The threshold follows the
                # parameter instead of a hard-coded 10.
                keep = y < classes[max_num_classes]
                X, y = X[keep], y[keep]
                modifications["classes_capped"] = True
            else:
                print("Too many classes")
                continue

        datasets.append([entry["name"], X, y, categorical_feats, attribute_names, modifications])

    return datasets, datalist
113 |
114 |
# Classification
# Presumably OpenML dataset IDs ("dids") for a small validation set - confirm against callers.
valid_dids_classification = [13, 59, 4, 15, 40710, 43, 1498]
# Presumably OpenML dataset IDs for a classification test set - confirm against callers.
test_dids_classification = [
    973,
    1596,
    40981,
    1468,
    40984,
    40975,
    41163,
    41147,
    1111,
    41164,
    1169,
    1486,
    41143,
    1461,
    41167,
    40668,
    41146,
    41169,
    41027,
    23517,
    41165,
    41161,
    41159,
    41138,
    1590,
    41166,
    1464,
    41168,
    41150,
    1489,
    41142,
    3,
    12,
    31,
    54,
    1067,
]
# Presumably OpenML dataset IDs for a larger classification validation set - confirm against callers.
valid_large_classification = [
    943,
    23512,
    49,
    838,
    1131,
    767,
    1142,
    748,
    1112,
    1541,
    384,
    912,
    1503,
    796,
    20,
    30,
    903,
    4541,
    961,
    805,
    1000,
    4135,
    1442,
    816,
    1130,
    906,
    1511,
    184,
    181,
    137,
    1452,
    1481,
    949,
    449,
    50,
    913,
    1071,
    831,
    843,
    9,
    896,
    1532,
    311,
    39,
    451,
    463,
    382,
    778,
    474,
    737,
    1162,
    1538,
    820,
    188,
    452,
    1156,
    37,
    957,
    911,
    1508,
    1054,
    745,
    1220,
    763,
    900,
    25,
    387,
    38,
    757,
    1507,
    396,
    4153,
    806,
    779,
    746,
    1037,
    871,
    717,
    1480,
    1010,
    1016,
    981,
    1547,
    1002,
    1126,
    1459,
    846,
    837,
    1042,
    273,
    1524,
    375,
    1018,
    1531,
    1458,
    6332,
    1546,
    1129,
    679,
    389,
]
257 |
# Presumably OpenML dataset IDs drawn from an OpenML "CC" benchmark suite - confirm the source.
open_cc_dids = [
    11,
    14,
    15,
    16,
    18,
    22,
    23,
    29,
    31,
    37,
    50,
    54,
    188,
    458,
    469,
    1049,
    1050,
    1063,
    1068,
    1510,
    1494,
    1480,
    1462,
    1464,
    6332,
    23381,
    40966,
    40982,
    40994,
    40975,
]
# Filtered by N_samples < 2000, N feats < 100, N classes < 10
# NOTE(review): it is unclear whether this filter describes the list above or the one below - confirm.
291 |
# Presumably OpenML dataset IDs for the validation counterpart of open_cc_dids - confirm against callers.
open_cc_valid_dids = [
    13,
    25,
    35,
    40,
    41,
    43,
    48,
    49,
    51,
    53,
    55,
    56,
    59,
    61,
    187,
    285,
    329,
    333,
    334,
    335,
    336,
    337,
    338,
    377,
    446,
    450,
    451,
    452,
    460,
    463,
    464,
    466,
    470,
    475,
    481,
    679,
    694,
    717,
    721,
    724,
    733,
    738,
    745,
    747,
    748,
    750,
    753,
    756,
    757,
    764,
    765,
    767,
    774,
    778,
    786,
    788,
    795,
    796,
    798,
    801,
    802,
    810,
    811,
    814,
    820,
    825,
    826,
    827,
    831,
    839,
    840,
    841,
    844,
    852,
    853,
    854,
    860,
    880,
    886,
    895,
    900,
    906,
    907,
    908,
    909,
    915,
    925,
    930,
    931,
    934,
    939,
    940,
    941,
    949,
    966,
    968,
    984,
    987,
    996,
    1048,
    1054,
    1071,
    1073,
    1100,
    1115,
    1412,
    1442,
    1443,
    1444,
    1446,
    1447,
    1448,
    1451,
    1453,
    1488,
    1490,
    1495,
    1498,
    1499,
    1506,
    1508,
    1511,
    1512,
    1520,
    1523,
    4153,
    23499,
    40496,
    40646,
    40663,
    40669,
    40680,
    40682,
    40686,
    40690,
    40693,
    40705,
    40706,
    40710,
    40711,
    40981,
    41430,
    41538,
    41919,
    41976,
    42172,
    42261,
    42544,
    42585,
    42638,
]
444 |
# Presumably OpenML dataset IDs for regression datasets with categorical features; the
# "grinzstjan" prefix looks like a misspelling of the benchmark author's name - confirm.
grinzstjan_categorical_regression = [
    44054,
    44055,
    44056,
    44057,
    44059,
    44061,
    44062,
    44063,
    44064,
    44065,
    44066,
    44068,
    44069,
]
460 |
# Presumably OpenML dataset IDs for classification datasets with numerical features only - confirm.
grinzstjan_numerical_classification = [
    44089,
    44090,
    44091,
    44120,
    44121,
    44122,
    44123,
    44124,
    44125,
    44126,
    44127,
    44128,
    44129,
    44130,
    44131,
]
478 |
# Presumably OpenML dataset IDs for classification datasets with categorical features - confirm.
grinzstjan_categorical_classification = [44156, 44157, 44159, 44160, 44161, 44162, 44186]

# All classification IDs of this family: the numerical list followed by the categorical list.
grinzstjan_classification = (
    grinzstjan_numerical_classification + grinzstjan_categorical_classification
)
484 |
--------------------------------------------------------------------------------
/tabdpt_datasets/talent.py:
--------------------------------------------------------------------------------
1 | import os
2 | import zipfile
3 | from pathlib import Path
4 |
5 | import gdown
6 | import numpy as np
7 | from sklearn.preprocessing import LabelEncoder, OrdinalEncoder
8 |
9 | from tabdpt_datasets.dataset import Dataset
10 |
11 |
12 | class TalentDataset(Dataset):
13 | """
14 | Dataset for LAMDA-TALENT baseline
15 | """
16 |
    @staticmethod
    def suite_name():
        """Return the identifier of this dataset suite, the string ``"talent"``."""
        return "talent"
20 |
    @staticmethod
    def all_names():
        """
        Return the list of dataset names known to this suite.

        ``prepare_data`` uses a dataset's name verbatim as the directory name under
        ``talent/data``, so the spellings below (including the unusual ones) must not
        be "corrected".
        """
        return [
            "1000-Cameras-Dataset",
            "2dplanes",
            "3D_Estimation_using_RSSI_of_WLAN_dataset",
            "3D_Estimation_using_RSSI_of_WLAN_dataset_complete_1_target",
            "abalone",
            "Abalone_reg",
            "accelerometer",
            "ada",
            "ada_agnostic",
            "ada_prior",
            "Ailerons",
            "airfoil_self_noise",
            "airlines_seed_0_nrows_2000_nclasses_10_ncols_100_stratify_True",
            "allbp",
            "allrep",
            "Amazon_employee_access",
            "analcatdata_authorship",
            "analcatdata_supreme",
            "Another-Dataset-on-used-Fiat-500-(1538-rows)",
            "archive2",
            "archive_r56_Maths",
            "archive_r56_Portuguese",
            "artificial-characters",
            "ASP-POTASSCO-classification",
            "auction_verification",
            "autoUniv-au4-2500",
            "autoUniv-au7-1100",
            "avocado_sales",
            "bank",
            "bank32nh",
            "bank8FM",
            "Bank_Customer_Churn_Dataset",
            "banknote_authentication",
            "baseball",
            "Basketball_c",
            "Bias_correction_r",
            "Bias_correction_r_2",
            "BLE_RSSI_dataset_for_Indoor_localization",
            "blogfeedback",
            "BNG(breast-w)",
            "BNG(cmc)",
            "BNG(echoMonths)",
            "BNG(lowbwt)",
            "BNG(mv)",
            "BNG(stock)",
            "BNG(tic-tac-toe)",
            "Brazilian_houses_reproduced",
            "California-Housing-Classification",
            "Cardiovascular-Disease-dataset",
            "car-evaluation",
            "CDC_Diabetes_Health_Indicators",
            "churn",
            "Click_prediction_small",
            "cmc",
            "combined_cycle_power_plant",
            "communities_and_crime",
            "company_bankruptcy_prediction",
            "compass",
            "compass_reg",
            "concrete_compressive_strength",
            "Contaminant-detection-in-packaged-cocoa-hazelnut-spread-jars-using-Microwaves-Sensing-and-Machine-Learning-10.0GHz(Urbinati)",
            "Contaminant-detection-in-packaged-cocoa-hazelnut-spread-jars-using-Microwaves-Sensing-and-Machine-Learning-10.5GHz(Urbinati)",
            "Contaminant-detection-in-packaged-cocoa-hazelnut-spread-jars-using-Microwaves-Sensing-and-Machine-Learning-11.0GHz(Urbinati)",
            "Contaminant-detection-in-packaged-cocoa-hazelnut-spread-jars-using-Microwaves-Sensing-and-Machine-Learning-9.0GHz(Urbinati)",
            "Contaminant-detection-in-packaged-cocoa-hazelnut-spread-jars-using-Microwaves-Sensing-and-Machine-Learning-9.5GHz(Urbinati)",
            "contraceptive_method_choice",
            "CookbookReviews",
            "CPMP-2015-regression",
            "CPMP-2015-runtime-regression",
            "CPS1988",
            "cpu_act",
            "cpu_small",
            "credit",
            "Credit_c",
            "credit_reg",
            "Customer_Personality_Analysis",
            "customer_satisfaction_in_airline",
            "dabetes_130-us_hospitals",
            "Data_Science_for_Good_Kiva_Crowdfunding",
            "Data_Science_Salaries",
            "dataset_sales",
            "debutanizer",
            "default_of_credit_card_clients",
            "delta_ailerons",
            "delta_elevators",
            "Diabetic_Retinopathy_Debrecen",
            "Diamonds",
            "dis",
            "dna",
            "drug_consumption",
            "dry_bean_dataset",
            "E-CommereShippingData",
            "eeg-eye-state",
            "electricity",
            "elevators",
            "Employee",
            "estimation_of_obesity_levels",
            "eye_movements",
            "eye_movements_bin",
            "Facebook_Comment_Volume",
            "FICO-HELOC-cleaned",
            "fifa",
            "Firm-Teacher_Clave-Direction_Classification",
            "first-order-theorem-proving",
            "Fitness_Club_c",
            "Food_Delivery_Time",
            "FOREX_audcad-day-High",
            "FOREX_audcad-hour-High",
            "FOREX_audchf-day-High",
            "FOREX_audjpy-day-High",
            "FOREX_audjpy-hour-High",
            "FOREX_audsgd-hour-High",
            "FOREX_audusd-hour-High",
            "FOREX_cadjpy-day-High",
            "FOREX_cadjpy-hour-High",
            "fried",
            "GAMETES_Epistasis_2-Way_20atts_0.1H_EDM-1_1",
            "GAMETES_Heterogeneity_20atts_1600_Het_0.4_0.2_50_EDM-2_001",
            "garments_worker_productivity",
            "gas-drift",
            "gas_turbine_CO_and_NOx_emission",
            "Gender_Gap_in_Spanish_WP",
            "GesturePhaseSegmentationProcessed",
            "gina_agnostic",
            "golf_play_dataset_extended",
            "Goodreads-Computer-Books",
            "healthcare_insurance_expenses",
            "Heart-Disease-Dataset-(Comprehensive)",
            "heloc",
            "hill-valley",
            "house_16H",
            "house_16H_reg",
            "house_8L",
            "houses",
            "house_sales_reduced",
            "housing_price_prediction",
            "HR_Analytics_Job_Change_of_Data_Scientists",
            "htru",
            "ibm-employee-performance",
            "IBM_HR_Analytics_Employee_Attrition_and_Performance",
            "IEEE80211aa-GATS",
            "Indian_pines",
            "INNHotelsGroup",
            "Insurance",
            "internet_firewall",
            "internet_usage",
            "Intersectional-Bias-Assessment",
            "in_vehicle_coupon_recommendation",
            "Is-this-a-good-customer",
            "JapaneseVowels",
            "jm1",
            "Job_Profitability",
            "jungle_chess_2pcs_raw_endgame_complete",
            "Kaggle_bike_sharing_demand_challange",
            "kc1",
            "KDD",
            "KDDCup09_upselling",
            "kdd_ipums_la_97-small",
            "kin8nm",
            "kropt",
            "kr-vs-k",
            "Laptop_Prices_Dataset",
            "Large-scale_Wave_Energy_Farm_Perth_100",
            "Large-scale_Wave_Energy_Farm_Perth_49",
            "Large-scale_Wave_Energy_Farm_Sydney_100",
            "Large-scale_Wave_Energy_Farm_Sydney_49",
            "law-school-admission-bianry",
            "led24",
            "led7",
            "letter",
            "Long",
            "madeline",
            "MagicTelescope",
            "mammography",
            "Marketing_Campaign",
            "maternal_health_risk",
            "mauna-loa-atmospheric-co2",
            "mfeat-factors",
            "mfeat-fourier",
            "mfeat-karhunen",
            "mfeat-morphological",
            "mfeat-pixel",
            "mfeat-zernike",
            "MiamiHousing2016",
            "MIC",
            "mice_protein_expression",
            "microaggregation2",
            "mobile_c36_oversampling",
            "Mobile_Phone_Market_in_Ghana",
            "Mobile_Price_Classification",
            "mozilla4",
            "mv",
            "NASA_PHM2008",
            "naticusdroid+android+permissions+dataset",
            "National_Health_and_Nutrition_Health_Survey",
            "national-longitudinal-survey-binary",
            "NHANES_age_prediction",
            "okcupid_stem",
            "one-hundred-plants-margin",
            "one-hundred-plants-shape",
            "one-hundred-plants-texture",
            "online_shoppers",
            "optdigits",
            "ozone_level",
            "ozone-level-8hr",
            "page-blocks",
            "Parkinson_Multiple_Sound_Recording",
            "Parkinsons_Telemonitoring",
            "pc1",
            "pc3",
            "pc4",
            "pendigits",
            "Performance-Prediction",
            "PhishingWebsites",
            "phoneme",
            "Physicochemical_r",
            "PieChart3",
            "Pima_Indians_Diabetes_Database",
            "PizzaCutter3",
            "pol",
            "pole",
            "pol_reg",
            "predict_students_dropout_and_academic_success",
            "puma32H",
            "puma8NH",
            "Pumpkin_Seeds",
            "qsar",
            "qsar_aquatic_toxicity",
            "QSAR_biodegradation",
            "qsar_fish_toxicity",
            "Rain_in_Australia",
            "rice_cammeo_and_osmancik",
            "ringnorm",
            "rl",
            "Satellite",
            "satellite_image",
            "satimage",
            "SDSS17",
            "segment",
            "seismic+bumps",
            "semeion",
            "sensory",
            "shill-bidding",
            "Shipping",
            "Shop_Customer_Data",
            "shrutime",
            "shuttle",
            "Smoking_and_Drinking_Dataset_with_body_signal",
            "socmob",
            "space_ga",
            "spambase",
            "splice",
            "sports_articles_for_objectivity_analysis",
            "statlog",
            "steel_industry_energy_consumption",
            "steel_plates_faults",
            "stock",
            "stock_fardamento02",
            "Student_Alcohol_Consumption",
            "Student_Performance_Portuguese",
            "sulfur",
            "Superconductivty",
            "svmguide3",
            "sylvine",
            "taiwanese_bankruptcy_prediction",
            "telco-customer-churn",
            "Telecom_Churn_Dataset",
            "texture",
            "thyroid",
            "thyroid-ann",
            "thyroid-dis",
            "topo_2_1",
            "treasury",
            "turiye_student_evaluation",
            "twonorm",
            "UJIndoorLoc",
            "UJI_Pen_Characters",
            "vehicle",
            "volkert",
            "volume",
            "VulNoneVul",
            "walking-activity",
            "wall-robot-navigation",
            "water_quality",
            "Water_Quality_and_Potability",
            "Waterstress",
            "waveform-5000",
            "waveform_database_generator",
            "waveform_database_generator_version_1",
            "weather_izmir",
            "website_phishing",
            "Wilt",
            "wind",
            "wine",
            "wine+quality",
            "Wine_Quality_red",
            "wine-quality-red",
            "Wine_Quality_white",
            "wine-quality-white",
            "yeast",
        ]
325 |
    def __init__(self, name, task_id=None):
        """Forward ``name`` and ``task_id`` unchanged to the ``Dataset`` base initializer."""
        super().__init__(name, task_id)
328 |
329 | def prepare_data(self, download_dir):
330 | sub_dir = Path(download_dir) / "talent"
331 | if not sub_dir.exists():
332 | zip_path = os.path.join(download_dir, "talent.zip")
333 | gdown.download(id="1-dzY-BhMzcqjCM8vMTkVwa0hOYQ1598T", output=zip_path)
334 | with zipfile.ZipFile(zip_path, "r") as z:
335 | z.extractall(sub_dir)
336 | os.remove(zip_path)
337 |
338 | # Data is split into numeric 'N' and categorical 'C' files
339 | N_splits = []
340 | C_splits = []
341 | y_splits = []
342 | for split in ("train", "val", "test"):
343 | npy = sub_dir / "data" / self.name / f"N_{split}.npy"
344 | if npy.exists():
345 | N_splits.append(np.load(npy, allow_pickle=True).astype(np.float32))
346 | npy = sub_dir / "data" / self.name / f"C_{split}.npy"
347 | if npy.exists():
348 | C_splits.append(np.load(npy, allow_pickle=True))
349 | y_splits.append(
350 | np.load(sub_dir / "data" / self.name / f"y_{split}.npy", allow_pickle=True)
351 | )
352 | if C_splits:
353 | C = OrdinalEncoder().fit_transform(np.concatenate(C_splits, axis=0))
354 | if N_splits:
355 | N = np.concatenate(N_splits, axis=0)
356 | self.X = np.concatenate((N, C), axis=1)
357 | self.metadata["categorical_feature_inds"] = [
358 | i + N.shape[1] for i in range(C.shape[1])
359 | ]
360 | else:
361 | self.X = C
362 | self.metadata["categorical_feature_inds"] = list(range(C.shape[1]))
363 | else:
364 | self.X = np.concatenate(N_splits, axis=0)
365 | self.y = np.concatenate(y_splits, axis=0)
366 | if self.y.dtype == "object":
367 | self.y = LabelEncoder().fit_transform(self.y)
368 | self.metadata["target_type"] = "classification"
369 | else:
370 | self.metadata["target_type"] = "regression"
371 | self.y = self.y.squeeze()
372 |
373 | train_len, val_len, test_len = len(y_splits[0]), len(y_splits[1]), len(y_splits[2])
374 | self._train_inds = range(train_len)
375 | self._val_inds = range(train_len, train_len + val_len)
376 | self._test_inds = range(train_len + val_len, train_len + val_len + test_len)
377 |
    def all_instances(self):
        """Return the ``(X, y)`` pair holding all splits, as assigned by ``prepare_data``."""
        return self.X, self.y
380 |
    def train_inds(self):
        """Return the ``range`` of training-split row indices computed by ``prepare_data``."""
        return self._train_inds
383 |
    def val_inds(self):
        """Return the ``range`` of validation-split row indices computed by ``prepare_data``."""
        return self._val_inds
386 |
    def test_inds(self):
        """Return the ``range`` of test-split row indices computed by ``prepare_data``."""
        return self._test_inds
389 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import os
3 | import random
4 | import sys
5 | from typing import Optional
6 |
7 | import faiss
8 | import numpy as np
9 | import pandas as pd
10 | import torch
11 | import torch.distributed as dist
12 | import torch.nn as nn
13 | from sklearn.impute import SimpleImputer
14 | from sklearn.preprocessing import LabelEncoder, StandardScaler
15 | from torch.utils.tensorboard import SummaryWriter
16 |
17 |
def compute_losses(output, task, y, config):
    """
    Compute the averaged classification and regression losses.

    Parameters:
    - output: Tensor of shape [L, B, C+1], the outputs of the network.
    - task: Tensor of shape [B, 1, 1], containing 0 for classification or 1 for regression.
    - y: Tensor of shape [L, B, 1], the target values.
    - config: configuration object; ``config.training.label_smoothing`` is passed to
      the cross-entropy loss.

    Returns:
    - class_loss: Averaged classification loss (CrossEntropyLoss), or a zero tensor of
      shape [1] if the batch contains no classification task.
    - reg_loss: Averaged regression loss (MSELoss), or a zero tensor of shape [1] if
      the batch contains no regression task.
    """

    # reshape the output and y tensors to [B, L, C+1] and [B, L, 1] respectively
    output = output.transpose(0, 1)
    y = y.transpose(0, 1)

    # Flatten the task tensor to [B]
    task = task.view(-1)
    classification_mask = task == 0
    regression_mask = task == 1

    class_loss = torch.zeros(1, device=output.device)
    reg_loss = torch.zeros(1, device=output.device)

    # Classification Loss
    if classification_mask.any():
        # Select classification batches
        outputs_class = output[classification_mask]  # Shape: [num_class_batches, L, C+1]
        y_class = y[classification_mask]  # Shape: [num_class_batches, L, 1]

        # Reshape for CrossEntropyLoss: [N, C+1], targets [N]
        outputs_class = outputs_class.view(-1, outputs_class.size(-1))  # [N, C+1]
        # view(-1) already yields a 1-D tensor. An extra squeeze() would collapse the
        # N == 1 case to a 0-dim target, which does not match the [1, C+1] logits.
        y_class = y_class.view(-1).long()  # [N]

        # Compute CrossEntropyLoss
        ce_loss_fn = nn.CrossEntropyLoss(
            reduction="mean", label_smoothing=config.training.label_smoothing
        )
        class_loss = ce_loss_fn(outputs_class, y_class)

    # Regression Loss
    if regression_mask.any():
        # Select regression batches; the last output channel is the regression head
        outputs_reg = output[regression_mask, :, -1]  # Shape: [num_reg_batches, L]
        y_reg = y[regression_mask].squeeze(-1)  # Shape: [num_reg_batches, L]

        # Compute MSELoss
        mse_loss_fn = nn.MSELoss(reduction="mean")
        reg_loss = mse_loss_fn(outputs_reg, y_reg)

    return class_loss, reg_loss
71 |
72 |
def get_combined_loss(loss_cls, loss_reg, task, config):
    """
    Blend the classification and regression losses by the batch's task mix.

    Each loss is first divided by ``config.training.num_agg``; the regression share
    is the mean of ``task`` (0 = classification, 1 = regression) and weights the
    regression term, while its complement weights the classification term.

    Parameters:
    - loss_cls: Tensor, classification loss.
    - loss_reg: Tensor, regression loss.
    - task: Tensor of task indicators (0 for classification, 1 for regression).
    - config: Configuration object providing ``training.num_agg``.
    Returns:
    - Tensor, the weighted sum of both scaled losses.
    """
    num_agg = config.training.num_agg
    regression_share = task.float().mean()

    cls_term = (loss_cls / num_agg) * (1 - regression_share)
    reg_term = (loss_reg / num_agg) * regression_share
    return cls_term + reg_term
89 |
90 |
class FAISS:
    """
    This class initializes a FAISS index with the provided data and allows for efficient nearest neighbor search.
    """

    def __init__(
        self, X: np.ndarray, use_hnsw: bool = False, hnsw_m: int = 32, metric: str = "L2"
    ) -> None:
        """
        Initializes the FAISS index with the provided data.
        Args:
            X (np.ndarray): The data to index, shape should be (n_samples, n_features).
            use_hnsw (bool): Whether to use HNSW index or not.
            hnsw_m (int): The number of bi-directional links created for each element in the HNSW index.
            metric (str): The distance metric to use, either "L2" for Euclidean distance or "IP" for inner product.
        Raises:
            NotImplementedError: If ``metric`` is neither "L2" nor "IP".
        """
        assert isinstance(X, np.ndarray), "X must be a numpy array"
        if metric == "L2":
            metric_type = faiss.METRIC_L2
        elif metric == "IP":
            metric_type = faiss.METRIC_INNER_PRODUCT
        else:
            raise NotImplementedError(f"Unsupported metric {metric!r}; expected 'L2' or 'IP'")

        # FAISS expects C-contiguous float32 rows.
        X = np.ascontiguousarray(X, dtype=np.float32)
        dim = X.shape[1]
        if use_hnsw:
            # The metric must be given at construction time: it determines the flat
            # storage the HNSW graph is built on. Assigning ``metric_type`` afterwards
            # leaves that storage on its default (L2) metric.
            self.index = faiss.IndexHNSWFlat(dim, hnsw_m, metric_type)
        elif metric == "L2":
            self.index = faiss.IndexFlatL2(dim)
        else:
            self.index = faiss.IndexFlatIP(dim)
        self.index.add(X)

    def get_knn_indices(self, queries: np.ndarray | torch.Tensor, k: int) -> np.ndarray:
        """Retrieve the k-nearest neighbors indices for the given queries.

        Args:
            queries (np.ndarray|torch.Tensor): query points for which to find nearest neighbors.
            k (int): number of nearest neighbors to retrieve.

        Returns:
            np.ndarray: k nearest neighbors indices for each query point.
        """
        if isinstance(queries, torch.Tensor):
            # detach() so tensors that track gradients can be converted as well
            queries = queries.detach().cpu().numpy()
        # Same layout/dtype normalisation as applied to the indexed data.
        queries = np.ascontiguousarray(queries, dtype=np.float32)
        assert isinstance(k, int)

        # search() returns (distances, indices); only the indices are needed
        _, indices_Xs = self.index.search(queries, k)
        return indices_Xs
145 |
146 |
def seed_everything(seed: int):
    """Set the random seed for reproducibility.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and CUDA), exports
    ``PYTHONHASHSEED`` and configures cuDNN for deterministic behaviour.

    Args:
        seed (int): seed value to set for random number generators.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may pick different algorithms from run to run, which
    # defeats the purpose of seeding; it has to be off for reproducible results.
    torch.backends.cudnn.benchmark = False
160 |
161 |
def cleanup():
    """Cleanup the distributed training environment.

    Does nothing when no process group is active (for example in a single-process
    run), so it is safe to call unconditionally from shutdown paths.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
165 |
166 |
def signal_handler(*_):
    """Handle Ctrl+C signal to gracefully exit the program.

    Accepts and ignores whatever positional arguments the caller passes, tears down
    the distributed environment through ``cleanup`` and terminates the process with
    exit status 0.
    """
    print("Received Ctrl+C, exiting...")
    cleanup()
    sys.exit(0)
172 |
173 |
def get_module(model: nn.Module) -> nn.Module:
    """Unwrap a model from its DistributedDataParallel container, if any.

    Args:
        model (torch.nn.Module): A plain module or one wrapped in DistributedDataParallel.
    Returns:
        torch.nn.Module: ``model.module`` for a DistributedDataParallel wrapper,
        otherwise ``model`` itself.
    """
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        return model.module
    return model
182 |
183 |
def log_param_norms(
    model: nn.Module, writer: SummaryWriter, step: int, task: torch.Tensor, global_step: int
):
    """Log the norms of model parameters and gradients to TensorBoard.

    Args:
        model (nn.Module): model to log parameter norms for (a DistributedDataParallel
            wrapper is unwrapped first).
        writer (SummaryWriter): TensorBoard writer to log the norms.
        step (int): step used for the per-component ``norms/*`` scalars.
        task (torch.Tensor): task indicators of the current batch (0 = classification,
            1 = regression); their mean is logged as ``Prior/reg_ratio``.
        global_step (int): step used for the gradient, parameter and prior scalars.
    """
    model = get_module(model)
    # Log encoder weight and bias norms (if bias exists)
    writer.add_scalar("norms/encoder_weight", torch.norm(model.encoder.weight).item(), step)
    if model.encoder.bias is not None:
        writer.add_scalar("norms/encoder_bias", torch.norm(model.encoder.bias).item(), step)

    # Log y_encoder weight and bias norms (if bias exists)
    writer.add_scalar("norms/y_encoder_weight", torch.norm(model.y_encoder.weight).item(), step)
    if model.y_encoder.bias is not None:
        writer.add_scalar("norms/y_encoder_bias", torch.norm(model.y_encoder.bias).item(), step)

    # Sum of the kv_proj weight norms over all transformer layers
    transformer_weights_norm = sum(
        torch.norm(layer.kv_proj.weight).item() for layer in model.transformer_encoder
    )

    writer.add_scalar("norms/transformer_weights", transformer_weights_norm, step)

    # Sum of the individual norms of every head parameter (weights and biases)
    head_weights_norm = sum(torch.norm(param).item() for param in model.head.parameters())

    writer.add_scalar("norms/head_weights_biases", head_weights_norm, step)

    # Collect the gradients once; parameters without a gradient are skipped.
    grads = [p.grad.detach() for p in model.parameters() if p.grad is not None]
    param_norm = torch.norm(torch.stack([torch.norm(p.detach(), 2) for p in model.parameters()]), 2)
    if grads:
        total_norm = torch.norm(torch.stack([torch.norm(g, 2) for g in grads]), 2)
        total_norm1 = torch.norm(torch.stack([torch.norm(g, 1) for g in grads]), 1)
        writer.add_scalar("Gradient/Norm/", total_norm, global_step=global_step)
        writer.add_scalar("Gradient/Norm_L1/", total_norm1, global_step=global_step)
    # else: no backward pass has populated gradients yet; torch.stack would fail on
    # an empty list, so the gradient scalars are simply not logged for this step.
    writer.add_scalar("Parameter/Norm/", param_norm, global_step=global_step)
    writer.add_scalar("Prior/reg_ratio", task.float().mean().item(), global_step=global_step)
236 |
237 |
def print_on_master_only(is_master: bool):
    """Replace the built-in ``print`` with a version that is silent on non-master processes.

    The replacement accepts an extra ``force`` keyword; when true, the message is
    printed regardless of ``is_master``.

    Args:
        is_master (bool): Whether the current process is the master process.
    """
    import builtins

    original_print = builtins.print

    def print(*args, **kwargs):
        forced = kwargs.pop("force", False)
        if not (is_master or forced):
            return
        original_print(*args, **kwargs)

    builtins.print = print
253 |
254 |
def init_dist(device: str):
    """Set up distributed training when launched with a ``LOCAL_RANK`` environment variable.

    Without ``LOCAL_RANK`` nothing is initialized and the given device is passed
    through. Otherwise the CUDA device is selected from the local rank, an NCCL
    process group is created, all processes synchronize on a barrier, and printing is
    restricted to rank 0.

    Args:
        device (str): The device to use for training (e.g., "cuda:0" or "cpu").

    Returns:
        bool: Whether distributed training is being used.
        int: The rank of the current process.
        str: The device to use for training.
    """
    local_rank = os.environ.get("LOCAL_RANK")
    if local_rank is None:
        # Plain single-process run
        return False, 0, device

    rank = int(local_rank)
    print("torch.distributed.launch and my rank is", rank)
    torch.cuda.set_device(rank)
    torch.distributed.init_process_group(
        backend="nccl",
        init_method="env://",
        timeout=datetime.timedelta(seconds=20),
        world_size=torch.cuda.device_count(),
        rank=rank,
    )
    torch.distributed.barrier()
    print_on_master_only(rank == 0)
    return True, rank, f"cuda:{rank}"
282 |
283 |
class DataPreprocessor:
    """Impute and standardize feature matrices, and convert DataFrames to numeric arrays."""

    def __init__(self, impute_method: str = "mean", encode_y: bool = False):
        """Initialize the DataPreprocessor.
        Parameters:
        - impute_method: str, method for imputing missing values (default is "mean").
        - encode_y: bool, whether to encode the target variable y (default is False).
          Stored on the instance; not read by any method of this class.
        """
        self.impute_method = impute_method
        self.encode_y = encode_y
        # Initialize the imputer and scaler
        self.imputer = SimpleImputer(strategy=impute_method)
        self.scaler = StandardScaler()

    def convert_cat2num(
        self, X: pd.DataFrame, y: pd.Series | None = None
    ) -> tuple[np.ndarray, np.ndarray, list[bool]]:
        """Convert categorical columns to numerical values.

        The caller's DataFrame is left unmodified. When ``y`` is None, the last
        column of ``X`` is split off and used as the target.

        Parameters:
        - X: DataFrame, input features.
        - y: Series or None, target variable.

        Returns:
        - X: numpy array (float32), features with categorical columns label-encoded.
        - y: numpy array, target variable with categorical values label-encoded.
        - cat_vals: list of bool with one entry per column of the input ``X`` (True if
          it was categorical); when ``y`` is given, one more entry for ``y`` is appended.
          In both cases the last entry therefore describes the target.
        """
        # Work on a copy so the label-encoded columns do not leak into the caller's frame.
        X = X.copy()
        cat_vals = []
        # Convert categorical columns to numerical values using dataframe information
        for col in X.columns:
            col_dtype = X[col].dtype
            if col_dtype == "object" or isinstance(col_dtype, pd.CategoricalDtype):
                le = LabelEncoder()
                X[col] = le.fit_transform(X[col])
                cat_vals.append(True)
            else:
                cat_vals.append(False)
        X = X.to_numpy().astype(np.float32)

        # Convert categorical target to numerical values
        if y is not None:
            if y.dtype == "object" or isinstance(y.dtype, pd.CategoricalDtype):
                le = LabelEncoder()
                y = le.fit_transform(y)
                cat_vals.append(True)
            else:
                y = y.to_numpy().astype(np.float32)
                cat_vals.append(False)
        else:
            # No explicit target: the last column of X is the target
            y = X[:, -1]
            X = np.delete(X, -1, axis=1)
        return X, y, cat_vals

    def fit_transform(self, X: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
        """
        Fit the imputer and scaler on the training data and transform it.

        Parameters:
        - X: numpy array, training features.

        Returns:
        - X: numpy array, transformed training features.
        """
        # Impute missing values
        X = self.imputer.fit_transform(X)

        # Scale the features
        X = self.scaler.fit_transform(X)

        return X

    def transform(self, X: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
        """
        Transform the test data using the fitted imputer and scaler.

        Parameters:
        - X: numpy array, testing features.

        Returns:
        - X: numpy array, transformed testing features.
        """
        # Impute missing values
        X = self.imputer.transform(X)

        # Scale the features
        X = self.scaler.transform(X)

        return X
371 |
372 |
def standardize(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Center and scale a tensor along its first dimension.

    The standard deviation is offset by 1e-6 before dividing, which avoids a
    division by zero for constant inputs.

    Args:
        tensor (torch.Tensor): Input tensor to standardize.
    Returns:
        torch.Tensor: Standardized tensor.
        torch.Tensor: Mean over dimension 0.
        torch.Tensor: Standard deviation over dimension 0, including the 1e-6 offset.
    """
    mean = tensor.mean(dim=0)
    std = tensor.std(dim=0) + 1e-6
    normalized = (tensor - mean) / std
    return normalized, mean, std
385 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import shutil
4 | import signal
5 |
6 | import hydra
7 | import openml
8 | import torch.distributed as dist
9 | import schedulefree
10 | import torch
11 | import torch.nn as nn
12 | from omegaconf import DictConfig, OmegaConf
13 | from torch.nn.attention import SDPBackend, sdpa_kernel
14 | from torch.utils.data import DataLoader
15 | from torch.utils.tensorboard import SummaryWriter
16 | from tqdm import tqdm
17 |
18 | from dataset import FullDataset, collate_fn
19 | from eval_full import FullEval
20 | from model import TabDPTModel
21 | from utils import (
22 | cleanup,
23 | compute_losses,
24 | get_combined_loss,
25 | init_dist,
26 | log_param_norms,
27 | seed_everything,
28 | signal_handler,
29 | )
30 |
31 |
def save(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    config: DictConfig,
    stats: dict,
    path: str,
    name: str,
) -> None:
    """Write a training checkpoint to ``{path}/{name}.ckpt``.

    The checkpoint is a dict with the model state (``"model"``), the optimizer state
    (``"opt"``), the configuration (``"cfg"``) and the statistics (``"stats"``).

    Args:
        model (nn.Module): The model to save.
        optimizer (torch.optim.Optimizer): The optimizer to save.
        config (DictConfig): The configuration for the experiment.
        stats (dict): The training statistics.
        path (str): The directory to save the checkpoint in.
        name (str): The checkpoint file name without extension.

    Returns:
        None
    """
    target_file = f"{path}/{name}.ckpt"
    payload = dict(
        model=model.state_dict(),
        opt=optimizer.state_dict(),
        cfg=config,
        stats=stats,
    )
    torch.save(payload, target_file)
60 |
61 |
def save_eval_callback(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    vals: list,
    epoch: int,
    config: DictConfig,
    writer: SummaryWriter,
    rank: int,
    stats: dict,
) -> None:
    """Checkpoint the current state, run the evaluations and track the best metrics.

    Every process runs the evaluations; only rank 0 writes checkpoints, prints,
    updates the ``best_*`` entries of ``stats`` and logs to TensorBoard. Model and
    optimizer are switched to eval mode for the evaluation and back to train mode
    afterwards (the optimizer only if it offers ``eval``/``train``).

    Args:
        model (nn.Module): The model being trained.
        optimizer (torch.optim.Optimizer): The optimizer used for training.
        vals (list): Evaluators; each ``eval`` call returns a mapping of metric name to value.
        epoch (int): The current epoch number.
        config (DictConfig): The configuration for the training process.
        writer (SummaryWriter): The TensorBoard writer for logging.
        rank (int): The rank of the current process.
        stats (dict): Training statistics, updated in place.

    Returns:
        None
    """
    is_main = rank == 0

    # always keep the most recent state
    stats["epoch_in_training"] = epoch
    if is_main:
        save(model, optimizer, config, stats, config.exp_path, "latest")

    # switch to evaluation mode
    model.eval()
    if hasattr(optimizer, "eval"):
        optimizer.eval()

    for evaluator in vals:
        results = evaluator.eval(model, context_length=config.training.eval_seq_len)
        for name, metric in results.items():
            if not is_main:
                continue
            print(f"Epoch {epoch} | {name}: {metric}")
            best_key = f"best_{name}"
            # higher is better; checkpoint only the metrics selected in the config
            if metric > stats.get(best_key, -float("inf")):
                stats[best_key] = metric
                if name in config.logging.save_metrics:
                    save(model, optimizer, config, stats, config.exp_path, best_key)
            writer.add_scalar(f"val/{name}/", metric, epoch)

    # back to training mode
    model.train()
    if hasattr(optimizer, "train"):
        optimizer.train()
112 |
113 |
def set_experiment_from_config(config: DictConfig) -> tuple[nn.Module, torch.optim.Optimizer, dict]:
    """Set up the experiment based on the provided configuration.

    Seeds the RNGs, resolves the experiment directory, prepares it according to
    ``config.training.reset_policy`` ("rm": start from scratch, "cnt": resume from
    ``latest.ckpt`` if present), divides the batch size by the number of GPUs, saves
    the config and builds the model and optimizer. ``config`` is modified in place
    (``exp_name``/``exp_path``/``training.batch_size``).

    Args:
        config (DictConfig): The configuration for the experiment.

    Raises:
        ValueError: If a checkpoint exists and the reset policy is neither "rm" nor "cnt".

    Returns:
        Tuple[nn.Module, torch.optim.Optimizer, dict]: A tuple containing the model,
        optimizer, and a dictionary with training statistics.
    """

    stats = {"epoch_in_training": 0}

    seed_everything(config.seed)

    if hasattr(config.data, "single_dataset_id") and config.data.single_dataset_id is not None:
        # Load the dataset metadata (we assume that download_data is not necessary)
        dataset = openml.datasets.get_dataset(config.data.single_dataset_id, download_data=False)
        # Use the dataset name for the experiment (replace spaces with underscores for file system safety)
        dataset_name = dataset.name.replace(" ", "_")
        config.exp_name = dataset_name
        print("Using single dataset for training:", dataset_name)
    config.exp_path = f"runs/{config.folder}/{config.exp_name}"
    latest_ckpt = f"{config.exp_path}/latest.ckpt"

    # Only rank 0 should modify the directory
    is_master = (not torch.distributed.is_initialized()) or (torch.distributed.get_rank() == 0)

    # directory setup
    if is_master:
        if os.path.exists(config.exp_path):
            print("Directory already exists:", config.exp_path)
            # restart training
            if config.training.reset_policy == "rm" or not os.path.exists(latest_ckpt):
                print("Remove the existing directory.")
                shutil.rmtree(config.exp_path, ignore_errors=True)
                os.makedirs(config.exp_path, exist_ok=True)
            # continue training from a saved checkpoint
            elif config.training.reset_policy == "cnt":
                print("Continue training from the latest checkpoint.")
            else:
                raise ValueError(
                    "Invalid reset_policy: must be either 'cnt' (resume) or 'rm' (delete)."
                )
        else:
            os.makedirs(config.exp_path, exist_ok=True)
            print("Created directory: ", config.exp_path)

    # Synchronize all processes so that the directory is created before continuing
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    # Decide about resuming on EVERY rank, after the directory is in its final state.
    # If only rank 0 loaded the checkpoint, the other ranks would continue with a
    # fresh optimizer state and fresh stats. A checkpoint can only still exist here
    # when the policy is "cnt" (every other path above removed it or raised).
    load_state_from_saved = config.training.reset_policy == "cnt" and os.path.exists(latest_ckpt)
    if load_state_from_saved:
        # weights_only=False: the checkpoint stores the config object next to the
        # tensors, so it is a full pickle -- only load checkpoints you trust.
        # map_location="cpu": do not allocate every rank's copy on the saving GPU.
        checkpoint = torch.load(latest_ckpt, map_location="cpu", weights_only=False)
        model_state = checkpoint["model"]
        opt_state = checkpoint["opt"]
        stats = checkpoint["stats"]

    # Check if the number of GPUs matches the configuration
    # assert torch.cuda.device_count() == len(
    #     config.env.gpus
    # ), f"Number of GPUs does not match the number of GPUs in the config, expected {len(config.env.gpus)}, found {torch.cuda.device_count()}"

    assert (
        config.training.batch_size % len(config.env.gpus) == 0
    ), "Batch size should be divisible by the number of GPUs"

    # adapt batch size for distributed training
    config.training.batch_size //= len(config.env.gpus)

    # save config
    OmegaConf.save(config=config, f=f"{config.exp_path}/config.yaml")

    print(f"Using {config.env.gpus} GPUs")

    # instantiate the model
    model = TabDPTModel(
        dropout=config.training.dropout,
        n_out=config.model.max_num_classes,
        nhead=config.model.nhead,
        nhid=config.model.nhid_factor * config.model.emsize,
        ninp=config.model.emsize,
        nlayers=config.model.nlayers,
        num_features=config.model.max_num_features,
    )
    print(f"{sum(p.numel() for p in model.parameters())/1e6:.{2}f} M parameters")

    # load the model state if resuming from a saved checkpoint
    if load_state_from_saved:
        model = TabDPTModel.load(model_state, config)
        del model_state

    # instantiate the optimizer
    optimizer = schedulefree.AdamWScheduleFree(
        model.parameters(),
        lr=config.training.lr,
        weight_decay=config.training.weight_decay,
        warmup_steps=1000,
        betas=(0.98, 0.999),
    )

    # load the optimizer state if resuming from a saved checkpoint
    if load_state_from_saved:
        optimizer.load_state_dict(opt_state)
        del opt_state

    return model, optimizer, stats
229 |
230 |
@hydra.main(version_base=None, config_path="configs", config_name="default_config")
def main(config: DictConfig):
    """Main function to run the training process.

    Sets up (optionally distributed) training from ``config``, builds the model,
    optimizer, dataset and data loader, then runs the epoch loop with gradient
    accumulation, TensorBoard logging and periodic evaluation/checkpointing.

    Args:
        config (DictConfig): The configuration for the training process.
    """

    # signal handling
    signal.signal(signal.SIGINT, signal_handler)

    # print config
    print("Config:", OmegaConf.to_yaml(config))

    # distributed training
    using_dist, rank, device = init_dist(config.env.device)

    model, optimizer, stats = set_experiment_from_config(config)
    model.to(device)

    # compile the model if specified in the config
    if config.training.compile:
        model = torch.compile(model)

    # using distributed data parallel if specified in the config
    if using_dist:
        print("Distributed training enabled.")
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[rank],
            output_device=rank,
        )

    # Initialize the SummaryWriter for TensorBoard logging
    writer = SummaryWriter(config.exp_path)

    # set validation datasets
    # TODO : make this configurable
    vals = [
        FullEval(
            device=device,
            max_feat=config.model.max_num_features,
            use_retrieval=config.data.eval_retrieval,
        )
    ]

    # select the device type
    device_type = "cpu" if config.env.device == "cpu" else "cuda"

    # Use autocast for mixed precision training
    selected_autocast = torch.autocast(device_type=device_type, dtype=torch.bfloat16)

    # Initialize the FullDataset for training
    print("Initializing Dataset...")
    dataset = FullDataset(device, config)

    # initialize the DataLoader for the dataset
    data_loader = DataLoader(
        dataset,
        batch_size=config.training.batch_size,
        num_workers=config.env.num_workers,
        pin_memory=True,
        collate_fn=collate_fn,
        shuffle=False,
        persistent_workers=True,
    )

    # Create an iterator for the DataLoader; it is consumed across all epochs
    iter_data_loader = iter(data_loader)

    # Number of forward/backward passes per epoch (optimizer steps x accumulation)
    batches_per_epoch = config.training.num_model_updates * config.training.num_agg

    # Initialize the outer training loop (resumes after the last recorded epoch)
    for epoch in range(stats["epoch_in_training"] + 1, config.training.num_epochs + 1):
        print("Epoch ", epoch)

        # initialize the loss accumulators
        epoch_loss_cls = 0.0
        epoch_loss_reg = 0.0

        # set the model and optimizer to training mode
        model.train()
        if hasattr(optimizer, "train"):
            optimizer.train()

        # initialize the inner training loop
        for batch in tqdm(range(batches_per_epoch)):
            # randomly set the evaluation position, i.e. the context length
            # ---- 1. choose a common eval_pos ---------------------------------------
            # Only touch the process group when distributed training is active:
            # dist.get_rank()/dist.broadcast() fail if no group was initialized.
            if using_dist and dist.get_rank() != 0:
                # non-zero ranks receive the value drawn by rank 0
                eval_pos_t = torch.empty(1, dtype=torch.long, device=device)
            else:
                eval_pos_t = torch.randint(
                    config.training.min_eval_pos,
                    config.training.max_eval_pos + 1,
                    (1,),
                    device=device,
                    dtype=torch.long,
                )

            if using_dist:
                dist.broadcast(eval_pos_t, src=0)
            eval_pos = int(eval_pos_t.item())

            # get the next batch from the DataLoader
            x, y, task = [a.to(device) for a in next(iter_data_loader)]

            # efficient forward pass and loss computation
            with selected_autocast, sdpa_kernel(SDPBackend.FLASH_ATTENTION):
                output, log_act_norms = model(
                    x, y.squeeze(-1)[:eval_pos], return_log_act_norms=True
                )

                # compute the losses on the rows after eval_pos
                loss_cls, loss_reg = compute_losses(output, task, y.squeeze(-1)[eval_pos:], config)

                # reweight the loss for combined classification and regression tasks
                loss = get_combined_loss(loss_cls, loss_reg, task, config)

            # detach the log_act_norms to avoid memory leak
            log_act_norms = {k: v.detach().item() for k, v in log_act_norms.items()}

            epoch_loss_cls += loss_cls.cpu().detach().item()
            epoch_loss_reg += loss_reg.cpu().detach().item()

            # backpropagation
            loss.backward()

            # gradient accumulation: step only every num_agg batches
            if (batch + 1) % config.training.num_agg == 0:
                # compute the global step
                global_step = batch + batches_per_epoch * epoch

                # log the losses and norms
                log_param_norms(model, writer, global_step, task, global_step)

                # clip gradients to avoid exploding gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.training.clip_grad_norm)

                # step the optimizer
                optimizer.step()
                optimizer.zero_grad()

        # log the average losses for the epoch
        # NOTE(review): the sums cover num_model_updates * num_agg batches but are
        # divided by num_model_updates only; kept as-is so logged curves stay
        # comparable with earlier runs -- confirm whether this is intended.
        writer.add_scalar(
            "loss/cls-average_train/", epoch_loss_cls / config.training.num_model_updates, epoch
        )
        writer.add_scalar(
            "loss/reg-average_train/", epoch_loss_reg / config.training.num_model_updates, epoch
        )

        # evaluate the model every eval_every epochs
        if epoch % config.logging.eval_every == 0:
            # synchronize all processes before evaluation
            if using_dist:
                torch.distributed.barrier()

            # evaluate the underlying module, not the DDP wrapper
            if isinstance(model, torch.nn.parallel.DistributedDataParallel):
                model_eval: nn.Module = model.module
            else:
                model_eval: nn.Module = model

            # save evaluation checkpoints and log metrics
            save_eval_callback(model_eval, optimizer, vals, epoch, config, writer, rank, stats)

            if using_dist:
                torch.distributed.barrier()

    # cleanup the model and optimizer
    if using_dist:
        cleanup()

    writer.close()


if __name__ == "__main__":
    main()
408 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import random
2 | from collections import defaultdict
3 |
4 | import numpy as np
5 | import openml
6 | import pandas as pd
7 | import torch
8 | from omegaconf import DictConfig
9 | from sklearn.preprocessing import LabelEncoder
10 | from torch.utils.data import Dataset
11 | from tqdm import tqdm
12 |
13 | from model import pad_x
14 | from utils import FAISS, DataPreprocessor
15 |
16 |
def collate_fn(batch):
    """Merge per-sample tensors into batch tensors for the DataLoader.

    Args:
        batch (list): A list of samples; each sample is a sequence of tensors
            in the same field order.

    Returns:
        list: One tensor per field, obtained by concatenating that field's
            tensors from every sample along dimension 1.
    """
    collated = []
    # zip(*batch) regroups the samples field by field
    for field_tensors in zip(*batch):
        collated.append(torch.cat(field_tensors, dim=1))
    return collated
28 |
29 |
class FullDataset(Dataset):
    """FullDataset class for loading and processing datasets.

    This class inherits from `torch.utils.data.Dataset` and is used to load datasets from OpenML,
    preprocess them, and provide samples for training a model.
    Every sample picks a random dataset and a random column as the prediction target, so both
    regression and classification tasks are produced. Context rows are either the k nearest
    neighbours of a random query row (retrieval) or a uniform random draw.
    """

    def __init__(self, device: str, config: DictConfig):
        """FullDataset initialization.

        Reads the dataset IDs from ``data_splits/noleak_training_datasets.csv``, fetches each
        dataset from OpenML, preprocesses it, and (when retrieval is enabled) builds a FAISS
        index per dataset.

        Args:
            device (str): The device to use for tensor operations.
            config (DictConfig): The configuration for the dataset.
        """
        # NOTE: despite the name this is the total number of batches over all epochs
        self.steps_per_epoch = (
            config.training.num_agg * config.training.num_model_updates * config.training.num_epochs
        )
        self.batch_size = config.training.batch_size
        self.device = device
        self.context_length = getattr(config.training, "seq_len", 1024)
        self.retrieval = getattr(config.data, "retrieval", True)
        self.max_feat = config.model.max_num_features
        self.y_reg_augment = config.data.y_reg_augment

        # Load dataset IDs from CSV file
        # TODO: make dataset_ids configurable
        self.dataset_ids = (
            pd.read_csv("data_splits/noleak_training_datasets.csv")["did"].values.ravel().tolist()
        )
        self.datasets = defaultdict(dict)

        # Load datasets and preprocess them
        for did in tqdm(self.dataset_ids):
            # download the dataset from OpenML
            dataset = openml.datasets.get_dataset(
                did,
                download_data=False,
                download_qualities=False,
                download_features_meta_data=False,
            )

            # Get the data as a pandas DataFrame
            X, y, _, _ = dataset.get_data(
                dataset_format="dataframe", target=dataset.default_target_attribute
            )

            # TODO: preprocessing is a bit different from the usual .fit_transform() since concatenation happens in the middle
            preprocessor = DataPreprocessor()
            # Preprocess the data
            X, y, cat_vals = preprocessor.convert_cat2num(X, y)
            X = preprocessor.imputer.fit_transform(X)

            # Concatenate the target as the last column
            X = np.concatenate([X, y[:, None]], axis=1)
            X = preprocessor.scaler.fit_transform(X)

            # Use the same flag (with its default) that generate_sample() consults,
            # so an index always exists whenever use_knn() will be called.
            if self.retrieval:
                faiss_knn = FAISS(X, metric="L2", use_hnsw=False)
            else:
                faiss_knn = None

            # key order matters: generate_sample() unpacks .values() positionally
            self.datasets[did] = {"X": X, "index": faiss_knn, "cat": cat_vals}

    def __len__(self):
        """Return the total number of samples in the dataset."""
        return self.steps_per_epoch * self.batch_size

    def transform_target(
        self, y: np.ndarray, cls_threshold: int = 10, cls_prob: float = 0.3
    ) -> tuple[np.ndarray, str]:
        """Transform the target variable into classification or regression format.

        Args:
            y (np.ndarray): The target variable.
            cls_threshold (int): Targets with more distinct values than this are treated as
                continuous; also the exclusive upper bound on the number of bins drawn.
            cls_prob (float): The probability of binning a continuous target into classes.

        Returns:
            tuple: A tuple containing the transformed target variable and the task type ("cls" or "reg").
        """
        unique_y = np.unique(y)  # sorted distinct values
        if len(unique_y) > cls_threshold:
            if np.random.rand() > cls_prob:
                return y, "reg"
            # Bin the continuous target: draw the class boundaries among the interior
            # distinct values and label each row by how many boundaries it exceeds.
            num_class = np.random.randint(2, cls_threshold)
            cls_boundary = np.random.choice(unique_y[1:-1], num_class - 1, replace=False)
            y = (y[:, None] > cls_boundary[None, :]).sum(1)
            return y, "cls"
        if len(unique_y) >= 2:
            # few distinct values: treat them directly as class labels
            le = LabelEncoder()
            y = le.fit_transform(y)
            y = y.astype(np.float32)
            return y, "cls"
        # constant target
        return y, "reg"

    def check_cls_sample(self, X: torch.Tensor, y: torch.Tensor) -> bool:
        """Check if the classification sample is valid.

        Only the tensor layouts are asserted; no sample is rejected.

        Args:
            X (torch.Tensor): The input features.
            y (torch.Tensor): The target labels.

        Returns:
            bool: True if the sample is valid, False otherwise.
        """
        assert y.dim() == 3 and y.shape[1] == 1 and y.shape[2] == 1
        assert X.dim() == 3 and X.shape[1] == 1
        return True

    def check_reg_sample(self, X: torch.Tensor, y: torch.Tensor) -> bool:
        """Check if the regression sample is valid.

        Args:
            X (torch.Tensor): The input features.
            y (torch.Tensor): The target labels.

        Returns:
            bool: True if the sample is valid, False otherwise.
        """
        assert y.dim() == 3 and y.shape[1] == 1 and y.shape[2] == 1
        assert X.dim() == 3 and X.shape[1] == 1
        # reject targets with extreme magnitudes
        if torch.max(y.abs().ravel()) > 10:
            return False
        # reject targets where a single value covers more than 95% of the rows
        if np.max(np.unique(y.ravel(), return_counts=True)[1]) > 0.95 * y.numel():
            return False
        return True

    @torch.no_grad()
    def __getitem__(self, _) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """get a sample from the dataset.

        The index is ignored; a fresh random sample is generated on every call. Generation is
        retried up to 100 times until the sample passes its validity check; after that the last
        generated sample is returned regardless.

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The input features, target values, and task type.
        """
        for _attempt in range(100):
            X, y, task = self.generate_sample()
            # task is 0 for classification, 1 for regression (see generate_sample)
            if task == 0:
                if self.check_cls_sample(X, y):
                    break
            else:
                if self.check_reg_sample(X, y):
                    break
        return X, y, task

    def _subsample_features(self, X_nni: np.ndarray) -> np.ndarray:
        """Keep a random subset of the columns on the last axis of ``X_nni``.

        The subset size is drawn between ``min(max_feat // 2, F // 2)`` and
        ``min(max_feat, F)`` (inclusive), where ``F`` is the number of available columns.
        """
        num_available = X_nni.shape[-1]
        num_features_sampled = random.randint(
            min(self.max_feat // 2, num_available // 2), min(self.max_feat, num_available)
        )
        random_feature_indices = np.random.choice(
            num_available, num_features_sampled, replace=False
        )
        return X_nni[..., random_feature_indices]

    @torch.no_grad()
    def context_sampler(self, num_features: int, num_samples: int, X_sample: torch.Tensor):
        """Sample a context from the input tensor.

        Args:
            num_features (int): The number of features in the input tensor.
            num_samples (int): The number of samples in the input tensor.
            X_sample (torch.Tensor): The input tensor.

        Returns:
            tuple[torch.Tensor, torch.Tensor, str]: The context features, target values, and task type.
        """
        target_idx = random.randint(0, num_features - 1)

        # If context length > the number of samples, we need to sample with replacement
        if self.context_length > num_samples:
            # every row once, plus extra rows drawn with replacement to fill the context
            random_selection_indices = np.random.choice(num_samples, num_samples, replace=False)
            random_selection_indices = np.concatenate(
                [
                    np.random.choice(num_samples, self.context_length - num_samples, replace=True),
                    random_selection_indices,
                ]
            )
        else:
            # If context length <= the number of samples, we can sample without replacement
            random_selection_indices = np.random.choice(
                num_samples, self.context_length, replace=False
            )
        # Select the samples and extract the target feature
        X_nni = X_sample[random_selection_indices]
        y_nni = X_nni[:, target_idx]

        # Remove the target feature from the feature set
        X_nni = np.delete(X_nni, target_idx, axis=1)

        # Transform the target feature into classification or regression format
        y_nni, task = self.transform_target(y_nni.ravel(), cls_threshold=10, cls_prob=0.5)

        # Reshape the tensors to match the expected dimensions
        X_nni = X_nni.reshape(-1, 1, X_nni.shape[-1])
        y_nni = y_nni.reshape(-1, 1, 1)

        # Select a random subset of features
        X_nni = self._subsample_features(X_nni)

        # Convert to PyTorch tensors
        X_nni = torch.Tensor(X_nni)
        y_nni = torch.Tensor(y_nni)

        return X_nni, y_nni, task

    @torch.no_grad()
    def use_knn(self, X_sample, index_sample, num_samples, num_features):
        """Use k-NN to find similar samples.

        Args:
            X_sample (np.ndarray): The input features.
            index_sample (FAISS): The FAISS index for k-NN search.
            num_samples (int): The number of samples.
            num_features (int): The number of features.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, str]: The context features, target values, and task type.
        """
        # Randomly select a query point
        x_q_idx = random.randint(0, num_samples - 1)
        x_q = X_sample[x_q_idx].reshape(1, -1).copy()
        x_q = x_q.astype(np.float32)

        # Randomly select a target feature index
        target_idx = random.randint(0, num_features - 1)

        # Get the indices of the k-nearest neighbors of the query point
        # NOTE(review): presumably returns shape (1, context_length) -- confirm against
        # FAISS.get_knn_indices, including what happens when the dataset has fewer rows.
        indices_X_nni = index_sample.get_knn_indices(x_q, self.context_length)
        X_nni = X_sample[torch.tensor(indices_X_nni)]
        X_nni = np.swapaxes(X_nni, 0, 1)

        # Extract the target feature and remove it from the feature set
        y_nni = X_nni[:, :, target_idx]
        X_nni = np.delete(X_nni, target_idx, axis=2)

        # Transform the target feature
        y_nni, task = self.transform_target(y_nni.ravel(), cls_threshold=10, cls_prob=0.5)

        # select a random subset of features
        X_nni = self._subsample_features(X_nni)

        # Convert to PyTorch tensors
        X_nni = torch.Tensor(X_nni)
        y_nni = torch.Tensor(y_nni)
        return X_nni, y_nni, task

    @torch.no_grad()
    def generate_sample(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generate a sample from the dataset.

        Raises:
            ValueError: If the task type is neither "cls" nor "reg".

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The input features, target values, and
                task type (0 for classification, 1 for regression).
        """
        # Randomly select a dataset ID
        did_sample = random.choices(self.dataset_ids, k=1)[0]

        # Get the dataset (values are stored in the order X, index, cat)
        X_sample, index_sample, _ = self.datasets[did_sample].values()
        num_samples, num_features = X_sample.shape

        if self.retrieval:  # Randomly select a query point and its k-nearest neighbors as context
            X_nni, y_nni, task = self.use_knn(X_sample, index_sample, num_samples, num_features)
        else:  # Randomly select a sample and its context
            X_nni, y_nni, task = self.context_sampler(num_features, num_samples, X_sample)

        # Pad the features to the maximum number of features
        X_nni = pad_x(X_nni, self.max_feat)

        # Shuffle the rows
        shuffle_indices = np.random.choice(self.context_length, self.context_length, replace=False)
        X = X_nni[shuffle_indices]
        y = y_nni[shuffle_indices]

        if task == "cls":  # If classification task, shuffle the labels
            le = LabelEncoder()
            y_encoded = le.fit_transform(y.ravel())
            # Randomly permute the class ids. Indexing by the encoded ids (rather than
            # looking them up by original label value) stays correct even if the labels
            # present in the context are not exactly 0..C-1.
            shuffled_ids = list(range(len(le.classes_)))
            random.shuffle(shuffled_ids)
            y = np.asarray(shuffled_ids, dtype=np.float32)[y_encoded]
        elif task == "reg":  # If regression task, apply normalization and augmentation
            y_means = y.mean(dim=0)
            y_stds = y.std(dim=0) + 1e-6
            y = (y - y_means) / y_stds
            if self.y_reg_augment:
                # Apply random nonlinear transformations to the target for augmentation
                # TODO: shouldn't this non-linear transformation be implemented
                # only for a certain calls, not always, e.g. with:
                # if random.random() < 0.5:
                y = random_nonlinear_transform(y)
                # re-standardize after the transformation
                y_means = y.mean(dim=0)
                y_stds = y.std(dim=0) + 1e-6
                y = (y - y_means) / y_stds
        else:
            raise ValueError("Task must be 'cls' or 'reg'")

        return (
            torch.Tensor(X),
            torch.Tensor(y).view(-1, 1, 1),
            torch.Tensor([0 if task == "cls" else 1]).unsqueeze(-1).unsqueeze(-1).long(),
        )
342 |
343 |
# Bank of element-wise nonlinearities used for target augmentation.
# Entries are selected by position, so the insertion order is significant.
FUNCTION_BANK = {
    "sin": torch.sin,
    "tanh": torch.tanh,
    "square": lambda t: t**2,
    "identity": lambda t: t,
    "step": lambda t: torch.where(t > 0, 1, 0),
    "relu": torch.nn.functional.relu,
    # sign-preserving variants so negative inputs stay defined
    "sqrt": lambda t: torch.sign(t) * torch.sqrt(torch.abs(t)),
    "log": lambda t: torch.log(1 + torch.abs(t)) * torch.sign(t),
}


def random_nonlinear_transform(
    y: torch.Tensor,
    n_transforms: int = 2,
    scale_range: tuple = (0.5, 2.0),
    bias_range: tuple = (-1.0, 1.0),
) -> torch.Tensor:
    """Chain several randomly chosen nonlinearities over a target tensor.

    For each step a function is drawn from ``FUNCTION_BANK`` together with a random
    scale and bias; with equal probability the step computes either
    ``scale * f(y + bias)`` or ``f(scale * y + bias)``.

    Args:
        y (torch.Tensor): target tensor to be transformed (left unmodified).
        n_transforms (int, optional): how many transformations to chain. Defaults to 2.
        scale_range (tuple, optional): (low, high) bounds for the uniform scale. Defaults to (0.5, 2.0).
        bias_range (tuple, optional): (low, high) bounds for the uniform bias. Defaults to (-1.0, 1.0).

    Returns:
        torch.Tensor: The transformed tensor, cast to float32.
    """
    device = y.device
    names = tuple(FUNCTION_BANK)
    transformed = y.clone()
    for _ in range(n_transforms):
        # draw order (function, scale, bias, placement) is part of the behaviour
        choice = torch.randint(low=0, high=len(names), size=(1,)).item()
        fn = FUNCTION_BANK[names[choice]]
        scale = torch.empty(1, device=device).uniform_(*scale_range).item()
        bias = torch.empty(1, device=device).uniform_(*bias_range).item()
        scale_outside = bool(torch.rand(1, device=device) < 0.5)
        if scale_outside:
            transformed = scale * fn(transformed + bias)
        else:
            transformed = fn(scale * transformed + bias)
    return transformed.float()
387 |
--------------------------------------------------------------------------------
/tabdpt_datasets/openml.py:
--------------------------------------------------------------------------------
1 | import os
2 | import warnings
3 | from typing import Optional
4 |
5 | import numpy as np
6 | import openml
7 | import pandas as pd
8 | from sklearn.preprocessing import LabelEncoder, OrdinalEncoder
9 |
10 | from tabdpt_datasets.dataset import Dataset
11 |
12 |
class OpenMLDataset(Dataset):
    """
    Generic class for loading any OpenML dataset
    """

    @staticmethod
    def all_names():
        """Return None: this generic class has no fixed list of dataset names."""
        return None

    @staticmethod
    def suite_name():
        """Return the suite identifier, "openml"."""
        return "openml"

    def __init__(
        self,
        name: str,
        task_id: Optional[str] = None,
        openml_dataset_id: Optional[str | int] = None,
        openml_task_id: Optional[str | int] = None,
    ):
        """Initializes the OpenMLDataset.

        Args:
            name (str): The name of the dataset.
            task_id (Optional[str], optional): Specifies the type of data split to use. Supported formats:
                - None or "default": Uses fold 0 of the OpenML task when an OpenML task ID is
                  given, otherwise a random split with seed 0.
                - "fold" followed by an integer (e.g. "fold3"): Uses the fold with that index
                  from the OpenML task definition. Only supported with OpenML task IDs.
                - "random-seed" followed by an integer (e.g. "random-seed42"): Creates a random
                  70/15/15 split using that seed.
            openml_dataset_id (Optional[str | int], optional): The OpenML dataset ID. Must be specified if
                `openml_task_id` is not provided. Defaults to None.
            openml_task_id (Optional[str | int], optional): The OpenML task ID. Must be specified if
                `openml_dataset_id` is not provided. Defaults to None.

        Raises:
            ValueError: Raised if neither or both `openml_dataset_id` and `openml_task_id` are provided.
            ValueError: Raised if `task_id` starts with "fold" but `openml_task_id` is not provided.
            ValueError: Raised if `task_id` has an invalid format.
        """
        super().__init__(name, task_id)

        # Initialize split indices to None
        self._train_inds = None
        self._val_inds = None
        self._test_inds = None

        # exactly one of the two identifiers must be given
        if (openml_dataset_id is None) == (openml_task_id is None):
            raise ValueError("Must specify exactly one of openml_dataset_id or openml_task_id")
        self.did = openml_dataset_id
        self.tid = openml_task_id

        if task_id is None or task_id == "default":
            # rng is only used for the random split taken when there is no task ID
            self.rng = np.random.default_rng(0)
            self.fold = 0
        elif task_id.startswith("fold"):
            if openml_task_id is None:
                raise ValueError("Can only use fold tasks with openml_task_id")
            self.fold = int(task_id.removeprefix("fold"))
        elif task_id.startswith("random-seed"):
            split_seed = int(task_id.removeprefix("random-seed"))
            self.rng = np.random.default_rng(split_seed)
            self.fold = None
        else:
            raise ValueError(f"Invalid task_id {task_id}")

    @property
    def openml_dataset(self):
        """The underlying OpenML dataset object; raises ValueError before prepare_data()."""
        if not hasattr(self, "_openml_dataset"):
            raise ValueError("Data not loaded yet")
        return self._openml_dataset

    def prepare_data(self, download_dir: str):
        """
        Downloads the OpenML dataset and prepares the data for use.

        Populates ``self.X``, ``self.y``, the train/val/test indices and ``self.metadata``.

        Args:
            download_dir (str): Directory to download the OpenML dataset to.
        """
        openml.config.set_root_cache_directory(os.path.join(download_dir, "openml_cache"))

        if self.tid:
            # retrieve task and dataset information
            task = openml.tasks.get_task(
                self.tid,
                download_splits=True,
                download_data=True,
                download_qualities=True,
                download_features_meta_data=True,
            )
            dataset = openml.datasets.get_dataset(
                task.dataset_id,
                download_data=True,
                download_qualities=True,
                download_features_meta_data=True,
            )

            # retrieve data and metadata
            X, y, _, self.column_names = dataset.get_data(target=dataset.default_target_attribute)

            self.metadata["openml_task_id"] = self.tid
            self.metadata["openml_dataset_id"] = dataset.dataset_id

            # use the task's predefined split; it provides no validation part
            if self.fold is not None:
                split = task.get_train_test_split_indices(fold=self.fold)
                self._train_inds = split.train
                self._val_inds = []
                self._test_inds = split.test

        else:
            dataset = openml.datasets.get_dataset(
                self.did,
                download_data=True,
                download_qualities=True,
                download_features_meta_data=True,
            )
            X, y, _, self.column_names = dataset.get_data(target=dataset.default_target_attribute)
            self.metadata["openml_dataset_id"] = self.did

        self.metadata["openml_dataset_name"] = dataset.name
        self.metadata["openml_dataset_description"] = dataset.description

        # no predefined fold in use: make a random 70/15/15 split
        if not self.tid or self.fold is None:
            n = len(X)
            perm = self.rng.permutation(n)
            self._train_inds = perm[: int(n * 0.7)]
            self._val_inds = perm[int(n * 0.7) : int(n * 0.85)]
            self._test_inds = perm[int(n * 0.85) :]

        # a comma in the default target attribute denotes several target columns
        if dataset.default_target_attribute and "," in dataset.default_target_attribute:
            y = None
            warnings.warn(
                f"Dataset {self.metadata['openml_dataset_id']} has multiple targets, which is "
                "not supported. Omitting targets."
            )

        self._openml_dataset = dataset

        categorical_inds = []
        for i, col in enumerate(X.columns):
            # Convert categorical features to ordinal integers
            # (isinstance check replaces the deprecated pd.api.types.is_categorical_dtype)
            col_dtype = X[col].dtype
            if col_dtype == "object" or isinstance(col_dtype, pd.CategoricalDtype):
                enc = OrdinalEncoder()
                X[[col]] = enc.fit_transform(X[[col]])
                categorical_inds.append(i)

        self.metadata["categorical_feature_inds"] = categorical_inds

        self.X = X.to_numpy().astype(np.float32)

        if y is None:
            self.y = None
            self.metadata["target_type"] = "none"
            return

        target_feature = [
            f for f in dataset.features.values() if f.name == dataset.default_target_attribute
        ][0]

        # encode target variable if it is categorical
        if (
            target_feature.data_type == "nominal"
            or y.dtype == "object"
            or isinstance(y.dtype, pd.CategoricalDtype)
        ):
            enc = LabelEncoder()
            self.y = enc.fit_transform(y)
            self.metadata["target_type"] = "classification"
        else:
            self.y = y.to_numpy().astype(np.float32)
            self.metadata["target_type"] = "regression"

    def all_instances(self):
        """Return the feature matrix and target array as ``(X, y)``."""
        return self.X, self.y

    def train_inds(self):
        """Return the training-row indices."""
        return self._train_inds

    def val_inds(self):
        """Return the validation-row indices."""
        return self._val_inds

    def test_inds(self):
        """Return the test-row indices."""
        return self._test_inds
195 |
196 |
class OpenMLTaskDataset(Dataset):
    """
    Base class for suites that are fully described by a list of OpenML task IDs.

    Subclasses list those IDs (as strings) in all_names(); all data access is
    delegated to a wrapped OpenMLDataset.
    """

    def __init__(self, name, task_id=None):
        # the dataset name is the OpenML task ID itself
        self.tid = int(name)
        wrapped = OpenMLDataset(name, task_id, openml_task_id=self.tid)
        self.openml_dataset = wrapped
        # initialise the base class with the task id reported by the wrapped dataset
        super().__init__(name, wrapped.task_id())
        assert name in self.all_names()

    def prepare_data(self, download_dir):
        """Prepare the wrapped dataset, then merge its metadata under this object's own."""
        self.openml_dataset.prepare_data(download_dir)
        combined = dict(self.openml_dataset.metadata)
        combined.update(self.metadata)  # entries already set here take precedence
        self.metadata = combined

    def all_instances(self):
        """Return whatever the wrapped dataset's all_instances() returns."""
        return self.openml_dataset.all_instances()

    def train_inds(self):
        """Return the wrapped dataset's training indices."""
        return self.openml_dataset.train_inds()

    def val_inds(self):
        """Return the wrapped dataset's validation indices."""
        return self.openml_dataset.val_inds()

    def test_inds(self):
        """Return the wrapped dataset's test indices."""
        return self.openml_dataset.test_inds()
225 |
226 |
# Note: these are OpenML task IDs, not dataset IDs.
# The source suite IDs are listed at
# https://github.com/LeoGrin/tabular-benchmark?tab=readme-ov-file#downloading-the-datasets
GRIN_NUM_CLS_IDS = [
    361055, 361060, 361061, 361062, 361063, 361065, 361066, 361068,
    361069, 361070, 361273, 361274, 361275, 361276, 361277, 361278,
]
GRIN_CAT_CLS_IDS = [
    361110, 361111, 361113, 361282, 361283, 361285, 361286,
]
GRIN_NUM_REG_IDS = [
    361072, 361073, 361074, 361076, 361077, 361078, 361079, 361080,
    361081, 361082, 361083, 361084, 361085, 361086, 361087, 361088,
    361279, 361280, 361281,
]
GRIN_CAT_REG_IDS = [
    361093, 361094, 361096, 361097, 361098, 361099, 361101, 361102,
    361103, 361104, 361287, 361288, 361289, 361291, 361292, 361293,
    361294,
]

# Full suite, in the order: num-cls, cat-cls, num-reg, cat-reg.
GRIN_ALL_IDS = [
    *GRIN_NUM_CLS_IDS,
    *GRIN_CAT_CLS_IDS,
    *GRIN_NUM_REG_IDS,
    *GRIN_CAT_REG_IDS,
]
291 |
292 |
class GrinsztajnDataset(OpenMLTaskDataset):
    """Suite whose dataset names are the OpenML task IDs in ``GRIN_ALL_IDS``."""

    @staticmethod
    def all_names():
        """Return every task ID of the suite as a string."""
        return list(map(str, GRIN_ALL_IDS))

    @staticmethod
    def suite_name():
        """Return the suite identifier."""
        return "grinsztajn"
301 |
302 |
# OpenML task IDs exposed by CC18Dataset (ascending order).
CC18_IDS = [
    3, 6, 11, 12, 14, 15, 16, 18,
    22, 23, 28, 29, 31, 32, 37, 43,
    45, 49, 53, 219, 2074, 2079, 3021, 3022,
    3481, 3549, 3560, 3573, 3902, 3903, 3904, 3913,
    3917, 3918, 7592, 9910, 9946, 9952, 9957, 9960,
    9964, 9971, 9976, 9977, 9978, 9981, 9985, 10093,
    10101, 14952, 14954, 14965, 14969, 14970, 125920, 125922,
    146195, 146800, 146817, 146819, 146820, 146821, 146822, 146824,
    146825, 167119, 167120, 167121, 167124, 167125, 167140, 167141,
]
377 |
378 |
class CC18Dataset(OpenMLTaskDataset):
    """Suite that exposes the IDs in CC18_IDS under the suite name "cc18"."""

    @staticmethod
    def all_names():
        """Return every ID in CC18_IDS as a string, preserving list order."""
        return list(map(str, CC18_IDS))

    @staticmethod
    def suite_name():
        """Return the identifier used for this suite."""
        return "cc18"
387 |
388 |
# IDs served (as strings) by CTR23Dataset.all_names().
# NOTE(review): presumably OpenML task IDs, given the OpenMLTaskDataset base class -- confirm.
CTR23_IDS = [
    361234,
    361235,
    361236,
    361237,
    361241,
    361242,
    361243,
    361244,
    361247,
    361249,
    361250,
    361251,
    361252,
    361253,
    361254,
    361255,
    361256,
    361257,
    361258,
    361259,
    361260,
    361261,
    361264,
    361266,
    361267,
    361268,
    361269,
    361272,
    361616,
    361617,
    361618,
    361619,
    361621,
    361622,
    361623,
]
426 |
427 |
class CTR23Dataset(OpenMLTaskDataset):
    """Suite that exposes the IDs in CTR23_IDS under the suite name "ctr23"."""

    @staticmethod
    def all_names():
        """Return every ID in CTR23_IDS as a string, preserving list order."""
        return list(map(str, CTR23_IDS))

    @staticmethod
    def suite_name():
        """Return the identifier used for this suite."""
        return "ctr23"
436 |
437 |
# IDs served (as strings) by AMLBDataset.all_names().
# NOTE(review): presumably OpenML task IDs, given the OpenMLTaskDataset base class -- confirm.
AMLB_IDS = [
    3,
    12,
    31,
    53,
    3917,
    3945,
    7592,
    7593,
    9952,
    9977,
    9981,
    10101,
    14965,
    34539,
    146195,
    146212,
    146606,
    146818,
    146821,
    146822,
    146825,
    167119,
    167120,
    168329,
    168330,
    168331,
    168332,
    168335,
    168337,
    168338,
    168868,
    168908,
    168909,
    168910,
    168911,
    168912,
    189354,
    189355,
    189356,
]
479 |
480 |
class AMLBDataset(OpenMLTaskDataset):
    """Suite that exposes the IDs in AMLB_IDS under the suite name "amlb"."""

    @staticmethod
    def all_names():
        """Return every ID in AMLB_IDS as a string, preserving list order."""
        return list(map(str, AMLB_IDS))

    @staticmethod
    def suite_name():
        """Return the identifier used for this suite."""
        return "amlb"
489 |
490 |
# Dataset IDs served (as strings) by ICLRTrainingDataset.all_names(). Each one is passed to
# OpenMLDataset as `openml_dataset_id` when an ICLRTrainingDataset is constructed.
ICLR_TRAINING_IDS = [
    41138,
    4135,
    4535,
    41434,
    375,
    1120,
    41150,
    40900,
    40536,
    1043,
    1119,
    1169,
    41147,
    1459,
    1466,
    1118,
    41142,
    23380,
    1596,
    41163,
    1471,
    846,
    1044,
    41164,
    1477,
    1476,
    1038,
    41159,
    23512,
    1479,
    821,
    41168,
    41143,
    184,
    1483,
    40679,
    24,
    1116,
    1568,
    1493,
    30,
    41145,
    1567,
    871,
    41161,
    41165,
    312,
    40685,
    1036,
    41146,
    41166,
    1509,
    40733,
    44089,
    44122,
    45022,
    45020,
    45028,
    45026,
    45038,
    45039,
    1111,
    1457,
    41167,
    41158,
    41144,
    41156,
    40498,
    41169,
    41162,
    42734,
    42732,
    42746,
    42742,
    43072,
    137,
    273,
    382,
    389,
    396,
    802,
    816,
    843,
    930,
    966,
    981,
    1002,
    1018,
    1037,
    1042,
    1112,
    1130,
    1142,
    1444,
    1453,
    1481,
    1503,
    1507,
    40646,
    40680,
    40706,
    44055,
    44056,
    44061,
    44063,
    44065,
    44068,
    44069,
    45041,
    45043,
    45045,
    45046,
    45047,
    44136,
    44137,
    44145,
    45032,
    4549,
    42572,
    42705,
    42728,
    41540,
    42724,
    42727,
    42730,
    41980,
    42563,
    3050,
    3277,
    43071,
]
623 |
624 |
class ICLRTrainingDataset(Dataset):
    """Suite of the datasets listed in ICLR_TRAINING_IDS.

    Each instance wraps an OpenMLDataset built from the same ID and forwards data preparation,
    instance access and split indices to it.
    """

    @staticmethod
    def all_names():
        """Return every ID in ICLR_TRAINING_IDS as a string, preserving list order."""
        return list(map(str, ICLR_TRAINING_IDS))

    @staticmethod
    def suite_name():
        """Return the identifier used for this suite."""
        return "iclr-training"

    def __init__(self, name):
        """Create the dataset called `name`, which must be one of all_names()."""
        super().__init__(name)
        valid_names = self.all_names()
        assert name in valid_names
        dataset_id = int(name)
        self.did = dataset_id
        self.openml_dataset = OpenMLDataset(name, openml_dataset_id=dataset_id)

    def prepare_data(self, download_dir):
        """Prepare the wrapped dataset, then merge its metadata into this object's metadata."""
        wrapped = self.openml_dataset
        wrapped.prepare_data(download_dir)
        # Keys already present on this object take precedence over the wrapped dataset's keys
        self.metadata = {**wrapped.metadata, **self.metadata}

    def all_instances(self):
        """Forward to the wrapped OpenMLDataset."""
        return self.openml_dataset.all_instances()

    def train_inds(self):
        """Forward to the wrapped OpenMLDataset."""
        return self.openml_dataset.train_inds()

    def val_inds(self):
        """Forward to the wrapped OpenMLDataset."""
        return self.openml_dataset.val_inds()

    def test_inds(self):
        """Forward to the wrapped OpenMLDataset."""
        return self.openml_dataset.test_inds()
655 |
--------------------------------------------------------------------------------
/tabdpt_datasets/catalogue.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import difflib
4 | import json
5 | import os
6 | import sys
7 | import warnings
8 | from dataclasses import dataclass
9 | from typing import Literal
10 |
11 | import numpy as np
12 | import scipy
13 |
14 | from tabdpt_datasets.annotated_tables import AnnotatedTablesDataset
15 | from tabdpt_datasets.dataset import Dataset
16 | from tabdpt_datasets.openml import (
17 | AMLBDataset,
18 | CC18Dataset,
19 | CTR23Dataset,
20 | GrinsztajnDataset,
21 | ICLRTrainingDataset,
22 | TabZillaDataset,
23 | )
24 | from tabdpt_datasets.tabred import TabredDataset
25 | from tabdpt_datasets.talent import TalentDataset
26 |
27 |
class NpGenericEncoder(json.JSONEncoder):
    """JSON encoder that serializes NumPy scalar types (e.g. np.float32) via their .item()."""

    def default(self, x):
        """Convert NumPy scalars to plain Python values; defer everything else to the base class."""
        if not isinstance(x, np.generic):
            # The base implementation raises TypeError for unsupported objects
            return super().default(x)
        return x.item()
34 |
35 |
# Only include dataset classes that provide a suite of named datasets that can each be loaded by
# name. CatalogueView builds its (suite_name, dataset_name) -> class lookup from this list, and
# Catalogue uses it as the default set of suites to catalogue.
SUITES = [
    CC18Dataset,
    CTR23Dataset,
    GrinsztajnDataset,
    TabredDataset,
    TalentDataset,
    AMLBDataset,
    TabZillaDataset,
    ICLRTrainingDataset,
    AnnotatedTablesDataset,
]

# Suites treated as evaluation data: CatalogueView.split() puts them in the eval view and
# CatalogueView.filter_duplicates() never removes their datasets.
EVAL_SUITES = [CC18Dataset, CTR23Dataset]
# Suite-name strings of EVAL_SUITES, used for membership tests against metadata keys
EVAL_SUITE_NAMES = [s.suite_name() for s in EVAL_SUITES]
52 |
53 |
def longest_overlap(s1: str, s2: str) -> int:
    """
    Returns the length of the longest contiguous substring shared by s1 and s2, ignoring case.

    The comparison is done on the lowercased strings, and the returned length is measured on
    them as well. Lowercasing can change a string's length (e.g. "\u0130" becomes two code
    points), so the search bounds must come from the lowercased strings; otherwise the tail of
    a name is silently excluded from the search.
    """
    a = s1.lower()
    b = s2.lower()
    sm = difflib.SequenceMatcher(a=a, b=b)
    return sm.find_longest_match(0, len(a), 0, len(b)).size
57 |
58 |
@dataclass
class Duplicate:
    """A suspected duplicate pair of datasets, as reported by CatalogueView.detect_duplicates."""

    # Suite and dataset name of the first dataset in the pair
    suite_name_1: str
    dataset_name_1: str
    # Suite and dataset name of the second dataset in the pair
    suite_name_2: str
    dataset_name_2: str
    # How confident the detector is that the two datasets are duplicates
    likelihood: Literal["low", "high", "certain"]
    # Human-readable explanations of why the pair was flagged
    reasons: list[str]

    def likelihood_gte(self, other: str):
        """
        Returns True if this duplicate's likelihood is greater than or equal to the other given
        likelihood, using the ordering low < high < certain.
        """
        assert other in ["low", "high", "certain"]
        if other == "low":
            # Every likelihood is at least "low"
            return True
        if other == "high":
            return self.likelihood != "low"
        # Only "certain" remains as the threshold
        return self.likelihood == "certain"
81 |
82 |
83 | class CatalogueView:
84 | """
85 | Base class for catalogue functions that don't manage metadata on disk. Allows for filtered views
86 | of the catalogue.
87 |
88 | Code outside this module generally shouldn't directly instantiate CatalogueViews, the Catalogue
89 | class should be used instead.
90 | """
91 |
92 | def __init__(self, metadata):
93 | self._metadata = metadata
94 | self.dataset_map = {(c.suite_name(), name): c for c in SUITES for name in c.all_names()}
95 |
96 | def filter(self, key: str, pred) -> CatalogueView:
97 | """
98 | Filter catalogue using a metadata key and corresponding required value, or a metadata key
99 | and a predicate that must return True when passed the value.
100 |
101 | E.g., c.filter('target_type', 'regression') or c.filter('size', lambda s: s > 1000)
102 |
103 | Returns a new CatalogueView.
104 | """
105 | if callable(pred):
106 | return CatalogueView({k: v for k, v in self._metadata.items() if pred(v[key])})
107 | return CatalogueView({k: v for k, v in self._metadata.items() if v[key] == pred})
108 |
    def metadata(self):
        """
        Return the metadata dict backing this view (the object itself, not a copy), keyed by
        (suite_name, dataset_name).
        """
        return self._metadata
111 |
112 | def dataset(self, suite_name: str, dataset_name: str) -> Dataset:
113 | """
114 | Get a Dataset object for the dataset with the given name.
115 | """
116 | return self.dataset_map[(suite_name, dataset_name)](dataset_name)
117 |
118 | def datasets(self) -> list[Dataset]:
119 | """
120 | Get dataset objects for all datasets in the current view.
121 | """
122 | return [self.dataset(*k) for k in self._metadata.keys()]
123 |
    def detect_duplicates(self, verbose=False) -> list[Duplicate]:
        """
        Detects possible duplicate datasets by comparing the metadata of every pair of entries
        in this view. Returns a list of Duplicate objects, one per suspicious pair.

        For each pair a list of reasons is collected (same row count, same feature count,
        similar target statistics, same file hash, overlapping names, columns with similar
        statistics). The reported likelihood is 'certain' when the sha1 hashes match, 'high'
        when there are more than two reasons, and 'low' when there is more than one reason or
        the column statistics are highly similar. Pairs reaching none of these are not
        reported. Callers pick their sensitivity by thresholding on the likelihood, e.g. with
        Duplicate.likelihood_gte.

        verbose: Print out duplicates and info on detection as they're detected
        """
        entries = list(self._metadata.items())
        # One KD-tree per dataset over its per-column statistics, used below to count columns
        # that have a near-identical counterpart in another dataset
        kd_trees = {}
        for k, m in entries:
            col_coeffs = []
            for mean, var, skew, kurt in zip(
                m["column_means"], m["column_vars"], m["column_skews"], m["column_kurtoses"]
            ):
                col_coeffs.append([mean, np.sqrt(var), skew, kurt])
            # Using moderate values for infinities so they don't drive all other values to zero once
            # normalized
            col_coeffs = np.nan_to_num(np.array(col_coeffs), posinf=1.0, neginf=-1.0)
            if len(col_coeffs) > 0:
                # Normalize row-wise so that the distance tolerance works reasonably
                col_coeffs = (col_coeffs - col_coeffs.mean(axis=1)[:, None]) / (
                    col_coeffs.std(axis=1)[:, None] + 1e-8
                )
                kd_trees[k] = scipy.spatial.KDTree(col_coeffs)

        duplicates = []
        # Visit every unordered pair once; entry i is always reported as the first dataset
        for i in range(len(entries)):
            for j in range(i + 1, len(entries)):
                k1, m1 = entries[i]
                k2, m2 = entries[j]
                # These suites should already be deduped
                if m1["suite_name"] == m2["suite_name"] and m1["suite_name"] in (
                    "cc18",
                    "ctr23",
                    "grinsztajn",
                    "tabred",
                    "amlb",
                ):
                    continue

                dup_reasons = []
                certain_dup = False
                high_feature_similarity = False

                if m1["n_rows"] == m2["n_rows"] and sum(c != "0" for c in str(m1["n_rows"])) > 1:
                    # Ignore n_rows matches if they only have one non-zero digit, e.g., 10000 vs
                    # 10000 is not really suspicious
                    dup_reasons.append(f"same n_rows {m1['n_rows']}")

                if m1["n_features"] == m2["n_features"]:
                    dup_reasons.append(f"same number of features {m1['n_features']}")

                if (
                    m1["target_type"] != "none"
                    and m2["target_type"] != "none"
                    and np.isclose(m1["y_mean"], m2["y_mean"])
                    and np.isclose(m1["y_var"], m2["y_var"])
                ):
                    dup_reasons.append(
                        f"similar target statistics: mean {m1['y_mean']} var {m1['y_var']} "
                        f"vs mean {m2['y_mean']} var {m2['y_var']}"
                    )

                # An identical file hash is the only signal treated as conclusive
                if "file_sha1" in m1 and "file_sha1" in m2 and m1["file_sha1"] == m2["file_sha1"]:
                    certain_dup = True
                    dup_reasons.append("same sha1 hash")

                # Prefer the OpenML name, fall back to the Kaggle name; a shared run of more
                # than 4 characters (case-insensitive) counts as a reason
                m1_name = m1.get("openml_dataset_name", m1.get("kaggle_dataset_name", None))
                m2_name = m2.get("openml_dataset_name", m2.get("kaggle_dataset_name", None))
                if m1_name and m2_name and longest_overlap(m1_name, m2_name) > 4:
                    dup_reasons.append("overlap in names:" f"{m1_name}, {m2_name}")

                # Compare statistics for each column wrt the target. If enough are similar, then
                # there might be leakage.
                if k1 in kd_trees and k2 in kd_trees:
                    # Check the number of columns that seem to have a pair in the other table
                    # Annoying output format - list of lists of indices in other tree
                    pair_inds = kd_trees[k1].query_ball_tree(kd_trees[k2], 1e-3)
                    # Columns of dataset 1 with at least one match, and distinct matched columns
                    # of dataset 2; the smaller count is used
                    n_similar_k1 = sum(len(xs) > 0 for xs in pair_inds)
                    n_similar_k2 = len(set(x for xs in pair_inds for x in xs))
                    n_similar = min(n_similar_k1, n_similar_k2)
                    thresh = min(5, m1["n_features"], m2["n_features"])

                    if n_similar >= thresh:
                        dup_reasons.append(f"{n_similar} columns with similar statistics")
                        min_n_features = min(m1["n_features"], m2["n_features"])
                        max_n_features = max(m1["n_features"], m2["n_features"])
                        n_feature_ratio = max_n_features / min_n_features
                        # Required fraction of matching columns: 1 up to 10 features, then
                        # halving for every tenfold increase in the feature count
                        detection_frac = min(1, 1 / 2 ** (np.log10(min_n_features) - 1))
                        if n_similar > detection_frac * min_n_features and n_feature_ratio < 1.5:
                            high_feature_similarity = True

                likelihood = (
                    "certain"
                    if certain_dup
                    else (
                        "high"
                        if len(dup_reasons) > 2
                        else "low" if len(dup_reasons) > 1 or high_feature_similarity else None
                    )
                )
                if likelihood:
                    if verbose:
                        print(f"Suspected duplicate {k1} {k2}: {', '.join(dup_reasons)}")
                        if m1["suite_name"] == m2["suite_name"]:
                            print(f"Within same suite! {m1['suite_name']}")
                    duplicates.append(
                        Duplicate(k1[0], k1[1], k2[0], k2[1], likelihood, dup_reasons)
                    )

        return duplicates
236 |
237 | def filter_duplicates(
238 | self, eval_min_likelihood="low", train_min_likelihood="high"
239 | ) -> CatalogueView:
240 | """
241 | Returns a new CatalogueView with duplicate datasets removed. Eval datasets are always kept.
242 |
243 | eval_min_likelihood: Either 'low', 'high', 'certain', or None. Determines minimum duplicate
244 | likelihood between a train and eval dataset for the train dataset to be removed. Set to None
245 | to ignore eval duplicates. The default is 'low' to prevent all suspected leakage.
246 | train_min_likelihood: Either 'low', 'high', 'certain', or None. Determines minimum duplicate
247 | likelihood between two train datasets for the first train dataset to be removed. Set to None
248 | to ignore train duplicates. The default is 'high' to avoid removing useful data, since
249 | leakage isn't a concern in the train set.
250 | """
251 | # Remove all duplicates with eval datasets, then go back and remove the first dataset from
252 | # each remaining pair of duplicates. Since detect_duplicates uses a fixed ordering, this
253 | # ensures all duplicates are resolved.
254 | duplicates = self.detect_duplicates(verbose=False)
255 | to_remove: set[tuple[str, str]] = set()
256 | dups_no_eval = []
257 | for d in duplicates:
258 | if d.suite_name_1 in EVAL_SUITE_NAMES and d.suite_name_2 in EVAL_SUITE_NAMES:
259 | continue
260 | elif d.suite_name_1 in EVAL_SUITE_NAMES:
261 | if eval_min_likelihood and d.likelihood_gte(eval_min_likelihood):
262 | to_remove.add((d.suite_name_2, d.dataset_name_2))
263 | elif d.suite_name_2 in EVAL_SUITE_NAMES:
264 | if eval_min_likelihood and d.likelihood_gte(eval_min_likelihood):
265 | to_remove.add((d.suite_name_1, d.dataset_name_1))
266 | else:
267 | dups_no_eval.append(d)
268 | for d in dups_no_eval:
269 | if (
270 | train_min_likelihood
271 | and d.likelihood_gte(train_min_likelihood)
272 | and (d.suite_name_2, d.dataset_name_2) not in to_remove
273 | ):
274 | to_remove.add((d.suite_name_1, d.dataset_name_1))
275 | return CatalogueView({k: v for k, v in self._metadata.items() if k not in to_remove})
276 |
277 | def total_size(self) -> (int, int):
278 | """
279 | Total instances and cells across all datasets
280 | """
281 | return (
282 | sum(d["n_rows"] for d in self._metadata.values()),
283 | sum(d["n_cells"] for d in self._metadata.values()),
284 | )
285 |
286 | def split(self) -> tuple[CatalogueView, CatalogueView]:
287 | """
288 | Splits datasets based on name into a training and eval set. Note that to avoid leakage,
289 | duplicates should already be filtered out. Uses fixed splits as discussed.
290 | """
291 | train, eval = {}, {}
292 | for k, v in self._metadata.items():
293 | if k[0] in EVAL_SUITE_NAMES:
294 | eval[k] = v
295 | else:
296 | train[k] = v
297 | return CatalogueView(train), CatalogueView(eval)
298 |
299 |
class Catalogue(CatalogueView):
    """
    Dataset catalogue, used to access the aggregate set of datasets in our project and filter,
    remove duplicates, split into train and hold-out, etc.

    Unlike CatalogueView, this class includes functionality to generate, load, and store the
    catalogue, so it should be generally used to create and access the catalogue.
    """

    def __init__(self, download_dir, suites=SUITES):
        """
        Load the cached metadata from `<download_dir>/metadata.json` if it exists; otherwise
        start with an empty catalogue and warn that an update is needed.

        download_dir: Directory holding the metadata cache; also passed to each dataset's
            prepare_data.
        suites: Dataset classes to include when updating the catalogue.

        Raises json.decoder.JSONDecodeError if the cache file exists but cannot be parsed.
        """
        self.download_dir = download_dir
        self.suites = suites
        self.cache_path = os.path.join(download_dir, "metadata.json")
        if os.path.exists(self.cache_path):
            with open(self.cache_path, "r") as f:
                try:
                    entries = json.load(f)
                except json.decoder.JSONDecodeError:
                    print(
                        f"JSON decoding failed - solution is probably to delete {self.cache_path} and rerun full_update"
                    )
                    raise
            # Stored as a list, convert back into a dict indexed by suite and dataset name
            metadata = {(d["suite_name"], d["dataset_name"]): d for d in entries}
        else:
            metadata = {}
            warnings.warn(
                "Catalogue cache doesn't exist yet - run catalogue.incremental_update() or catalogue.full_update() before using"
            )
        super().__init__(metadata)

    def _write_cache(self):
        """Write all metadata entries to the cache file as a JSON list."""
        with open(self.cache_path, "w") as f:
            json.dump(list(self._metadata.values()), f, cls=NpGenericEncoder, indent=4)

    def _update_suites(self, skip_existing):
        """
        Visit every dataset of every configured suite, printing progress, and regenerate its
        metadata. With skip_existing, datasets already in the catalogue are left untouched. The
        cache is written once at the end.
        """
        n_total = sum(len(c.all_names()) for c in self.suites)
        n_seen = 0
        for c in self.suites:
            for name in c.all_names():
                n_seen += 1
                sys.stdout.write(f"\rUpdating catalogue... {n_seen}/{n_total}")
                sys.stdout.flush()
                # Metadata is keyed by (suite_name, dataset_name), so the membership test must
                # use the same tuple; testing the bare name never matches.
                if skip_existing and (c.suite_name(), name) in self._metadata:
                    continue
                self.update_dataset(c(name), update_cache=False)
        print("Done")
        self._write_cache()

    def update_dataset(self, dataset, update_cache=True):
        """
        Run auto_populate_metadata to update the metadata for an individual dataset, optionally
        writing to the cache.
        """
        dataset.prepare_data(self.download_dir)
        dataset.auto_populate_metadata()
        self._metadata[(dataset.suite_name(), dataset.name)] = dataset.metadata
        if update_cache:
            self._write_cache()

    def incremental_update(self):
        """
        Update metadata and cache for datasets that don't already appear in metadata.
        """
        self._update_suites(skip_existing=True)

    def full_update(self):
        """
        Re-generate metadata and cache entirely.
        """
        self._metadata = {}
        self._update_suites(skip_existing=False)
376 |
--------------------------------------------------------------------------------
/tabdpt.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Optional
3 |
4 | import numpy as np
5 | import torch
6 | from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
7 | from torch.nn.attention import SDPBackend, sdpa_kernel
8 |
9 | from model import TabDPTModel, Task, convert_to_torch_tensor, pad_x
10 | from utils import FAISS, DataPreprocessor, standardize, seed_everything
11 |
# Defaults for the estimator constructors and prediction methods in this module.
_DEFAULT_DEVICE = "cuda:0"  # default `device` argument of the estimators
_INF_BATCH_SIZE = 512  # default `inf_batch_size`: test rows handled per retrieval batch
_DEFAULT_CONTEXT_SIZE = 128  # default `context_size`: number of context rows
_DEFAULT_RETRIEVAL = True  # default `use_retrieval` of predict_proba
_DEFAULT_TEMPERATURE = 0.8  # default `temperature` of predict_proba
17 |
18 |
19 | class TabDPTEstimator(BaseEstimator):
20 | """A PyTorch-based implementation of TabDPT for tabular data tasks."""
21 |
22 | def __init__(
23 | self,
24 | model: Optional[TabDPTModel] = None,
25 | path: str = "",
26 | mode: Task = Task.CLS,
27 | inf_batch_size: int = _INF_BATCH_SIZE,
28 | device: str = _DEFAULT_DEVICE,
29 | tensor_eval: bool = False,
30 | ):
31 | """Initialize the TabDPTEstimator.
32 |
33 | Args:
34 | model (Optional[TabDPTModel], optional): The TabDPT model to use. Defaults to None.
35 | path (str, optional): Path to the model checkpoint. Defaults to "".
36 | mode (Task, optional): The task mode (classification or regression). Defaults to Task.CLS.
37 | inf_batch_size (int, optional): Inference batch size. Defaults to 512.
38 | device (str, optional): Device to run the model on. Defaults to _DEFAULT_DEVICE.
39 | tensor_eval (bool, optional): Whether to use tensor evaluation. Defaults to False.
40 |
41 | Raises:
42 | ValueError: If both model and path are None.
43 | """
44 | seed_everything(42)
45 | self.mode = mode
46 | self.inf_batch_size = inf_batch_size
47 | self.device = device
48 | self.tensor_eval = tensor_eval
49 | if model is None:
50 | if path:
51 | checkpoint = torch.load(path, weights_only=False)
52 | self.model = TabDPTModel.load(
53 | model_state=checkpoint["model"], config=checkpoint["cfg"]
54 | )
55 | self.model.eval()
56 | else:
57 | raise ValueError("Either model or path must be provided")
58 | else:
59 | self.model = model
60 | self.max_features = self.model.num_features
61 | self.max_num_classes = self.model.n_out
62 |
63 | def fit(self, X: np.ndarray, y: np.ndarray, faiss_index=None) -> None:
64 | """Fit the model to the training data.
65 |
66 | Args:
67 | X (np.ndarray): Training features, a 2D numpy array of shape [n_samples, n_features].
68 | y (np.ndarray): Training labels, a 1D numpy array of shape [n_samples].
69 | faiss_index (Optional[FAISS], optional): Precomputed FAISS index for retrieval. Defaults to None.
70 | """
71 | if not self.tensor_eval:
72 | assert isinstance(X, np.ndarray), "X must be a numpy array"
73 | assert isinstance(y, np.ndarray), "y must be a numpy array"
74 | assert X.shape[0] == y.shape[0], "X and y must have the same number of samples"
75 | assert X.ndim == 2, "X must be a 2D array"
76 | assert y.ndim == 1, "y must be a 1D array"
77 |
78 | self.preprocessor = DataPreprocessor()
79 | X = self.preprocessor.fit_transform(X)
80 |
81 | # initialize the Faiss index if not provided
82 | self.faiss_knn = faiss_index if faiss_index is not None else FAISS(X)
83 |
84 | self.n_instances, self.n_features = X.shape
85 | self.X_train = X
86 | self.y_train = y
87 |
88 | self.autocast = torch.autocast(device_type="cuda", dtype=torch.bfloat16)
89 |
90 | def _prepare_prediction(self, X: np.ndarray) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
91 | """preprocess the input data for prediction.
92 | This method handles the transformation of the input data, including imputation,
93 | scaling, and dimensionality reduction if necessary.
94 |
95 | Args:
96 | X (np.ndarray): input features for prediction.
97 |
98 | Returns:
99 | tuple[torch.Tensor, torch.Tensor, torch.Tensor]: training features, training labels, and test features.
100 | """
101 |
102 | # preprocess the input data
103 | if not self.tensor_eval:
104 | self.X_test = self.preprocessor.transform(X)
105 | else:
106 | self.X_test = X
107 |
108 | train_x, train_y, test_x = (
109 | convert_to_torch_tensor(self.X_train).to(self.device).float(),
110 | convert_to_torch_tensor(self.y_train).to(self.device).float(),
111 | convert_to_torch_tensor(self.X_test).to(self.device).float(),
112 | )
113 |
114 | # Apply PCA optionally to reduce the number of features
115 | if self.n_features > self.max_features:
116 | _, _, self.V = torch.pca_lowrank(train_x, q=self.max_features)
117 | train_x = train_x @ self.V
118 | else:
119 | self.V = None
120 |
121 | # apply PCA to the test set if V is not None (i.e., PCA was applied to the training set)
122 | test_x = test_x @ self.V if self.V is not None else test_x
123 |
124 | return train_x, train_y, test_x
125 |
126 | @torch.no_grad()
127 | def no_retrieval_data(
128 | self,
129 | train_x: torch.Tensor,
130 | train_y: torch.Tensor,
131 | test_x: torch.Tensor,
132 | context_size: int = _DEFAULT_CONTEXT_SIZE,
133 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
134 | """Prepare data for the no retrieval scenario.
135 |
136 | Args:
137 | train_x (torch.Tensor): Training features with shape [n_train, d].
138 | train_y (torch.Tensor): Training labels with shape [n_train].
139 | test_x (torch.Tensor): Test features with shape [n_test, d].
140 | context_size (int, optional): Number of context samples to use.
141 | Defaults to _DEFAULT_CONTEXT_SIZE.
142 |
143 | Returns:
144 | tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
145 | Context features, context labels, and evaluation features.
146 | """
147 | n_context = min(context_size, train_x.shape[0])
148 | idx_context = np.random.choice(train_x.shape[0], n_context, replace=False)
149 |
150 | # shape: [n_context, d]
151 | context_features = train_x[idx_context]
152 | context_labels = train_y[idx_context]
153 |
154 | # shape: [n_context, 1, d]
155 | context_features = pad_x(context_features.unsqueeze(1), self.max_features).to(self.device)
156 | context_labels = context_labels.unsqueeze(1).float() # shape: [n_context, 1]
157 |
158 | # shape: [n_test, 1, d]
159 | eval_features = pad_x(test_x.unsqueeze(1), self.max_features).to(self.device)
160 |
161 | return context_features, context_labels, eval_features
162 |
163 | @torch.no_grad()
164 | def batch_retrieval_data(
165 | self,
166 | train_x: torch.Tensor,
167 | train_y: torch.Tensor,
168 | test_x: torch.Tensor,
169 | batch_index: int,
170 | context_size: int,
171 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
172 | """Prepare data for the batch retrieval scenario.
173 |
174 | Args:
175 | batch_index (int): Batch index.
176 | context_size (int): Number of context samples to use.
177 |
178 | Returns:
179 | tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
180 | Context features, context labels, and evaluation features.
181 | """
182 | n_test = test_x.shape[0]
183 |
184 | start = batch_index * self.inf_batch_size
185 | end = min(n_test, (batch_index + 1) * self.inf_batch_size)
186 |
187 | indices_knn = self.faiss_knn.get_knn_indices(self.X_test[start:end], k=context_size)
188 | knn_features = train_x[torch.tensor(indices_knn)]
189 | knn_labels = train_y[torch.tensor(indices_knn)]
190 |
191 | # swap => [context_size, batch_size, d]
192 | knn_features = np.swapaxes(knn_features.cpu().numpy(), 0, 1)
193 | knn_labels = np.swapaxes(knn_labels.cpu().numpy(), 0, 1)
194 |
195 | knn_features = pad_x(torch.Tensor(knn_features), self.max_features).to(self.device)
196 | knn_labels = torch.Tensor(knn_labels).to(self.device)
197 |
198 | eval_features = pad_x(test_x[start:end].unsqueeze(0), self.max_features).to(self.device)
199 | return knn_features, knn_labels, eval_features
200 |
201 |
202 | class TabDPTClassifier(TabDPTEstimator, ClassifierMixin):
203 | """
204 | A PyTorch-based implementation of TabDPT for classification tasks.
205 | """
206 |
    def __init__(
        self,
        model: Optional[TabDPTModel] = None,
        path: str = "",
        mode: Task = Task.CLS,
        inf_batch_size: int = _INF_BATCH_SIZE,
        device: str = _DEFAULT_DEVICE,
        tensor_eval: bool = False,
    ):
        """Initialize the classifier; every argument is forwarded unchanged to TabDPTEstimator.

        Args:
            model (Optional[TabDPTModel], optional): The TabDPT model to use. Defaults to None.
            path (str, optional): Path to the model checkpoint. Defaults to "".
            mode (Task, optional): The task mode. Defaults to Task.CLS.
            inf_batch_size (int, optional): Inference batch size. Defaults to _INF_BATCH_SIZE.
            device (str, optional): Device to run the model on. Defaults to _DEFAULT_DEVICE.
            tensor_eval (bool, optional): Whether to use tensor evaluation. Defaults to False.
        """
        super().__init__(
            model=model,
            path=path,
            mode=mode,
            inf_batch_size=inf_batch_size,
            device=device,
            tensor_eval=tensor_eval,
        )
224 |
225 | def fit(self, X, y, faiss_index=None):
226 | super().fit(X, y, faiss_index)
227 | # Number of classes
228 | if self.tensor_eval:
229 | self.num_classes = len(torch.unique(self.y_train))
230 | else:
231 | self.num_classes = len(np.unique(self.y_train))
232 | assert self.num_classes > 1, "Number of classes must be greater than 1"
233 |
234 | def _predict_large_cls(self, X_context, X_eval, y_context) -> torch.Tensor:
235 | """Digit-level prediction for the case where num_classes > self.max_num_classes.
236 | Here, X_context + X_eval is [L, 1, d], and y_context is [L_context, 1].
237 |
238 | Args:
239 | X_context (torch.Tensor): Context features with shape [L_context, 1, d].
240 | X_eval (torch.Tensor): Evaluation features with shape [L_eval, 1, d].
241 | y_context (torch.Tensor): Context labels with shape [L_context, 1].
242 |
243 | Returns:
244 | torch.Tensor: Predicted probabilities with shape [L_eval, 1, num_classes].
245 | """
246 | # number of digits needed to represent num_classes
247 | num_digits = math.ceil(math.log(self.num_classes, self.max_num_classes))
248 |
249 | digit_preds = []
250 | for i in range(num_digits):
251 | y_context_digit = (y_context // (self.max_num_classes**i)) % self.max_num_classes
252 |
253 | with self.autocast, sdpa_kernel(SDPBackend.FLASH_ATTENTION):
254 | pred = self.model(
255 | x_src=torch.cat([X_context, X_eval], dim=0),
256 | y_src=y_context_digit,
257 | )
258 | # shape: [L_context + L_eval, 1, max_num_classes]
259 | digit_preds.append(pred[..., : self.max_num_classes].float())
260 |
261 | # Combine digit predictions
262 | B = X_context.shape[1]
263 | L_eval = X_eval.shape[0]
264 | full_pred = torch.zeros((L_eval, B, self.num_classes), device=X_context.device)
265 | # For each of the L_eval positions, compute the class probabilities
266 | for class_idx in range(self.num_classes):
267 | # sum across digits
268 | class_pred = torch.zeros((L_eval, B), device=X_eval.device)
269 | for digit_idx, digit_pred in enumerate(digit_preds):
270 | digit_value = (
271 | class_idx // (self.max_num_classes**digit_idx)
272 | ) % self.max_num_classes
273 | # digit_pred shape: [L_context + L_eval, 1, max_num_classes]
274 | # The last L_eval rows correspond to the actual predictions
275 | class_pred += digit_pred[-L_eval:, :, digit_value] # shape: [L_eval]
276 |
277 | full_pred[:, :, class_idx] = class_pred
278 | return full_pred # shape: [L_eval, 1, self.num_classes]
279 |
280 | @torch.no_grad()
281 | def predict_proba(
282 | self,
283 | X: np.ndarray,
284 | temperature: float = _DEFAULT_TEMPERATURE,
285 | context_size: int = _DEFAULT_CONTEXT_SIZE,
286 | use_retrieval: bool = _DEFAULT_RETRIEVAL,
287 | ) -> np.ndarray:
288 | """predict class probabilities for the input data.
289 |
290 |
291 | Args:
292 | X (np.ndarray): input features for prediction.
293 | temperature (float, optional): output probability temperature. Defaults to 0.8.
294 | context_size (int, optional): context size. Defaults to 128.
295 | use_retrieval (bool, optional):
296 | - If `use_retrieval=True`, do batch-by-batch retrieval.
297 | - If `use_retrieval=False`, pick one random context and feed all test points in one shot:
298 | [N_context + N_test, 1, d]. Defaults to True.
299 |
300 | Returns:
301 | np.ndarray: probbilities for each class for each test instance.
302 | """
303 | train_x, train_y, test_x = self._prepare_prediction(X)
304 | n_test = test_x.shape[0]
305 |
306 | # 1) If context_size >= entire training set => use them all
307 | if context_size >= self.n_instances:
308 | # shape: [n_train, 1, d]
309 | X_context = pad_x(train_x[:, None, :], self.max_features).to(self.device)
310 | y_context = train_y[:, None].float()
311 | # shape: [n_test, 1, d]
312 | X_eval = pad_x(test_x[:, None, :], self.max_features).to(self.device)
313 |
314 | if self.num_classes <= self.max_num_classes:
315 | with self.autocast, sdpa_kernel(SDPBackend.FLASH_ATTENTION):
316 | pred = self.model(
317 | x_src=torch.cat([X_context, X_eval], dim=0),
318 | y_src=y_context,
319 | )
320 | # shape: [n_train + n_test, 1, max_num_classes]
321 | pred = pred[..., : self.num_classes].float()
322 | # extract the last n_test rows => the test predictions
323 | test_preds = pred[-n_test:, 0, :] # shape: [n_test, num_classes]
324 |
325 | else:
326 | # Large-class approach
327 | test_preds = self._predict_large_cls(X_context, X_eval, y_context)
328 | test_preds = test_preds.squeeze(1) # shape: [n_test, num_classes]
329 |
330 | test_preds /= temperature
331 | test_preds = torch.nn.functional.softmax(test_preds, dim=-1)
332 | return test_preds.detach().cpu().numpy()
333 |
334 | # TODO: combine retrieval step with that of training
335 | # 2) If we want retrieval
336 | if use_retrieval:
337 | preds_list = []
338 | # batch the test set
339 | for b in range(math.ceil(n_test / self.inf_batch_size)):
340 | X_nni, y_nni, X_eval = self.batch_retrieval_data(
341 | train_x, train_y, test_x, b, context_size
342 | )
343 | # forward pass
344 | if self.num_classes <= self.max_num_classes: # small-class case
345 | with self.autocast, sdpa_kernel(SDPBackend.FLASH_ATTENTION):
346 | pred = self.model(
347 | x_src=torch.cat([X_nni, X_eval], dim=0),
348 | y_src=y_nni,
349 | )
350 | # shape: [context_size + 1, batch_size, max_num_classes]
351 | pred = pred[..., : self.num_classes].float()
352 | # last row => predictions for test batch
353 | batch_preds = pred[-1, :, :] # shape: [batch_size, num_classes]
354 | else:
355 | # large-class case
356 | batch_preds_full = self._predict_large_cls(X_nni, X_eval, y_nni)
357 | batch_preds = batch_preds_full[-1, :, :]
358 |
359 | batch_preds /= temperature
360 | batch_preds = torch.nn.functional.softmax(batch_preds, dim=-1)
361 | preds_list.append(batch_preds)
362 |
363 | preds_all = torch.cat(preds_list, dim=0) # [n_test, num_classes]
364 | return preds_all.cpu().numpy()
365 |
366 | # 3) If we want NO retrieval => single pass
367 | # [N_context, 1, d] + [N_test, 1, d] => [N_context + N_test, 1, d]
368 | else:
369 | X_ctx, y_ctx, X_eval = self.no_retrieval_data(train_x, train_y, test_x)
370 |
371 | if self.num_classes <= self.max_num_classes:
372 | with self.autocast, sdpa_kernel(SDPBackend.FLASH_ATTENTION):
373 | pred = self.model(
374 | x_src=torch.cat([X_ctx, X_eval], dim=0),
375 | y_src=y_ctx,
376 | )
377 | # shape: [n_context + n_test, 1, max_num_classes]
378 | pred = pred[..., : self.num_classes].float()
379 | # The last n_test rows => the actual test preds
380 | test_preds = pred[-n_test:, 0, :] # [n_test, num_classes]
381 | else:
382 | test_preds_full = self._predict_large_cls(X_ctx, X_eval, y_ctx)
383 | # shape: [n_context + n_test - ???, 1, num_classes]
384 | # We only want the last n_test rows
385 | test_preds = test_preds_full[-n_test:, 0, :]
386 |
387 | test_preds /= temperature
388 | test_preds = torch.nn.functional.softmax(test_preds, dim=-1)
389 | return test_preds.detach().cpu().numpy()
390 |
def predict(
    self,
    X: np.ndarray,
    temperature: float = _DEFAULT_TEMPERATURE,
    context_size: int = _DEFAULT_CONTEXT_SIZE,
    use_retrieval: bool = _DEFAULT_RETRIEVAL,
) -> np.ndarray:
    """Return the index of the most probable class for each row of ``X``.

    The work is delegated to ``predict_proba``, called with exactly the
    arguments received here; the result is reduced with an argmax over its
    last axis.

    Args:
        X (np.ndarray): input features for prediction.
        temperature (float, optional): temperature forwarded to
            ``predict_proba``. Defaults to ``_DEFAULT_TEMPERATURE``.
        context_size (int, optional): context size forwarded to
            ``predict_proba``. Defaults to ``_DEFAULT_CONTEXT_SIZE``.
        use_retrieval (bool, optional): retrieval switch forwarded to
            ``predict_proba``. Defaults to ``_DEFAULT_RETRIEVAL``.

    Returns:
        np.ndarray: argmax over the last axis of the ``predict_proba`` output.
    """
    class_probabilities = self.predict_proba(
        X,
        temperature=temperature,
        context_size=context_size,
        use_retrieval=use_retrieval,
    )
    # The last axis of the probability array enumerates the classes.
    return class_probabilities.argmax(axis=-1)
414 |
415 |
class TabDPTRegressor(TabDPTEstimator, RegressorMixin):
    """A PyTorch-based implementation of TabDPT for regression tasks."""

    def __init__(
        self,
        model: Optional[TabDPTModel] = None,
        path: str = "",
        mode: Task = Task.REG,
        inf_batch_size: int = _INF_BATCH_SIZE,
        device: str = _DEFAULT_DEVICE,
        tensor_eval: bool = False,
    ):
        """Create a regressor.

        Every argument is forwarded unchanged, by keyword, to
        ``TabDPTEstimator.__init__``; this subclass only changes the default
        of ``mode`` to ``Task.REG``.
        """
        super().__init__(
            model=model,
            path=path,
            mode=mode,
            inf_batch_size=inf_batch_size,
            device=device,
            tensor_eval=tensor_eval,
        )

    def _forward_standardized(self, X_context, y_context, X_eval):
        """Standardize the context targets and run one forward pass.

        Args:
            X_context: context features; concatenated with ``X_eval`` along dim 0.
            y_context: context targets, passed through ``standardize`` before
                being given to the model as ``y_src``.
            X_eval: features of the rows to predict.

        Returns:
            tuple: ``(pred, y_means, y_stds)`` where ``pred`` is the last channel
            of the model output cast to float32 (still in standardized units) and
            ``y_means`` / ``y_stds`` are the statistics returned by
            ``standardize``, needed to undo the standardization.
        """
        y_norm, y_means, y_stds = standardize(y_context)
        with self.autocast, sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            pred = self.model(
                x_src=torch.cat([X_context, X_eval], dim=0),
                y_src=y_norm,
            )
        # The regression output is read from the last index of the final dimension.
        return pred[..., -1].float(), y_means, y_stds

    @torch.no_grad()
    def predict(
        self,
        X: np.ndarray,
        context_size: int = _DEFAULT_CONTEXT_SIZE,
        use_retrieval: bool = _DEFAULT_RETRIEVAL,
    ) -> np.ndarray:
        """Predict regression values for the input data.

        Args:
            X (np.ndarray): input features for prediction.
            context_size (int, optional): size of the context. When it is at
                least ``self.n_instances`` the whole training set is used as
                context and ``use_retrieval`` is ignored.
                Defaults to ``_DEFAULT_CONTEXT_SIZE``.
            use_retrieval (bool, optional): whether to use retrieval.
                - If use_retrieval=True, build a context per test batch with
                  ``batch_retrieval_data`` and predict batch by batch.
                - If use_retrieval=False, build a single context with
                  ``no_retrieval_data`` and predict all test rows in one pass.
                Defaults to ``_DEFAULT_RETRIEVAL``.

        Returns:
            np.ndarray: one regression value per test instance.
        """
        train_x, train_y, test_x = self._prepare_prediction(X)
        n_test = test_x.shape[0]

        # Nothing to predict. Without this guard `pred[-0:, 0]` below would select
        # every row (context included) and the retrieval loop would hand an empty
        # list to `torch.cat`.
        # NOTE(review): float32 mirrors the `.float()` cast on the model output;
        # confirm it matches the dtype callers see on the non-empty paths.
        if n_test == 0:
            return np.empty(0, dtype=np.float32)

        # 1) If context_size >= entire training set => use them all
        if context_size >= self.n_instances:
            # A singleton axis is inserted at dim 1 before padding the features.
            X_ctx = pad_x(train_x[:, None, :], self.max_features).to(self.device)
            X_eval = pad_x(test_x[:, None, :], self.max_features).to(self.device)
            y_ctx = train_y[:, None].float()

            pred, y_means, y_stds = self._forward_standardized(X_ctx, y_ctx, X_eval)
            # last n_test rows => test preds; then reverse the standardization
            test_preds = pred[-n_test:, 0] * y_stds + y_means
            return test_preds.detach().cpu().numpy()

        # 2) If we want retrieval
        if use_retrieval:
            batch_outputs = []
            n_batches = math.ceil(n_test / self.inf_batch_size)
            for b in range(n_batches):
                X_nni, y_nni, X_eval = self.batch_retrieval_data(
                    train_x, train_y, test_x, b, context_size
                )
                pred, y_means, y_stds = self._forward_standardized(X_nni, y_nni, X_eval)
                # last row => predictions for the test batch; reverse standardization
                batch_preds = pred[-1, :] * y_stds + y_means
                batch_outputs.append(batch_preds.cpu())

            # reshape(-1) rather than squeeze(): squeeze() would turn a single
            # prediction into a 0-d array, unlike the other two branches.
            return torch.cat(batch_outputs).reshape(-1).detach().numpy()

        # 3) If we want NO retrieval => single pass over context + all test rows
        X_ctx, y_ctx, X_eval = self.no_retrieval_data(train_x, train_y, test_x)
        pred, y_means, y_stds = self._forward_standardized(X_ctx, y_ctx, X_eval)
        # The last n_test positions are the test predictions; reverse standardization
        test_preds = pred[-n_test:, 0] * y_stds + y_means
        return test_preds.detach().cpu().numpy()
525 |
--------------------------------------------------------------------------------
/data_splits/cls_datasets.csv:
--------------------------------------------------------------------------------
1 | did,tid,num_instances,num_features,num_classes,num_missing_values,tabzilla_name,source,test_small,test_large,test_all
2 | 41138,168868.0,76000.0,171.0,2.0,1078695.0,openml__APSFailure__168868,"['tabzilla', 'amlb']",False,False,False
3 | 4135,34539.0,32769.0,10.0,2.0,0.0,openml__Amazon_employee_access__34539,"['tabzilla', 'amlb']",False,False,False
4 | 40981,146818.0,690.0,15.0,2.0,0.0,openml__Australian__146818,"['tabzilla', 'amlb']",False,False,False
5 | 4134,9910.0,3751.0,1777.0,2.0,0.0,openml__Bioresponse__9910,"['tabzilla', 'cc18', 'amlb', 'amlb']",False,False,True
6 | 40927,167124.0,60000.0,3073.0,10.0,0.0,openml__CIFAR_10__167124,"['tabzilla', 'cc18']",False,False,True
7 | 4535,168340.0,299285.0,42.0,,0.0,openml__Census-Income__168340,['tabzilla'],False,False,False
8 | 41434,190408.0,39948.0,12.0,2.0,0.0,openml__Click_prediction_small__190408,['tabzilla'],False,False,False
9 | 40923,167121.0,92000.0,1025.0,46.0,0.0,openml__Devnagari-Script__167121,"['tabzilla', 'cc18']",False,False,True
10 | 40996,146825.0,70000.0,785.0,10.0,0.0,openml__Fashion-MNIST__146825,"['tabzilla', 'cc18', 'amlb']",False,False,True
11 | 4538,14969.0,9873.0,33.0,5.0,0.0,openml__GesturePhaseSegmentationProcessed__14969,"['tabzilla', 'cc18', 'amlb', 'amlb']",True,True,True
12 | 40978,167125.0,3279.0,1559.0,2.0,0.0,openml__Internet-Advertisements__167125,"['tabzilla', 'cc18', 'amlb', 'amlb']",False,False,True
13 | 375,3510.0,9961.0,15.0,9.0,0.0,openml__JapaneseVowels__3510,['tabzilla'],False,False,False
14 | 40496,125921.0,500.0,8.0,10.0,0.0,openml__LED-display-domain-7digit__125921,['tabzilla'],False,False,False
15 | 1120,3954.0,19020.0,12.0,2.0,0.0,openml__MagicTelescope__3954,['tabzilla'],False,False,False
16 | 40966,146800.0,1080.0,82.0,8.0,1396.0,openml__MiceProtein__146800,"['tabzilla', 'cc18']",True,True,True
17 | 41150,168335.0,130064.0,51.0,2.0,0.0,openml__MiniBooNE__168335,"['tabzilla', 'amlb']",False,False,False
18 | 4534,14952.0,11055.0,31.0,2.0,0.0,openml__PhishingWebsites__14952,"['tabzilla', 'cc18', 'amlb', 'amlb']",True,True,True
19 | 40900,167211.0,5100.0,37.0,2.0,0.0,openml__Satellite__167211,"['tabzilla', 'amlb', 'amlb']",False,False,False
20 | 40536,146607.0,8378.0,121.0,2.0,18372.0,openml__SpeedDating__146607,['tabzilla'],False,False,False
21 | 1455,10089.0,120.0,7.0,2.0,0.0,openml__acute-inflammations__10089,['tabzilla'],False,False,False
22 | 1043,3896.0,4562.0,49.0,2.0,0.0,openml__ada_agnostic__3896,['tabzilla'],False,False,False
23 | 1119,3953.0,32561.0,16.0,2.0,4262.0,openml__adult-census__3953,['tabzilla'],False,False,False
24 | 1590,7592.0,48842.0,15.0,2.0,6465.0,openml__adult__7592,"['tabzilla', 'cc18', 'amlb']",True,True,True
25 | 1169,189354.0,539383.0,8.0,2.0,0.0,openml__airlines__189354,"['tabzilla', 'amlb']",False,False,False
26 | 41147,189356.0,425240.0,79.0,2.0,2734000.0,openml__albert__189356,"['tabzilla', 'amlb']",False,False,False
27 | 458,3549.0,841.0,71.0,4.0,0.0,openml__analcatdata_authorship__3549,"['tabzilla', 'cc18']",True,True,True
28 | 448,3540.0,120.0,4.0,2.0,0.0,openml__analcatdata_boxing1__3540,['tabzilla'],False,False,False
29 | 875,3739.0,100.0,4.0,2.0,0.0,openml__analcatdata_chlamydia__3739,['tabzilla'],False,False,False
30 | 469,3560.0,797.0,5.0,6.0,0.0,openml__analcatdata_dmft__3560,"['tabzilla', 'cc18']",True,True,True
31 | 1,2867.0,898.0,39.0,5.0,0.0,openml__anneal__2867,['tabzilla'],False,False,False
32 | 5,5.0,452.0,280.0,13.0,408.0,openml__arrhythmia__5,['tabzilla'],False,False,False
33 | 1459,14964.0,10218.0,8.0,10.0,0.0,openml__artificial-characters__14964,['tabzilla'],False,False,False
34 | 7,7.0,226.0,70.0,24.0,317.0,openml__audiology__7,['tabzilla'],False,False,False
35 | 9,9.0,205.0,26.0,6.0,59.0,openml__autos__9,['tabzilla'],False,False,False
36 | 11,11.0,625.0,5.0,3.0,0.0,openml__balance-scale__11,"['tabzilla', 'cc18']",True,True,True
37 | 1461,14965.0,45211.0,17.0,2.0,0.0,openml__bank-marketing__14965,"['tabzilla', 'cc18', 'amlb']",True,True,True
38 | 1558,9899.0,4521.0,17.0,2.0,0.0,openml__bank-marketing__9899,['tabzilla'],False,False,False
39 | 1462,10093.0,1372.0,5.0,2.0,0.0,openml__banknote-authentication__10093,"['tabzilla', 'cc18']",True,True,True
40 | 1464,10101.0,748.0,5.0,2.0,0.0,openml__blood-transfusion-service-center__145836,"['tabzilla', 'tabzilla', 'cc18', 'amlb']",True,True,True
41 | 13,145799.0,286.0,10.0,2.0,9.0,openml__breast-cancer__145799,['tabzilla'],False,False,False
42 | 15,15.0,699.0,10.0,2.0,16.0,openml__breast-w__15,"['tabzilla', 'cc18']",True,True,True
43 | 40664,146192.0,1728.0,22.0,4.0,0.0,openml__car-evaluation__146192,['tabzilla'],False,False,False
44 | 40975,146821.0,1728.0,7.0,4.0,0.0,openml__car__146821,"['tabzilla', 'cc18', 'amlb']",True,True,True
45 | 1466,9979.0,2126.0,36.0,10.0,0.0,openml__cardiotocography__9979,['tabzilla'],False,False,False
46 | 1118,3952.0,28056.0,7.0,18.0,0.0,openml__chess__3952,['tabzilla'],False,False,False
47 | 41142,168908.0,5418.0,1637.0,2.0,0.0,openml__christine__168908,"['tabzilla', 'amlb']",False,False,False
48 | 40701,167141.0,5000.0,21.0,2.0,0.0,openml__churn__167141,"['tabzilla', 'cc18', 'amlb', 'amlb']",True,True,True
49 | 23380,14967.0,2796.0,35.0,6.0,68100.0,openml__cjs__14967,['tabzilla'],False,False,False
50 | 40994,146819.0,540.0,21.0,2.0,0.0,openml__climate-model-simulation-crashes__146819,"['tabzilla', 'cc18']",True,True,True
51 | 23,23.0,1473.0,10.0,3.0,0.0,openml__cmc__23,"['tabzilla', 'cc18', 'amlb', 'amlb']",True,True,True
52 | 1468,9981.0,1080.0,857.0,9.0,0.0,openml__cnae-9__9981,"['tabzilla', 'cc18', 'amlb']",False,False,True
53 | 25,25.0,368.0,27.0,2.0,1927.0,openml__colic__25,['tabzilla'],False,False,False
54 | 27,27.0,368.0,23.0,2.0,1927.0,openml__colic__27,['tabzilla'],False,False,False
55 | 478,3567.0,500.0,22.0,15.0,0.0,openml__collins__3567,['tabzilla'],False,False,False
56 | 40668,146195.0,67557.0,43.0,3.0,0.0,openml__connect-4__146195,"['tabzilla', 'cc18', 'amlb']",True,True,True
57 | 1596,7593.0,581012.0,55.0,7.0,0.0,openml__covertype__7593,"['tabzilla', 'amlb']",False,False,False
58 | 29,29.0,690.0,16.0,2.0,67.0,openml__credit-approval__29,"['tabzilla', 'cc18']",True,True,True
59 | 31,31.0,1000.0,21.0,2.0,0.0,openml__credit-g__31,"['tabzilla', 'cc18', 'amlb']",True,True,True
60 | 6332,14954.0,540.0,40.0,2.0,999.0,openml__cylinder-bands__14954,"['tabzilla', 'cc18']",True,True,True
61 | 35,35.0,366.0,35.0,6.0,8.0,openml__dermatology__35,['tabzilla'],False,False,False
62 | 37,37.0,768.0,9.0,2.0,0.0,openml__diabetes__37,"['tabzilla', 'cc18']",True,True,True
63 | 41163,168909.0,10000.0,2001.0,5.0,0.0,openml__dilbert__168909,"['tabzilla', 'amlb']",False,False,False
64 | 40670,167140.0,3186.0,181.0,3.0,0.0,openml__dna__167140,"['tabzilla', 'cc18', 'amlb', 'amlb']",False,True,True
65 | 23381,125920.0,500.0,13.0,2.0,835.0,openml__dresses-sales__125920,"['tabzilla', 'cc18']",True,True,True
66 | 39,145977.0,336.0,8.0,8.0,0.0,openml__ecoli__145977,['tabzilla'],False,False,False
67 | 1471,14951.0,14980.0,15.0,2.0,0.0,openml__eeg-eye-state__14951,['tabzilla'],False,False,False
68 | 151,219.0,45312.0,9.0,2.0,0.0,openml__electricity__219,"['tabzilla', 'cc18']",True,True,True
69 | 846,3711.0,16599.0,19.0,2.0,0.0,openml__elevators__3711,['tabzilla'],False,False,False
70 | 188,2079.0,736.0,20.0,5.0,448.0,openml__eucalyptus__2079,"['tabzilla', 'cc18', 'amlb', 'amlb']",True,True,True
71 | 1044,3897.0,10936.0,28.0,3.0,0.0,openml__eye_movements__3897,['tabzilla'],False,False,False
72 | 41164,168910.0,8237.0,801.0,7.0,0.0,openml__fabert__168910,"['tabzilla', 'amlb']",False,False,False
73 | 1473,9984.0,100.0,10.0,2.0,0.0,openml__fertility__9984,['tabzilla'],False,False,False
74 | 1475,9985.0,6118.0,52.0,6.0,0.0,openml__first-order-theorem-proving__9985,"['tabzilla', 'cc18', 'amlb', 'amlb']",True,True,True
75 | 477,3566.0,67.0,16.0,5.0,0.0,openml__fl2000__3566,['tabzilla'],False,False,False
76 | 754,3620.0,100.0,6.0,2.0,0.0,openml__fri_c0_100_5__3620,['tabzilla'],False,False,False
77 | 916,3779.0,100.0,6.0,2.0,0.0,openml__fri_c3_100_5__3779,['tabzilla'],False,False,False
78 | 1477,9987.0,13910.0,130.0,6.0,0.0,openml__gas-drift-different-concentrations__9987,['tabzilla'],False,False,False
79 | 1476,9986.0,13910.0,129.0,6.0,0.0,openml__gas-drift__9986,['tabzilla'],False,False,False
80 | 1038,3891.0,3468.0,971.0,2.0,0.0,openml__gina_agnostic__3891,['tabzilla'],False,False,False
81 | 41,40.0,214.0,10.0,6.0,0.0,openml__glass__40,['tabzilla'],False,False,False
82 | 41159,168337.0,20000.0,4297.0,2.0,0.0,openml__guillermo__168337,"['tabzilla', 'amlb']",False,False,False
83 | 43,42.0,306.0,4.0,2.0,0.0,openml__haberman__42,['tabzilla'],False,False,False
84 | 1478,14970.0,10299.0,562.0,6.0,0.0,openml__har__14970,"['tabzilla', 'cc18']",False,False,True
85 | 329,146063.0,160.0,5.0,3.0,0.0,openml__hayes-roth__146063,['tabzilla'],False,False,False
86 | 49,48.0,303.0,14.0,2.0,7.0,openml__heart-c__48,['tabzilla'],False,False,False
87 | 51,50.0,294.0,14.0,2.0,782.0,openml__heart-h__50,['tabzilla'],False,False,False
88 | 55,54.0,155.0,20.0,2.0,167.0,openml__hepatitis__54,['tabzilla'],False,False,False
89 | 23512,146606.0,98050.0,29.0,2.0,9.0,openml__higgs__146606,['tabzilla'],False,False,False
90 | 1479,145847.0,1212.0,101.0,2.0,0.0,openml__hill-valley__145847,['tabzilla'],False,False,False
91 | 821,3686.0,22784.0,17.0,2.0,0.0,openml__house_16H__3686,['tabzilla'],False,False,False
92 | 1480,9971.0,583.0,11.0,2.0,0.0,openml__ilpd__9971,"['tabzilla', 'cc18']",True,True,True
93 | 59,145984.0,351.0,35.0,2.0,0.0,openml__ionosphere__145984,['tabzilla'],False,False,False
94 | 61,59.0,150.0,5.0,3.0,0.0,openml__iris__59,['tabzilla'],False,False,False
95 | 451,3543.0,500.0,6.0,2.0,32.0,openml__irish__3543,['tabzilla'],False,False,False
96 | 300,3481.0,7797.0,618.0,26.0,0.0,openml__isolet__3481,"['tabzilla', 'cc18']",False,False,True
97 | 41168,168330.0,83733.0,55.0,4.0,0.0,openml__jannis__168330,"['tabzilla', 'amlb']",False,False,False
98 | 41143,168911.0,2984.0,145.0,2.0,0.0,openml__jasmine__168911,"['tabzilla', 'amlb']",False,False,False
99 | 1053,3904.0,10885.0,22.0,2.0,25.0,openml__jm1__3904,"['tabzilla', 'cc18']",True,True,True
100 | 41027,167119.0,44819.0,7.0,3.0,0.0,openml__jungle_chess_2pcs_raw_endgame_complete__167119,"['tabzilla', 'cc18', 'amlb']",True,True,True
101 | 1067,3917.0,2109.0,22.0,2.0,0.0,openml__kc1__3917,"['tabzilla', 'cc18', 'amlb']",True,True,True
102 | 1063,3913.0,522.0,22.0,2.0,0.0,openml__kc2__3913,"['tabzilla', 'cc18']",True,True,True
103 | 3,3.0,3196.0,37.0,2.0,0.0,openml__kr-vs-kp__3,"['tabzilla', 'cc18', 'amlb']",True,True,True
104 | 184,2076.0,28056.0,7.0,18.0,0.0,openml__kropt__2076,['tabzilla'],False,False,False
105 | 4,4.0,57.0,17.0,2.0,326.0,openml__labor__4,['tabzilla'],False,False,False
106 | 1483,9974.0,164860.0,8.0,11.0,0.0,openml__ldpa__9974,['tabzilla'],False,False,False
107 | 6,6.0,20000.0,17.0,26.0,0.0,openml__letter__6,"['tabzilla', 'cc18']",False,False,True
108 | 42855,360948.0,360.0,105.0,,0.0,openml__libras__360948,['tabzilla'],False,False,False
109 | 163,146024.0,32.0,57.0,3.0,5.0,openml__lung-cancer__146024,['tabzilla'],False,False,False
110 | 10,10.0,148.0,19.0,4.0,0.0,openml__lymph__10,['tabzilla'],False,False,False
111 | 1485,9976.0,2600.0,501.0,2.0,0.0,openml__madelon__9976,"['tabzilla', 'cc18']",False,False,True
112 | 40679,146206.0,19020.0,11.0,2.0,0.0,openml__magic__146206,['tabzilla'],False,False,False
113 | 12,12.0,2000.0,217.0,10.0,0.0,openml__mfeat-factors__12,"['tabzilla', 'cc18', 'amlb']",False,True,True
114 | 14,14.0,2000.0,77.0,10.0,0.0,openml__mfeat-fourier__14,"['tabzilla', 'cc18']",True,True,True
115 | 16,16.0,2000.0,65.0,10.0,0.0,openml__mfeat-karhunen__16,"['tabzilla', 'cc18']",True,True,True
116 | 18,18.0,2000.0,7.0,10.0,0.0,openml__mfeat-morphological__18,"['tabzilla', 'cc18']",True,True,True
117 | 40979,146824.0,2000.0,241.0,10.0,0.0,openml__mfeat-pixel__146824,"['tabzilla', 'cc18']",False,True,True
118 | 22,22.0,2000.0,48.0,10.0,0.0,openml__mfeat-zernike__22,"['tabzilla', 'cc18']",True,True,True
119 | 554,3573.0,70000.0,785.0,10.0,0.0,openml__mnist_784__3573,"['tabzilla', 'cc18']",False,False,True
120 | 334,146065.0,601.0,7.0,2.0,0.0,openml__monks-problems-2__146065,['tabzilla'],False,False,False
121 | 24,24.0,8124.0,23.0,2.0,2480.0,openml__mushroom__24,['tabzilla'],False,False,False
122 | 1116,3950.0,6598.0,168.0,2.0,0.0,openml__musk__3950,['tabzilla'],False,False,False
123 | 1486,9977.0,34465.0,119.0,2.0,0.0,openml__nomao__9977,"['tabzilla', 'cc18', 'amlb']",False,True,True
124 | 23517,167120.0,96320.0,22.0,2.0,0.0,openml__numerai28.6__167120,"['tabzilla', 'cc18', 'amlb']",True,True,True
125 | 1568,9892.0,12958.0,9.0,4.0,0.0,openml__nursery__9892,['tabzilla'],False,False,False
126 | 1493,9956.0,1599.0,65.0,100.0,0.0,openml__one-hundred-plants-texture__9956,['tabzilla'],False,False,False
127 | 28,28.0,5620.0,65.0,10.0,0.0,openml__optdigits__28,"['tabzilla', 'cc18']",True,True,True
128 | 1487,9978.0,2534.0,73.0,2.0,0.0,openml__ozone-level-8hr__9978,"['tabzilla', 'cc18', 'amlb', 'amlb']",True,True,True
129 | 30,30.0,5473.0,11.0,5.0,0.0,openml__page-blocks__30,['tabzilla'],False,False,False
130 | 1068,3918.0,1109.0,22.0,2.0,0.0,openml__pc1__3918,"['tabzilla', 'cc18']",True,True,True
131 | 1050,3903.0,1563.0,38.0,2.0,0.0,openml__pc3__3903,"['tabzilla', 'cc18']",True,True,True
132 | 1049,3902.0,1458.0,38.0,2.0,0.0,openml__pc4__3902,"['tabzilla', 'cc18', 'amlb', 'amlb']",True,True,True
133 | 32,32.0,10992.0,17.0,10.0,0.0,openml__pendigits__32,"['tabzilla', 'cc18']",True,True,True
134 | 41145,190410.0,5832.0,309.0,2.0,0.0,openml__philippine__190410,"['tabzilla', 'amlb', 'amlb']",False,False,False
135 | 1489,9952.0,5404.0,6.0,2.0,0.0,openml__phoneme__9952,"['tabzilla', 'cc18', 'amlb']",True,True,True
136 | 1567,9890.0,1025009.0,11.0,10.0,0.0,openml__poker-hand__9890,['tabzilla'],False,False,False
137 | 871,3735.0,3848.0,6.0,2.0,0.0,openml__pollen__3735,['tabzilla'],False,False,False
138 | 40683,146210.0,88.0,9.0,2.0,0.0,openml__postoperative-patient-data__146210,['tabzilla'],False,False,False
139 | 171,146032.0,339.0,18.0,21.0,225.0,openml__primary-tumor__146032,['tabzilla'],False,False,False
140 | 470,3561.0,672.0,10.0,2.0,1200.0,openml__profb__3561,['tabzilla'],False,False,False
141 | 1494,9957.0,1055.0,42.0,2.0,0.0,openml__qsar-biodeg__9957,"['tabzilla', 'cc18', 'amlb', 'amlb']",True,True,True
142 | 782,3647.0,120.0,3.0,2.0,0.0,openml__rabe_266__3647,['tabzilla'],False,False,False
143 | 41161,168338.0,20000.0,4297.0,2.0,0.0,openml__riccardo__168338,"['tabzilla', 'amlb']",False,False,False
144 | 41165,168332.0,10000.0,7201.0,10.0,0.0,openml__robert__168332,"['tabzilla', 'amlb']",False,False,False
145 | 182,2074.0,6430.0,37.0,6.0,0.0,openml__satimage__2074,"['tabzilla', 'cc18']",True,True,True
146 | 312,3485.0,2407.0,300.0,2.0,0.0,openml__scene__3485,['tabzilla'],False,False,False
147 | 40984,146822.0,2310.0,20.0,7.0,0.0,openml__segment__146822,"['tabzilla', 'cc18', 'amlb']",True,True,True
148 | 1501,9964.0,1593.0,257.0,10.0,0.0,openml__semeion__9964,"['tabzilla', 'cc18']",False,True,True
149 | 40685,146212.0,58000.0,10.0,7.0,0.0,openml__shuttle__146212,"['tabzilla', 'amlb']",False,False,False
150 | 38,3021.0,3772.0,30.0,2.0,6064.0,openml__sick__3021,"['tabzilla', 'cc18']",True,True,True
151 | 1502,9965.0,245057.0,4.0,2.0,0.0,openml__skin-segmentation__9965,['tabzilla'],False,False,False
152 | 934,3797.0,1156.0,6.0,2.0,0.0,openml__socmob__3797,['tabzilla'],False,False,False
153 | 174,2068.0,1066.0,13.0,3.0,0.0,openml__solar-flare__2068,['tabzilla'],False,False,False
154 | 40,39.0,208.0,61.0,2.0,0.0,openml__sonar__39,['tabzilla'],False,False,False
155 | 42,41.0,683.0,36.0,19.0,2337.0,openml__soybean__41,['tabzilla'],False,False,False
156 | 44,43.0,4601.0,58.0,2.0,0.0,openml__spambase__43,"['tabzilla', 'cc18']",True,True,True
157 | 46,45.0,3190.0,61.0,3.0,0.0,openml__splice__45,"['tabzilla', 'cc18']",True,True,True
158 | 40982,146817.0,1941.0,28.0,7.0,0.0,openml__steel-plates-fault__146817,"['tabzilla', 'cc18', 'amlb', 'amlb']",True,True,True
159 | 1036,3889.0,14395.0,217.0,2.0,0.0,openml__sylva_agnostic__3889,['tabzilla'],False,False,False
160 | 41146,168912.0,5124.0,21.0,2.0,0.0,openml__sylvine__168912,"['tabzilla', 'amlb']",False,False,False
161 | 377,3512.0,600.0,61.0,6.0,0.0,openml__synthetic_control__3512,['tabzilla'],False,False,False
162 | 48,47.0,151.0,6.0,3.0,0.0,openml__tae__47,['tabzilla'],False,False,False
163 | 40499,125922.0,5500.0,41.0,11.0,0.0,openml__texture__125922,"['tabzilla', 'cc18']",False,True,True
164 | 50,49.0,958.0,10.0,2.0,0.0,openml__tic-tac-toe__49,"['tabzilla', 'cc18']",True,True,True
165 | 885,3748.0,131.0,4.0,2.0,0.0,openml__transplant__3748,['tabzilla'],False,False,False
166 | 54,53.0,846.0,19.0,4.0,0.0,openml__vehicle__53,"['tabzilla', 'cc18', 'amlb']",True,True,True
167 | 736,3602.0,111.0,4.0,2.0,0.0,openml__visualizing_environmental__3602,['tabzilla'],False,False,False
168 | 867,3731.0,130.0,3.0,2.0,0.0,openml__visualizing_livestock__3731,['tabzilla'],False,False,False
169 | 41166,168331.0,58310.0,181.0,10.0,0.0,openml__volkert__168331,"['tabzilla', 'amlb']",False,False,False
170 | 307,3022.0,990.0,13.0,11.0,0.0,openml__vowel__3022,"['tabzilla', 'cc18']",False,True,True
171 | 1509,9945.0,149332.0,5.0,22.0,0.0,openml__walking-activity__9945,['tabzilla'],False,False,False
172 | 1497,9960.0,5456.0,25.0,4.0,0.0,openml__wall-robot-navigation__9960,"['tabzilla', 'cc18']",True,True,True
173 | 1510,9946.0,569.0,31.0,2.0,0.0,openml__wdbc__9946,"['tabzilla', 'cc18']",True,True,True
174 | 40983,146820.0,4839.0,6.0,2.0,0.0,openml__wilt__146820,"['tabzilla', 'cc18', 'amlb', 'amlb']",True,True,True
175 | 40733,145793.0,1269.0,9.0,4.0,0.0,openml__yeast__145793,['tabzilla'],False,False,False
176 | 44089,361055.0,16714.0,11.0,2.0,0.0,,['grins'],False,False,False
177 | 44120,361060.0,38474.0,8.0,2.0,0.0,,['grins'],False,False,False
178 | 44121,361061.0,566602.0,11.0,2.0,0.0,,['grins'],False,False,False
179 | 44122,361062.0,10082.0,27.0,2.0,0.0,,['grins'],False,False,False
180 | 44123,361063.0,13488.0,17.0,2.0,0.0,,['grins'],False,False,False
181 | 44125,361065.0,13376.0,11.0,2.0,0.0,,['grins'],False,False,False
182 | 44126,361066.0,10578.0,8.0,2.0,0.0,,['grins'],False,False,False
183 | 44128,361068.0,72998.0,51.0,2.0,0.0,,['grins'],False,False,False
184 | 44129,361069.0,940160.0,25.0,2.0,0.0,,['grins'],False,False,False
185 | 44130,361070.0,7608.0,21.0,2.0,0.0,,['grins'],False,False,False
186 | 45022,361273.0,71090.0,8.0,2.0,0.0,,['grins'],False,False,False
187 | 45021,361274.0,57580.0,55.0,2.0,0.0,,['grins'],False,False,False
188 | 45020,361275.0,13272.0,21.0,2.0,0.0,,['grins'],False,False,False
189 | 45019,361276.0,3434.0,420.0,2.0,0.0,,['grins'],False,False,False
190 | 45028,361277.0,20634.0,9.0,2.0,0.0,,['grins'],False,False,False
191 | 45026,361278.0,10000.0,23.0,2.0,0.0,,['grins'],False,False,False
192 | 44156,361110.0,38474.0,9.0,2.0,0.0,,['grins'],False,False,False
193 | 44157,361111.0,7608.0,24.0,2.0,0.0,,['grins'],False,False,False
194 | 44159,361113.0,423680.0,55.0,2.0,0.0,,['grins'],False,False,False
195 | 45035,361282.0,58252.0,32.0,2.0,0.0,,['grins'],False,False,False
196 | 45036,361283.0,13272.0,22.0,2.0,0.0,,['grins'],False,False,False
197 | 45038,361285.0,111762.0,33.0,2.0,0.0,,['grins'],False,False,False
198 | 45039,361286.0,4966.0,12.0,2.0,0.0,,['grins'],False,False,False
199 | 181,2073.0,1484.0,9.0,10.0,0.0,,"['amlb', 'amlb']",False,False,False
200 | 1111,3945.0,50000.0,231.0,2.0,8024152.0,,['amlb'],False,False,False
201 | 1457,10090.0,1500.0,10001.0,50.0,0.0,,"['amlb', 'amlb']",False,False,False
202 | 41167,189355.0,416188.0,61.0,355.0,0.0,,['amlb'],False,False,False
203 | 41158,189922.0,3153.0,971.0,2.0,0.0,,"['amlb', 'amlb']",False,False,False
204 | 41144,190392.0,3140.0,260.0,2.0,0.0,,"['amlb', 'amlb']",False,False,False
205 | 41156,190411.0,4147.0,49.0,2.0,0.0,,"['amlb', 'amlb']",False,False,False
206 | 41157,190412.0,100.0,10001.0,2.0,0.0,,"['amlb', 'amlb']",False,False,False
207 | 4541,211986.0,101766.0,50.0,3.0,0.0,,"['amlb', 'amlb']",False,False,False
208 | 1515,359953.0,571.0,1301.0,20.0,0.0,,"['amlb', 'amlb']",False,False,False
209 | 40498,359974.0,4898.0,12.0,7.0,0.0,,"['amlb', 'amlb']",False,False,False
210 | 41169,359984.0,65196.0,28.0,100.0,0.0,,['amlb'],False,False,False
211 | 41162,359991.0,72983.0,33.0,2.0,149271.0,,"['amlb', 'amlb']",False,False,False
212 | 42733,359992.0,39948.0,12.0,2.0,0.0,,"['amlb', 'amlb']",False,False,False
213 | 42734,359993.0,50789.0,20.0,3.0,154107.0,,"['amlb', 'amlb']",False,False,False
214 | 42732,359994.0,2215023.0,9.0,2.0,0.0,,"['amlb', 'amlb']",False,False,False
215 | 42746,360112.0,4898431.0,42.0,23.0,0.0,,"['amlb', 'amlb']",False,False,False
216 | 42742,360113.0,595212.0,58.0,2.0,846458.0,,"['amlb', 'amlb']",False,False,False
217 | 42769,360114.0,1000000.0,29.0,2.0,0.0,,"['amlb', 'amlb']",False,False,False
218 | 43072,360975.0,50000.0,14892.0,2.0,19658569.0,,"['amlb', 'amlb']",False,False,False
219 | 20,,2000.0,241.0,10.0,0.0,,['additional'],False,False,
220 | 53,,270.0,14.0,2.0,0.0,,['additional'],False,False,
221 | 56,,435.0,17.0,2.0,392.0,,['additional'],False,False,
222 | 137,,39366.0,10.0,2.0,0.0,,['additional'],False,False,
223 | 273,,120919.0,1002.0,2.0,0.0,,['additional'],False,False,
224 | 285,,194.0,29.0,8.0,0.0,,['additional'],False,False,
225 | 311,,937.0,50.0,2.0,0.0,,['additional'],False,False,
226 | 333,,556.0,7.0,2.0,0.0,,['additional'],False,False,
227 | 335,,554.0,7.0,2.0,0.0,,['additional'],False,False,
228 | 336,,267.0,23.0,2.0,0.0,,['additional'],False,False,
229 | 337,,349.0,45.0,2.0,0.0,,['additional'],False,False,
230 | 338,,155.0,9.0,4.0,0.0,,['additional'],False,False,
231 | 382,,7019.0,61.0,8.0,48089.0,,['additional'],False,False,
232 | 384,,336.0,7903.0,6.0,0.0,,['additional'],False,False,
233 | 387,,414.0,6430.0,9.0,0.0,,['additional'],False,False,
234 | 389,,2463.0,2001.0,17.0,0.0,,['additional'],False,False,
235 | 396,,3204.0,13196.0,6.0,0.0,,['additional'],False,False,
236 | 446,,200.0,8.0,2.0,0.0,,['additional'],False,False,
237 | 449,,163.0,27.0,5.0,9.0,,['additional'],False,False,
238 | 450,,264.0,5.0,2.0,0.0,,['additional'],False,False,
239 | 452,,285.0,8.0,7.0,27.0,,['additional'],False,False,
240 | 460,,379.0,8.0,4.0,1418.0,,['additional'],False,False,
241 | 463,,180.0,32.0,2.0,0.0,,['additional'],False,False,
242 | 464,,250.0,3.0,2.0,0.0,,['additional'],False,False,
243 | 466,,340.0,15.0,2.0,834.0,,['additional'],False,False,
244 | 474,,364.0,33.0,6.0,101.0,,['additional'],False,False,
245 | 475,,400.0,6.0,4.0,0.0,,['additional'],False,False,
246 | 481,,209.0,9.0,2.0,15.0,,['additional'],False,False,
247 | 679,,1024.0,3.0,4.0,0.0,,['additional'],False,False,
248 | 694,,310.0,9.0,9.0,0.0,,['additional'],False,False,
249 | 717,,508.0,11.0,2.0,0.0,,['additional'],False,False,
250 | 721,,200.0,11.0,2.0,0.0,,['additional'],False,False,
251 | 724,,468.0,4.0,2.0,0.0,,['additional'],False,False,
252 | 733,,209.0,7.0,2.0,0.0,,['additional'],False,False,
253 | 738,,195.0,11.0,2.0,2.0,,['additional'],False,False,
254 | 745,,159.0,16.0,2.0,0.0,,['additional'],False,False,
255 | 746,,250.0,26.0,2.0,0.0,,['additional'],False,False,
256 | 747,,167.0,5.0,2.0,0.0,,['additional'],False,False,
257 | 748,,163.0,6.0,2.0,0.0,,['additional'],False,False,
258 | 750,,500.0,8.0,2.0,0.0,,['additional'],False,False,
259 | 753,,194.0,33.0,2.0,0.0,,['additional'],False,False,
260 | 757,,528.0,22.0,2.0,504.0,,['additional'],False,False,
261 | 763,,250.0,11.0,2.0,0.0,,['additional'],False,False,
262 | 764,,450.0,4.0,2.0,0.0,,['additional'],False,False,
263 | 765,,475.0,4.0,2.0,0.0,,['additional'],False,False,
264 | 767,,475.0,4.0,2.0,0.0,,['additional'],False,False,
265 | 774,,662.0,4.0,2.0,0.0,,['additional'],False,False,
266 | 778,,252.0,15.0,2.0,0.0,,['additional'],False,False,
267 | 779,,500.0,26.0,2.0,0.0,,['additional'],False,False,
268 | 786,,303.0,14.0,2.0,6.0,,['additional'],False,False,
269 | 788,,186.0,61.0,2.0,0.0,,['additional'],False,False,
270 | 795,,662.0,4.0,2.0,0.0,,['additional'],False,False,
271 | 796,,209.0,8.0,2.0,0.0,,['additional'],False,False,
272 | 798,,303.0,14.0,2.0,6.0,,['additional'],False,False,
273 | 801,,185.0,3.0,2.0,0.0,,['additional'],False,False,
274 | 802,,1945.0,19.0,2.0,1133.0,,['additional'],False,False,
275 | 805,,500.0,51.0,2.0,0.0,,['additional'],False,False,
276 | 806,,1000.0,51.0,2.0,0.0,,['additional'],False,False,
277 | 810,,418.0,19.0,2.0,1239.0,,['additional'],False,False,
278 | 811,,264.0,3.0,2.0,0.0,,['additional'],False,False,
279 | 814,,468.0,3.0,2.0,0.0,,['additional'],False,False,
280 | 816,,8192.0,9.0,2.0,0.0,,['additional'],False,False,
281 | 820,,235.0,13.0,2.0,0.0,,['additional'],False,False,
282 | 825,,506.0,21.0,2.0,0.0,,['additional'],False,False,
283 | 827,,662.0,4.0,2.0,0.0,,['additional'],False,False,
284 | 831,,398.0,8.0,2.0,6.0,,['additional'],False,False,
285 | 837,,1000.0,51.0,2.0,0.0,,['additional'],False,False,
286 | 838,,500.0,26.0,2.0,0.0,,['additional'],False,False,
287 | 839,,782.0,9.0,2.0,466.0,,['additional'],False,False,
288 | 840,,205.0,26.0,2.0,57.0,,['additional'],False,False,
289 | 841,,950.0,10.0,2.0,0.0,,['additional'],False,False,
290 | 843,,22784.0,9.0,2.0,0.0,,['additional'],False,False,
291 | 844,,286.0,10.0,2.0,9.0,,['additional'],False,False,
292 | 852,,159.0,10.0,2.0,6.0,,['additional'],False,False,
293 | 854,,158.0,8.0,2.0,87.0,,['additional'],False,False,
294 | 860,,380.0,3.0,2.0,0.0,,['additional'],False,False,
295 | 880,,284.0,11.0,2.0,0.0,,['additional'],False,False,
296 | 886,,500.0,8.0,2.0,0.0,,['additional'],False,False,
297 | 895,,222.0,3.0,2.0,0.0,,['additional'],False,False,
298 | 896,,500.0,26.0,2.0,0.0,,['additional'],False,False,
299 | 900,,400.0,7.0,2.0,0.0,,['additional'],False,False,
300 | 903,,1000.0,26.0,2.0,0.0,,['additional'],False,False,
301 | 906,,400.0,8.0,2.0,0.0,,['additional'],False,False,
302 | 907,,400.0,8.0,2.0,0.0,,['additional'],False,False,
303 | 908,,400.0,8.0,2.0,0.0,,['additional'],False,False,
304 | 909,,400.0,8.0,2.0,0.0,,['additional'],False,False,
305 | 911,,250.0,6.0,2.0,0.0,,['additional'],False,False,
306 | 912,,1000.0,6.0,2.0,0.0,,['additional'],False,False,
307 | 913,,1000.0,11.0,2.0,0.0,,['additional'],False,False,
308 | 915,,315.0,14.0,2.0,0.0,,['additional'],False,False,
309 | 925,,323.0,5.0,2.0,0.0,,['additional'],False,False,
310 | 930,,1302.0,34.0,2.0,7830.0,,['additional'],False,False,
311 | 931,,662.0,4.0,2.0,0.0,,['additional'],False,False,
312 | 939,,228.0,9.0,2.0,20.0,,['additional'],False,False,
313 | 940,,527.0,37.0,2.0,542.0,,['additional'],False,False,
314 | 941,,189.0,10.0,2.0,0.0,,['additional'],False,False,
315 | 943,,500.0,11.0,2.0,0.0,,['additional'],False,False,
316 | 949,,559.0,5.0,2.0,0.0,,['additional'],False,False,
317 | 957,,412.0,9.0,2.0,96.0,,['additional'],False,False,
318 | 966,,1340.0,17.0,2.0,20.0,,['additional'],False,False,
319 | 968,,365.0,4.0,2.0,30.0,,['additional'],False,False,
320 | 981,,10108.0,69.0,2.0,2699.0,,['additional'],False,False,
321 | 984,,366.0,5.0,2.0,1.0,,['additional'],False,False,
322 | 987,,500.0,23.0,2.0,0.0,,['additional'],False,False,
323 | 996,,214.0,10.0,2.0,0.0,,['additional'],False,False,
324 | 1000,,3772.0,30.0,2.0,6064.0,,['additional'],False,False,
325 | 1002,,7485.0,56.0,2.0,32427.0,,['additional'],False,False,
326 | 1016,,990.0,14.0,2.0,0.0,,['additional'],False,False,
327 | 1018,,8844.0,57.0,2.0,34843.0,,['additional'],False,False,
328 | 1037,,4562.0,15.0,2.0,88.0,,['additional'],False,False,
329 | 1042,,3468.0,785.0,2.0,0.0,,['additional'],False,False,
330 | 1048,,369.0,9.0,2.0,0.0,,['additional'],False,False,
331 | 1054,,161.0,40.0,2.0,0.0,,['additional'],False,False,
332 | 1071,,403.0,38.0,2.0,0.0,,['additional'],False,False,
333 | 1073,,274.0,9.0,2.0,0.0,,['additional'],False,False,
334 | 1100,,478.0,11.0,3.0,0.0,,['additional'],False,False,
335 | 1112,,50000.0,231.0,2.0,8024152.0,,['additional'],False,False,
336 | 1115,,151.0,7.0,3.0,0.0,,['additional'],False,False,
337 | 1126,,412.0,10936.0,2.0,0.0,,['additional'],False,False,
338 | 1129,,384.0,10936.0,2.0,0.0,,['additional'],False,False,
339 | 1130,,1545.0,10936.0,2.0,0.0,,['additional'],False,False,
340 | 1131,,193.0,10936.0,2.0,0.0,,['additional'],False,False,
341 | 1142,,1545.0,10936.0,2.0,0.0,,['additional'],False,False,
342 | 1156,,275.0,10936.0,2.0,0.0,,['additional'],False,False,
343 | 1162,,322.0,10936.0,2.0,0.0,,['additional'],False,False,
344 | 1412,,226.0,24.0,2.0,0.0,,['additional'],False,False,
345 | 1442,,253.0,38.0,2.0,0.0,,['additional'],False,False,
346 | 1443,,661.0,38.0,2.0,0.0,,['additional'],False,False,
347 | 1444,,1043.0,38.0,2.0,0.0,,['additional'],False,False,
348 | 1446,,296.0,38.0,2.0,0.0,,['additional'],False,False,
349 | 1447,,327.0,38.0,2.0,0.0,,['additional'],False,False,
350 | 1448,,194.0,40.0,2.0,0.0,,['additional'],False,False,
351 | 1451,,705.0,38.0,2.0,0.0,,['additional'],False,False,
352 | 1452,,745.0,37.0,2.0,0.0,,['additional'],False,False,
353 | 1453,,1077.0,38.0,2.0,0.0,,['additional'],False,False,
354 | 1481,,28056.0,7.0,18.0,0.0,,['additional'],False,False,
355 | 1488,,195.0,23.0,2.0,0.0,,['additional'],False,False,
356 | 1490,,182.0,13.0,2.0,0.0,,['additional'],False,False,
357 | 1495,,250.0,7.0,2.0,0.0,,['additional'],False,False,
358 | 1498,,462.0,10.0,2.0,0.0,,['additional'],False,False,
359 | 1499,,210.0,8.0,3.0,0.0,,['additional'],False,False,
360 | 1503,,263256.0,15.0,10.0,0.0,,['additional'],False,False,
361 | 1506,,470.0,17.0,2.0,0.0,,['additional'],False,False,
362 | 1507,,7400.0,21.0,2.0,0.0,,['additional'],False,False,
363 | 1508,,403.0,6.0,5.0,0.0,,['additional'],False,False,
364 | 1511,,440.0,9.0,2.0,0.0,,['additional'],False,False,
365 | 1512,,200.0,14.0,5.0,0.0,,['additional'],False,False,
366 | 1520,,164.0,91.0,5.0,0.0,,['additional'],False,False,
367 | 1523,,310.0,7.0,3.0,0.0,,['additional'],False,False,
368 | 1531,,10176.0,4.0,5.0,0.0,,['additional'],False,False,
369 | 1532,,10668.0,4.0,5.0,0.0,,['additional'],False,False,
370 | 1538,,8753.0,4.0,5.0,0.0,,['additional'],False,False,
371 | 1541,,8654.0,4.0,5.0,0.0,,['additional'],False,False,
372 | 1546,,1112.0,4.0,5.0,0.0,,['additional'],False,False,
373 | 1547,,1000.0,21.0,2.0,0.0,,['additional'],False,False,
374 | 4153,,180.0,68.0,6.0,0.0,,['additional'],False,False,
375 | 23499,,277.0,10.0,2.0,0.0,,['additional'],False,False,
376 | 40646,,1600.0,21.0,2.0,0.0,,['additional'],False,False,
377 | 40663,,399.0,33.0,5.0,0.0,,['additional'],False,False,
378 | 40669,,160.0,7.0,2.0,0.0,,['additional'],False,False,
379 | 40680,,1324.0,11.0,2.0,0.0,,['additional'],False,False,
380 | 40682,,215.0,6.0,3.0,0.0,,['additional'],False,False,
381 | 40690,,512.0,10.0,2.0,0.0,,['additional'],False,False,
382 | 40693,,973.0,10.0,2.0,0.0,,['additional'],False,False,
383 | 40705,,959.0,45.0,2.0,0.0,,['additional'],False,False,
384 | 40706,,1124.0,11.0,2.0,0.0,,['additional'],False,False,
385 | 40710,,303.0,14.0,2.0,0.0,,['additional'],False,False,
386 | 40711,,303.0,8.0,5.0,0.0,,['additional'],False,False,
387 | 41430,,281.0,98.0,2.0,2.0,,['additional'],False,False,
388 | 41538,,246.0,7.0,2.0,0.0,,['additional'],False,False,
389 | 41919,,527.0,23.0,4.0,0.0,,['additional'],False,False,
390 | 41976,,156.0,81.0,2.0,0.0,,['additional'],False,False,
391 | 42172,,202.0,20.0,2.0,17.0,,['additional'],False,False,
392 | 42261,,150.0,5.0,3.0,0.0,,['additional'],False,False,
393 | 42544,,265.0,11.0,8.0,0.0,,['additional'],False,False,
394 | 42585,,344.0,7.0,3.0,18.0,,['additional'],False,False,
395 | 42638,,891.0,8.0,2.0,689.0,,['additional'],False,False,
396 |
--------------------------------------------------------------------------------