├── figures ├── scaling.png └── performance-comparison.png ├── tabdpt_datasets ├── __pycache__ │ ├── dataset.cpython-311.pyc │ ├── openml.cpython-311.pyc │ └── __init__.cpython-311.pyc ├── cc_test_datasets_multiclass.pickle ├── cc_valid_datasets_multiclass.pickle ├── tabred.py ├── annotated_tables.py ├── dataset.py ├── __init__.py ├── talent.py ├── openml.py └── catalogue.py ├── predict.py ├── data_splits ├── noleak_training_datasets_anysize.csv ├── noleak_training_datasets.csv ├── reg_datasets.csv └── cls_datasets.csv ├── configs └── default_config.yaml ├── LICENSE ├── README.md ├── transformer_layer.py ├── requirements.txt ├── eval_full.py ├── model.py ├── utils.py ├── train.py ├── dataset.py └── tabdpt.py /figures/scaling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/layer6ai-labs/TabDPT-training/HEAD/figures/scaling.png -------------------------------------------------------------------------------- /figures/performance-comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/layer6ai-labs/TabDPT-training/HEAD/figures/performance-comparison.png -------------------------------------------------------------------------------- /tabdpt_datasets/__pycache__/dataset.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/layer6ai-labs/TabDPT-training/HEAD/tabdpt_datasets/__pycache__/dataset.cpython-311.pyc -------------------------------------------------------------------------------- /tabdpt_datasets/__pycache__/openml.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/layer6ai-labs/TabDPT-training/HEAD/tabdpt_datasets/__pycache__/openml.cpython-311.pyc -------------------------------------------------------------------------------- /tabdpt_datasets/cc_test_datasets_multiclass.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/layer6ai-labs/TabDPT-training/HEAD/tabdpt_datasets/cc_test_datasets_multiclass.pickle -------------------------------------------------------------------------------- /tabdpt_datasets/cc_valid_datasets_multiclass.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/layer6ai-labs/TabDPT-training/HEAD/tabdpt_datasets/cc_valid_datasets_multiclass.pickle -------------------------------------------------------------------------------- /tabdpt_datasets/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/layer6ai-labs/TabDPT-training/HEAD/tabdpt_datasets/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | from sklearn.datasets import fetch_california_housing, load_breast_cancer 2 | from sklearn.metrics import accuracy_score, r2_score 3 | from sklearn.model_selection import train_test_split 4 | 5 | from tabdpt import TabDPTClassifier, TabDPTRegressor 6 | 7 | # classification example 8 | X, y = load_breast_cancer(return_X_y=True) 9 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42) 10 | 11 | model = TabDPTClassifier(path="./checkpoints/latest.ckpt") 12 | model.fit(X_train, y_train) 13 | y_pred = model.predict(X_test, temperature=0.8, context_size=1024, use_retrieval=True) 14 | print("classification accuracy score = ", accuracy_score(y_test, y_pred)) 15 | 16 | 17 | # regression example 18 | X, y = fetch_california_housing(return_X_y=True) 19 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42) 20 | 21 | model = TabDPTRegressor(path="./checkpoints/latest.ckpt") 22 | model.fit(X_train, y_train) 23 | y_pred = model.predict(X_test, context_size=512, use_retrieval=True) 24 | print("regression r2 score =", r2_score(y_test, y_pred)) 25 | -------------------------------------------------------------------------------- /data_splits/noleak_training_datasets_anysize.csv: -------------------------------------------------------------------------------- 1 | did,task 2 | 4135,cls 3 | 41434,cls 4 | 375,cls 5 | 1120,cls 6 | 40900,cls 7 | 1043,cls 8 | 1119,cls 9 | 1,cls 10 | 1459,cls 11 | 1466,cls 12 | 23380,cls 13 | 1471,cls 14 | 846,cls 15 | 1044,cls 16 | 821,cls 17 | 40679,cls 18 | 334,cls 19 | 24,cls 20 | 1568,cls 21 | 30,cls 22 | 871,cls 23 | 470,cls 24 | 41146,cls 25 | 377,cls 26 | 40733,cls 27 | 44089,cls 28 | 44122,cls 29 | 45020,cls 30 | 45028,cls 31 | 45026,cls 32 | 45039,cls 33 | 41156,cls 34 | 137,cls 35 | 311,cls 36 | 333,cls 37 | 335,cls 38 | 382,cls 39 | 717,cls 40 | 757,cls 41 | 802,cls 42 | 816,cls 43 | 825,cls 44 | 839,cls 45 | 841,cls 46 | 843,cls 47 | 930,cls 48 | 940,cls 49 | 949,cls 50 | 966,cls 51 | 981,cls 52 | 1002,cls 53 | 1018,cls 54 | 1037,cls 55 | 1443,cls 56 | 1444,cls 57 | 1451,cls 58 | 1452,cls 59 | 1453,cls 60 | 1507,cls 61 | 40646,cls 62 | 40680,cls 63 | 40690,cls 64 | 40693,cls 65 | 40705,cls 66 | 40706,cls 67 | 41919,cls 68 | 42638,cls 69 | 44055,reg 70 | 44056,reg 71 | 44063,reg 72 | 44136,reg 73 | 44137,reg 74 | 44145,reg 75 | 45032,reg 76 | 546,reg 77 | 42724,reg 78 | 42727,reg 79 | 531,reg 80 | 42563,reg 81 | -------------------------------------------------------------------------------- /data_splits/noleak_training_datasets.csv: -------------------------------------------------------------------------------- 1 | did 2 | 41138 3 | 4135 4 | 4535 5 | 41434 6 | 375 7 | 1120 8 | 41150 9 | 40900 10 | 40536 11 | 1043 12 | 1169 13 | 41147 14 | 1459 15 | 1466 16 | 1118 17 | 41142 18 | 23380 19 | 1596 20 | 41163 21 | 1471 22 | 846 23 | 1044 24 | 41164 25 | 1477 26 | 1476 27 | 41159 28 | 23512 29 | 1479 30 | 821 31 | 41168 32 | 41143 33 | 184 34 | 1483 35 | 40679 36 | 24 37 | 1116 38 | 1568 39 | 1493 40 | 30 41 | 41145 42 | 1567 43 | 871 44 | 41161 45 | 41165 46 | 312 47 | 40685 48 | 1036 49 | 41146 50 | 41166 51 | 1509 52 | 40733 53 | 44089 54 | 44122 55 | 45022 56 | 45020 57 | 45026 58 | 45038 59 | 45039 60 | 1111 61 | 1457 62 | 41167 63 | 41144 64 | 41156 65 | 41169 66 | 41162 67 | 42734 68 | 42732 69 | 42746 70 | 42742 71 | 43072 72 | 273 73 | 382 74 | 389 75 | 396 76 | 802 77 | 816 78 | 843 79 | 930 80 | 966 81 | 981 82 | 1002 83 | 1018 84 | 1037 85 | 1112 86 | 1130 87 | 1142 88 | 1444 89 | 1453 90 | 1481 91 | 1503 92 | 1507 93 | 40646 94 | 40680 95 | 40706 96 | 44055 97 | 44056 98 | 44061 99 | 44063 100 | 44065 101 | 44068 102 | 44069 103 | 45041 104 | 45043 105 | 45045 106 | 45046 107 | 45047 108 | 44136 109 | 44137 110 | 44145 111 | 45032 112 | 4549 113 | 42572 114 | 42705 115 | 42728 116 | 41540 117 | 42724 118 | 42727 119 | 42730 120 | 41980 121 | 42563 122 | 3050 123 | 3277 124 | 43071 125 | -------------------------------------------------------------------------------- /configs/default_config.yaml: -------------------------------------------------------------------------------- 1 | version: 0.0.1 2 | description: TabDPT 3 | seed: 42 4 | exp_name: "default" 5 | folder: "default" 6 | exp_path: '' 7 | 8 | env: 9 | device: 'cuda:0' 10 | gpus: [0, 1] 11 | num_workers: 32 12 | 13 | model: 14 | emsize: 512 15 | max_num_classes: 10 16 | max_num_features: 100 17 | # number of heads in the transformer, 8 worked slightly better than 4 18 | nhead: 8 19 | nhid_factor: 2 20 | nlayers: 12 21 | 22 | 23 | training: 24 | num_epochs: 2048 25 | num_model_updates: 128 26 | num_agg: 1 # gradient aggreation steps: to be adjusted based on the number of gpus 27 | batch_size: 256 # to be adjusted based on the number of gpus 28 | lr: 0.0005 29 | weight_decay: 0.05 30 | dropout: 0.0 31 | # minimum number of elements in the context: too low means lots of noise, too high means we never see small contexts 32 | min_eval_pos: 50 33 | # maximum number of elements in the context 34 | max_eval_pos: 1024 35 | # Fixed size sequence length. Number of queries if seq_len - max_eval_pos 36 | # ensure seq_len >= max_eval_pos + constant where constant is the minimum number of queries 37 | # Too small also means a lot of noise 38 | seq_len: 1536 39 | # seq len for eval 40 | eval_seq_len: 1024 41 | label_smoothing: 0.1 42 | reset_policy: 'rm' # reset policy: 'rm' for remove, 'cnt' for continue 43 | compile: true 44 | clip_grad_norm: 1 45 | 46 | data: 47 | y_reg_augment: true 48 | # retreival during training 49 | retrieval: true 50 | # retrieval during eval: this may take more memory than the training, be careful with memory 51 | eval_retrieval: true 52 | 53 | logging: 54 | eval_every: 10 55 | save_metrics: 56 | - valid_agg 57 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright Notice: © Copyright 2018 The Toronto-Dominion Bank and/or its affiliates 2 | 3 | Permission is hereby granted, subject to the conditions below, free of charge, to 4 | any person obtaining a copy of this software and associated documentation files 5 | (the "Software"), to use, copy, distribute, publish and modify the Software, only 6 | for research and for no other purpose. For clarity, and without limitation, this 7 | licence does not permit use of Software or any part thereof for commercial purposes. 8 | 9 | Patents: This permission does not grant any patent licenses in the Software. 10 | 11 | Conditions: 12 | 1. The above copyright notice and the following disclaimer shall be included in all 13 | copies or substantial portions of the Software. 14 | 2. You must give appropriate credit, provide a link to the license, and indicate if 15 | changes were made. You may do so in any reasonable manner, but not in any way that 16 | suggests that the copyright holder endorses you or your use of the Software. 17 | 18 | Names and Trademarks: No permission to use the names or trademarks of the copyright 19 | holder are granted, except as required for reasonable and customary use in describing 20 | the origin of the Software and reproducing the content of the copyright notice. 21 | 22 | DISCLAIMER: THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 23 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 24 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 25 | HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF 26 | CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE 27 | OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # TabDPT: Scaling Tabular Foundation Models on Real Data 4 | 5 | [![arxiv](https://img.shields.io/static/v1?label=arXiv&message=2410.18164&color=B31B1B&logo=arXiv)](https://arxiv.org/abs/2410.18164) 6 | [![huggingface](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-FFD21E)](https://huggingface.co/Layer6/TabDPT) 7 | 8 |
9 | 10 | **TabDPT** is an open-source foundation model for tabular data based on in-context learning. It is trained on real-world data and can generalize to new tasks **without** additional training or hyperparameter tuning. 11 | 12 | This repository provides the full training code to build your own TabDPT model. A lightweight inference interface is available [here](https://github.com/layer6ai-labs/TabDPT-inference), which can support the evaluation of either the existing TabDPT model or any new models that are trained using this repository. 13 | 14 | 15 | ## Usage 16 | 17 | We provide basic usage tips below. The details can be found by stepping through the code. 18 | 19 | ### Installation 20 | 21 | Before running the code, make sure to install the required Python packages: 22 | 23 | ``` 24 | pip install -r requirements.txt 25 | ``` 26 | 27 | You will also need a C compiler such as `gcc` for building some dependencies. On Ubuntu, you can install it with: 28 | 29 | ``` 30 | sudo apt-get update 31 | sudo apt-get install build-essential 32 | ``` 33 | 34 | ### Training Example 35 | 36 | 37 | To train a fresh TabDPT model with default hyperparameters on a single GPU, use the following command: 38 | 39 | ``` 40 | CUDA_VISIBLE_DEVICES=0 python train.py exp_name="TabDPT" 41 | ``` 42 | 43 | #### Multi GPU Training Example 44 | 45 | If instead you want to use Multi-GPU, do the following: 46 | ``` 47 | CUDA_VISIBLE_DEVICES=4,5,6,7 torchrun --nproc_per_node=4 --rdzv_endpoint=localhost:29500 train.py \ 48 | env.gpus="[4,5,6,7]" \ 49 | exp_name="my_multi_gpu_test" 50 | ``` 51 | 52 | **Notes:** 53 | - Adjust `nproc_per_node` to the number of GPUs. 54 | - If there are communication issues when using several multi gpu training runs on the same node, change the `rdzv_endpoint` port as it can be maxxed out. 55 | 56 | 57 | ## Citation and Acknowledgements 58 | 59 | If citing the paper, please use the following BibTeX: 60 | 61 | ``` 62 | @article{ma2024tabdpt, 63 | title={TabDPT: Scaling Tabular Foundation Models on Real Data}, 64 | author={Ma, Junwei and Thomas, Valentin and Hosseinzadeh, Rasa and Kamkari, Hamidreza and Labach, Alex and Cresswell, Jesse C and Golestan, Keyvan and Yu, Guangwei and Caterini, Anthony L and Volkovs, Maksims}, 65 | journal={arXiv preprint arXiv:2410.18164}, 66 | year={2024} 67 | } 68 | ``` 69 | 70 | Additionally, a huge thank you to [Nafiseh Ghoroghchian](https://github.com/NaGho) for spearheading the effort of refactoring and making this codebase fit for pubilc consumption, and thank you to [Roc Zhang](https://github.com/Zhang-Haipeng) for making the codebase compatible with `safetensors`. 71 | -------------------------------------------------------------------------------- /transformer_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torch.nn import GELU, LayerNorm, Linear 5 | 6 | 7 | class TransformerEncoderLayer(nn.Module): 8 | """Transformer Encoder Layer use in TabDPT.""" 9 | 10 | def __init__(self, embed_dim: int, num_heads: int, ff_dim: int) -> None: 11 | """ 12 | Args: 13 | embed_dim (int): Dimension of the embedding. 14 | num_heads (int): Number of attention heads. 15 | ff_dim (int): Dimension of the feed-forward network. 16 | """ 17 | super().__init__() 18 | self.embed_dim = embed_dim 19 | self.head_dim = embed_dim // num_heads 20 | self.num_heads = num_heads 21 | self.kv_proj = nn.Linear(embed_dim, 2 * embed_dim, bias=False) 22 | self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False) 23 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False) 24 | self.attn_norm = LayerNorm(embed_dim) 25 | self.ff_norm = LayerNorm(embed_dim) 26 | self.ff = nn.Sequential(Linear(embed_dim, ff_dim), GELU(), Linear(ff_dim, embed_dim)) 27 | self.q_norm = LayerNorm(self.head_dim) 28 | self.k_norm = LayerNorm(self.head_dim) 29 | 30 | def forward(self, x: torch.Tensor, eval_pos: int) -> torch.Tensor: 31 | """ 32 | Args: 33 | x (torch.tensor): Input tensor of shape (L, B, D) where B is batch size, L is sequence length, and D is embedding dimension. 34 | eval_pos (int): Evaluation position used for slicing attention keys and values. 35 | Returns: 36 | torch.tensor: Output tensor of the same shape as input. 37 | """ 38 | # switch to (B, L, D) for attention computation 39 | x = x.transpose(0, 1) 40 | B, L, _ = x.size() 41 | 42 | # Normalize the input 43 | h = self.attn_norm(x) 44 | 45 | # project to query, key, and value with linear layers 46 | q = self.q_proj(h) 47 | # slice the key and value projections to the evaluation position 48 | k, v = self.kv_proj(h[:, :eval_pos]).chunk(2, dim=-1) 49 | 50 | # reshape and transpose for multi-head attention 51 | # q: (B, L, D) -> (B, L, num_heads, head_dim) 52 | # k: (B, eval_pos, D) -> (B, num_heads, eval_pos, head_dim) 53 | # v: (B, eval_pos, D) -> (B, num_heads, eval_pos, head_dim) 54 | q = q.view(B, L, self.num_heads, self.head_dim).transpose(1, 2) 55 | k = k.view(B, eval_pos, self.num_heads, self.head_dim).transpose(1, 2) 56 | v = v.view(B, eval_pos, self.num_heads, self.head_dim).transpose(1, 2) 57 | 58 | # apply layer normalization to query and key 59 | q, k = self.q_norm(q), self.k_norm(k) 60 | 61 | # compute scaled dot-product attention 62 | attn = F.scaled_dot_product_attention(q, k, v).transpose(1, 2) 63 | attn = self.out_proj(attn.reshape(B, L, self.num_heads * self.head_dim)) 64 | 65 | # residual connection and feed-forward network 66 | x = x + attn 67 | x = x + self.ff(self.ff_norm(x)) 68 | 69 | # back to (L, B, D) for output 70 | return x.transpose(0, 1) 71 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.3.0 2 | antlr4-python3-runtime==4.9.3 3 | anyio==4.9.0 4 | argon2-cffi==23.1.0 5 | argon2-cffi-bindings==21.2.0 6 | arrow==1.3.0 7 | asttokens==3.0.0 8 | async-lru==2.0.5 9 | attrs==25.3.0 10 | babel==2.17.0 11 | beautifulsoup4==4.13.4 12 | bleach==6.2.0 13 | certifi==2025.4.26 14 | cffi==1.17.1 15 | charset-normalizer==3.4.2 16 | comm==0.2.2 17 | config==0.5.1 18 | debugpy==1.8.14 19 | decorator==5.2.1 20 | defusedxml==0.7.1 21 | executing==2.2.0 22 | faiss-cpu==1.11.0 23 | fastjsonschema==2.21.1 24 | filelock==3.18.0 25 | fqdn==1.5.1 26 | fsspec==2025.5.0 27 | gdown==5.2.0 28 | grpcio==1.71.0 29 | h11==0.16.0 30 | hf-xet==1.1.2 31 | httpcore==1.0.9 32 | httpx==0.28.1 33 | huggingface-hub==0.32.2 34 | hydra-core==1.3.2 35 | idna==3.10 36 | ipykernel==6.29.5 37 | ipython==9.2.0 38 | ipython_pygments_lexers==1.1.1 39 | ipywidgets==8.1.7 40 | isoduration==20.11.0 41 | jedi==0.19.2 42 | Jinja2==3.1.6 43 | joblib==1.5.1 44 | json5==0.12.0 45 | jsonpointer==3.0.0 46 | jsonschema==4.24.0 47 | jsonschema-specifications==2025.4.1 48 | jupyter==1.1.1 49 | jupyter-console==6.6.3 50 | jupyter-events==0.12.0 51 | jupyter-lsp==2.2.5 52 | jupyter_client==8.6.3 53 | jupyter_core==5.8.1 54 | jupyter_server==2.16.0 55 | jupyter_server_terminals==0.5.3 56 | jupyterlab==4.4.3 57 | jupyterlab_pygments==0.3.0 58 | jupyterlab_server==2.27.3 59 | jupyterlab_widgets==3.0.15 60 | kaggle==1.7.4.5 61 | liac-arff==2.5.0 62 | Markdown==3.8 63 | MarkupSafe==3.0.2 64 | matplotlib==3.10.0 65 | matplotlib-inline==0.1.7 66 | minio==7.2.15 67 | mistune==3.1.3 68 | mpmath==1.3.0 69 | nbclient==0.10.2 70 | nbconvert==7.16.6 71 | nbformat==5.10.4 72 | nest-asyncio==1.6.0 73 | networkx==3.4.2 74 | notebook==7.4.3 75 | notebook_shim==0.2.4 76 | omegaconf==2.3.0 77 | openml==0.15.1 78 | overrides==7.7.0 79 | pandocfilters==1.5.1 80 | parso==0.8.4 81 | pexpect==4.9.0 82 | platformdirs==4.3.8 83 | polars==1.30.0 84 | prometheus_client==0.22.0 85 | prompt_toolkit==3.0.51 86 | protobuf==6.31.1 87 | psutil==7.0.0 88 | ptyprocess==0.7.0 89 | pure_eval==0.2.3 90 | pyarrow==20.0.0 91 | pycparser==2.22 92 | pycryptodome==3.23.0 93 | Pygments==2.19.1 94 | PyQt6==6.7.1 95 | PySocks==1.7.1 96 | python-json-logger==3.3.0 97 | python-slugify==8.0.4 98 | PyYAML==6.0.2 99 | pyzmq==26.4.0 100 | referencing==0.36.2 101 | regex==2024.11.6 102 | requests==2.32.3 103 | rfc3339-validator==0.1.4 104 | rfc3986-validator==0.1.1 105 | rpds-py==0.25.1 106 | safetensors==0.5.3 107 | schedulefree==1.4.1 108 | scikit-learn==1.6.1 109 | scipy==1.15.3 110 | Send2Trash==1.8.3 111 | setuptools==72.1.0 112 | sniffio==1.3.1 113 | soupsieve==2.7 114 | stack-data==0.6.3 115 | sympy==1.14.0 116 | tensorboard==2.19.0 117 | tensorboard-data-server==0.7.2 118 | terminado==0.18.1 119 | text-unidecode==1.3 120 | threadpoolctl==3.6.0 121 | tinycss2==1.4.0 122 | tokenizers==0.21.1 123 | torch==2.7.0 124 | torchaudio==2.7.0 125 | torchvision==0.22.0 126 | tqdm==4.67.1 127 | traitlets==5.14.3 128 | transformers==4.52.3 129 | types-python-dateutil==2.9.0.20250516 130 | typing_extensions==4.13.2 131 | uri-template==1.3.0 132 | urllib3==2.4.0 133 | wcwidth==0.2.13 134 | webcolors==24.11.1 135 | webencodings==0.5.1 136 | websocket-client==1.8.0 137 | Werkzeug==3.1.3 138 | wheel==0.45.1 139 | widgetsnbextension==4.0.14 140 | xmltodict==0.14.2 141 | -------------------------------------------------------------------------------- /tabdpt_datasets/tabred.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | 6 | import tabdpt_datasets.external.tabred.cooking_time as cooking_time 7 | import tabdpt_datasets.external.tabred.delivery_eta as delivery_eta 8 | import tabdpt_datasets.external.tabred.maps_routing as maps_routing 9 | import tabdpt_datasets.external.tabred.weather as weather 10 | from tabdpt_datasets.dataset import Dataset 11 | 12 | 13 | class TabredDataset(Dataset): 14 | @staticmethod 15 | def all_names(): 16 | return ["cooking_time", "delivery_eta", "maps_routing", "weather_forecasting"] 17 | 18 | @staticmethod 19 | def suite_name(): 20 | return "tabred" 21 | 22 | def __init__(self, name, task_id=None): 23 | """ 24 | task_id options correspond to the splits provided by the dataset: 25 | "split-default", "split-random-0", "split-random-1", "split-random-2", 26 | "split-sliding-window-0", "split-sliding-window-1", "split-sliding-window-2" 27 | """ 28 | 29 | if task_id is None: 30 | task_id = "split-default" 31 | assert task_id in ( 32 | "split-default", 33 | "split-random-0", 34 | "split-random-1", 35 | "split-random-2", 36 | "split-sliding-window-0", 37 | "split-sliding-window-1", 38 | "split-sliding-window-2", 39 | ) 40 | super().__init__(name, task_id) 41 | self.metadata["target_type"] = "regression" 42 | 43 | def prepare_data(self, download_dir): 44 | download_dir = Path(download_dir) 45 | 46 | if self.name == "cooking_time": 47 | download_fn = cooking_time.main 48 | sub_dir = "cooking-time" 49 | elif self.name == "delivery_eta": 50 | download_fn = delivery_eta.main 51 | sub_dir = "delivery-eta" 52 | elif self.name == "maps_routing": 53 | download_fn = maps_routing.main 54 | sub_dir = "maps-routing" 55 | elif self.name == "weather_forecasting": 56 | download_fn = weather.main 57 | sub_dir = "weather" 58 | 59 | if not all(os.path.exists(download_dir / sub_dir / f) for f in ("X_num.npy", "Y.npy")): 60 | download_fn(download_dir) 61 | X_mats = [] 62 | X_mats.append(np.load(download_dir / sub_dir / "X_num.npy")) 63 | if os.path.exists(download_dir / sub_dir / "X_bin.npy"): 64 | X_mats.append(np.load(download_dir / sub_dir / "X_bin.npy")) 65 | if os.path.exists(download_dir / sub_dir / "X_cat.npy"): 66 | X_cat = np.load(download_dir / sub_dir / "X_cat.npy") 67 | n_num = sum(X.shape[1] for X in X_mats) 68 | self.metadata["categorical_feature_inds"] = [n_num + i for i in range(X_cat.shape[1])] 69 | X_mats.append(X_cat) 70 | self.X = np.concatenate(X_mats, axis=1) 71 | self.y = np.load(download_dir / sub_dir / "Y.npy") 72 | self._train_inds = np.load(download_dir / sub_dir / self._task_id / "train_idx.npy") 73 | self._val_inds = np.load(download_dir / sub_dir / self._task_id / "val_idx.npy") 74 | self._test_inds = np.load(download_dir / sub_dir / self._task_id / "test_idx.npy") 75 | 76 | def all_instances(self): 77 | return self.X, self.y 78 | 79 | def train_inds(self): 80 | return self._train_inds 81 | 82 | def val_inds(self): 83 | return self._val_inds 84 | 85 | def test_inds(self): 86 | return self._test_inds 87 | -------------------------------------------------------------------------------- /tabdpt_datasets/annotated_tables.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import time 4 | from pathlib import Path 5 | 6 | import kaggle 7 | import numpy as np 8 | import pandas as pd 9 | from sklearn.preprocessing import OrdinalEncoder 10 | 11 | from tabdpt_datasets.dataset import Dataset 12 | from tabdpt_datasets.external.annotated_tables.ids_deduped import ANNOTATED_TABLE_IDS 13 | 14 | 15 | class AnnotatedTablesDataset(Dataset): 16 | """ 17 | Dataset of the tables used to evaluate TabPFN in https://arxiv.org/pdf/2406.16349, with 18 | duplicate datasets left out. Since Kaggle doesn't have a standard file layout, we just take the 19 | largest CSV file. Because of that, there's no target or test set, and this dataset should only 20 | be used for training. 21 | """ 22 | 23 | @staticmethod 24 | def all_names(): 25 | return [f"annotated_tables_{id}" for id in ANNOTATED_TABLE_IDS] 26 | 27 | @staticmethod 28 | def suite_name(): 29 | return "annotated_tables" 30 | 31 | def __init__(self, name, task_id=None, split_seed=0): 32 | """ 33 | Note that split_seed is only used for train/val split, since there's no test set. 34 | """ 35 | if task_id is None: 36 | task_id = f"random-seed{split_seed}" 37 | else: 38 | assert task_id.startswith("random-seed") 39 | split_seed = int(task_id.removeprefix("random-seed")) 40 | super().__init__(name, task_id) 41 | self.kaggle_id = name[len("annotated_tables_") :] 42 | self.rng = np.random.default_rng(split_seed) 43 | self.metadata["kaggle_dataset_name"] = self.kaggle_id.split("/")[-1] 44 | 45 | def prepare_data(self, download_dir): 46 | dataset_dir = os.path.join( 47 | download_dir, "annotated-tables", self.kaggle_id.replace("/", "-") 48 | ) 49 | if not os.path.exists(dataset_dir): 50 | os.makedirs(dataset_dir) 51 | try: 52 | kaggle.api.dataset_download_files(self.kaggle_id, dataset_dir, unzip=True) 53 | except kaggle.rest.ApiException as e: 54 | if e.status == 429: # Rate limit 55 | time.sleep(0.5) 56 | kaggle.api.dataset_download_files(self.kaggle_id, dataset_dir, unzip=True) 57 | elif e.status == 403: # Rate limit 58 | raise ValueError( 59 | f"HTTP 403 when downloading Kaggle dataset {self.kaggle_id}. It was " 60 | "probably deleted from Kaggle - try removing it from " 61 | "datasets/external/annotated_tables/ids.py" 62 | ) 63 | else: 64 | raise e 65 | csv_paths = sorted(Path(dataset_dir).rglob("*.csv"), key=lambda p: p.stat().st_size) 66 | if len(csv_paths) == 0: 67 | raise ValueError("Couldn not load data for " + self.name) 68 | 69 | hasher = hashlib.new("sha1") 70 | with open(csv_paths[-1], "rb") as f: 71 | hasher.update(f.read()) 72 | self.metadata["file_sha1"] = hasher.hexdigest() 73 | X = pd.read_csv(csv_paths[-1], low_memory=False) 74 | 75 | categorical_inds = [] 76 | to_drop = [] 77 | n_dropped = 0 78 | for i, col in enumerate(X.columns): 79 | if X[col].dtype == "object" or isinstance(X[col], pd.CategoricalDtype): 80 | enc = OrdinalEncoder() 81 | X[[col]] = enc.fit_transform(X[[col]]) 82 | # Drop high cardinality 83 | if len(enc.categories_[0]) > 100: 84 | to_drop.append(col) 85 | n_dropped += 1 86 | else: 87 | categorical_inds.append(i - n_dropped) 88 | X.drop(columns=to_drop, inplace=True) 89 | self.metadata["categorical_feature_inds"] = categorical_inds 90 | self.X = X.to_numpy().astype(np.float32) 91 | 92 | n = len(X) 93 | perm = self.rng.permutation(n) 94 | self._train_inds = perm[: int(n * 0.85)] 95 | self._val_inds = perm[int(n * 0.85) :] 96 | 97 | def all_instances(self): 98 | return self.X, None 99 | 100 | def train_inds(self): 101 | return self._train_inds 102 | 103 | def val_inds(self): 104 | return self._val_inds 105 | 106 | def test_inds(self): 107 | return [] 108 | -------------------------------------------------------------------------------- /data_splits/reg_datasets.csv: -------------------------------------------------------------------------------- 1 | did,tid,num_instances,num_features,num_missing_values,source,test 2 | 44956,361234,4177.0,9.0,0.0,['ctr23'],True 3 | 44957,361235,1503.0,6.0,0.0,['ctr23'],True 4 | 44958,361236,2043.0,8.0,0.0,['ctr23'],True 5 | 44959,361237,1030.0,9.0,0.0,['ctr23'],True 6 | 44963,361241,45730.0,10.0,0.0,['ctr23'],True 7 | 44964,361242,21263.0,82.0,0.0,['ctr23'],True 8 | 44965,361243,1059.0,117.0,0.0,['ctr23'],True 9 | 44966,361244,1066.0,11.0,0.0,['ctr23'],True 10 | 44969,361247,11934.0,15.0,0.0,['ctr23'],True 11 | 44971,361249,4898.0,12.0,0.0,['ctr23'],True 12 | 44972,361250,1599.0,12.0,0.0,['ctr23'],True 13 | 44973,361251,10000.0,13.0,0.0,['ctr23'],True 14 | 44974,361252,68784.0,19.0,0.0,['ctr23'],True 15 | 44975,361253,72000.0,49.0,0.0,['ctr23'],True 16 | 44976,361254,48933.0,22.0,0.0,['ctr23'],True 17 | 44977,361255,20640.0,9.0,0.0,['ctr23'],True 18 | 44978,361256,8192.0,22.0,0.0,['ctr23'],True 19 | 44979,361257,53940.0,10.0,0.0,['ctr23'],True 20 | 44980,361258,8192.0,9.0,0.0,['ctr23'],True 21 | 44981,361259,8192.0,33.0,0.0,['ctr23'],True 22 | 44983,361260,13932.0,16.0,0.0,['ctr23'],True 23 | 44984,361261,28155.0,7.0,0.0,['ctr23'],True 24 | 44987,361264,1156.0,6.0,0.0,['ctr23'],True 25 | 44989,361266,21613.0,22.0,0.0,['ctr23'],True 26 | 44990,361267,10692.0,10.0,0.0,['ctr23'],True 27 | 44992,361268,24624.0,44.0,69696.0,['ctr23'],True 28 | 44993,361269,22272.0,12.0,0.0,['ctr23'],True 29 | 45012,361272,19178.0,29.0,0.0,['ctr23'],True 30 | 41021,361616,1232.0,15.0,3600.0,"['ctr23', 'amlb_reg']",True 31 | 44960,361617,768.0,9.0,0.0,['ctr23'],True 32 | 44962,361618,517.0,13.0,0.0,['ctr23'],True 33 | 44967,361619,649.0,31.0,0.0,['ctr23'],True 34 | 44970,361621,908.0,7.0,0.0,['ctr23'],True 35 | 44994,361622,804.0,18.0,0.0,['ctr23'],True 36 | 45402,361623,3107.0,7.0,0.0,['ctr23'],True 37 | 44055,361093,4052.0,8.0,0.0,['grins_reg'],False 38 | 44056,361094,8641.0,5.0,0.0,['grins_reg'],False 39 | 44059,361096,53940.0,10.0,0.0,['grins_reg'],False 40 | 44061,361097,4209.0,360.0,0.0,['grins_reg'],False 41 | 44062,361098,10692.0,12.0,0.0,['grins_reg'],False 42 | 44063,361099,17379.0,12.0,0.0,['grins_reg'],False 43 | 44065,361101,581835.0,17.0,0.0,['grins_reg'],False 44 | 44066,361102,21613.0,18.0,0.0,['grins_reg'],False 45 | 44068,361103,394299.0,7.0,0.0,['grins_reg'],False 46 | 44069,361104,241600.0,10.0,0.0,['grins_reg'],False 47 | 45041,361287,8885.0,256.0,0.0,['grins_reg'],False 48 | 45042,361288,4177.0,9.0,0.0,['grins_reg'],False 49 | 45043,361289,52031.0,5.0,0.0,['grins_reg'],False 50 | 45045,361291,5465575.0,12.0,0.0,['grins_reg'],False 51 | 45046,361292,188318.0,125.0,0.0,['grins_reg'],False 52 | 45047,361293,1000000.0,6.0,0.0,['grins_reg'],False 53 | 45048,361294,163065.0,4.0,0.0,['grins_reg'],False 54 | 44132,361072,8192.0,22.0,0.0,['grins_reg'],False 55 | 44133,361073,15000.0,27.0,0.0,['grins_reg'],False 56 | 44134,361074,16599.0,17.0,0.0,['grins_reg'],False 57 | 44136,361076,6497.0,12.0,0.0,['grins_reg'],False 58 | 44137,361077,13750.0,34.0,0.0,['grins_reg'],False 59 | 44138,361078,20640.0,9.0,0.0,['grins_reg'],False 60 | 44139,361079,22784.0,17.0,0.0,['grins_reg'],False 61 | 44140,361080,53940.0,7.0,0.0,['grins_reg'],False 62 | 44141,361081,10692.0,9.0,0.0,['grins_reg'],False 63 | 44142,361082,17379.0,7.0,0.0,['grins_reg'],False 64 | 44143,361083,581835.0,10.0,0.0,['grins_reg'],False 65 | 44144,361084,21613.0,16.0,0.0,['grins_reg'],False 66 | 44145,361085,10081.0,7.0,0.0,['grins_reg'],False 67 | 44146,361086,163065.0,4.0,0.0,['grins_reg'],False 68 | 44147,361087,13932.0,14.0,0.0,['grins_reg'],False 69 | 44148,361088,21263.0,80.0,0.0,['grins_reg'],False 70 | 45032,361279,8885.0,43.0,0.0,['grins_reg'],False 71 | 45033,361280,4177.0,8.0,0.0,['grins_reg'],False 72 | 45034,361281,5465575.0,9.0,0.0,['grins_reg'],False 73 | 42225,233211,53940.0,10.0,0.0,['amlb_reg'],False 74 | 42571,233212,188318.0,131.0,0.0,['amlb_reg'],False 75 | 4549,233213,583250.0,78.0,0.0,['amlb_reg'],False 76 | 42572,233214,4459.0,4992.0,0.0,['amlb_reg'],False 77 | 42570,233215,4209.0,377.0,0.0,['amlb_reg'],False 78 | 42705,317614,400000.0,101.0,0.0,['amlb_reg'],False 79 | 42728,359929,10000000.0,10.0,0.0,['amlb_reg'],False 80 | 550,359930,2178.0,4.0,0.0,['amlb_reg'],False 81 | 546,359931,576.0,12.0,0.0,['amlb_reg'],False 82 | 541,359932,1156.0,6.0,0.0,['amlb_reg'],False 83 | 507,359933,3107.0,7.0,0.0,['amlb_reg'],False 84 | 505,359934,240.0,125.0,0.0,['amlb_reg'],False 85 | 287,359935,6497.0,12.0,0.0,['amlb_reg'],False 86 | 216,359936,16599.0,19.0,0.0,['amlb_reg'],False 87 | 41540,359937,166821.0,10.0,0.0,['amlb_reg'],False 88 | 42688,359938,10692.0,13.0,0.0,['amlb_reg'],False 89 | 422,359939,8885.0,267.0,0.0,['amlb_reg'],False 90 | 416,359940,8885.0,252.0,0.0,['amlb_reg'],False 91 | 42724,359941,39644.0,60.0,0.0,['amlb_reg'],False 92 | 42727,359942,7063.0,45.0,104249.0,['amlb_reg'],False 93 | 42729,359943,581835.0,19.0,0.0,['amlb_reg'],False 94 | 42726,359944,4177.0,9.0,0.0,['amlb_reg'],False 95 | 42730,359945,1994.0,127.0,39202.0,['amlb_reg'],False 96 | 201,359946,15000.0,49.0,0.0,['amlb_reg'],False 97 | 41980,359948,4440.0,117.0,27150.0,['amlb_reg'],False 98 | 42731,359949,21613.0,22.0,0.0,['amlb_reg'],False 99 | 531,359950,506.0,14.0,0.0,['amlb_reg'],False 100 | 42563,359951,1460.0,80.0,6965.0,['amlb_reg'],False 101 | 574,359952,22784.0,17.0,0.0,['amlb_reg'],False 102 | 3050,360932,5742.0,1026.0,0.0,['amlb_reg'],False 103 | 3277,360933,5766.0,1026.0,0.0,['amlb_reg'],False 104 | 43071,360945,1090.0,145.0,0.0,['amlb_reg'],False 105 | -------------------------------------------------------------------------------- /tabdpt_datasets/dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from abc import ABC, abstractmethod, abstractstaticmethod 4 | 5 | import numpy as np 6 | import scipy 7 | import torch.utils.data 8 | 9 | 10 | class Dataset(ABC): 11 | """ 12 | Base class for dataset classes. 13 | 14 | A dataset class represents a suite of datasets and an instance of the class represents a single 15 | loaded dataset from the suite. 16 | 17 | To support evaluation over a range of methods, the current code assumes that the entire dataset 18 | can be loaded into memory and accessed as numpy arrays using all_instances, which also supports 19 | fast indexing. Synthetic datasets with a fixed set of instances can either generate their data at 20 | initialization or lazily. 21 | 22 | We will probably want to add separate methods for synthetic datasets that generate an indefinite 23 | number of instances on the fly, since these are less appropriate for evaluation and might not be 24 | usable for all baseline models anyway (e.g., decision trees that expect all instances to be 25 | available at once). 26 | 27 | To enable reproducible splits, datasets take an optional parameter task_id. Subclasses should 28 | handle splitting in prepare_data in a reproducible way given a task_id. They can also implement 29 | functions to generate splits that return a task_id, ensuring future reproducibility. 30 | """ 31 | 32 | def __init__(self, name: str, task_id: str = None): 33 | """ 34 | Params: 35 | name - the name of the dataset to load 36 | task_id - an optional ID that subclasses should use to ensure reproducibility 37 | """ 38 | self.name = name 39 | self._task_id = task_id 40 | self.metadata = {"suite_name": self.suite_name(), "dataset_name": name} 41 | self.column_names = None 42 | 43 | @abstractstaticmethod 44 | def suite_name() -> str: 45 | """ 46 | Name of the suite of datasets provided by this class, for cataloguing. Typically using all 47 | lowercase for consistency. 48 | """ 49 | pass 50 | 51 | @abstractstaticmethod 52 | def all_names() -> list[str] | None: 53 | """ 54 | Return the names of all datasets provided by this class, or None if it does not provide a 55 | fixed set of datasets 56 | """ 57 | pass 58 | 59 | @abstractmethod 60 | def prepare_data(self, download_dir: str): 61 | """ 62 | Download data if needed and do any CPU-side preprocessing, splitting, etc as needed. 63 | """ 64 | pass 65 | 66 | def task_id(self) -> str | None: 67 | """ 68 | Returns the task_id that can be used to reproduce the current task, or None if not possible. 69 | 70 | If task_id is specified as a constructor argument, this should always return the same value. 71 | Otherwise, subclasses can either choose to modify _task_id or override this method directly 72 | to return the task_id for the settings that the dataset was set up with. 73 | """ 74 | return self._task_id 75 | 76 | def __len__(self) -> int: 77 | xs, _ = self.all_instances() 78 | return xs.shape[0] 79 | 80 | def __getitem__(self, i) -> tuple[np.ndarray, np.ndarray | int | float | None]: 81 | xs, ys = self.all_instances() 82 | if ys is not None: 83 | return xs[i], ys[i] 84 | return xs[i], None 85 | 86 | @abstractmethod 87 | def all_instances(self) -> tuple[np.ndarray, np.ndarray | None]: 88 | """ 89 | Return all instances as a feature matrix and target vector. 90 | """ 91 | pass 92 | 93 | @abstractmethod 94 | def train_inds(self) -> list[int] | np.ndarray | range: 95 | pass 96 | 97 | @abstractmethod 98 | def val_inds(self) -> list[int] | np.ndarray | range: 99 | pass 100 | 101 | @abstractmethod 102 | def test_inds(self) -> list[int] | np.ndarray | range: 103 | pass 104 | 105 | def train_instances(self) -> tuple[np.ndarray, np.ndarray]: 106 | X, y = self.all_instances() 107 | return X[self.train_inds()], y[self.train_inds()] 108 | 109 | def val_instances(self) -> tuple[np.ndarray, np.ndarray]: 110 | X, y = self.all_instances() 111 | return X[self.val_inds()], y[self.val_inds()] 112 | 113 | def test_instances(self) -> tuple[np.ndarray, np.ndarray]: 114 | X, y = self.all_instances() 115 | return X[self.test_inds()], y[self.test_inds()] 116 | 117 | def auto_populate_metadata(self): 118 | X, y = self.all_instances() 119 | self.metadata["n_rows"] = X.shape[0] 120 | self.metadata["n_features"] = X.shape[1] 121 | self.metadata["n_cells"] = X.shape[0] * X.shape[1] 122 | if y is None: 123 | self.metadata["target_type"] = "none" 124 | else: 125 | self.metadata["y_mean"] = np.mean(y) 126 | self.metadata["y_var"] = np.var(y) 127 | if "target_type" not in self.metadata: 128 | self.metadata["target_type"] = "unknown" 129 | self.metadata["n_cells"] += y.shape[0] 130 | 131 | self.metadata["frac_missing"] = np.isnan(X).mean() 132 | self.metadata["frac_rows_with_missing"] = np.isnan(X).any(axis=1).mean() 133 | self.metadata["frac_features_with_missing"] = np.isnan(X).any(axis=0).mean() 134 | 135 | lin_coeffs = [] 136 | for i in range(X.shape[1]): 137 | col = np.nan_to_num(X[:, i]) 138 | if np.all(np.isclose(col, col[0])): 139 | lin_coeffs.append(None) 140 | continue 141 | try: 142 | res = np.linalg.lstsq(np.stack((col, np.ones_like(col)), axis=1), y, rcond=None) 143 | lin_coeffs.append(res[0].tolist()) 144 | except np.linalg.LinAlgError: 145 | lin_coeffs.append(None) 146 | continue 147 | self.metadata["column_lin_coeffs"] = lin_coeffs 148 | 149 | self.metadata["column_means"] = X.mean(axis=0).tolist() 150 | self.metadata["column_vars"] = X.var(axis=0).tolist() 151 | self.metadata["column_skews"] = scipy.stats.skew(X, axis=0).tolist() 152 | self.metadata["column_kurtoses"] = scipy.stats.kurtosis(X, axis=0).tolist() 153 | 154 | 155 | class TorchDataset(torch.utils.data.Dataset): 156 | """ 157 | Utility class for accessing a Dataset split as a torch Dataset 158 | """ 159 | 160 | def __init__(self, dataset: Dataset, split): 161 | assert split in ("train", "val", "test", "all") 162 | self.dataset = dataset 163 | if split == "all": 164 | self.inds = range(len(dataset)) 165 | elif split == "train": 166 | self.inds = dataset.train_inds() 167 | elif split == "val": 168 | self.inds = dataset.val_inds() 169 | elif split == "test": 170 | self.inds = dataset.test_inds() 171 | 172 | def __len__(self): 173 | return len(self.inds) 174 | 175 | def __getitem__(self, i): 176 | return self.dataset[self.inds[i]] 177 | -------------------------------------------------------------------------------- /eval_full.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import Optional 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import scipy 7 | import torch 8 | from sklearn.metrics import accuracy_score, f1_score, r2_score, roc_auc_score 9 | from tqdm import tqdm 10 | 11 | from tabdpt_datasets.openml import OpenMLDataset 12 | from model import TabDPTModel, Task 13 | from tabdpt import TabDPTClassifier, TabDPTRegressor 14 | from utils import FAISS, DataPreprocessor 15 | 16 | 17 | class FullEval: 18 | def __init__( 19 | self, 20 | device, 21 | max_feat, 22 | impute_method="mean", 23 | use_retrieval=False, 24 | ): 25 | """Initialize the FullEval class for evaluating tabular data. 26 | 27 | Args: 28 | device (str): The device to use for evaluation (e.g., "cuda:0" or "cpu"). 29 | max_feat (int): The maximum number of features to use for evaluation. 30 | impute_method (str, optional): The imputation method to use for missing values. Defaults to "mean". 31 | use_retrieval (bool, optional): Whether to use retrieval-augmented generation. Defaults to False. 32 | """ 33 | self.device = device 34 | self.max_feat = max_feat 35 | self.use_retrieval = use_retrieval 36 | 37 | # loading classification dataset 38 | df_eval_cls = pd.read_csv("data_splits/cls_datasets.csv") 39 | self.cc18_dids = df_eval_cls[df_eval_cls["test_all"] == True][ 40 | "did" 41 | ].values.tolist() # 72 datasets 42 | 43 | # loading regression dataset 44 | reg_df = pd.read_csv("data_splits/reg_datasets.csv") 45 | ctr_df = reg_df[reg_df["test"] == True] 46 | self.ctr_dids = ctr_df["did"].values.tolist() 47 | 48 | # get did (dataset ID) to tid (task ID) mapping 49 | did_tid_mapping = dict( 50 | zip( 51 | df_eval_cls[df_eval_cls["test_all"] == True]["did"], 52 | df_eval_cls[df_eval_cls["test_all"] == True]["tid"], 53 | ) 54 | ) 55 | did_tid_mapping.update(dict(zip(ctr_df["did"], ctr_df["tid"]))) 56 | 57 | self.datasets = {} 58 | for did in [*self.cc18_dids, *self.ctr_dids]: 59 | dataset = OpenMLDataset("openml_dataset", openml_task_id=int(did_tid_mapping[did])) 60 | dataset.prepare_data("data") 61 | dataset_name = dataset.openml_dataset.name 62 | 63 | X_train, y_train = dataset.train_instances() 64 | X_val, y_val = dataset.val_instances() 65 | 66 | X_train = np.concatenate([X_train, X_val], axis=0) 67 | y_train = np.concatenate([y_train, y_val], axis=0) 68 | 69 | X_test, y_test = dataset.test_instances() 70 | 71 | # TODO missing convert_cat2num step in full_dataset.py 72 | # Preprocess data 73 | preprocessor = DataPreprocessor(impute_method=impute_method) 74 | X_train = preprocessor.fit_transform(X_train) 75 | X_test = preprocessor.transform(X_test) 76 | 77 | # Create faiss index 78 | faiss_knn = FAISS(X_train, use_hnsw=False, metric="L2") 79 | 80 | # back to tensor 81 | X_train = torch.tensor(X_train).to(device) 82 | X_test = torch.tensor(X_test).to(device) 83 | y_train = torch.tensor(y_train).to(device) 84 | y_test = torch.tensor(y_test) 85 | 86 | self.datasets[did] = (dataset_name, faiss_knn, X_train, X_test, y_train, y_test) 87 | 88 | @torch.no_grad() 89 | def eval( 90 | self, 91 | model: TabDPTModel, 92 | context_length: int = 1024, 93 | inf_batch_size=512, 94 | temperature=0.8, 95 | return_individual_perfs=False, 96 | ) -> tuple[dict, Optional[dict]]: 97 | """Evaluate tdicthe model on classification and regression datasets. 98 | 99 | Args: 100 | model (TabDPTModel): The TabDPT model to evaluate. 101 | context_length (int, optional): . Defaults to 1024. 102 | inf_batch_size (int, optional): Inference batch size. Defaults to 512. 103 | temperature (float, optional): _description_. Defaults to 0.8. 104 | return_individual_perfs (bool, optional): _description_. Defaults to False. 105 | 106 | Returns: 107 | dict: performance metrics for classification and regression tasks. 108 | Optional[dict]: individual performance metrics for each dataset if return_individual_perfs is True. 109 | """ 110 | classifier = TabDPTClassifier( 111 | model=model, 112 | mode=Task.CLS, 113 | device=self.device, 114 | inf_batch_size=inf_batch_size, 115 | tensor_eval=True, 116 | ) 117 | regressor = TabDPTRegressor( 118 | model=model, 119 | mode=Task.REG, 120 | device=self.device, 121 | inf_batch_size=inf_batch_size, 122 | tensor_eval=True, 123 | ) 124 | 125 | cls_performance = defaultdict(lambda: defaultdict(list)) 126 | reg_performance = defaultdict(lambda: defaultdict(list)) 127 | 128 | final_perfs = {} 129 | individual_perfs = {} 130 | 131 | # evaluation for classification datasets 132 | for did in tqdm(self.cc18_dids): 133 | dataset_name, faiss_index, X_train, X_test, y_train, y_test = self.datasets[did] 134 | 135 | classifier.fit(X_train, y_train, faiss_index) 136 | pred_val = classifier.predict_proba( 137 | X_test, 138 | temperature=temperature, 139 | context_size=context_length, 140 | use_retrieval=self.use_retrieval, 141 | ) 142 | 143 | if len(np.unique(y_test)) == 2: 144 | auc = roc_auc_score(y_test, pred_val[:, 1]) 145 | else: 146 | auc = roc_auc_score(y_test, pred_val, multi_class="ovo") 147 | 148 | f1 = f1_score(y_test, np.argmax(pred_val, axis=1), average="weighted") 149 | acc = accuracy_score(y_test, np.argmax(pred_val, axis=1)) 150 | ce = torch.nn.functional.cross_entropy( 151 | torch.Tensor(pred_val).float(), torch.Tensor(y_test).long() 152 | ) 153 | cls_performance["cc18"][did] = [acc, f1, auc, ce] 154 | individual_perfs[dataset_name] = [acc, f1, auc, ce] 155 | 156 | cls_perfs = np.array(list(cls_performance["cc18"].values())) 157 | cls_perfs_mean = cls_perfs.mean(0) 158 | 159 | final_perfs["cls-cc18-acc"] = cls_perfs_mean[0] 160 | final_perfs["cls-cc18-f1"] = cls_perfs_mean[1] 161 | final_perfs["cls-cc18-auc"] = cls_perfs_mean[2] 162 | final_perfs["cls-cc18-ce"] = cls_perfs_mean[3] 163 | 164 | # evaluation for regression datasets 165 | for did in self.ctr_dids: 166 | dataset_name, faiss_index, X_train, X_test, y_train, y_test = self.datasets[did] 167 | 168 | regressor.fit(X_train, y_train, faiss_index) 169 | pred_val = regressor.predict( 170 | X_test, context_size=context_length, use_retrieval=self.use_retrieval 171 | ).flatten() 172 | # scaler = StandardScaler() 173 | # y_test = scaler.fit_transform(y_test.reshape(-1, 1)).flatten() 174 | 175 | mse = np.mean((y_test.cpu().numpy() - pred_val) ** 2) 176 | correlation = scipy.stats.pearsonr(y_test.cpu().numpy(), pred_val.flatten()) 177 | r2 = r2_score(y_test.cpu().numpy(), pred_val) 178 | 179 | reg_performance["ctr"][did] = [mse, correlation[0], r2] 180 | individual_perfs[dataset_name] = [mse, correlation[0], r2] 181 | 182 | reg_perfs = np.array(list(reg_performance["ctr"].values())) 183 | reg_perfs_mean = reg_perfs.mean(0) 184 | final_perfs["reg-ctr-mse"] = reg_perfs_mean[0] 185 | final_perfs["reg-ctr-cor"] = reg_perfs_mean[1] 186 | final_perfs["reg-ctr-r2"] = reg_perfs_mean[2] 187 | 188 | if return_individual_perfs: 189 | return final_perfs, individual_perfs 190 | else: 191 | return final_perfs 192 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from omegaconf import DictConfig 7 | 8 | from transformer_layer import TransformerEncoderLayer 9 | 10 | 11 | class Task(Enum): 12 | """Enum representing the type of task for the model. 13 | REG: Regression task. 14 | CLS: Classification task. 15 | """ 16 | 17 | REG = 1 18 | CLS = 2 19 | 20 | 21 | def pad_x(X: torch.Tensor, num_features: int) -> torch.Tensor: 22 | """Pad the input tensor X with zeros to match the specified number of features. 23 | Args: 24 | X (torch.Tensor): Input tensor of shape (seq_len, batch_size, n_features). 25 | num_features (int): Desired number of features after padding. 26 | Returns: 27 | torch.Tensor: Padded tensor of shape (seq_len, batch_size, num_features). 28 | """ 29 | seq_len, batch_size, n_features = X.shape 30 | zero_feature_padding = torch.zeros( 31 | (seq_len, batch_size, num_features - n_features), device=X.device 32 | ) 33 | return torch.cat([X, zero_feature_padding], -1) 34 | 35 | 36 | def maskmean(x: torch.Tensor, mask: torch.Tensor, dim: int) -> torch.Tensor: 37 | """Compute the mean of x along the specified dimension, ignoring masked values. 38 | Args: 39 | x (torch.Tensor): Input tensor of shape (time, batch, hidden dimension). 40 | mask (torch.Tensor): Mask tensor of the same shape as x, where True indicates valid values. 41 | dim (int): Dimension along which to compute the mean. 42 | Returns: 43 | torch.Tensor: Mean of x along the specified dimension, with masked values ignored. 44 | """ 45 | x = torch.where(mask, x, 0) 46 | return x.sum(dim=dim, keepdim=True) / mask.sum(dim=dim, keepdim=True) 47 | 48 | 49 | def maskstd(x: torch.Tensor, mask: torch.Tensor, dim: int = 0): 50 | """Compute the standard deviation of x along the specified dimension, ignoring masked values. 51 | Args: 52 | x (torch.Tensor): Input tensor of shape (time, batch, hidden dimension). 53 | mask (torch.Tensor): Mask tensor of the same shape as x, where True indicates valid values. 54 | dim (int): Dimension along which to compute the standard deviation. 55 | Returns: 56 | torch.Tensor: Standard deviation of x along the specified dimension, with masked values ignored. 57 | """ 58 | num = mask.sum(dim=dim, keepdim=True) 59 | mean = maskmean(x, mask, dim=0) 60 | diffs = torch.where(mask, mean - x, 0) 61 | return ((diffs**2).sum(dim=0, keepdim=True) / (num - 1)) ** 0.5 62 | 63 | 64 | def normalize_data(data: torch.Tensor, eval_pos: int) -> torch.Tensor: 65 | """Normalize the input data by subtracting the mean and dividing by the standard deviation. 66 | 67 | Args: 68 | data (torch.Tensor): input data of shape (time, batch, hidden dimension). 69 | eval_pos (int): Evaluation position used for slicing attention keys and values. 70 | 71 | Returns: 72 | torch.Tensor: normalized data of shape (time, batch, hidden dimension). 73 | """ 74 | X = data[:eval_pos] if eval_pos > 0 else data 75 | mask = ~torch.isnan(X) 76 | mean = maskmean(X, mask, dim=0) 77 | std = maskstd(X, mask, dim=0) + 1e-6 78 | data = (data - mean) / std 79 | return data 80 | 81 | 82 | def clip_outliers(data: torch.Tensor, eval_pos: int, n_sigma: int = 4): 83 | """ 84 | clip outliers in data based on a given number of standard deviations. 85 | 86 | Args: 87 | data (torch.Tensor): Input data of shape (time, batch, hidden dimension). 88 | eval_pos (int): Evaluation position used for slicing attention keys and values. 89 | n_sigma (int): Number of standard deviations to use for clipping outliers. 90 | Returns: 91 | torch.Tensor: Data with outliers clipped, of shape (time, batch, hidden dimension).""" 92 | assert len(data.shape) == 3, "X must be T,B,H" 93 | X = data[:eval_pos] if eval_pos > 0 else data 94 | mask = ~torch.isnan(X) 95 | mean = maskmean(X, mask, dim=0) 96 | cutoff = n_sigma * maskstd(X, mask, dim=0) 97 | mask &= cutoff >= torch.abs(X - mean) 98 | cutoff = n_sigma * maskstd(X, mask, dim=0) 99 | return torch.clip(data, mean - cutoff, mean + cutoff) 100 | 101 | 102 | def convert_to_torch_tensor(input: np.ndarray | torch.Tensor) -> torch.Tensor: 103 | """Convert a NumPy array or a PyTorch tensor to a PyTorch tensor. 104 | Args: 105 | input (np.ndarray | torch.Tensor): Input data to be converted. 106 | Returns: 107 | torch.Tensor: Converted PyTorch tensor. 108 | Raises: 109 | TypeError: If the input is neither a NumPy array nor a PyTorch tensor. 110 | """ 111 | if isinstance(input, np.ndarray): 112 | return torch.from_numpy(input) 113 | elif torch.is_tensor(input): 114 | return input 115 | else: 116 | raise TypeError("Input must be a NumPy array or a PyTorch tensor.") 117 | 118 | 119 | class TabDPTModel(nn.Module): 120 | def __init__( 121 | self, 122 | dropout: float, 123 | n_out: int, 124 | nhead: int, 125 | nhid: int, 126 | ninp: int, 127 | nlayers: int, 128 | num_features: int, 129 | ): 130 | """TabDPTModel initialization. 131 | 132 | Args: 133 | dropout (float): Dropout rate. 134 | n_out (int): Number of output classes. 135 | nhead (int): Number of attention heads. 136 | nhid (int): Hidden dimension. 137 | ninp (int): Input dimension. 138 | nlayers (int): Number of transformer layers. 139 | num_features (int): Number of input features. 140 | """ 141 | super().__init__() 142 | self.n_out = n_out # number of output classes 143 | self.ninp = ninp # embedding dimension 144 | self.transformer_encoder = nn.ModuleList( 145 | [ 146 | TransformerEncoderLayer( 147 | embed_dim=ninp, 148 | num_heads=nhead, 149 | ff_dim=nhid, 150 | ) 151 | for _ in range(nlayers) 152 | ] 153 | ) 154 | self.num_features = num_features 155 | self.encoder = nn.Linear(num_features, ninp) 156 | self.dropout = nn.Dropout(p=dropout) 157 | self.y_encoder = nn.Linear(1, ninp) 158 | self.head = nn.Sequential(nn.Linear(ninp, nhid), nn.GELU(), nn.Linear(nhid, n_out + 1)) 159 | 160 | def forward( 161 | self, 162 | x_src: torch.Tensor, 163 | y_src: torch.Tensor, 164 | return_log_act_norms: bool = False, 165 | ) -> torch.Tensor: 166 | """Forward pass of the TabDPTModel. 167 | Args: 168 | x_src (torch.Tensor): Input features of shape (time, batch, hidden dimension). 169 | y_src (torch.Tensor): Target values of shape (T, B). 170 | return_log_act_norms (bool): Whether to return activation norms for logging. 171 | Returns: 172 | torch.Tensor: Predicted values of shape (T, B, n_out + 1). 173 | """ 174 | eval_pos = y_src.shape[0] 175 | 176 | # preproces features by normalizing and clipping outliers 177 | x_src = clip_outliers(x_src, -1 if self.training else eval_pos, n_sigma=4) 178 | x_src = normalize_data(x_src, -1 if self.training else eval_pos) 179 | x_src = clip_outliers(x_src, -1 if self.training else eval_pos, n_sigma=4) 180 | x_src = torch.nan_to_num(x_src, nan=0) 181 | 182 | # feature encoding 183 | x_src = self.encoder(x_src) 184 | mean = (x_src**2).mean(dim=-1, keepdim=True) 185 | rms = torch.sqrt(mean) 186 | x_src = x_src / rms 187 | 188 | # target encoding 189 | y_src = self.y_encoder(y_src.unsqueeze(-1)) 190 | train_x = x_src[:eval_pos] + y_src 191 | src = torch.cat([train_x, x_src[eval_pos:]], 0) 192 | 193 | log_act_norms = {} 194 | log_act_norms["y"] = torch.norm(y_src, dim=-1).mean() 195 | 196 | # transformer layers 197 | for l, layer in enumerate(self.transformer_encoder): 198 | if l in [0, 1, 3, 6, 9]: 199 | log_act_norms[f"layer_{l}"] = torch.norm(src, dim=-1).mean() 200 | src = layer(src, eval_pos) 201 | 202 | # final head 203 | pred = self.head(src) 204 | 205 | if return_log_act_norms: 206 | return pred[eval_pos:], log_act_norms 207 | else: 208 | return pred[eval_pos:] 209 | 210 | @classmethod 211 | def load(cls, model_state: dict, config: DictConfig) -> nn.Module: 212 | """Load a pre-trained TabDPTModel from a state dictionary. 213 | 214 | Args: 215 | cls: TODO 216 | model_state (dict): state dictionary containing the model parameters. 217 | config (DictConfig): configuration object containing model parameters. 218 | 219 | Returns: 220 | nn.Module: model instance with loaded parameters. 221 | """ 222 | # TODO loading model inside its own class without self? 223 | assert config.model.max_num_classes > 2 224 | model = TabDPTModel( 225 | dropout=config.training.dropout, 226 | n_out=config.model.max_num_classes, 227 | nhead=config.model.nhead, 228 | nhid=config.model.emsize * config.model.nhid_factor, 229 | ninp=config.model.emsize, 230 | nlayers=config.model.nlayers, 231 | num_features=config.model.max_num_features, 232 | ) 233 | 234 | module_prefix = "_orig_mod." 235 | model_state = {k.replace(module_prefix, ""): v for k, v in model_state.items()} 236 | model.load_state_dict(model_state) 237 | model.to(config.env.device) 238 | model.eval() 239 | return model 240 | -------------------------------------------------------------------------------- /tabdpt_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import openml 3 | import pandas as pd 4 | import torch 5 | 6 | 7 | def get_openml_classification(did, max_samples, multiclass=True, shuffled=True): 8 | dataset = openml.datasets.get_dataset(did) 9 | X, y, categorical_indicator, attribute_names = dataset.get_data( 10 | dataset_format="array", target=dataset.default_target_attribute 11 | ) 12 | 13 | if not multiclass: 14 | X = X[y < 2] 15 | y = y[y < 2] 16 | 17 | if multiclass and not shuffled: 18 | raise NotImplementedError("This combination of multiclass and shuffling isn't implemented") 19 | 20 | if not isinstance(X, np.ndarray) or not isinstance(y, np.ndarray): 21 | print("Not a NP Array, skipping") 22 | return None, None, None, None 23 | 24 | if not shuffled: 25 | sort = np.argsort(y) if y.mean() < 0.5 else np.argsort(-y) 26 | pos = int(y.sum()) if y.mean() < 0.5 else int((1 - y).sum()) 27 | X, y = X[sort][-pos * 2 :], y[sort][-pos * 2 :] 28 | y = torch.tensor(y).reshape(2, -1).transpose(0, 1).reshape(-1).flip([0]).float() 29 | X = ( 30 | torch.tensor(X) 31 | .reshape(2, -1, X.shape[1]) 32 | .transpose(0, 1) 33 | .reshape(-1, X.shape[1]) 34 | .flip([0]) 35 | .float() 36 | ) 37 | else: 38 | order = np.arange(y.shape[0]) 39 | np.random.seed(13) 40 | np.random.shuffle(order) 41 | X, y = torch.tensor(X[order]), torch.tensor(y[order]) 42 | if max_samples: 43 | X, y = X[:max_samples], y[:max_samples] 44 | 45 | return X, y, list(np.where(categorical_indicator)[0]), attribute_names 46 | 47 | 48 | def load_openml_list( 49 | dids, 50 | filter_for_nan=False, 51 | num_feats=100, 52 | min_samples=100, 53 | max_samples=400, 54 | multiclass=True, 55 | max_num_classes=10, 56 | shuffled=True, 57 | return_capped=False, 58 | ): 59 | datasets = [] 60 | openml_list = openml.datasets.list_datasets(dids) 61 | print(f"Number of datasets: {len(openml_list)}") 62 | 63 | datalist = pd.DataFrame.from_dict(openml_list, orient="index") 64 | if filter_for_nan: 65 | datalist = datalist[datalist["NumberOfInstancesWithMissingValues"] == 0] 66 | print(f"Number of datasets after Nan and feature number filtering: {len(datalist)}") 67 | 68 | for ds in datalist.index: 69 | modifications = { 70 | "samples_capped": False, 71 | "classes_capped": False, 72 | "feats_capped": False, 73 | } 74 | entry = datalist.loc[ds] 75 | 76 | if entry["NumberOfClasses"] == 0.0: 77 | raise Exception("Regression not supported") 78 | # X, y, categorical_feats, attribute_names = get_openml_regression(int(entry.did), max_samples) 79 | else: 80 | X, y, categorical_feats, attribute_names = get_openml_classification( 81 | int(entry.did), max_samples, multiclass=multiclass, shuffled=shuffled 82 | ) 83 | if X is None: 84 | continue 85 | 86 | if X.shape[1] > num_feats: 87 | if return_capped: 88 | X = X[:, 0:num_feats] 89 | categorical_feats = [c for c in categorical_feats if c < num_feats] 90 | modifications["feats_capped"] = True 91 | else: 92 | print("Too many features") 93 | continue 94 | if X.shape[0] == max_samples: 95 | modifications["samples_capped"] = True 96 | 97 | if X.shape[0] < min_samples: 98 | print("Too few samples left") 99 | continue 100 | 101 | if len(np.unique(y)) > max_num_classes: 102 | if return_capped: 103 | X = X[y < np.unique(y)[10]] 104 | y = y[y < np.unique(y)[10]] 105 | modifications["classes_capped"] = True 106 | else: 107 | print("Too many classes") 108 | continue 109 | 110 | datasets += [[entry["name"], X, y, categorical_feats, attribute_names, modifications]] 111 | 112 | return datasets, datalist 113 | 114 | 115 | # Classification 116 | valid_dids_classification = [13, 59, 4, 15, 40710, 43, 1498] 117 | test_dids_classification = [ 118 | 973, 119 | 1596, 120 | 40981, 121 | 1468, 122 | 40984, 123 | 40975, 124 | 41163, 125 | 41147, 126 | 1111, 127 | 41164, 128 | 1169, 129 | 1486, 130 | 41143, 131 | 1461, 132 | 41167, 133 | 40668, 134 | 41146, 135 | 41169, 136 | 41027, 137 | 23517, 138 | 41165, 139 | 41161, 140 | 41159, 141 | 41138, 142 | 1590, 143 | 41166, 144 | 1464, 145 | 41168, 146 | 41150, 147 | 1489, 148 | 41142, 149 | 3, 150 | 12, 151 | 31, 152 | 54, 153 | 1067, 154 | ] 155 | valid_large_classification = [ 156 | 943, 157 | 23512, 158 | 49, 159 | 838, 160 | 1131, 161 | 767, 162 | 1142, 163 | 748, 164 | 1112, 165 | 1541, 166 | 384, 167 | 912, 168 | 1503, 169 | 796, 170 | 20, 171 | 30, 172 | 903, 173 | 4541, 174 | 961, 175 | 805, 176 | 1000, 177 | 4135, 178 | 1442, 179 | 816, 180 | 1130, 181 | 906, 182 | 1511, 183 | 184, 184 | 181, 185 | 137, 186 | 1452, 187 | 1481, 188 | 949, 189 | 449, 190 | 50, 191 | 913, 192 | 1071, 193 | 831, 194 | 843, 195 | 9, 196 | 896, 197 | 1532, 198 | 311, 199 | 39, 200 | 451, 201 | 463, 202 | 382, 203 | 778, 204 | 474, 205 | 737, 206 | 1162, 207 | 1538, 208 | 820, 209 | 188, 210 | 452, 211 | 1156, 212 | 37, 213 | 957, 214 | 911, 215 | 1508, 216 | 1054, 217 | 745, 218 | 1220, 219 | 763, 220 | 900, 221 | 25, 222 | 387, 223 | 38, 224 | 757, 225 | 1507, 226 | 396, 227 | 4153, 228 | 806, 229 | 779, 230 | 746, 231 | 1037, 232 | 871, 233 | 717, 234 | 1480, 235 | 1010, 236 | 1016, 237 | 981, 238 | 1547, 239 | 1002, 240 | 1126, 241 | 1459, 242 | 846, 243 | 837, 244 | 1042, 245 | 273, 246 | 1524, 247 | 375, 248 | 1018, 249 | 1531, 250 | 1458, 251 | 6332, 252 | 1546, 253 | 1129, 254 | 679, 255 | 389, 256 | ] 257 | 258 | open_cc_dids = [ 259 | 11, 260 | 14, 261 | 15, 262 | 16, 263 | 18, 264 | 22, 265 | 23, 266 | 29, 267 | 31, 268 | 37, 269 | 50, 270 | 54, 271 | 188, 272 | 458, 273 | 469, 274 | 1049, 275 | 1050, 276 | 1063, 277 | 1068, 278 | 1510, 279 | 1494, 280 | 1480, 281 | 1462, 282 | 1464, 283 | 6332, 284 | 23381, 285 | 40966, 286 | 40982, 287 | 40994, 288 | 40975, 289 | ] 290 | # Filtered by N_samples < 2000, N feats < 100, N classes < 10 291 | 292 | open_cc_valid_dids = [ 293 | 13, 294 | 25, 295 | 35, 296 | 40, 297 | 41, 298 | 43, 299 | 48, 300 | 49, 301 | 51, 302 | 53, 303 | 55, 304 | 56, 305 | 59, 306 | 61, 307 | 187, 308 | 285, 309 | 329, 310 | 333, 311 | 334, 312 | 335, 313 | 336, 314 | 337, 315 | 338, 316 | 377, 317 | 446, 318 | 450, 319 | 451, 320 | 452, 321 | 460, 322 | 463, 323 | 464, 324 | 466, 325 | 470, 326 | 475, 327 | 481, 328 | 679, 329 | 694, 330 | 717, 331 | 721, 332 | 724, 333 | 733, 334 | 738, 335 | 745, 336 | 747, 337 | 748, 338 | 750, 339 | 753, 340 | 756, 341 | 757, 342 | 764, 343 | 765, 344 | 767, 345 | 774, 346 | 778, 347 | 786, 348 | 788, 349 | 795, 350 | 796, 351 | 798, 352 | 801, 353 | 802, 354 | 810, 355 | 811, 356 | 814, 357 | 820, 358 | 825, 359 | 826, 360 | 827, 361 | 831, 362 | 839, 363 | 840, 364 | 841, 365 | 844, 366 | 852, 367 | 853, 368 | 854, 369 | 860, 370 | 880, 371 | 886, 372 | 895, 373 | 900, 374 | 906, 375 | 907, 376 | 908, 377 | 909, 378 | 915, 379 | 925, 380 | 930, 381 | 931, 382 | 934, 383 | 939, 384 | 940, 385 | 941, 386 | 949, 387 | 966, 388 | 968, 389 | 984, 390 | 987, 391 | 996, 392 | 1048, 393 | 1054, 394 | 1071, 395 | 1073, 396 | 1100, 397 | 1115, 398 | 1412, 399 | 1442, 400 | 1443, 401 | 1444, 402 | 1446, 403 | 1447, 404 | 1448, 405 | 1451, 406 | 1453, 407 | 1488, 408 | 1490, 409 | 1495, 410 | 1498, 411 | 1499, 412 | 1506, 413 | 1508, 414 | 1511, 415 | 1512, 416 | 1520, 417 | 1523, 418 | 4153, 419 | 23499, 420 | 40496, 421 | 40646, 422 | 40663, 423 | 40669, 424 | 40680, 425 | 40682, 426 | 40686, 427 | 40690, 428 | 40693, 429 | 40705, 430 | 40706, 431 | 40710, 432 | 40711, 433 | 40981, 434 | 41430, 435 | 41538, 436 | 41919, 437 | 41976, 438 | 42172, 439 | 42261, 440 | 42544, 441 | 42585, 442 | 42638, 443 | ] 444 | 445 | grinzstjan_categorical_regression = [ 446 | 44054, 447 | 44055, 448 | 44056, 449 | 44057, 450 | 44059, 451 | 44061, 452 | 44062, 453 | 44063, 454 | 44064, 455 | 44065, 456 | 44066, 457 | 44068, 458 | 44069, 459 | ] 460 | 461 | grinzstjan_numerical_classification = [ 462 | 44089, 463 | 44090, 464 | 44091, 465 | 44120, 466 | 44121, 467 | 44122, 468 | 44123, 469 | 44124, 470 | 44125, 471 | 44126, 472 | 44127, 473 | 44128, 474 | 44129, 475 | 44130, 476 | 44131, 477 | ] 478 | 479 | grinzstjan_categorical_classification = [44156, 44157, 44159, 44160, 44161, 44162, 44186] 480 | 481 | grinzstjan_classification = ( 482 | grinzstjan_numerical_classification + grinzstjan_categorical_classification 483 | ) 484 | -------------------------------------------------------------------------------- /tabdpt_datasets/talent.py: -------------------------------------------------------------------------------- 1 | import os 2 | import zipfile 3 | from pathlib import Path 4 | 5 | import gdown 6 | import numpy as np 7 | from sklearn.preprocessing import LabelEncoder, OrdinalEncoder 8 | 9 | from tabdpt_datasets.dataset import Dataset 10 | 11 | 12 | class TalentDataset(Dataset): 13 | """ 14 | Dataset for LAMDA-TALENT baseline 15 | """ 16 | 17 | @staticmethod 18 | def suite_name(): 19 | return "talent" 20 | 21 | @staticmethod 22 | def all_names(): 23 | return [ 24 | "1000-Cameras-Dataset", 25 | "2dplanes", 26 | "3D_Estimation_using_RSSI_of_WLAN_dataset", 27 | "3D_Estimation_using_RSSI_of_WLAN_dataset_complete_1_target", 28 | "abalone", 29 | "Abalone_reg", 30 | "accelerometer", 31 | "ada", 32 | "ada_agnostic", 33 | "ada_prior", 34 | "Ailerons", 35 | "airfoil_self_noise", 36 | "airlines_seed_0_nrows_2000_nclasses_10_ncols_100_stratify_True", 37 | "allbp", 38 | "allrep", 39 | "Amazon_employee_access", 40 | "analcatdata_authorship", 41 | "analcatdata_supreme", 42 | "Another-Dataset-on-used-Fiat-500-(1538-rows)", 43 | "archive2", 44 | "archive_r56_Maths", 45 | "archive_r56_Portuguese", 46 | "artificial-characters", 47 | "ASP-POTASSCO-classification", 48 | "auction_verification", 49 | "autoUniv-au4-2500", 50 | "autoUniv-au7-1100", 51 | "avocado_sales", 52 | "bank", 53 | "bank32nh", 54 | "bank8FM", 55 | "Bank_Customer_Churn_Dataset", 56 | "banknote_authentication", 57 | "baseball", 58 | "Basketball_c", 59 | "Bias_correction_r", 60 | "Bias_correction_r_2", 61 | "BLE_RSSI_dataset_for_Indoor_localization", 62 | "blogfeedback", 63 | "BNG(breast-w)", 64 | "BNG(cmc)", 65 | "BNG(echoMonths)", 66 | "BNG(lowbwt)", 67 | "BNG(mv)", 68 | "BNG(stock)", 69 | "BNG(tic-tac-toe)", 70 | "Brazilian_houses_reproduced", 71 | "California-Housing-Classification", 72 | "Cardiovascular-Disease-dataset", 73 | "car-evaluation", 74 | "CDC_Diabetes_Health_Indicators", 75 | "churn", 76 | "Click_prediction_small", 77 | "cmc", 78 | "combined_cycle_power_plant", 79 | "communities_and_crime", 80 | "company_bankruptcy_prediction", 81 | "compass", 82 | "compass_reg", 83 | "concrete_compressive_strength", 84 | "Contaminant-detection-in-packaged-cocoa-hazelnut-spread-jars-using-Microwaves-Sensing-and-Machine-Learning-10.0GHz(Urbinati)", 85 | "Contaminant-detection-in-packaged-cocoa-hazelnut-spread-jars-using-Microwaves-Sensing-and-Machine-Learning-10.5GHz(Urbinati)", 86 | "Contaminant-detection-in-packaged-cocoa-hazelnut-spread-jars-using-Microwaves-Sensing-and-Machine-Learning-11.0GHz(Urbinati)", 87 | "Contaminant-detection-in-packaged-cocoa-hazelnut-spread-jars-using-Microwaves-Sensing-and-Machine-Learning-9.0GHz(Urbinati)", 88 | "Contaminant-detection-in-packaged-cocoa-hazelnut-spread-jars-using-Microwaves-Sensing-and-Machine-Learning-9.5GHz(Urbinati)", 89 | "contraceptive_method_choice", 90 | "CookbookReviews", 91 | "CPMP-2015-regression", 92 | "CPMP-2015-runtime-regression", 93 | "CPS1988", 94 | "cpu_act", 95 | "cpu_small", 96 | "credit", 97 | "Credit_c", 98 | "credit_reg", 99 | "Customer_Personality_Analysis", 100 | "customer_satisfaction_in_airline", 101 | "dabetes_130-us_hospitals", 102 | "Data_Science_for_Good_Kiva_Crowdfunding", 103 | "Data_Science_Salaries", 104 | "dataset_sales", 105 | "debutanizer", 106 | "default_of_credit_card_clients", 107 | "delta_ailerons", 108 | "delta_elevators", 109 | "Diabetic_Retinopathy_Debrecen", 110 | "Diamonds", 111 | "dis", 112 | "dna", 113 | "drug_consumption", 114 | "dry_bean_dataset", 115 | "E-CommereShippingData", 116 | "eeg-eye-state", 117 | "electricity", 118 | "elevators", 119 | "Employee", 120 | "estimation_of_obesity_levels", 121 | "eye_movements", 122 | "eye_movements_bin", 123 | "Facebook_Comment_Volume", 124 | "FICO-HELOC-cleaned", 125 | "fifa", 126 | "Firm-Teacher_Clave-Direction_Classification", 127 | "first-order-theorem-proving", 128 | "Fitness_Club_c", 129 | "Food_Delivery_Time", 130 | "FOREX_audcad-day-High", 131 | "FOREX_audcad-hour-High", 132 | "FOREX_audchf-day-High", 133 | "FOREX_audjpy-day-High", 134 | "FOREX_audjpy-hour-High", 135 | "FOREX_audsgd-hour-High", 136 | "FOREX_audusd-hour-High", 137 | "FOREX_cadjpy-day-High", 138 | "FOREX_cadjpy-hour-High", 139 | "fried", 140 | "GAMETES_Epistasis_2-Way_20atts_0.1H_EDM-1_1", 141 | "GAMETES_Heterogeneity_20atts_1600_Het_0.4_0.2_50_EDM-2_001", 142 | "garments_worker_productivity", 143 | "gas-drift", 144 | "gas_turbine_CO_and_NOx_emission", 145 | "Gender_Gap_in_Spanish_WP", 146 | "GesturePhaseSegmentationProcessed", 147 | "gina_agnostic", 148 | "golf_play_dataset_extended", 149 | "Goodreads-Computer-Books", 150 | "healthcare_insurance_expenses", 151 | "Heart-Disease-Dataset-(Comprehensive)", 152 | "heloc", 153 | "hill-valley", 154 | "house_16H", 155 | "house_16H_reg", 156 | "house_8L", 157 | "houses", 158 | "house_sales_reduced", 159 | "housing_price_prediction", 160 | "HR_Analytics_Job_Change_of_Data_Scientists", 161 | "htru", 162 | "ibm-employee-performance", 163 | "IBM_HR_Analytics_Employee_Attrition_and_Performance", 164 | "IEEE80211aa-GATS", 165 | "Indian_pines", 166 | "INNHotelsGroup", 167 | "Insurance", 168 | "internet_firewall", 169 | "internet_usage", 170 | "Intersectional-Bias-Assessment", 171 | "in_vehicle_coupon_recommendation", 172 | "Is-this-a-good-customer", 173 | "JapaneseVowels", 174 | "jm1", 175 | "Job_Profitability", 176 | "jungle_chess_2pcs_raw_endgame_complete", 177 | "Kaggle_bike_sharing_demand_challange", 178 | "kc1", 179 | "KDD", 180 | "KDDCup09_upselling", 181 | "kdd_ipums_la_97-small", 182 | "kin8nm", 183 | "kropt", 184 | "kr-vs-k", 185 | "Laptop_Prices_Dataset", 186 | "Large-scale_Wave_Energy_Farm_Perth_100", 187 | "Large-scale_Wave_Energy_Farm_Perth_49", 188 | "Large-scale_Wave_Energy_Farm_Sydney_100", 189 | "Large-scale_Wave_Energy_Farm_Sydney_49", 190 | "law-school-admission-bianry", 191 | "led24", 192 | "led7", 193 | "letter", 194 | "Long", 195 | "madeline", 196 | "MagicTelescope", 197 | "mammography", 198 | "Marketing_Campaign", 199 | "maternal_health_risk", 200 | "mauna-loa-atmospheric-co2", 201 | "mfeat-factors", 202 | "mfeat-fourier", 203 | "mfeat-karhunen", 204 | "mfeat-morphological", 205 | "mfeat-pixel", 206 | "mfeat-zernike", 207 | "MiamiHousing2016", 208 | "MIC", 209 | "mice_protein_expression", 210 | "microaggregation2", 211 | "mobile_c36_oversampling", 212 | "Mobile_Phone_Market_in_Ghana", 213 | "Mobile_Price_Classification", 214 | "mozilla4", 215 | "mv", 216 | "NASA_PHM2008", 217 | "naticusdroid+android+permissions+dataset", 218 | "National_Health_and_Nutrition_Health_Survey", 219 | "national-longitudinal-survey-binary", 220 | "NHANES_age_prediction", 221 | "okcupid_stem", 222 | "one-hundred-plants-margin", 223 | "one-hundred-plants-shape", 224 | "one-hundred-plants-texture", 225 | "online_shoppers", 226 | "optdigits", 227 | "ozone_level", 228 | "ozone-level-8hr", 229 | "page-blocks", 230 | "Parkinson_Multiple_Sound_Recording", 231 | "Parkinsons_Telemonitoring", 232 | "pc1", 233 | "pc3", 234 | "pc4", 235 | "pendigits", 236 | "Performance-Prediction", 237 | "PhishingWebsites", 238 | "phoneme", 239 | "Physicochemical_r", 240 | "PieChart3", 241 | "Pima_Indians_Diabetes_Database", 242 | "PizzaCutter3", 243 | "pol", 244 | "pole", 245 | "pol_reg", 246 | "predict_students_dropout_and_academic_success", 247 | "puma32H", 248 | "puma8NH", 249 | "Pumpkin_Seeds", 250 | "qsar", 251 | "qsar_aquatic_toxicity", 252 | "QSAR_biodegradation", 253 | "qsar_fish_toxicity", 254 | "Rain_in_Australia", 255 | "rice_cammeo_and_osmancik", 256 | "ringnorm", 257 | "rl", 258 | "Satellite", 259 | "satellite_image", 260 | "satimage", 261 | "SDSS17", 262 | "segment", 263 | "seismic+bumps", 264 | "semeion", 265 | "sensory", 266 | "shill-bidding", 267 | "Shipping", 268 | "Shop_Customer_Data", 269 | "shrutime", 270 | "shuttle", 271 | "Smoking_and_Drinking_Dataset_with_body_signal", 272 | "socmob", 273 | "space_ga", 274 | "spambase", 275 | "splice", 276 | "sports_articles_for_objectivity_analysis", 277 | "statlog", 278 | "steel_industry_energy_consumption", 279 | "steel_plates_faults", 280 | "stock", 281 | "stock_fardamento02", 282 | "Student_Alcohol_Consumption", 283 | "Student_Performance_Portuguese", 284 | "sulfur", 285 | "Superconductivty", 286 | "svmguide3", 287 | "sylvine", 288 | "taiwanese_bankruptcy_prediction", 289 | "telco-customer-churn", 290 | "Telecom_Churn_Dataset", 291 | "texture", 292 | "thyroid", 293 | "thyroid-ann", 294 | "thyroid-dis", 295 | "topo_2_1", 296 | "treasury", 297 | "turiye_student_evaluation", 298 | "twonorm", 299 | "UJIndoorLoc", 300 | "UJI_Pen_Characters", 301 | "vehicle", 302 | "volkert", 303 | "volume", 304 | "VulNoneVul", 305 | "walking-activity", 306 | "wall-robot-navigation", 307 | "water_quality", 308 | "Water_Quality_and_Potability", 309 | "Waterstress", 310 | "waveform-5000", 311 | "waveform_database_generator", 312 | "waveform_database_generator_version_1", 313 | "weather_izmir", 314 | "website_phishing", 315 | "Wilt", 316 | "wind", 317 | "wine", 318 | "wine+quality", 319 | "Wine_Quality_red", 320 | "wine-quality-red", 321 | "Wine_Quality_white", 322 | "wine-quality-white", 323 | "yeast", 324 | ] 325 | 326 | def __init__(self, name, task_id=None): 327 | super().__init__(name, task_id) 328 | 329 | def prepare_data(self, download_dir): 330 | sub_dir = Path(download_dir) / "talent" 331 | if not sub_dir.exists(): 332 | zip_path = os.path.join(download_dir, "talent.zip") 333 | gdown.download(id="1-dzY-BhMzcqjCM8vMTkVwa0hOYQ1598T", output=zip_path) 334 | with zipfile.ZipFile(zip_path, "r") as z: 335 | z.extractall(sub_dir) 336 | os.remove(zip_path) 337 | 338 | # Data is split into numeric 'N' and categorical 'C' files 339 | N_splits = [] 340 | C_splits = [] 341 | y_splits = [] 342 | for split in ("train", "val", "test"): 343 | npy = sub_dir / "data" / self.name / f"N_{split}.npy" 344 | if npy.exists(): 345 | N_splits.append(np.load(npy, allow_pickle=True).astype(np.float32)) 346 | npy = sub_dir / "data" / self.name / f"C_{split}.npy" 347 | if npy.exists(): 348 | C_splits.append(np.load(npy, allow_pickle=True)) 349 | y_splits.append( 350 | np.load(sub_dir / "data" / self.name / f"y_{split}.npy", allow_pickle=True) 351 | ) 352 | if C_splits: 353 | C = OrdinalEncoder().fit_transform(np.concatenate(C_splits, axis=0)) 354 | if N_splits: 355 | N = np.concatenate(N_splits, axis=0) 356 | self.X = np.concatenate((N, C), axis=1) 357 | self.metadata["categorical_feature_inds"] = [ 358 | i + N.shape[1] for i in range(C.shape[1]) 359 | ] 360 | else: 361 | self.X = C 362 | self.metadata["categorical_feature_inds"] = list(range(C.shape[1])) 363 | else: 364 | self.X = np.concatenate(N_splits, axis=0) 365 | self.y = np.concatenate(y_splits, axis=0) 366 | if self.y.dtype == "object": 367 | self.y = LabelEncoder().fit_transform(self.y) 368 | self.metadata["target_type"] = "classification" 369 | else: 370 | self.metadata["target_type"] = "regression" 371 | self.y = self.y.squeeze() 372 | 373 | train_len, val_len, test_len = len(y_splits[0]), len(y_splits[1]), len(y_splits[2]) 374 | self._train_inds = range(train_len) 375 | self._val_inds = range(train_len, train_len + val_len) 376 | self._test_inds = range(train_len + val_len, train_len + val_len + test_len) 377 | 378 | def all_instances(self): 379 | return self.X, self.y 380 | 381 | def train_inds(self): 382 | return self._train_inds 383 | 384 | def val_inds(self): 385 | return self._val_inds 386 | 387 | def test_inds(self): 388 | return self._test_inds 389 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import random 4 | import sys 5 | from typing import Optional 6 | 7 | import faiss 8 | import numpy as np 9 | import pandas as pd 10 | import torch 11 | import torch.distributed as dist 12 | import torch.nn as nn 13 | from sklearn.impute import SimpleImputer 14 | from sklearn.preprocessing import LabelEncoder, StandardScaler 15 | from torch.utils.tensorboard import SummaryWriter 16 | 17 | 18 | def compute_losses(output, task, y, config): 19 | """ 20 | Compute the averaged classification and regression losses. 21 | 22 | Parameters: 23 | - output: Tensor of shape [L, B, C+1], the outputs of the network. 24 | - task: Tensor of shape [B, 1, 1], containing 0 for classification or 1 for regression. 25 | - y: Tensor of shape [L, B, 1], the target values. 26 | 27 | Returns: 28 | - class_loss: Averaged classification loss (CrossEntropyLoss) or None if no classification task. 29 | - reg_loss: Averaged regression loss (MSELoss) or None if no regression task. 30 | """ 31 | 32 | # reshape the output and y tensors to [B, L, C+1] and [B, L, 1] respectively 33 | output = output.transpose(0, 1) 34 | y = y.transpose(0, 1) 35 | 36 | # Flatten the task tensor to [B] 37 | task = task.view(-1) 38 | classification_mask = task == 0 39 | regression_mask = task == 1 40 | 41 | class_loss = torch.zeros(1, device=output.device) 42 | reg_loss = torch.zeros(1, device=output.device) 43 | 44 | # Classification Loss 45 | if classification_mask.any(): 46 | # Select classification batches 47 | outputs_class = output[classification_mask] # Shape: [num_class_batches, L, C+1] 48 | y_class = y[classification_mask] # Shape: [num_class_batches, L, 1] 49 | 50 | # Reshape for CrossEntropyLoss: [N, C+1], targets [N] 51 | outputs_class = outputs_class.view(-1, outputs_class.size(-1)) # [N, C+1] 52 | y_class = y_class.view(-1).long().squeeze() # [N] 53 | 54 | # Compute CrossEntropyLoss 55 | ce_loss_fn = nn.CrossEntropyLoss( 56 | reduction="mean", label_smoothing=config.training.label_smoothing 57 | ) 58 | class_loss = ce_loss_fn(outputs_class, y_class) 59 | 60 | # Regression Loss 61 | if regression_mask.any(): 62 | # Select regression batches 63 | outputs_reg = output[regression_mask, :, -1] # Shape: [num_reg_batches, L] 64 | y_reg = y[regression_mask].squeeze(-1) # Shape: [num_reg_batches, L] 65 | 66 | # Compute MSELoss 67 | mse_loss_fn = nn.MSELoss(reduction="mean") 68 | reg_loss = mse_loss_fn(outputs_reg, y_reg) 69 | 70 | return class_loss, reg_loss 71 | 72 | 73 | def get_combined_loss(loss_cls, loss_reg, task, config): 74 | """ 75 | Combine classification and regression losses based on the task ratio. 76 | Parameters: 77 | - loss_cls: Tensor, classification loss. 78 | - loss_reg: Tensor, regression loss. 79 | - task: Tensor, indicating the task type (0 for classification, 1 for regression). 80 | - config: Configuration object containing training parameters. 81 | Returns: 82 | - loss: Tensor, combined loss. 83 | """ 84 | reg_ratio = task.float().mean() 85 | loss_cls = loss_cls / config.training.num_agg 86 | loss_reg = loss_reg / config.training.num_agg 87 | loss = loss_cls * (1 - reg_ratio) + loss_reg * reg_ratio 88 | return loss 89 | 90 | 91 | class FAISS: 92 | """ 93 | This class initializes a FAISS index with the provided data and allows for efficient nearest neighbor search. 94 | """ 95 | 96 | def __init__( 97 | self, X: np.ndarray, use_hnsw: bool = False, hnsw_m: int = 32, metric: str = "L2" 98 | ) -> None: 99 | """ 100 | Initializes the FAISS index with the provided data. 101 | Args: 102 | X (np.ndarray): The data to index, shape should be (n_samples, n_features). 103 | use_hnsw (bool): Whether to use HNSW index or not. 104 | hnsw_m (int): The number of bi-directional links created for each element in the HNSW index. 105 | metric (str): The distance metric to use, either "L2" for Euclidean distance or "IP" for inner product. 106 | """ 107 | assert isinstance(X, np.ndarray), "X must be a numpy array" 108 | X = np.ascontiguousarray(X) 109 | X = X.astype(np.float32) 110 | if use_hnsw: 111 | self.index = faiss.IndexHNSWFlat(X.shape[1], hnsw_m) 112 | if metric == "L2": 113 | self.index.metric_type = faiss.METRIC_L2 114 | elif metric == "IP": 115 | self.index.metric_type = faiss.METRIC_INNER_PRODUCT 116 | else: 117 | raise NotImplementedError 118 | else: 119 | if metric == "L2": 120 | self.index = faiss.IndexFlatL2(X.shape[1]) 121 | elif metric == "IP": 122 | self.index = faiss.IndexFlatIP(X.shape[1]) 123 | else: 124 | raise NotImplementedError 125 | self.index.add(X) 126 | 127 | def get_knn_indices(self, queries: np.ndarray | torch.Tensor, k: int) -> np.ndarray: 128 | """retreive the k-nearest neighbors indices for the given queries. 129 | 130 | Args: 131 | queries (np.ndarray|torch.Tensor): query points for which to find nearest neighbors. 132 | k (int): number of nearest neighbors to retrieve. 133 | 134 | Returns: 135 | np.ndarray: k nearest neighbors indices for each query point. 136 | """ 137 | if isinstance(queries, torch.Tensor): 138 | queries = queries.cpu().numpy() 139 | queries = np.ascontiguousarray(queries) 140 | assert isinstance(k, int) 141 | 142 | knns = self.index.search(queries, k) 143 | indices_Xs = knns[1] 144 | return indices_Xs 145 | 146 | 147 | def seed_everything(seed: int): 148 | """Set the random seed for reproducibility. 149 | 150 | Args: 151 | seed (int): seed value to set for random number generators. 152 | """ 153 | random.seed(seed) 154 | os.environ["PYTHONHASHSEED"] = str(seed) 155 | np.random.seed(seed) 156 | torch.manual_seed(seed) 157 | torch.cuda.manual_seed(seed) 158 | torch.backends.cudnn.deterministic = True 159 | torch.backends.cudnn.benchmark = True 160 | 161 | 162 | def cleanup(): 163 | """Cleanup the distributed training environment.""" 164 | dist.destroy_process_group() 165 | 166 | 167 | def signal_handler(*_): 168 | """Handle Ctrl+C signal to gracefully exit the program.""" 169 | print("Received Ctrl+C, exiting...") 170 | cleanup() 171 | sys.exit(0) 172 | 173 | 174 | def get_module(model: nn.Module) -> nn.Module: 175 | """Get the underlying module from a DistributedDataParallel model. 176 | Args: 177 | model (torch.nn.Module): The model, possibly wrapped in DistributedDataParallel. 178 | Returns: 179 | torch.nn.Module: The underlying module if model is DistributedDataParallel, else the model itself. 180 | """ 181 | return model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model 182 | 183 | 184 | def log_param_norms( 185 | model: nn.Module, writer: SummaryWriter, step: int, task: str, global_step: int 186 | ): 187 | """Log the norms of model parameters to TensorBoard. 188 | 189 | Args: 190 | model (nn.Module): model to log parameter norms for. 191 | writer (SummaryWriter): TensorBoard writer to log the norms. 192 | step (int): current training step, used for logging. 193 | task (str): regression or classification task, used for logging. 194 | global_step (int): current global step, used for logging. 195 | """ 196 | model = get_module(model) 197 | # Log encoder weight and bias norms (if bias exists) 198 | writer.add_scalar("norms/encoder_weight", torch.norm(model.encoder.weight).item(), step) 199 | if model.encoder.bias is not None: 200 | writer.add_scalar("norms/encoder_bias", torch.norm(model.encoder.bias).item(), step) 201 | 202 | # Log y_encoder weight and bias norms (if bias exists) 203 | writer.add_scalar("norms/y_encoder_weight", torch.norm(model.y_encoder.weight).item(), step) 204 | if model.y_encoder.bias is not None: 205 | writer.add_scalar("norms/y_encoder_bias", torch.norm(model.y_encoder.bias).item(), step) 206 | 207 | # Log transformer encoder total weights and biases (if biases exist) 208 | transformer_weights_norm = sum( 209 | torch.norm(layer.kv_proj.weight).item() for layer in model.transformer_encoder 210 | ) 211 | 212 | writer.add_scalar("norms/transformer_weights", transformer_weights_norm, step) 213 | 214 | # Log classifier head weight and bias norms 215 | head_weights_norm = sum(torch.norm(param).item() for param in model.head.parameters()) 216 | 217 | writer.add_scalar("norms/head_weights_biases", head_weights_norm, step) 218 | 219 | total_norm = torch.norm( 220 | torch.stack( 221 | [torch.norm(p.grad.detach(), 2) for p in model.parameters() if p.grad is not None] 222 | ), 223 | 2, 224 | ) 225 | total_norm1 = torch.norm( 226 | torch.stack( 227 | [torch.norm(p.grad.detach(), 1) for p in model.parameters() if p.grad is not None] 228 | ), 229 | 1, 230 | ) 231 | param_norm = torch.norm(torch.stack([torch.norm(p.detach(), 2) for p in model.parameters()]), 2) 232 | writer.add_scalar("Gradient/Norm/", total_norm, global_step=global_step) 233 | writer.add_scalar("Gradient/Norm_L1/", total_norm1, global_step=global_step) 234 | writer.add_scalar("Parameter/Norm/", param_norm, global_step=global_step) 235 | writer.add_scalar("Prior/reg_ratio", task.float().mean().item(), global_step=global_step) 236 | 237 | 238 | def print_on_master_only(is_master: bool): 239 | """Override the built-in print function to only print on the master process. 240 | Args: 241 | is_master (bool): Whether the current process is the master process. 242 | """ 243 | import builtins as __builtin__ 244 | 245 | builtin_print = __builtin__.print 246 | 247 | def print(*args, **kwargs): 248 | force = kwargs.pop("force", False) 249 | if is_master or force: 250 | builtin_print(*args, **kwargs) 251 | 252 | __builtin__.print = print 253 | 254 | 255 | def init_dist(device: str): 256 | """Initialize distributed training. 257 | 258 | Args: 259 | device (str): The device to use for training (e.g., "cuda:0" or "cpu"). 260 | 261 | Returns: 262 | bool: Whether distributed training is being used. 263 | int: The rank of the current process. 264 | str: The device to use for training. 265 | """ 266 | if "LOCAL_RANK" in os.environ: 267 | rank = int(os.environ["LOCAL_RANK"]) 268 | print("torch.distributed.launch and my rank is", rank) 269 | torch.cuda.set_device(rank) 270 | torch.distributed.init_process_group( 271 | backend="nccl", 272 | init_method="env://", 273 | timeout=datetime.timedelta(seconds=20), 274 | world_size=torch.cuda.device_count(), 275 | rank=rank, 276 | ) 277 | torch.distributed.barrier() 278 | print_on_master_only(rank == 0) 279 | return True, rank, f"cuda:{rank}" 280 | else: 281 | return False, 0, device 282 | 283 | 284 | class DataPreprocessor: 285 | def __init__(self, impute_method: str = "mean", encode_y: bool = False): 286 | """Initialize the DataPreprocessor. 287 | Parameters: 288 | - impute_method: str, method for imputing missing values (default is "mean"). 289 | - encode_y: bool, whether to encode the target variable y (default is False). 290 | """ 291 | self.impute_method = impute_method 292 | self.encode_y = encode_y 293 | # Initialize the imputer and scaler 294 | self.imputer = SimpleImputer(strategy=impute_method) 295 | self.scaler = StandardScaler() 296 | 297 | def convert_cat2num( 298 | self, X: pd.DataFrame, y: pd.Series | None = None 299 | ) -> tuple[np.ndarray, np.ndarray, list[bool]]: 300 | """Convert categorical columns to numerical values. 301 | 302 | Parameters: 303 | - X: DataFrame, input features. 304 | - y: Series or None, target variable. 305 | 306 | Returns: 307 | - X: numpy array, features with categorical columns converted to numerical. 308 | - y: numpy array, target variable with categorical values converted to numerical. 309 | - cat_vals: list of bool, indicating which columns were categorical. 310 | """ 311 | cat_vals = [] 312 | # Convert categorical columns to numerical values using dataframe information 313 | for col in X.columns: 314 | if X[col].dtype == "object" or pd.api.types.is_categorical_dtype(X[col]): 315 | le = LabelEncoder() 316 | X[col] = le.fit_transform(X[col]) 317 | cat_vals.append(True) 318 | else: 319 | cat_vals.append(False) 320 | X = X.to_numpy().astype(np.float32) 321 | 322 | # Convert categorical target to numerical values 323 | if y is not None: 324 | if y.dtype == "object" or pd.api.types.is_categorical_dtype(y): 325 | le = LabelEncoder() 326 | y = le.fit_transform(y) 327 | cat_vals.append(True) 328 | else: 329 | y = y.to_numpy().astype(np.float32) 330 | cat_vals.append(False) 331 | else: 332 | y = X[:, -1] 333 | X = np.delete(X, -1, axis=1) 334 | return X, y, cat_vals 335 | 336 | def fit_transform(self, X: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor: 337 | """ 338 | Fit the imputer and scaler on the training data and transform it. 339 | 340 | Parameters: 341 | - X: numpy array, training features. 342 | 343 | Returns: 344 | - X: numpy array, transformed training features. 345 | """ 346 | # Impute missing values 347 | X = self.imputer.fit_transform(X) 348 | 349 | # Scale the features 350 | X = self.scaler.fit_transform(X) 351 | 352 | return X 353 | 354 | def transform(self, X: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor: 355 | """ 356 | Transform the test data using the fitted imputer and scaler. 357 | 358 | Parameters: 359 | - X: numpy array, testing features. 360 | 361 | Returns: 362 | - X: numpy array, transformed testing features. 363 | """ 364 | # Impute missing values 365 | X = self.imputer.transform(X) 366 | 367 | # Scale the features 368 | X = self.scaler.transform(X) 369 | 370 | return X 371 | 372 | 373 | def standardize(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 374 | """Standardize the input tensor by subtracting the mean and dividing by the standard deviation. 375 | Args: 376 | tensor (torch.Tensor): Input tensor to standardize. 377 | Returns: 378 | torch.Tensor: Standardized tensor. 379 | torch.Tensor: Mean of the input tensor. 380 | torch.Tensor: Standard deviation of the input tensor. 381 | """ 382 | y_means = tensor.mean(dim=0) 383 | y_stds = tensor.std(dim=0) + 1e-6 384 | return (tensor - y_means) / y_stds, y_means, y_stds 385 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import shutil 4 | import signal 5 | 6 | import hydra 7 | import openml 8 | import torch.distributed as dist 9 | import schedulefree 10 | import torch 11 | import torch.nn as nn 12 | from omegaconf import DictConfig, OmegaConf 13 | from torch.nn.attention import SDPBackend, sdpa_kernel 14 | from torch.utils.data import DataLoader 15 | from torch.utils.tensorboard import SummaryWriter 16 | from tqdm import tqdm 17 | 18 | from dataset import FullDataset, collate_fn 19 | from eval_full import FullEval 20 | from model import TabDPTModel 21 | from utils import ( 22 | cleanup, 23 | compute_losses, 24 | get_combined_loss, 25 | init_dist, 26 | log_param_norms, 27 | seed_everything, 28 | signal_handler, 29 | ) 30 | 31 | 32 | def save( 33 | model: nn.Module, 34 | optimizer: torch.optim.Optimizer, 35 | config: DictConfig, 36 | stats: dict, 37 | path: str, 38 | name: str, 39 | ) -> None: 40 | """Save the model and optimizer state. 41 | 42 | Args: 43 | model (nn.Module): The model to save. 44 | optimizer (torch.optim.Optimizer): The optimizer to save. 45 | config (DictConfig): The configuration for the experiment. 46 | stats (dict): The training statistics. 47 | path (str): The path to save the checkpoint. 48 | name (str): The name of the checkpoint file. 49 | 50 | Returns: 51 | None 52 | """ 53 | ckpt = { 54 | "model": model.state_dict(), 55 | "opt": optimizer.state_dict(), 56 | "cfg": config, 57 | "stats": stats, 58 | } 59 | torch.save(ckpt, f"{path}/{name}.ckpt") 60 | 61 | 62 | def save_eval_callback( 63 | model: nn.Module, 64 | optimizer: torch.optim.Optimizer, 65 | vals: list, 66 | epoch: int, 67 | config: DictConfig, 68 | writer: SummaryWriter, 69 | rank: int, 70 | stats: dict, 71 | ) -> None: 72 | """Save evaluation checkpoints and log metrics. 73 | 74 | Args: 75 | model (nn.Module): The model being trained. 76 | optimizer (torch.optim.Optimizer): The optimizer used for training. 77 | vals (list): A list of validation datasets. 78 | epoch (int): The current epoch number. 79 | config (DictConfig): The configuration for the training process. 80 | writer (SummaryWriter): The TensorBoard writer for logging. 81 | rank (int): The rank of the current process. 82 | stats (dict): A dictionary to store training statistics. 83 | 84 | Returns: 85 | None 86 | """ 87 | # save latest model 88 | stats["epoch_in_training"] = epoch 89 | if rank == 0: 90 | save(model, optimizer, config, stats, config.exp_path, "latest") 91 | 92 | # eval and save best 93 | model.eval() 94 | if hasattr(optimizer, "eval"): 95 | optimizer.eval() 96 | 97 | for val in vals: 98 | for name, metric in val.eval(model, context_length=config.training.eval_seq_len).items(): 99 | if rank == 0: 100 | print(f"Epoch {epoch} | {name}: {metric}") 101 | # only save checkpoints using the validation metric 102 | if metric > stats.get(f"best_{name}", -float("inf")): 103 | stats[f"best_{name}"] = metric 104 | if name in config.logging.save_metrics: 105 | save(model, optimizer, config, stats, config.exp_path, f"best_{name}") 106 | writer.add_scalar(f"val/{name}/", metric, epoch) 107 | 108 | # reset the model and optimizer to training mode 109 | model.train() 110 | if hasattr(optimizer, "train"): 111 | optimizer.train() 112 | 113 | 114 | def set_experiment_from_config(config: DictConfig) -> tuple[nn.Module, torch.optim.Optimizer, dict]: 115 | """Set up the experiment based on the provided configuration. 116 | 117 | Args: 118 | config (DictConfig): The configuration for the experiment. 119 | 120 | Raises: 121 | Exception: If the reset policy is not recognized. 122 | 123 | Returns: 124 | Tuple[nn.Module, torch.optim.Optimizer, dict]: A tuple containing the model, 125 | optimizer, and a dictionary with training statistics. 126 | """ 127 | 128 | load_state_from_saved = False 129 | stats = {"epoch_in_training": 0} 130 | 131 | seed_everything(config.seed) 132 | 133 | if hasattr(config.data, "single_dataset_id") and config.data.single_dataset_id is not None: 134 | # Load the dataset metadata (we assume that download_data is not necessary) 135 | dataset = openml.datasets.get_dataset(config.data.single_dataset_id, download_data=False) 136 | # Use the dataset name for the experiment (replace spaces with underscores for file system safety) 137 | dataset_name = dataset.name.replace(" ", "_") 138 | config.exp_name = dataset_name 139 | print("Using single dataset for training:", dataset_name) 140 | config.exp_path = f"runs/{config.folder}/{config.exp_name}" 141 | 142 | # Only rank 0 should modify the directory 143 | is_master = (not torch.distributed.is_initialized()) or (torch.distributed.get_rank() == 0) 144 | 145 | # directory setup and checkpoint handling 146 | if is_master: 147 | if os.path.exists(config.exp_path): 148 | print("Directory already exists:", config.exp_path) 149 | # restart training 150 | if config.training.reset_policy == "rm" or not os.path.exists( 151 | f"{config.exp_path}/latest.ckpt" 152 | ): 153 | print("Remove the existing directory.") 154 | shutil.rmtree(config.exp_path, ignore_errors=True) 155 | os.makedirs(config.exp_path, exist_ok=True) 156 | # continue training from a saved checkpoint 157 | elif config.training.reset_policy == "cnt": 158 | load_state_from_saved = True 159 | checkpoint = torch.load(f"{config.exp_path}/latest.ckpt") 160 | model_state = checkpoint["model"] 161 | opt_state = checkpoint["opt"] 162 | stats = checkpoint["stats"] 163 | non_saved_num_epochs = config.training.num_epochs 164 | config.training.num_epochs = non_saved_num_epochs 165 | print("Continue training. Using saved config.") 166 | else: 167 | raise ValueError( 168 | "Invalid reset_policy: must be either 'cnt' (resume) or 'rm' (delete)." 169 | ) 170 | else: 171 | os.makedirs(config.exp_path, exist_ok=True) 172 | print("Created directory: ", config.exp_path) 173 | 174 | # Synchronize all processes so that the directory is created before continuing 175 | if torch.distributed.is_initialized(): 176 | torch.distributed.barrier() 177 | 178 | # Check if the number of GPUs matches the configuration 179 | # assert torch.cuda.device_count() == len( 180 | # config.env.gpus 181 | # ), f"Number of GPUs does not match the number of GPUs in the config, expected {len(config.env.gpus)}, found {torch.cuda.device_count()}" 182 | 183 | assert ( 184 | config.training.batch_size % len(config.env.gpus) == 0 185 | ), "Batch size should be divisible by the number of GPUs" 186 | 187 | # adapt batch size for distributed training 188 | config.training.batch_size //= len(config.env.gpus) 189 | 190 | # save config 191 | OmegaConf.save(config=config, f=f"{config.exp_path}/config.yaml") 192 | 193 | print(f"Using {config.env.gpus} GPUs") 194 | 195 | # instantiate the model 196 | model = TabDPTModel( 197 | dropout=config.training.dropout, 198 | n_out=config.model.max_num_classes, 199 | nhead=config.model.nhead, 200 | nhid=config.model.nhid_factor * config.model.emsize, 201 | ninp=config.model.emsize, 202 | nlayers=config.model.nlayers, 203 | num_features=config.model.max_num_features, 204 | ) 205 | print(f"{sum(p.numel() for p in model.parameters())/1e6:.{2}f} M parameters") 206 | 207 | # load the model state if resuming from a saved checkpoint 208 | # TODO : handle model_state definition in a more robust way 209 | if load_state_from_saved: 210 | model = TabDPTModel.load(model_state, config) 211 | del model_state 212 | 213 | # instantiate the optimizer 214 | optimizer = schedulefree.AdamWScheduleFree( 215 | model.parameters(), 216 | lr=config.training.lr, 217 | weight_decay=config.training.weight_decay, 218 | warmup_steps=1000, 219 | betas=(0.98, 0.999), 220 | ) 221 | 222 | # load the optimizer state if resuming from a saved checkpoint 223 | # TODO : handle model_state definition in a more robust way 224 | if load_state_from_saved: 225 | optimizer.load_state_dict(opt_state) 226 | del opt_state 227 | 228 | return model, optimizer, stats 229 | 230 | 231 | @hydra.main(version_base=None, config_path="configs", config_name="default_config") 232 | def main(config: DictConfig): 233 | """Main function to run the training process. 234 | 235 | Args: 236 | config (DictConfig): The configuration for the training process. 237 | """ 238 | 239 | # signal handling 240 | signal.signal(signal.SIGINT, signal_handler) 241 | 242 | # print config 243 | print("Config:", OmegaConf.to_yaml(config)) 244 | 245 | # distributed training 246 | using_dist, rank, device = init_dist(config.env.device) 247 | 248 | model, optimizer, stats = set_experiment_from_config(config) 249 | model.to(device) 250 | 251 | # compile the model if specified in the config 252 | if config.training.compile: 253 | model = torch.compile(model) 254 | 255 | # using distributed data parallel if specified in the config 256 | if using_dist: 257 | print("Distributed training enabled.") 258 | model = torch.nn.parallel.DistributedDataParallel( 259 | model, 260 | device_ids=[rank], 261 | output_device=rank, 262 | ) 263 | 264 | # Initialize the SummaryWriter for TensorBoard logging 265 | writer = SummaryWriter(config.exp_path) 266 | 267 | # set validation datasets 268 | # TODO : make this configurable 269 | vals = [ 270 | FullEval( 271 | device=device, 272 | max_feat=config.model.max_num_features, 273 | use_retrieval=config.data.eval_retrieval, 274 | ) 275 | ] 276 | 277 | # select the device type 278 | device_type = "cpu" if config.env.device == "cpu" else "cuda" 279 | 280 | # Use autocast for mixed precision training 281 | selected_autocast = torch.autocast(device_type=device_type, dtype=torch.bfloat16) 282 | 283 | # Initialize the FullDataset for training 284 | print("Initializing Dataset...") 285 | dataset = FullDataset(device, config) 286 | 287 | # initialize the DataLoader for the dataset 288 | data_loader = DataLoader( 289 | dataset, 290 | batch_size=config.training.batch_size, 291 | num_workers=config.env.num_workers, 292 | pin_memory=True, 293 | collate_fn=collate_fn, 294 | shuffle=False, 295 | persistent_workers=True, 296 | ) 297 | 298 | # Create an iterator for the DataLoader 299 | iter_data_loader = iter(data_loader) 300 | 301 | # Initialize the outer training loop 302 | for epoch in range(stats["epoch_in_training"] + 1, config.training.num_epochs + 1): 303 | print("Epoch ", epoch) 304 | 305 | # initialize the loss accumulators 306 | epoch_loss_cls = 0.0 307 | epoch_loss_reg = 0.0 308 | 309 | # set the model and optimizer to training mode 310 | model.train() 311 | if hasattr(optimizer, "train"): 312 | optimizer.train() 313 | 314 | # initialize the inner training loop 315 | for batch in tqdm(range(config.training.num_model_updates * config.training.num_agg)): 316 | # randomly set the evaluation position, i.e. the context length 317 | # ---- 1. choose a common eval_pos --------------------------------------- 318 | if dist.get_rank() == 0: 319 | eval_pos_t = torch.randint( 320 | config.training.min_eval_pos, 321 | config.training.max_eval_pos + 1, 322 | (1,), 323 | device=device, 324 | dtype=torch.long, 325 | ) 326 | else: 327 | eval_pos_t = torch.empty(1, dtype=torch.long, device=device) 328 | 329 | dist.broadcast(eval_pos_t, src=0) 330 | eval_pos = int(eval_pos_t.item()) 331 | # eval_pos = random.randint(config.training.min_eval_pos, config.training.max_eval_pos) 332 | 333 | # get the next batch from the DataLoader 334 | x, y, task = [a.to(device) for a in next(iter_data_loader)] 335 | 336 | # efficient forward pass and loss computation 337 | with selected_autocast, sdpa_kernel(SDPBackend.FLASH_ATTENTION): 338 | output, log_act_norms = model( 339 | x, y.squeeze(-1)[:eval_pos], return_log_act_norms=True 340 | ) 341 | 342 | # compute the losses 343 | loss_cls, loss_reg = compute_losses(output, task, y.squeeze(-1)[eval_pos:], config) 344 | 345 | # reweight the loss for combined classification and regression tasks 346 | loss = get_combined_loss(loss_cls, loss_reg, task, config) 347 | 348 | # detach the log_act_norms to avoid memory leak 349 | log_act_norms = {k: v.detach().item() for k, v in log_act_norms.items()} 350 | 351 | epoch_loss_cls += loss_cls.cpu().detach().item() 352 | epoch_loss_reg += loss_reg.cpu().detach().item() 353 | 354 | # backpropagation 355 | loss.backward() 356 | 357 | # gradient accumulation 358 | if (batch + 1) % config.training.num_agg == 0: 359 | # compute the global step 360 | global_step = ( 361 | batch + config.training.num_model_updates * config.training.num_agg * epoch 362 | ) 363 | 364 | # log the losses and norms 365 | log_param_norms(model, writer, global_step, task, global_step) 366 | 367 | # clip gradients to avoid exploding gradients 368 | torch.nn.utils.clip_grad_norm_(model.parameters(), config.training.clip_grad_norm) 369 | 370 | # step the optimizer 371 | optimizer.step() 372 | optimizer.zero_grad() 373 | 374 | # log the average losses for the epoch 375 | writer.add_scalar( 376 | "loss/cls-average_train/", epoch_loss_cls / config.training.num_model_updates, epoch 377 | ) 378 | writer.add_scalar( 379 | "loss/reg-average_train/", epoch_loss_reg / config.training.num_model_updates, epoch 380 | ) 381 | 382 | # evaluate the model every eval_every epochs 383 | if epoch % config.logging.eval_every == 0: 384 | # synchronize all processes before evaluation 385 | if using_dist: 386 | torch.distributed.barrier() 387 | 388 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 389 | model_eval: nn.Module = model.module 390 | else: 391 | model_eval: nn.Module = model 392 | 393 | # save evaluation checkpoints and log metrics 394 | save_eval_callback(model_eval, optimizer, vals, epoch, config, writer, rank, stats) 395 | 396 | if using_dist: 397 | torch.distributed.barrier() 398 | 399 | # cleanup the model and optimizer 400 | if using_dist: 401 | cleanup() 402 | 403 | writer.close() 404 | 405 | 406 | if __name__ == "__main__": 407 | main() 408 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | import openml 6 | import pandas as pd 7 | import torch 8 | from omegaconf import DictConfig 9 | from sklearn.preprocessing import LabelEncoder 10 | from torch.utils.data import Dataset 11 | from tqdm import tqdm 12 | 13 | from model import pad_x 14 | from utils import FAISS, DataPreprocessor 15 | 16 | 17 | def collate_fn(batch): 18 | """Collate function for DataLoader. 19 | 20 | Args: 21 | batch (list): A list of samples from the dataset. 22 | 23 | Returns: 24 | tuple: A tuple containing the concatenated input tensors and the corresponding target tensors. 25 | """ 26 | ret = [torch.cat(el, dim=1) for el in zip(*batch)] 27 | return ret 28 | 29 | 30 | class FullDataset(Dataset): 31 | """FullDataset class for loading and processing datasets. 32 | This class inherits from `torch.utils.data.Dataset` and is used to load datasets from OpenML, 33 | preprocess them, and provide samples for training a model. 34 | It supports both regression and classification tasks, with the ability to use k-NN for context sampling. 35 | """ 36 | 37 | def __init__(self, device: str, config: DictConfig): 38 | """FullDataset initialization. 39 | 40 | Args: 41 | device (str): The device to use for tensor operations. 42 | config (DictConfig): The configuration for the dataset. 43 | """ 44 | self.steps_per_epoch = ( 45 | config.training.num_agg * config.training.num_model_updates * config.training.num_epochs 46 | ) 47 | self.batch_size = config.training.batch_size 48 | self.device = device 49 | self.context_length = getattr(config.training, "seq_len", 1024) 50 | self.retrieval = getattr(config.data, "retrieval", True) 51 | self.max_feat = config.model.max_num_features 52 | self.y_reg_augment = config.data.y_reg_augment 53 | 54 | # Load dataset IDs from CSV file 55 | # TODO: make dataset_ids configurable 56 | self.dataset_ids = ( 57 | pd.read_csv("data_splits/noleak_training_datasets.csv")["did"].values.ravel().tolist() 58 | ) 59 | self.datasets = defaultdict(dict) 60 | 61 | # Load datasets and preprocess them 62 | for did in tqdm(self.dataset_ids): 63 | # download the dataset from OpenML 64 | dataset = openml.datasets.get_dataset( 65 | did, 66 | download_data=False, 67 | download_qualities=False, 68 | download_features_meta_data=False, 69 | ) 70 | 71 | # Get the data as a pandas DataFrame 72 | X, y, _, _ = dataset.get_data( 73 | dataset_format="dataframe", target=dataset.default_target_attribute 74 | ) 75 | 76 | # TODO: preprocessing is a bit different from the usual .fit_transform() since concatenation happens in the middle 77 | preprocessor = DataPreprocessor() 78 | # Preprocess the data 79 | X, y, cat_vals = preprocessor.convert_cat2num(X, y) 80 | X = preprocessor.imputer.fit_transform(X) 81 | 82 | # Concatenate the target as the last column 83 | X = np.concatenate([X, y[:, None]], axis=1) 84 | X = preprocessor.scaler.fit_transform(X) 85 | 86 | if config.data.retrieval: 87 | faiss_knn = FAISS(X, metric="L2", use_hnsw=False) 88 | else: 89 | faiss_knn = None 90 | 91 | self.datasets[did] = {"X": X, "index": faiss_knn, "cat": cat_vals} 92 | 93 | def __len__(self): 94 | """Return the total number of samples in the dataset.""" 95 | return self.steps_per_epoch * self.batch_size 96 | 97 | def transform_target( 98 | self, y: np.ndarray, cls_threshold: float = 10, cls_prob: float = 0.3 99 | ) -> tuple[np.ndarray, str]: 100 | """Transform the target variable into classification or regression format. 101 | Args: 102 | y (np.ndarray): The target variable. 103 | cls_threshold (int): The threshold for classification. 104 | cls_prob (float): The probability of returning a classification task. 105 | Returns: 106 | tuple: A tuple containing the transformed target variable and the task type ("cls" or "reg"). 107 | """ 108 | unique_y = np.unique(y) 109 | if len(unique_y) > cls_threshold: 110 | if np.random.rand() > cls_prob: 111 | return y, "reg" 112 | else: 113 | num_class = np.random.randint(2, cls_threshold) 114 | cls_boundary = np.random.choice( 115 | sorted(np.unique(y))[1:-1], num_class - 1, replace=False 116 | ) 117 | y = (y[:, None] > cls_boundary[None, :]).sum(1) 118 | return y, "cls" 119 | elif len(unique_y) >= 2: 120 | le = LabelEncoder() 121 | y = le.fit_transform(y) 122 | y = y.astype(np.float32) 123 | return y, "cls" 124 | else: 125 | return y, "reg" 126 | 127 | def check_cls_sample(self, X: torch.Tensor, y: torch.Tensor) -> bool: 128 | """Check if the classification sample is valid. 129 | Args: 130 | X (torch.Tensor): The input features. 131 | y (torch.Tensor): The target labels. 132 | Returns: 133 | bool: True if the sample is valid, False otherwise. 134 | """ 135 | assert y.dim() == 3 and y.shape[1] == 1 and y.shape[2] == 1 136 | assert X.dim() == 3 and X.shape[1] == 1 137 | return True 138 | 139 | def check_reg_sample(self, X: torch.Tensor, y: torch.Tensor) -> bool: 140 | """Check if the regression sample is valid. 141 | 142 | Args: 143 | X (torch.Tensor): The input features. 144 | y (torch.Tensor): The target labels. 145 | Returns: 146 | bool: True if the sample is valid, False otherwise. 147 | """ 148 | assert y.dim() == 3 and y.shape[1] == 1 and y.shape[2] == 1 149 | assert X.dim() == 3 and X.shape[1] == 1 150 | if torch.max(y.abs().ravel()) > 10: 151 | return False 152 | if np.max(np.unique(y.ravel(), return_counts=True)[1]) > 0.95 * y.numel(): 153 | return False 154 | return True 155 | 156 | @torch.no_grad() 157 | def __getitem__(self, _) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 158 | """get a sample from the dataset. 159 | 160 | Returns: 161 | tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The input features, target values, and task type. 162 | """ 163 | i = 0 164 | while i < 100: 165 | X, y, task = self.generate_sample() 166 | if task == 0: 167 | if self.check_cls_sample(X, y): 168 | break 169 | else: 170 | if self.check_reg_sample(X, y): 171 | break 172 | i += 1 173 | return X, y, task 174 | 175 | @torch.no_grad() 176 | def context_sampler(self, num_features: int, num_samples: int, X_sample: torch.Tensor): 177 | """Sample a context from the input tensor. 178 | 179 | Args: 180 | num_features (int): The number of features in the input tensor. 181 | num_samples (int): The number of samples in the input tensor. 182 | X_sample (torch.Tensor): The input tensor. 183 | 184 | Returns: 185 | tuple[torch.Tensor, torch.Tensor, str]: The context features, target values, and task type. 186 | """ 187 | target_idx = random.randint(0, num_features - 1) 188 | 189 | # If context length > the number of samples, we need to sample with replacement 190 | if self.context_length > num_samples: 191 | random_selection_indices = np.random.choice(num_samples, num_samples, replace=False) 192 | random_selection_indices = np.concatenate( 193 | [ 194 | np.random.choice(num_samples, self.context_length - num_samples, replace=True), 195 | random_selection_indices, 196 | ] 197 | ) 198 | else: 199 | # If context length <= the number of samples, we can sample without replacement 200 | random_selection_indices = np.random.choice( 201 | num_samples, self.context_length, replace=False 202 | ) 203 | # Select the samples and extract the target feature 204 | X_nni = X_sample[random_selection_indices] 205 | y_nni = X_nni[:, target_idx] 206 | 207 | # Remove the target feature from the feature set 208 | X_nni = np.delete(X_nni, target_idx, axis=1) 209 | 210 | # Transform the target feature into classification or regression format 211 | y_nni, task = self.transform_target(y_nni.ravel(), cls_threshold=10, cls_prob=0.5) 212 | 213 | # Reshape the tensors to match the expected dimensions 214 | X_nni = X_nni.reshape(-1, 1, X_nni.shape[-1]) 215 | y_nni = y_nni.reshape(-1, 1, 1) 216 | 217 | # Select a random subset of features 218 | # Ensure that the number of features sampled is within the limits 219 | num_features_sampled = random.randint( 220 | min(self.max_feat // 2, X_nni.shape[-1] // 2), min(self.max_feat, X_nni.shape[-1]) 221 | ) 222 | random_feature_indices = np.random.choice( 223 | X_nni.shape[-1], num_features_sampled, replace=False 224 | ) 225 | X_nni = X_nni[..., random_feature_indices] 226 | 227 | # Convert to PyTorch tensors 228 | X_nni = torch.Tensor(X_nni) 229 | y_nni = torch.Tensor(y_nni) 230 | 231 | return X_nni, y_nni, task 232 | 233 | @torch.no_grad() 234 | def use_knn(self, X_sample, index_sample, num_samples, num_features): 235 | """Use k-NN to find similar samples. 236 | 237 | Args: 238 | X_sample (np.ndarray): The input features. 239 | index_sample (FAISS): The FAISS index for k-NN search. 240 | num_samples (int): The number of samples. 241 | num_features (int): The number of features. 242 | 243 | Returns: 244 | Tuple[torch.Tensor, torch.Tensor, str]: The context features, target values, and task type. 245 | """ 246 | # Randomly select a query point 247 | x_q_idx = random.randint(0, num_samples - 1) 248 | x_q = X_sample[x_q_idx].reshape(1, -1).copy() 249 | x_q = x_q.astype(np.float32) 250 | 251 | # Randomly select a target feature index 252 | target_idx = random.randint(0, num_features - 1) 253 | 254 | # Get the indices of the k-nearest neighbors of the query point 255 | indices_X_nni = index_sample.get_knn_indices(x_q, self.context_length) 256 | X_nni = X_sample[torch.tensor(indices_X_nni)] 257 | X_nni = np.swapaxes(X_nni, 0, 1) 258 | 259 | # Extract the target feature and remove it from the feature set 260 | y_nni = X_nni[:, :, target_idx] 261 | X_nni = np.delete(X_nni, target_idx, axis=2) 262 | 263 | # Transform the target feature 264 | y_nni, task = self.transform_target(y_nni.ravel(), cls_threshold=10, cls_prob=0.5) 265 | 266 | # select a random subset of features 267 | num_features_sampled = random.randint( 268 | min(self.max_feat // 2, X_nni.shape[-1] // 2), min(self.max_feat, X_nni.shape[-1]) 269 | ) 270 | random_feature_indices = np.random.choice( 271 | X_nni.shape[-1], num_features_sampled, replace=False 272 | ) 273 | X_nni = X_nni[..., random_feature_indices] 274 | 275 | # Convert to PyTorch tensors 276 | X_nni = torch.Tensor(X_nni) 277 | y_nni = torch.Tensor(y_nni) 278 | return X_nni, y_nni, task 279 | 280 | @torch.no_grad() 281 | def generate_sample(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 282 | """Generate a sample from the dataset. 283 | 284 | Raises: 285 | ValueError: If the generated sample is invalid. 286 | 287 | Returns: 288 | Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The input features, target values, and task type. 289 | """ 290 | # Randomly select a dataset ID 291 | did_sample = random.choices(self.dataset_ids, k=1)[0] 292 | 293 | # Get the dataset 294 | X_sample, index_sample, _ = self.datasets[did_sample].values() 295 | num_samples, num_features = X_sample.shape 296 | 297 | if self.retrieval: # Randomly select a query point and its k-nearest neighbors as context 298 | X_nni, y_nni, task = self.use_knn(X_sample, index_sample, num_samples, num_features) 299 | else: # Randomly select a sample and its context 300 | X_nni, y_nni, task = self.context_sampler(num_features, num_samples, X_sample) 301 | 302 | # Pad the features to the maximum number of features 303 | X_nni, y_nni = ( 304 | pad_x(X_nni, self.max_feat), 305 | y_nni, 306 | ) 307 | 308 | # Shuffle the rows 309 | shuffle_indices = np.random.choice(self.context_length, self.context_length, replace=False) 310 | X = X_nni[shuffle_indices] 311 | y = y_nni[shuffle_indices] 312 | 313 | if task == "cls": # If classification task, shuffle the labels 314 | le = LabelEncoder() 315 | y_ = le.fit_transform(y.ravel()) 316 | classes = list(le.classes_) 317 | random.shuffle(classes) 318 | mapping = {original: shuffled for original, shuffled in zip(le.classes_, classes)} 319 | y_random = np.vectorize(mapping.get)(y_) 320 | y = y_random 321 | elif task == "reg": # If regression task, apply normalization and augmentation 322 | y_means = y.mean(dim=0) 323 | y_stds = y.std(dim=0) + 1e-6 324 | y = (y - y_means) / y_stds 325 | if self.y_reg_augment: 326 | # Apply random nonlinear transformations to the target for augmentation 327 | # TODO: shouldn't this non-linear transformation be implemented 328 | # only for a certain calls, not always, e.g. with: 329 | # if random.random() < 0.5: 330 | y = random_nonlinear_transform(y) 331 | y_means = y.mean(dim=0) 332 | y_stds = y.std(dim=0) + 1e-6 333 | y = (y - y_means) / y_stds 334 | else: 335 | raise ValueError("Task must be 'cls' or 'reg'") 336 | 337 | return ( 338 | torch.Tensor(X), 339 | torch.Tensor(y).view(-1, 1, 1), 340 | torch.Tensor([0 if task == "cls" else 1]).unsqueeze(-1).unsqueeze(-1).long(), 341 | ) 342 | 343 | 344 | # A small bank of nonlinear functions (feel free to add/remove as appropriate) 345 | FUNCTION_BANK = { 346 | "sin": torch.sin, 347 | "tanh": torch.tanh, 348 | "square": lambda x: x**2, 349 | "identity": lambda x: x, 350 | "step": lambda x: torch.where(x > 0, 1, 0), 351 | "relu": torch.nn.functional.relu, 352 | "sqrt": lambda x: torch.sign(x) * torch.sqrt(torch.abs(x)), 353 | "log": lambda x: torch.log(1 + torch.abs(x)) * torch.sign(x), 354 | } 355 | 356 | 357 | def random_nonlinear_transform( 358 | y: torch.Tensor, 359 | n_transforms: int = 2, 360 | scale_range: tuple = (0.5, 2.0), 361 | bias_range: tuple = (-1.0, 1.0), 362 | ) -> torch.Tensor: 363 | """Applies a series of random nonlinear transformations to the input tensor. 364 | 365 | Args: 366 | y (torch.Tensor): target tensor to be transformed. 367 | n_transforms (int, optional): number of transformations to apply. Defaults to 2. 368 | scale_range (tuple, optional): range of scaling factors to apply. Defaults to (0.5, 2.0). 369 | bias_range (tuple, optional): range of bias values to apply. Defaults to (-1.0, 1.0). 370 | 371 | Returns: 372 | torch.Tensor: Transformed tensor after applying the random nonlinear transformations. 373 | """ 374 | device = y.device 375 | y_out = y.clone() 376 | func_keys = list(FUNCTION_BANK.keys()) 377 | for _ in range(n_transforms): 378 | func_idx = torch.randint(low=0, high=len(func_keys), size=(1,)).item() 379 | func = FUNCTION_BANK[func_keys[func_idx]] 380 | scale = torch.empty(1, device=device).uniform_(*scale_range).item() 381 | bias = torch.empty(1, device=device).uniform_(*bias_range).item() 382 | if torch.rand(1, device=device) < 0.5: 383 | y_out = scale * func(y_out + bias) 384 | else: 385 | y_out = func(scale * y_out + bias) 386 | return y_out.float() 387 | -------------------------------------------------------------------------------- /tabdpt_datasets/openml.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | from typing import Optional 4 | 5 | import numpy as np 6 | import openml 7 | import pandas as pd 8 | from sklearn.preprocessing import LabelEncoder, OrdinalEncoder 9 | 10 | from tabdpt_datasets.dataset import Dataset 11 | 12 | 13 | class OpenMLDataset(Dataset): 14 | """ 15 | Generic class for loading any OpenML dataset 16 | """ 17 | 18 | @staticmethod 19 | def all_names(): 20 | return None 21 | 22 | @staticmethod 23 | def suite_name(): 24 | return "openml" 25 | 26 | def __init__( 27 | self, 28 | name: str, 29 | task_id: Optional[str] = None, 30 | openml_dataset_id: Optional[str | int] = None, 31 | openml_task_id: Optional[str | int] = None, 32 | ): 33 | """Initializes the OpenMLDataset. 34 | 35 | name (str): The name of the dataset. 36 | task_id (Optional[str], optional): Specifies the type of data split to use. Supported formats: 37 | - "default": Uses the default split provided by the OpenML dataset if available, otherwise 38 | defaults to a random split with seed 0. 39 | - "fold": Uses the fold with the given index from the OpenML task definition. Only 40 | supported with OpenML task IDs. 41 | - "random-seed": Creates a random 70/15/15 split using the specified integer seed. 42 | openml_dataset_id (Optional[str | int], optional): The OpenML dataset ID. Must be specified if 43 | `openml_task_id` is not provided. Defaults to None. 44 | openml_task_id (Optional[str | int], optional): The OpenML task ID. Must be specified if 45 | `openml_dataset_id` is not provided. Defaults to None. 46 | 47 | ValueError: Raised if neither or both `openml_dataset_id` and `openml_task_id` are provided. 48 | ValueError: Raised if `task_id` starts with "fold" but `openml_task_id` is not provided. 49 | ValueError: Raised if `task_id` has an invalid format. 50 | """ 51 | super().__init__(name, task_id) 52 | 53 | # Initialize split indices to None 54 | self._train_inds = None 55 | self._val_inds = None 56 | self._test_inds = None 57 | 58 | if (openml_dataset_id is None) == (openml_task_id is None): 59 | raise ValueError("Must specify exactly one of openml_dataset_id or openml_task_id") 60 | self.did = openml_dataset_id 61 | self.tid = openml_task_id 62 | 63 | if task_id is None or task_id == "default": 64 | self.rng = np.random.default_rng(0) 65 | self.fold = 0 66 | elif task_id.startswith("fold"): 67 | if openml_task_id is None: 68 | raise ValueError("Can only use fold tasks with openml_task_id") 69 | self.fold = int(task_id.removeprefix("fold")) 70 | elif task_id.startswith("random-seed"): 71 | split_seed = int(task_id.removeprefix("random-seed")) 72 | self.rng = np.random.default_rng(split_seed) 73 | self.fold = None 74 | else: 75 | raise ValueError(f"Invalid task_id {task_id}") 76 | 77 | @property 78 | def openml_dataset(self): 79 | if not hasattr(self, "_openml_dataset"): 80 | raise ValueError("Data not loaded yet") 81 | return self._openml_dataset 82 | 83 | def prepare_data(self, download_dir: str): 84 | """ 85 | Downloads the OpenML dataset and prepares the data for use. 86 | Args: 87 | download_dir (str): Directory to download the OpenML dataset to. 88 | """ 89 | openml.config.set_root_cache_directory(os.path.join(download_dir, "openml_cache")) 90 | 91 | if self.tid: 92 | # retreive task and dataset information 93 | task = openml.tasks.get_task( 94 | self.tid, 95 | download_splits=True, 96 | download_data=True, 97 | download_qualities=True, 98 | download_features_meta_data=True, 99 | ) 100 | dataset = openml.datasets.get_dataset( 101 | task.dataset_id, 102 | download_data=True, 103 | download_qualities=True, 104 | download_features_meta_data=True, 105 | ) 106 | 107 | # retrieve data and metadata 108 | X, y, _, self.column_names = dataset.get_data(target=dataset.default_target_attribute) 109 | n = len(X) 110 | 111 | self.metadata["openml_task_id"] = self.tid 112 | self.metadata["openml_dataset_id"] = dataset.dataset_id 113 | 114 | # 115 | if self.fold is not None: 116 | split = task.get_train_test_split_indices(fold=self.fold) 117 | self._train_inds = split.train 118 | self._val_inds = [] 119 | self._test_inds = split.test 120 | 121 | else: 122 | dataset = openml.datasets.get_dataset( 123 | self.did, 124 | download_data=True, 125 | download_qualities=True, 126 | download_features_meta_data=True, 127 | ) 128 | X, y, _, self.column_names = dataset.get_data(target=dataset.default_target_attribute) 129 | self.metadata["openml_dataset_id"] = self.did 130 | 131 | self.metadata["openml_dataset_name"] = dataset.name 132 | self.metadata["openml_dataset_description"] = dataset.description 133 | 134 | if not self.tid or self.fold is None: 135 | n = len(X) 136 | perm = self.rng.permutation(n) 137 | self._train_inds = perm[: int(n * 0.7)] 138 | self._val_inds = perm[int(n * 0.7) : int(n * 0.85)] 139 | self._test_inds = perm[int(n * 0.85) :] 140 | 141 | if dataset.default_target_attribute and "," in dataset.default_target_attribute: 142 | y = None 143 | warnings.warn( 144 | f"Dataset {self.metadata['openml_dataset_id']} has multiple targets, which is " 145 | "not supported. Omitting targets." 146 | ) 147 | 148 | self._openml_dataset = dataset 149 | 150 | categorical_inds = [] 151 | for i, col in enumerate(X.columns): 152 | # Convert categorical features to ordinal integers 153 | if X[col].dtype == "object" or pd.api.types.is_categorical_dtype(X[col]): 154 | enc = OrdinalEncoder() 155 | X[[col]] = enc.fit_transform(X[[col]]) 156 | categorical_inds.append(i) 157 | 158 | self.metadata["categorical_feature_inds"] = categorical_inds 159 | 160 | self.X = X.to_numpy().astype(np.float32) 161 | 162 | if y is None: 163 | self.y = None 164 | self.metadata["target_type"] = "none" 165 | return 166 | 167 | target_feature = [ 168 | f for f in dataset.features.values() if f.name == dataset.default_target_attribute 169 | ][0] 170 | 171 | # encode target variable if it is categorical 172 | if ( 173 | target_feature.data_type == "nominal" 174 | or y.dtype == "object" 175 | or pd.api.types.is_categorical_dtype(y) 176 | ): 177 | enc = LabelEncoder() 178 | self.y = enc.fit_transform(y) 179 | self.metadata["target_type"] = "classification" 180 | else: 181 | self.y = y.to_numpy().astype(np.float32) 182 | self.metadata["target_type"] = "regression" 183 | 184 | def all_instances(self): 185 | return self.X, self.y 186 | 187 | def train_inds(self): 188 | return self._train_inds 189 | 190 | def val_inds(self): 191 | return self._val_inds 192 | 193 | def test_inds(self): 194 | return self._test_inds 195 | 196 | 197 | class OpenMLTaskDataset(Dataset): 198 | """ 199 | Abstract class for dataset classes that are fully defined by a list of OpenML task IDs, which 200 | are the names in all_names() 201 | """ 202 | 203 | def __init__(self, name, task_id=None): 204 | self.tid = int(name) 205 | self.openml_dataset = OpenMLDataset(name, task_id, openml_task_id=self.tid) 206 | task_id = self.openml_dataset.task_id() 207 | super().__init__(name, task_id) 208 | assert name in self.all_names() 209 | 210 | def prepare_data(self, download_dir): 211 | self.openml_dataset.prepare_data(download_dir) 212 | self.metadata = {**self.openml_dataset.metadata, **self.metadata} 213 | 214 | def all_instances(self): 215 | return self.openml_dataset.all_instances() 216 | 217 | def train_inds(self): 218 | return self.openml_dataset.train_inds() 219 | 220 | def val_inds(self): 221 | return self.openml_dataset.val_inds() 222 | 223 | def test_inds(self): 224 | return self.openml_dataset.test_inds() 225 | 226 | 227 | # Note: openml task ids, not dataset ids 228 | # See https://github.com/LeoGrin/tabular-benchmark?tab=readme-ov-file#downloading-the-datasets 229 | # for source suite ids. 230 | GRIN_NUM_CLS_IDS = [ 231 | 361055, 232 | 361060, 233 | 361061, 234 | 361062, 235 | 361063, 236 | 361065, 237 | 361066, 238 | 361068, 239 | 361069, 240 | 361070, 241 | 361273, 242 | 361274, 243 | 361275, 244 | 361276, 245 | 361277, 246 | 361278, 247 | ] 248 | GRIN_CAT_CLS_IDS = [361110, 361111, 361113, 361282, 361283, 361285, 361286] 249 | GRIN_NUM_REG_IDS = [ 250 | 361072, 251 | 361073, 252 | 361074, 253 | 361076, 254 | 361077, 255 | 361078, 256 | 361079, 257 | 361080, 258 | 361081, 259 | 361082, 260 | 361083, 261 | 361084, 262 | 361085, 263 | 361086, 264 | 361087, 265 | 361088, 266 | 361279, 267 | 361280, 268 | 361281, 269 | ] 270 | GRIN_CAT_REG_IDS = [ 271 | 361093, 272 | 361094, 273 | 361096, 274 | 361097, 275 | 361098, 276 | 361099, 277 | 361101, 278 | 361102, 279 | 361103, 280 | 361104, 281 | 361287, 282 | 361288, 283 | 361289, 284 | 361291, 285 | 361292, 286 | 361293, 287 | 361294, 288 | ] 289 | 290 | GRIN_ALL_IDS = GRIN_NUM_CLS_IDS + GRIN_CAT_CLS_IDS + GRIN_NUM_REG_IDS + GRIN_CAT_REG_IDS 291 | 292 | 293 | class GrinsztajnDataset(OpenMLTaskDataset): 294 | @staticmethod 295 | def all_names(): 296 | return [str(tid) for tid in GRIN_ALL_IDS] 297 | 298 | @staticmethod 299 | def suite_name(): 300 | return "grinsztajn" 301 | 302 | 303 | CC18_IDS = [ 304 | 3, 305 | 6, 306 | 11, 307 | 12, 308 | 14, 309 | 15, 310 | 16, 311 | 18, 312 | 22, 313 | 23, 314 | 28, 315 | 29, 316 | 31, 317 | 32, 318 | 37, 319 | 43, 320 | 45, 321 | 49, 322 | 53, 323 | 219, 324 | 2074, 325 | 2079, 326 | 3021, 327 | 3022, 328 | 3481, 329 | 3549, 330 | 3560, 331 | 3573, 332 | 3902, 333 | 3903, 334 | 3904, 335 | 3913, 336 | 3917, 337 | 3918, 338 | 7592, 339 | 9910, 340 | 9946, 341 | 9952, 342 | 9957, 343 | 9960, 344 | 9964, 345 | 9971, 346 | 9976, 347 | 9977, 348 | 9978, 349 | 9981, 350 | 9985, 351 | 10093, 352 | 10101, 353 | 14952, 354 | 14954, 355 | 14965, 356 | 14969, 357 | 14970, 358 | 125920, 359 | 125922, 360 | 146195, 361 | 146800, 362 | 146817, 363 | 146819, 364 | 146820, 365 | 146821, 366 | 146822, 367 | 146824, 368 | 146825, 369 | 167119, 370 | 167120, 371 | 167121, 372 | 167124, 373 | 167125, 374 | 167140, 375 | 167141, 376 | ] 377 | 378 | 379 | class CC18Dataset(OpenMLTaskDataset): 380 | @staticmethod 381 | def all_names(): 382 | return [str(tid) for tid in CC18_IDS] 383 | 384 | @staticmethod 385 | def suite_name(): 386 | return "cc18" 387 | 388 | 389 | CTR23_IDS = [ 390 | 361234, 391 | 361235, 392 | 361236, 393 | 361237, 394 | 361241, 395 | 361242, 396 | 361243, 397 | 361244, 398 | 361247, 399 | 361249, 400 | 361250, 401 | 361251, 402 | 361252, 403 | 361253, 404 | 361254, 405 | 361255, 406 | 361256, 407 | 361257, 408 | 361258, 409 | 361259, 410 | 361260, 411 | 361261, 412 | 361264, 413 | 361266, 414 | 361267, 415 | 361268, 416 | 361269, 417 | 361272, 418 | 361616, 419 | 361617, 420 | 361618, 421 | 361619, 422 | 361621, 423 | 361622, 424 | 361623, 425 | ] 426 | 427 | 428 | class CTR23Dataset(OpenMLTaskDataset): 429 | @staticmethod 430 | def all_names(): 431 | return [str(tid) for tid in CTR23_IDS] 432 | 433 | @staticmethod 434 | def suite_name(): 435 | return "ctr23" 436 | 437 | 438 | AMLB_IDS = [ 439 | 3, 440 | 12, 441 | 31, 442 | 53, 443 | 3917, 444 | 3945, 445 | 7592, 446 | 7593, 447 | 9952, 448 | 9977, 449 | 9981, 450 | 10101, 451 | 14965, 452 | 34539, 453 | 146195, 454 | 146212, 455 | 146606, 456 | 146818, 457 | 146821, 458 | 146822, 459 | 146825, 460 | 167119, 461 | 167120, 462 | 168329, 463 | 168330, 464 | 168331, 465 | 168332, 466 | 168335, 467 | 168337, 468 | 168338, 469 | 168868, 470 | 168908, 471 | 168909, 472 | 168910, 473 | 168911, 474 | 168912, 475 | 189354, 476 | 189355, 477 | 189356, 478 | ] 479 | 480 | 481 | class AMLBDataset(OpenMLTaskDataset): 482 | @staticmethod 483 | def all_names(): 484 | return [str(tid) for tid in AMLB_IDS] 485 | 486 | @staticmethod 487 | def suite_name(): 488 | return "amlb" 489 | 490 | 491 | ICLR_TRAINING_IDS = [ 492 | 41138, 493 | 4135, 494 | 4535, 495 | 41434, 496 | 375, 497 | 1120, 498 | 41150, 499 | 40900, 500 | 40536, 501 | 1043, 502 | 1119, 503 | 1169, 504 | 41147, 505 | 1459, 506 | 1466, 507 | 1118, 508 | 41142, 509 | 23380, 510 | 1596, 511 | 41163, 512 | 1471, 513 | 846, 514 | 1044, 515 | 41164, 516 | 1477, 517 | 1476, 518 | 1038, 519 | 41159, 520 | 23512, 521 | 1479, 522 | 821, 523 | 41168, 524 | 41143, 525 | 184, 526 | 1483, 527 | 40679, 528 | 24, 529 | 1116, 530 | 1568, 531 | 1493, 532 | 30, 533 | 41145, 534 | 1567, 535 | 871, 536 | 41161, 537 | 41165, 538 | 312, 539 | 40685, 540 | 1036, 541 | 41146, 542 | 41166, 543 | 1509, 544 | 40733, 545 | 44089, 546 | 44122, 547 | 45022, 548 | 45020, 549 | 45028, 550 | 45026, 551 | 45038, 552 | 45039, 553 | 1111, 554 | 1457, 555 | 41167, 556 | 41158, 557 | 41144, 558 | 41156, 559 | 40498, 560 | 41169, 561 | 41162, 562 | 42734, 563 | 42732, 564 | 42746, 565 | 42742, 566 | 43072, 567 | 137, 568 | 273, 569 | 382, 570 | 389, 571 | 396, 572 | 802, 573 | 816, 574 | 843, 575 | 930, 576 | 966, 577 | 981, 578 | 1002, 579 | 1018, 580 | 1037, 581 | 1042, 582 | 1112, 583 | 1130, 584 | 1142, 585 | 1444, 586 | 1453, 587 | 1481, 588 | 1503, 589 | 1507, 590 | 40646, 591 | 40680, 592 | 40706, 593 | 44055, 594 | 44056, 595 | 44061, 596 | 44063, 597 | 44065, 598 | 44068, 599 | 44069, 600 | 45041, 601 | 45043, 602 | 45045, 603 | 45046, 604 | 45047, 605 | 44136, 606 | 44137, 607 | 44145, 608 | 45032, 609 | 4549, 610 | 42572, 611 | 42705, 612 | 42728, 613 | 41540, 614 | 42724, 615 | 42727, 616 | 42730, 617 | 41980, 618 | 42563, 619 | 3050, 620 | 3277, 621 | 43071, 622 | ] 623 | 624 | 625 | class ICLRTrainingDataset(Dataset): 626 | @staticmethod 627 | def all_names(): 628 | return [str(did) for did in ICLR_TRAINING_IDS] 629 | 630 | @staticmethod 631 | def suite_name(): 632 | return "iclr-training" 633 | 634 | def __init__(self, name): 635 | super().__init__(name) 636 | assert name in self.all_names() 637 | self.did = int(name) 638 | self.openml_dataset = OpenMLDataset(name, openml_dataset_id=self.did) 639 | 640 | def prepare_data(self, download_dir): 641 | self.openml_dataset.prepare_data(download_dir) 642 | self.metadata = {**self.openml_dataset.metadata, **self.metadata} 643 | 644 | def all_instances(self): 645 | return self.openml_dataset.all_instances() 646 | 647 | def train_inds(self): 648 | return self.openml_dataset.train_inds() 649 | 650 | def val_inds(self): 651 | return self.openml_dataset.val_inds() 652 | 653 | def test_inds(self): 654 | return self.openml_dataset.test_inds() 655 | -------------------------------------------------------------------------------- /tabdpt_datasets/catalogue.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import difflib 4 | import json 5 | import os 6 | import sys 7 | import warnings 8 | from dataclasses import dataclass 9 | from typing import Literal 10 | 11 | import numpy as np 12 | import scipy 13 | 14 | from tabdpt_datasets.annotated_tables import AnnotatedTablesDataset 15 | from tabdpt_datasets.dataset import Dataset 16 | from tabdpt_datasets.openml import ( 17 | AMLBDataset, 18 | CC18Dataset, 19 | CTR23Dataset, 20 | GrinsztajnDataset, 21 | ICLRTrainingDataset, 22 | TabZillaDataset, 23 | ) 24 | from tabdpt_datasets.tabred import TabredDataset 25 | from tabdpt_datasets.talent import TalentDataset 26 | 27 | 28 | class NpGenericEncoder(json.JSONEncoder): 29 | # Handle types like np.float32 30 | def default(self, x): 31 | if isinstance(x, np.generic): 32 | return x.item() 33 | return json.JSONEncoder.default(self, x) 34 | 35 | 36 | # Only include dataset classes that provide a suite of named datasets that can each be loaded by 37 | # name 38 | SUITES = [ 39 | CC18Dataset, 40 | CTR23Dataset, 41 | GrinsztajnDataset, 42 | TabredDataset, 43 | TalentDataset, 44 | AMLBDataset, 45 | TabZillaDataset, 46 | ICLRTrainingDataset, 47 | AnnotatedTablesDataset, 48 | ] 49 | 50 | EVAL_SUITES = [CC18Dataset, CTR23Dataset] 51 | EVAL_SUITE_NAMES = [s.suite_name() for s in EVAL_SUITES] 52 | 53 | 54 | def longest_overlap(s1: str, s2: str) -> int: 55 | sm = difflib.SequenceMatcher(a=s1.lower(), b=s2.lower()) 56 | return sm.find_longest_match(0, len(s1), 0, len(s2))[2] 57 | 58 | 59 | @dataclass 60 | class Duplicate: 61 | suite_name_1: str 62 | dataset_name_1: str 63 | suite_name_2: str 64 | dataset_name_2: str 65 | likelihood: Literal["low", "high", "certain"] 66 | reasons: list[str] 67 | 68 | def likelihood_gte(self, other: str): 69 | """ 70 | Returns True if this duplicate's likelihood is greater than or equal to the other given 71 | likelihood. 72 | """ 73 | assert other in ["low", "high", "certain"] 74 | return ( 75 | other == "low" 76 | or other == "high" 77 | and self.likelihood != "low" 78 | or other == "certain" 79 | and self.likelihood == "certain" 80 | ) 81 | 82 | 83 | class CatalogueView: 84 | """ 85 | Base class for catalogue functions that don't manage metadata on disk. Allows for filtered views 86 | of the catalogue. 87 | 88 | Code outside this module generally shouldn't directly instantiate CatalogueViews, the Catalogue 89 | class should be used instead. 90 | """ 91 | 92 | def __init__(self, metadata): 93 | self._metadata = metadata 94 | self.dataset_map = {(c.suite_name(), name): c for c in SUITES for name in c.all_names()} 95 | 96 | def filter(self, key: str, pred) -> CatalogueView: 97 | """ 98 | Filter catalogue using a metadata key and corresponding required value, or a metadata key 99 | and a predicate that must return True when passed the value. 100 | 101 | E.g., c.filter('target_type', 'regression') or c.filter('size', lambda s: s > 1000) 102 | 103 | Returns a new CatalogueView. 104 | """ 105 | if callable(pred): 106 | return CatalogueView({k: v for k, v in self._metadata.items() if pred(v[key])}) 107 | return CatalogueView({k: v for k, v in self._metadata.items() if v[key] == pred}) 108 | 109 | def metadata(self): 110 | return self._metadata 111 | 112 | def dataset(self, suite_name: str, dataset_name: str) -> Dataset: 113 | """ 114 | Get a Dataset object for the dataset with the given name. 115 | """ 116 | return self.dataset_map[(suite_name, dataset_name)](dataset_name) 117 | 118 | def datasets(self) -> list[Dataset]: 119 | """ 120 | Get dataset objects for all datasets in the current view. 121 | """ 122 | return [self.dataset(*k) for k in self._metadata.keys()] 123 | 124 | def detect_duplicates(self, verbose=False) -> list[Duplicate]: 125 | """ 126 | Detects possible duplicate datasets. Sensitivity can be selected for detecting leakage (high 127 | sensitivity, default) or just removing redundant data (low sensitivity). Returns a list of 128 | Duplicate objects. 129 | 130 | verbose: Print out duplicates and info on detection as they're detected 131 | """ 132 | entries = list(self._metadata.items()) 133 | kd_trees = {} 134 | for k, m in entries: 135 | col_coeffs = [] 136 | for mean, var, skew, kurt in zip( 137 | m["column_means"], m["column_vars"], m["column_skews"], m["column_kurtoses"] 138 | ): 139 | col_coeffs.append([mean, np.sqrt(var), skew, kurt]) 140 | # Using moderate values for infinities so they don't drive all other values to zero once 141 | # normalized 142 | col_coeffs = np.nan_to_num(np.array(col_coeffs), posinf=1.0, neginf=-1.0) 143 | if len(col_coeffs) > 0: 144 | # Normalize row-wise so that the distance tolerance works reasonably 145 | col_coeffs = (col_coeffs - col_coeffs.mean(axis=1)[:, None]) / ( 146 | col_coeffs.std(axis=1)[:, None] + 1e-8 147 | ) 148 | kd_trees[k] = scipy.spatial.KDTree(col_coeffs) 149 | 150 | duplicates = [] 151 | for i in range(len(entries)): 152 | for j in range(i + 1, len(entries)): 153 | k1, m1 = entries[i] 154 | k2, m2 = entries[j] 155 | # These suites should already be deduped 156 | if m1["suite_name"] == m2["suite_name"] and m1["suite_name"] in ( 157 | "cc18", 158 | "ctr23", 159 | "grinsztajn", 160 | "tabred", 161 | "amlb", 162 | ): 163 | continue 164 | 165 | dup_reasons = [] 166 | certain_dup = False 167 | high_feature_similarity = False 168 | 169 | if m1["n_rows"] == m2["n_rows"] and sum(c != "0" for c in str(m1["n_rows"])) > 1: 170 | # Ignore n_rows mismatches if they only have one non-zero digit, e.g., 10000 vs 171 | # 10000 is not really suspicious 172 | dup_reasons.append(f"same n_rows {m1['n_rows']}") 173 | 174 | if m1["n_features"] == m2["n_features"]: 175 | dup_reasons.append(f"same number of features {m1['n_features']}") 176 | 177 | if ( 178 | m1["target_type"] != "none" 179 | and m2["target_type"] != "none" 180 | and np.isclose(m1["y_mean"], m2["y_mean"]) 181 | and np.isclose(m1["y_var"], m2["y_var"]) 182 | ): 183 | dup_reasons.append( 184 | f"similar target statistics: mean {m1['y_mean']} var {m1['y_var']} " 185 | f"vs mean {m2['y_mean']} var {m2['y_var']}" 186 | ) 187 | 188 | if "file_sha1" in m1 and "file_sha1" in m2 and m1["file_sha1"] == m2["file_sha1"]: 189 | certain_dup = True 190 | dup_reasons.append("same sha1 hash") 191 | 192 | m1_name = m1.get("openml_dataset_name", m1.get("kaggle_dataset_name", None)) 193 | m2_name = m2.get("openml_dataset_name", m2.get("kaggle_dataset_name", None)) 194 | if m1_name and m2_name and longest_overlap(m1_name, m2_name) > 4: 195 | dup_reasons.append("overlap in names:" f"{m1_name}, {m2_name}") 196 | 197 | # Compare statistics for each column wrt the target. If enough are similar, then 198 | # there might be leakage. 199 | if k1 in kd_trees and k2 in kd_trees: 200 | # Check the number of columns that seem to have a pair in the other table 201 | # Annoying output format - list of lists of indices in other tree 202 | pair_inds = kd_trees[k1].query_ball_tree(kd_trees[k2], 1e-3) 203 | n_similar_k1 = sum(len(xs) > 0 for xs in pair_inds) 204 | n_similar_k2 = len(set(x for xs in pair_inds for x in xs)) 205 | n_similar = min(n_similar_k1, n_similar_k2) 206 | thresh = min(5, m1["n_features"], m2["n_features"]) 207 | 208 | if n_similar >= thresh: 209 | dup_reasons.append(f"{n_similar} columns with similar statistics") 210 | min_n_features = min(m1["n_features"], m2["n_features"]) 211 | max_n_features = max(m1["n_features"], m2["n_features"]) 212 | n_feature_ratio = max_n_features / min_n_features 213 | detection_frac = min(1, 1 / 2 ** (np.log10(min_n_features) - 1)) 214 | if n_similar > detection_frac * min_n_features and n_feature_ratio < 1.5: 215 | high_feature_similarity = True 216 | 217 | likelihood = ( 218 | "certain" 219 | if certain_dup 220 | else ( 221 | "high" 222 | if len(dup_reasons) > 2 223 | else "low" if len(dup_reasons) > 1 or high_feature_similarity else None 224 | ) 225 | ) 226 | if likelihood: 227 | if verbose: 228 | print(f"Suspected duplicate {k1} {k2}: {', '.join(dup_reasons)}") 229 | if m1["suite_name"] == m2["suite_name"]: 230 | print(f"Within same suite! {m1['suite_name']}") 231 | duplicates.append( 232 | Duplicate(k1[0], k1[1], k2[0], k2[1], likelihood, dup_reasons) 233 | ) 234 | 235 | return duplicates 236 | 237 | def filter_duplicates( 238 | self, eval_min_likelihood="low", train_min_likelihood="high" 239 | ) -> CatalogueView: 240 | """ 241 | Returns a new CatalogueView with duplicate datasets removed. Eval datasets are always kept. 242 | 243 | eval_min_likelihood: Either 'low', 'high', 'certain', or None. Determines minimum duplicate 244 | likelihood between a train and eval dataset for the train dataset to be removed. Set to None 245 | to ignore eval duplicates. The default is 'low' to prevent all suspected leakage. 246 | train_min_likelihood: Either 'low', 'high', 'certain', or None. Determines minimum duplicate 247 | likelihood between two train datasets for the first train dataset to be removed. Set to None 248 | to ignore train duplicates. The default is 'high' to avoid removing useful data, since 249 | leakage isn't a concern in the train set. 250 | """ 251 | # Remove all duplicates with eval datasets, then go back and remove the first dataset from 252 | # each remaining pair of duplicates. Since detect_duplicates uses a fixed ordering, this 253 | # ensures all duplicates are resolved. 254 | duplicates = self.detect_duplicates(verbose=False) 255 | to_remove: set[tuple[str, str]] = set() 256 | dups_no_eval = [] 257 | for d in duplicates: 258 | if d.suite_name_1 in EVAL_SUITE_NAMES and d.suite_name_2 in EVAL_SUITE_NAMES: 259 | continue 260 | elif d.suite_name_1 in EVAL_SUITE_NAMES: 261 | if eval_min_likelihood and d.likelihood_gte(eval_min_likelihood): 262 | to_remove.add((d.suite_name_2, d.dataset_name_2)) 263 | elif d.suite_name_2 in EVAL_SUITE_NAMES: 264 | if eval_min_likelihood and d.likelihood_gte(eval_min_likelihood): 265 | to_remove.add((d.suite_name_1, d.dataset_name_1)) 266 | else: 267 | dups_no_eval.append(d) 268 | for d in dups_no_eval: 269 | if ( 270 | train_min_likelihood 271 | and d.likelihood_gte(train_min_likelihood) 272 | and (d.suite_name_2, d.dataset_name_2) not in to_remove 273 | ): 274 | to_remove.add((d.suite_name_1, d.dataset_name_1)) 275 | return CatalogueView({k: v for k, v in self._metadata.items() if k not in to_remove}) 276 | 277 | def total_size(self) -> (int, int): 278 | """ 279 | Total instances and cells across all datasets 280 | """ 281 | return ( 282 | sum(d["n_rows"] for d in self._metadata.values()), 283 | sum(d["n_cells"] for d in self._metadata.values()), 284 | ) 285 | 286 | def split(self) -> tuple[CatalogueView, CatalogueView]: 287 | """ 288 | Splits datasets based on name into a training and eval set. Note that to avoid leakage, 289 | duplicates should already be filtered out. Uses fixed splits as discussed. 290 | """ 291 | train, eval = {}, {} 292 | for k, v in self._metadata.items(): 293 | if k[0] in EVAL_SUITE_NAMES: 294 | eval[k] = v 295 | else: 296 | train[k] = v 297 | return CatalogueView(train), CatalogueView(eval) 298 | 299 | 300 | class Catalogue(CatalogueView): 301 | """ 302 | Dataset catalogue, used to access the aggregate set of datasets in our project and filter, 303 | remove duplicates, split into train and hold-out, etc. 304 | 305 | Unlike CatalogueView, this class includes functionality to generate, load, and store the 306 | catalogue, so it should be generally used to create and access the catalogue. 307 | """ 308 | 309 | def __init__(self, download_dir, suites=SUITES): 310 | self.download_dir = download_dir 311 | self.suites = suites 312 | self.cache_path = os.path.join(download_dir, "metadata.json") 313 | if os.path.exists(self.cache_path): 314 | with open(self.cache_path, "r") as f: 315 | try: 316 | metadata = json.load(f) 317 | except json.decoder.JSONDecodeError as e: 318 | print( 319 | f"JSON decoding failed - solution is probably to delete {self.cache_path} and rerun full_update" 320 | ) 321 | raise e 322 | # Stored as a list, convert back into a dict indexed by suite and dataset name 323 | metadata = {(d["suite_name"], d["dataset_name"]): d for d in metadata} 324 | else: 325 | metadata = {} 326 | warnings.warn( 327 | "Catalogue cache doesn't exist yet - run catalogue.incremental_update() or catalogue.full_update() before using" 328 | ) 329 | super().__init__(metadata) 330 | 331 | def update_dataset(self, dataset, update_cache=True): 332 | """ 333 | Run auto_populate_metadata to update the metadata for an individual dataset, optionally 334 | writing to the cache. 335 | """ 336 | dataset.prepare_data(self.download_dir) 337 | dataset.auto_populate_metadata() 338 | self._metadata[(dataset.suite_name(), dataset.name)] = dataset.metadata 339 | if update_cache: 340 | with open(self.cache_path, "w") as f: 341 | json.dump(list(self._metadata.values()), f, cls=NpGenericEncoder, indent=4) 342 | 343 | def incremental_update(self): 344 | """ 345 | Update metadata and cache for datasets that don't already appear in metadata. 346 | """ 347 | n_updated = 0 348 | n_total = sum(len(c.all_names()) for c in self.suites) 349 | for c in self.suites: 350 | for name in c.all_names(): 351 | n_updated += 1 352 | sys.stdout.write(f"\rUpdating catalogue... {n_updated}/{n_total}") 353 | sys.stdout.flush() 354 | if name not in self._metadata: 355 | self.update_dataset(c(name), update_cache=False) 356 | print("Done") 357 | with open(self.cache_path, "w") as f: 358 | json.dump(list(self._metadata.values()), f, cls=NpGenericEncoder, indent=4) 359 | 360 | def full_update(self): 361 | """ 362 | Re-generate metadata and cache entirely. 363 | """ 364 | self._metadata = {} 365 | n_updated = 0 366 | n_total = sum(len(c.all_names()) for c in self.suites) 367 | for c in self.suites: 368 | for name in c.all_names(): 369 | n_updated += 1 370 | sys.stdout.write(f"\rUpdating catalogue... {n_updated}/{n_total}") 371 | sys.stdout.flush() 372 | self.update_dataset(c(name), update_cache=False) 373 | print("Done") 374 | with open(self.cache_path, "w") as f: 375 | json.dump(list(self._metadata.values()), f, cls=NpGenericEncoder, indent=4) 376 | -------------------------------------------------------------------------------- /tabdpt.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional 3 | 4 | import numpy as np 5 | import torch 6 | from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin 7 | from torch.nn.attention import SDPBackend, sdpa_kernel 8 | 9 | from model import TabDPTModel, Task, convert_to_torch_tensor, pad_x 10 | from utils import FAISS, DataPreprocessor, standardize, seed_everything 11 | 12 | _DEFAULT_DEVICE = "cuda:0" 13 | _INF_BATCH_SIZE = 512 14 | _DEFAULT_CONTEXT_SIZE = 128 15 | _DEFAULT_RETRIEVAL = True 16 | _DEFAULT_TEMPERATURE = 0.8 17 | 18 | 19 | class TabDPTEstimator(BaseEstimator): 20 | """A PyTorch-based implementation of TabDPT for tabular data tasks.""" 21 | 22 | def __init__( 23 | self, 24 | model: Optional[TabDPTModel] = None, 25 | path: str = "", 26 | mode: Task = Task.CLS, 27 | inf_batch_size: int = _INF_BATCH_SIZE, 28 | device: str = _DEFAULT_DEVICE, 29 | tensor_eval: bool = False, 30 | ): 31 | """Initialize the TabDPTEstimator. 32 | 33 | Args: 34 | model (Optional[TabDPTModel], optional): The TabDPT model to use. Defaults to None. 35 | path (str, optional): Path to the model checkpoint. Defaults to "". 36 | mode (Task, optional): The task mode (classification or regression). Defaults to Task.CLS. 37 | inf_batch_size (int, optional): Inference batch size. Defaults to 512. 38 | device (str, optional): Device to run the model on. Defaults to _DEFAULT_DEVICE. 39 | tensor_eval (bool, optional): Whether to use tensor evaluation. Defaults to False. 40 | 41 | Raises: 42 | ValueError: If both model and path are None. 43 | """ 44 | seed_everything(42) 45 | self.mode = mode 46 | self.inf_batch_size = inf_batch_size 47 | self.device = device 48 | self.tensor_eval = tensor_eval 49 | if model is None: 50 | if path: 51 | checkpoint = torch.load(path, weights_only=False) 52 | self.model = TabDPTModel.load( 53 | model_state=checkpoint["model"], config=checkpoint["cfg"] 54 | ) 55 | self.model.eval() 56 | else: 57 | raise ValueError("Either model or path must be provided") 58 | else: 59 | self.model = model 60 | self.max_features = self.model.num_features 61 | self.max_num_classes = self.model.n_out 62 | 63 | def fit(self, X: np.ndarray, y: np.ndarray, faiss_index=None) -> None: 64 | """Fit the model to the training data. 65 | 66 | Args: 67 | X (np.ndarray): Training features, a 2D numpy array of shape [n_samples, n_features]. 68 | y (np.ndarray): Training labels, a 1D numpy array of shape [n_samples]. 69 | faiss_index (Optional[FAISS], optional): Precomputed FAISS index for retrieval. Defaults to None. 70 | """ 71 | if not self.tensor_eval: 72 | assert isinstance(X, np.ndarray), "X must be a numpy array" 73 | assert isinstance(y, np.ndarray), "y must be a numpy array" 74 | assert X.shape[0] == y.shape[0], "X and y must have the same number of samples" 75 | assert X.ndim == 2, "X must be a 2D array" 76 | assert y.ndim == 1, "y must be a 1D array" 77 | 78 | self.preprocessor = DataPreprocessor() 79 | X = self.preprocessor.fit_transform(X) 80 | 81 | # initialize the Faiss index if not provided 82 | self.faiss_knn = faiss_index if faiss_index is not None else FAISS(X) 83 | 84 | self.n_instances, self.n_features = X.shape 85 | self.X_train = X 86 | self.y_train = y 87 | 88 | self.autocast = torch.autocast(device_type="cuda", dtype=torch.bfloat16) 89 | 90 | def _prepare_prediction(self, X: np.ndarray) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 91 | """preprocess the input data for prediction. 92 | This method handles the transformation of the input data, including imputation, 93 | scaling, and dimensionality reduction if necessary. 94 | 95 | Args: 96 | X (np.ndarray): input features for prediction. 97 | 98 | Returns: 99 | tuple[torch.Tensor, torch.Tensor, torch.Tensor]: training features, training labels, and test features. 100 | """ 101 | 102 | # preprocess the input data 103 | if not self.tensor_eval: 104 | self.X_test = self.preprocessor.transform(X) 105 | else: 106 | self.X_test = X 107 | 108 | train_x, train_y, test_x = ( 109 | convert_to_torch_tensor(self.X_train).to(self.device).float(), 110 | convert_to_torch_tensor(self.y_train).to(self.device).float(), 111 | convert_to_torch_tensor(self.X_test).to(self.device).float(), 112 | ) 113 | 114 | # Apply PCA optionally to reduce the number of features 115 | if self.n_features > self.max_features: 116 | _, _, self.V = torch.pca_lowrank(train_x, q=self.max_features) 117 | train_x = train_x @ self.V 118 | else: 119 | self.V = None 120 | 121 | # apply PCA to the test set if V is not None (i.e., PCA was applied to the training set) 122 | test_x = test_x @ self.V if self.V is not None else test_x 123 | 124 | return train_x, train_y, test_x 125 | 126 | @torch.no_grad() 127 | def no_retrieval_data( 128 | self, 129 | train_x: torch.Tensor, 130 | train_y: torch.Tensor, 131 | test_x: torch.Tensor, 132 | context_size: int = _DEFAULT_CONTEXT_SIZE, 133 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 134 | """Prepare data for the no retrieval scenario. 135 | 136 | Args: 137 | train_x (torch.Tensor): Training features with shape [n_train, d]. 138 | train_y (torch.Tensor): Training labels with shape [n_train]. 139 | test_x (torch.Tensor): Test features with shape [n_test, d]. 140 | context_size (int, optional): Number of context samples to use. 141 | Defaults to _DEFAULT_CONTEXT_SIZE. 142 | 143 | Returns: 144 | tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 145 | Context features, context labels, and evaluation features. 146 | """ 147 | n_context = min(context_size, train_x.shape[0]) 148 | idx_context = np.random.choice(train_x.shape[0], n_context, replace=False) 149 | 150 | # shape: [n_context, d] 151 | context_features = train_x[idx_context] 152 | context_labels = train_y[idx_context] 153 | 154 | # shape: [n_context, 1, d] 155 | context_features = pad_x(context_features.unsqueeze(1), self.max_features).to(self.device) 156 | context_labels = context_labels.unsqueeze(1).float() # shape: [n_context, 1] 157 | 158 | # shape: [n_test, 1, d] 159 | eval_features = pad_x(test_x.unsqueeze(1), self.max_features).to(self.device) 160 | 161 | return context_features, context_labels, eval_features 162 | 163 | @torch.no_grad() 164 | def batch_retrieval_data( 165 | self, 166 | train_x: torch.Tensor, 167 | train_y: torch.Tensor, 168 | test_x: torch.Tensor, 169 | batch_index: int, 170 | context_size: int, 171 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 172 | """Prepare data for the batch retrieval scenario. 173 | 174 | Args: 175 | batch_index (int): Batch index. 176 | context_size (int): Number of context samples to use. 177 | 178 | Returns: 179 | tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 180 | Context features, context labels, and evaluation features. 181 | """ 182 | n_test = test_x.shape[0] 183 | 184 | start = batch_index * self.inf_batch_size 185 | end = min(n_test, (batch_index + 1) * self.inf_batch_size) 186 | 187 | indices_knn = self.faiss_knn.get_knn_indices(self.X_test[start:end], k=context_size) 188 | knn_features = train_x[torch.tensor(indices_knn)] 189 | knn_labels = train_y[torch.tensor(indices_knn)] 190 | 191 | # swap => [context_size, batch_size, d] 192 | knn_features = np.swapaxes(knn_features.cpu().numpy(), 0, 1) 193 | knn_labels = np.swapaxes(knn_labels.cpu().numpy(), 0, 1) 194 | 195 | knn_features = pad_x(torch.Tensor(knn_features), self.max_features).to(self.device) 196 | knn_labels = torch.Tensor(knn_labels).to(self.device) 197 | 198 | eval_features = pad_x(test_x[start:end].unsqueeze(0), self.max_features).to(self.device) 199 | return knn_features, knn_labels, eval_features 200 | 201 | 202 | class TabDPTClassifier(TabDPTEstimator, ClassifierMixin): 203 | """ 204 | A PyTorch-based implementation of TabDPT for classification tasks. 205 | """ 206 | 207 | def __init__( 208 | self, 209 | model: Optional[TabDPTModel] = None, 210 | path: str = "", 211 | mode: Task = Task.CLS, 212 | inf_batch_size: int = _INF_BATCH_SIZE, 213 | device: str = _DEFAULT_DEVICE, 214 | tensor_eval: bool = False, 215 | ): 216 | super().__init__( 217 | model=model, 218 | path=path, 219 | mode=mode, 220 | inf_batch_size=inf_batch_size, 221 | device=device, 222 | tensor_eval=tensor_eval, 223 | ) 224 | 225 | def fit(self, X, y, faiss_index=None): 226 | super().fit(X, y, faiss_index) 227 | # Number of classes 228 | if self.tensor_eval: 229 | self.num_classes = len(torch.unique(self.y_train)) 230 | else: 231 | self.num_classes = len(np.unique(self.y_train)) 232 | assert self.num_classes > 1, "Number of classes must be greater than 1" 233 | 234 | def _predict_large_cls(self, X_context, X_eval, y_context) -> torch.Tensor: 235 | """Digit-level prediction for the case where num_classes > self.max_num_classes. 236 | Here, X_context + X_eval is [L, 1, d], and y_context is [L_context, 1]. 237 | 238 | Args: 239 | X_context (torch.Tensor): Context features with shape [L_context, 1, d]. 240 | X_eval (torch.Tensor): Evaluation features with shape [L_eval, 1, d]. 241 | y_context (torch.Tensor): Context labels with shape [L_context, 1]. 242 | 243 | Returns: 244 | torch.Tensor: Predicted probabilities with shape [L_eval, 1, num_classes]. 245 | """ 246 | # number of digits needed to represent num_classes 247 | num_digits = math.ceil(math.log(self.num_classes, self.max_num_classes)) 248 | 249 | digit_preds = [] 250 | for i in range(num_digits): 251 | y_context_digit = (y_context // (self.max_num_classes**i)) % self.max_num_classes 252 | 253 | with self.autocast, sdpa_kernel(SDPBackend.FLASH_ATTENTION): 254 | pred = self.model( 255 | x_src=torch.cat([X_context, X_eval], dim=0), 256 | y_src=y_context_digit, 257 | ) 258 | # shape: [L_context + L_eval, 1, max_num_classes] 259 | digit_preds.append(pred[..., : self.max_num_classes].float()) 260 | 261 | # Combine digit predictions 262 | B = X_context.shape[1] 263 | L_eval = X_eval.shape[0] 264 | full_pred = torch.zeros((L_eval, B, self.num_classes), device=X_context.device) 265 | # For each of the L_eval positions, compute the class probabilities 266 | for class_idx in range(self.num_classes): 267 | # sum across digits 268 | class_pred = torch.zeros((L_eval, B), device=X_eval.device) 269 | for digit_idx, digit_pred in enumerate(digit_preds): 270 | digit_value = ( 271 | class_idx // (self.max_num_classes**digit_idx) 272 | ) % self.max_num_classes 273 | # digit_pred shape: [L_context + L_eval, 1, max_num_classes] 274 | # The last L_eval rows correspond to the actual predictions 275 | class_pred += digit_pred[-L_eval:, :, digit_value] # shape: [L_eval] 276 | 277 | full_pred[:, :, class_idx] = class_pred 278 | return full_pred # shape: [L_eval, 1, self.num_classes] 279 | 280 | @torch.no_grad() 281 | def predict_proba( 282 | self, 283 | X: np.ndarray, 284 | temperature: float = _DEFAULT_TEMPERATURE, 285 | context_size: int = _DEFAULT_CONTEXT_SIZE, 286 | use_retrieval: bool = _DEFAULT_RETRIEVAL, 287 | ) -> np.ndarray: 288 | """predict class probabilities for the input data. 289 | 290 | 291 | Args: 292 | X (np.ndarray): input features for prediction. 293 | temperature (float, optional): output probability temperature. Defaults to 0.8. 294 | context_size (int, optional): context size. Defaults to 128. 295 | use_retrieval (bool, optional): 296 | - If `use_retrieval=True`, do batch-by-batch retrieval. 297 | - If `use_retrieval=False`, pick one random context and feed all test points in one shot: 298 | [N_context + N_test, 1, d]. Defaults to True. 299 | 300 | Returns: 301 | np.ndarray: probbilities for each class for each test instance. 302 | """ 303 | train_x, train_y, test_x = self._prepare_prediction(X) 304 | n_test = test_x.shape[0] 305 | 306 | # 1) If context_size >= entire training set => use them all 307 | if context_size >= self.n_instances: 308 | # shape: [n_train, 1, d] 309 | X_context = pad_x(train_x[:, None, :], self.max_features).to(self.device) 310 | y_context = train_y[:, None].float() 311 | # shape: [n_test, 1, d] 312 | X_eval = pad_x(test_x[:, None, :], self.max_features).to(self.device) 313 | 314 | if self.num_classes <= self.max_num_classes: 315 | with self.autocast, sdpa_kernel(SDPBackend.FLASH_ATTENTION): 316 | pred = self.model( 317 | x_src=torch.cat([X_context, X_eval], dim=0), 318 | y_src=y_context, 319 | ) 320 | # shape: [n_train + n_test, 1, max_num_classes] 321 | pred = pred[..., : self.num_classes].float() 322 | # extract the last n_test rows => the test predictions 323 | test_preds = pred[-n_test:, 0, :] # shape: [n_test, num_classes] 324 | 325 | else: 326 | # Large-class approach 327 | test_preds = self._predict_large_cls(X_context, X_eval, y_context) 328 | test_preds = test_preds.squeeze(1) # shape: [n_test, num_classes] 329 | 330 | test_preds /= temperature 331 | test_preds = torch.nn.functional.softmax(test_preds, dim=-1) 332 | return test_preds.detach().cpu().numpy() 333 | 334 | # TODO: combine retrieval step with that of training 335 | # 2) If we want retrieval 336 | if use_retrieval: 337 | preds_list = [] 338 | # batch the test set 339 | for b in range(math.ceil(n_test / self.inf_batch_size)): 340 | X_nni, y_nni, X_eval = self.batch_retrieval_data( 341 | train_x, train_y, test_x, b, context_size 342 | ) 343 | # forward pass 344 | if self.num_classes <= self.max_num_classes: # small-class case 345 | with self.autocast, sdpa_kernel(SDPBackend.FLASH_ATTENTION): 346 | pred = self.model( 347 | x_src=torch.cat([X_nni, X_eval], dim=0), 348 | y_src=y_nni, 349 | ) 350 | # shape: [context_size + 1, batch_size, max_num_classes] 351 | pred = pred[..., : self.num_classes].float() 352 | # last row => predictions for test batch 353 | batch_preds = pred[-1, :, :] # shape: [batch_size, num_classes] 354 | else: 355 | # large-class case 356 | batch_preds_full = self._predict_large_cls(X_nni, X_eval, y_nni) 357 | batch_preds = batch_preds_full[-1, :, :] 358 | 359 | batch_preds /= temperature 360 | batch_preds = torch.nn.functional.softmax(batch_preds, dim=-1) 361 | preds_list.append(batch_preds) 362 | 363 | preds_all = torch.cat(preds_list, dim=0) # [n_test, num_classes] 364 | return preds_all.cpu().numpy() 365 | 366 | # 3) If we want NO retrieval => single pass 367 | # [N_context, 1, d] + [N_test, 1, d] => [N_context + N_test, 1, d] 368 | else: 369 | X_ctx, y_ctx, X_eval = self.no_retrieval_data(train_x, train_y, test_x) 370 | 371 | if self.num_classes <= self.max_num_classes: 372 | with self.autocast, sdpa_kernel(SDPBackend.FLASH_ATTENTION): 373 | pred = self.model( 374 | x_src=torch.cat([X_ctx, X_eval], dim=0), 375 | y_src=y_ctx, 376 | ) 377 | # shape: [n_context + n_test, 1, max_num_classes] 378 | pred = pred[..., : self.num_classes].float() 379 | # The last n_test rows => the actual test preds 380 | test_preds = pred[-n_test:, 0, :] # [n_test, num_classes] 381 | else: 382 | test_preds_full = self._predict_large_cls(X_ctx, X_eval, y_ctx) 383 | # shape: [n_context + n_test - ???, 1, num_classes] 384 | # We only want the last n_test rows 385 | test_preds = test_preds_full[-n_test:, 0, :] 386 | 387 | test_preds /= temperature 388 | test_preds = torch.nn.functional.softmax(test_preds, dim=-1) 389 | return test_preds.detach().cpu().numpy() 390 | 391 | def predict( 392 | self, 393 | X: np.ndarray, 394 | temperature: float = _DEFAULT_TEMPERATURE, 395 | context_size: int = _DEFAULT_CONTEXT_SIZE, 396 | use_retrieval: bool = _DEFAULT_RETRIEVAL, 397 | ) -> np.ndarray: 398 | """predict class labels for the input data. 399 | This method uses the `predict_proba` method to get class probabilities and then 400 | returns the class with the highest probability for each instance. 401 | 402 | Args: 403 | X (np.ndarray): input features for prediction. 404 | temperature (float, optional): temperature for prediction. Defaults to 0.8. 405 | context_size (int, optional): size of context. Defaults to 128. 406 | use_retrieval (bool, optional): whether to use retrieval. Defaults to True. 407 | 408 | Returns: 409 | np.ndarray: class labels for each test instance. 410 | """ 411 | return self.predict_proba( 412 | X, temperature=temperature, context_size=context_size, use_retrieval=use_retrieval 413 | ).argmax(axis=-1) 414 | 415 | 416 | class TabDPTRegressor(TabDPTEstimator, RegressorMixin): 417 | """A PyTorch-based implementation of TabDPT for regression tasks.""" 418 | 419 | def __init__( 420 | self, 421 | model: Optional[TabDPTModel] = None, 422 | path: str = "", 423 | mode: Task = Task.REG, 424 | inf_batch_size: int = _INF_BATCH_SIZE, 425 | device: str = _DEFAULT_DEVICE, 426 | tensor_eval: bool = False, 427 | ): 428 | super().__init__( 429 | model=model, 430 | path=path, 431 | mode=mode, 432 | inf_batch_size=inf_batch_size, 433 | device=device, 434 | tensor_eval=tensor_eval, 435 | ) 436 | 437 | @torch.no_grad() 438 | def predict( 439 | self, 440 | X: np.ndarray, 441 | context_size: int = _DEFAULT_CONTEXT_SIZE, 442 | use_retrieval: bool = _DEFAULT_RETRIEVAL, 443 | ) -> np.ndarray: 444 | """predict regression values for the input data. 445 | 446 | Args: 447 | X (np.ndarray): input features for prediction. 448 | context_size (int, optional): size of the conte. Defaults to True.xt. Defaults to 128. 449 | use_retrieval (bool, optional): whether to use retrieval. 450 | - If use_retrieval=True, do KNN retrieval (batched). 451 | - If use_retrieval=False, create a single random context and do [N_context + N_test, 1, d] once. 452 | Defaults to True. 453 | Returns: 454 | np.ndarray: regression values for each test instance. 455 | """ 456 | train_x, train_y, test_x = self._prepare_prediction(X) 457 | n_test = test_x.shape[0] 458 | 459 | # 1) If context_size >= entire training set => use them all 460 | if context_size >= self.n_instances: 461 | X_train = pad_x(train_x[:, None, :], self.max_features).to(self.device) 462 | X_test = pad_x(test_x[:, None, :], self.max_features).to(self.device) 463 | y_train = train_y[:, None].float() 464 | 465 | # standardize 466 | y_train, y_means, y_stds = standardize(y_train) 467 | 468 | with self.autocast, sdpa_kernel(SDPBackend.FLASH_ATTENTION): 469 | pred = self.model( 470 | x_src=torch.cat([X_train, X_test], dim=0), 471 | y_src=y_train, 472 | ) 473 | pred = pred[..., -1].float() 474 | # last n_test rows => test preds 475 | test_preds = pred[-n_test:, 0] # shape: [n_test] 476 | test_preds = test_preds * y_stds + y_means 477 | return test_preds.detach().cpu().numpy() 478 | 479 | # 2) If we want retrieval 480 | if use_retrieval: 481 | pred_list = [] 482 | for b in range(math.ceil(n_test / self.inf_batch_size)): 483 | X_nni, y_nni, X_eval = self.batch_retrieval_data( 484 | train_x, train_y, test_x, b, context_size 485 | ) 486 | # standardize context targets 487 | y_nni, y_means, y_stds = standardize(y_nni) 488 | 489 | # forward pass 490 | with self.autocast, sdpa_kernel(SDPBackend.FLASH_ATTENTION): 491 | pred = self.model( 492 | x_src=torch.cat([X_nni, X_eval], dim=0), 493 | y_src=y_nni, 494 | ) 495 | pred = pred[..., -1].float() 496 | # last row => predictions for the test batch 497 | batch_preds = pred[-1, :] # shape: [batch_size] 498 | # reverse standardization 499 | batch_preds = batch_preds * y_stds + y_means 500 | pred_list.append(batch_preds.cpu()) 501 | 502 | return torch.cat(pred_list).squeeze().detach().cpu().numpy() 503 | 504 | # 3) If we want NO retrieval => single pass 505 | else: 506 | X_ctx, y_ctx, X_eval = self.no_retrieval_data(train_x, train_y, test_x) 507 | # standardize y_ctx 508 | y_ctx_norm, y_means, y_stds = standardize(y_ctx) 509 | 510 | # forward pass 511 | with self.autocast, sdpa_kernel(SDPBackend.FLASH_ATTENTION): 512 | pred = self.model( 513 | x_src=torch.cat([X_ctx, X_eval], dim=0), # shape: [n_context + n_test, 1, d] 514 | y_src=y_ctx_norm, 515 | ) 516 | # shape: [n_context + n_test, 1, out_dim] 517 | # The last dimension presumably has the regression output at the last index 518 | pred = pred[..., -1].float() 519 | 520 | # The last n_test positions are the test predictions 521 | test_preds = pred[-n_test:, 0] # shape: [n_test] 522 | # reverse standardization 523 | test_preds = test_preds * y_stds + y_means 524 | return test_preds.detach().cpu().numpy() 525 | -------------------------------------------------------------------------------- /data_splits/cls_datasets.csv: -------------------------------------------------------------------------------- 1 | did,tid,num_instances,num_features,num_classes,num_missing_values,tabzilla_name,source,test_small,test_large,test_all 2 | 41138,168868.0,76000.0,171.0,2.0,1078695.0,openml__APSFailure__168868,"['tabzilla', 'amlb']",False,False,False 3 | 4135,34539.0,32769.0,10.0,2.0,0.0,openml__Amazon_employee_access__34539,"['tabzilla', 'amlb']",False,False,False 4 | 40981,146818.0,690.0,15.0,2.0,0.0,openml__Australian__146818,"['tabzilla', 'amlb']",False,False,False 5 | 4134,9910.0,3751.0,1777.0,2.0,0.0,openml__Bioresponse__9910,"['tabzilla', 'cc18', 'amlb', 'amlb']",False,False,True 6 | 40927,167124.0,60000.0,3073.0,10.0,0.0,openml__CIFAR_10__167124,"['tabzilla', 'cc18']",False,False,True 7 | 4535,168340.0,299285.0,42.0,,0.0,openml__Census-Income__168340,['tabzilla'],False,False,False 8 | 41434,190408.0,39948.0,12.0,2.0,0.0,openml__Click_prediction_small__190408,['tabzilla'],False,False,False 9 | 40923,167121.0,92000.0,1025.0,46.0,0.0,openml__Devnagari-Script__167121,"['tabzilla', 'cc18']",False,False,True 10 | 40996,146825.0,70000.0,785.0,10.0,0.0,openml__Fashion-MNIST__146825,"['tabzilla', 'cc18', 'amlb']",False,False,True 11 | 4538,14969.0,9873.0,33.0,5.0,0.0,openml__GesturePhaseSegmentationProcessed__14969,"['tabzilla', 'cc18', 'amlb', 'amlb']",True,True,True 12 | 40978,167125.0,3279.0,1559.0,2.0,0.0,openml__Internet-Advertisements__167125,"['tabzilla', 'cc18', 'amlb', 'amlb']",False,False,True 13 | 375,3510.0,9961.0,15.0,9.0,0.0,openml__JapaneseVowels__3510,['tabzilla'],False,False,False 14 | 40496,125921.0,500.0,8.0,10.0,0.0,openml__LED-display-domain-7digit__125921,['tabzilla'],False,False,False 15 | 1120,3954.0,19020.0,12.0,2.0,0.0,openml__MagicTelescope__3954,['tabzilla'],False,False,False 16 | 40966,146800.0,1080.0,82.0,8.0,1396.0,openml__MiceProtein__146800,"['tabzilla', 'cc18']",True,True,True 17 | 41150,168335.0,130064.0,51.0,2.0,0.0,openml__MiniBooNE__168335,"['tabzilla', 'amlb']",False,False,False 18 | 4534,14952.0,11055.0,31.0,2.0,0.0,openml__PhishingWebsites__14952,"['tabzilla', 'cc18', 'amlb', 'amlb']",True,True,True 19 | 40900,167211.0,5100.0,37.0,2.0,0.0,openml__Satellite__167211,"['tabzilla', 'amlb', 'amlb']",False,False,False 20 | 40536,146607.0,8378.0,121.0,2.0,18372.0,openml__SpeedDating__146607,['tabzilla'],False,False,False 21 | 1455,10089.0,120.0,7.0,2.0,0.0,openml__acute-inflammations__10089,['tabzilla'],False,False,False 22 | 1043,3896.0,4562.0,49.0,2.0,0.0,openml__ada_agnostic__3896,['tabzilla'],False,False,False 23 | 1119,3953.0,32561.0,16.0,2.0,4262.0,openml__adult-census__3953,['tabzilla'],False,False,False 24 | 1590,7592.0,48842.0,15.0,2.0,6465.0,openml__adult__7592,"['tabzilla', 'cc18', 'amlb']",True,True,True 25 | 1169,189354.0,539383.0,8.0,2.0,0.0,openml__airlines__189354,"['tabzilla', 'amlb']",False,False,False 26 | 41147,189356.0,425240.0,79.0,2.0,2734000.0,openml__albert__189356,"['tabzilla', 'amlb']",False,False,False 27 | 458,3549.0,841.0,71.0,4.0,0.0,openml__analcatdata_authorship__3549,"['tabzilla', 'cc18']",True,True,True 28 | 448,3540.0,120.0,4.0,2.0,0.0,openml__analcatdata_boxing1__3540,['tabzilla'],False,False,False 29 | 875,3739.0,100.0,4.0,2.0,0.0,openml__analcatdata_chlamydia__3739,['tabzilla'],False,False,False 30 | 469,3560.0,797.0,5.0,6.0,0.0,openml__analcatdata_dmft__3560,"['tabzilla', 'cc18']",True,True,True 31 | 1,2867.0,898.0,39.0,5.0,0.0,openml__anneal__2867,['tabzilla'],False,False,False 32 | 5,5.0,452.0,280.0,13.0,408.0,openml__arrhythmia__5,['tabzilla'],False,False,False 33 | 1459,14964.0,10218.0,8.0,10.0,0.0,openml__artificial-characters__14964,['tabzilla'],False,False,False 34 | 7,7.0,226.0,70.0,24.0,317.0,openml__audiology__7,['tabzilla'],False,False,False 35 | 9,9.0,205.0,26.0,6.0,59.0,openml__autos__9,['tabzilla'],False,False,False 36 | 11,11.0,625.0,5.0,3.0,0.0,openml__balance-scale__11,"['tabzilla', 'cc18']",True,True,True 37 | 1461,14965.0,45211.0,17.0,2.0,0.0,openml__bank-marketing__14965,"['tabzilla', 'cc18', 'amlb']",True,True,True 38 | 1558,9899.0,4521.0,17.0,2.0,0.0,openml__bank-marketing__9899,['tabzilla'],False,False,False 39 | 1462,10093.0,1372.0,5.0,2.0,0.0,openml__banknote-authentication__10093,"['tabzilla', 'cc18']",True,True,True 40 | 1464,10101.0,748.0,5.0,2.0,0.0,openml__blood-transfusion-service-center__145836,"['tabzilla', 'tabzilla', 'cc18', 'amlb']",True,True,True 41 | 13,145799.0,286.0,10.0,2.0,9.0,openml__breast-cancer__145799,['tabzilla'],False,False,False 42 | 15,15.0,699.0,10.0,2.0,16.0,openml__breast-w__15,"['tabzilla', 'cc18']",True,True,True 43 | 40664,146192.0,1728.0,22.0,4.0,0.0,openml__car-evaluation__146192,['tabzilla'],False,False,False 44 | 40975,146821.0,1728.0,7.0,4.0,0.0,openml__car__146821,"['tabzilla', 'cc18', 'amlb']",True,True,True 45 | 1466,9979.0,2126.0,36.0,10.0,0.0,openml__cardiotocography__9979,['tabzilla'],False,False,False 46 | 1118,3952.0,28056.0,7.0,18.0,0.0,openml__chess__3952,['tabzilla'],False,False,False 47 | 41142,168908.0,5418.0,1637.0,2.0,0.0,openml__christine__168908,"['tabzilla', 'amlb']",False,False,False 48 | 40701,167141.0,5000.0,21.0,2.0,0.0,openml__churn__167141,"['tabzilla', 'cc18', 'amlb', 'amlb']",True,True,True 49 | 23380,14967.0,2796.0,35.0,6.0,68100.0,openml__cjs__14967,['tabzilla'],False,False,False 50 | 40994,146819.0,540.0,21.0,2.0,0.0,openml__climate-model-simulation-crashes__146819,"['tabzilla', 'cc18']",True,True,True 51 | 23,23.0,1473.0,10.0,3.0,0.0,openml__cmc__23,"['tabzilla', 'cc18', 'amlb', 'amlb']",True,True,True 52 | 1468,9981.0,1080.0,857.0,9.0,0.0,openml__cnae-9__9981,"['tabzilla', 'cc18', 'amlb']",False,False,True 53 | 25,25.0,368.0,27.0,2.0,1927.0,openml__colic__25,['tabzilla'],False,False,False 54 | 27,27.0,368.0,23.0,2.0,1927.0,openml__colic__27,['tabzilla'],False,False,False 55 | 478,3567.0,500.0,22.0,15.0,0.0,openml__collins__3567,['tabzilla'],False,False,False 56 | 40668,146195.0,67557.0,43.0,3.0,0.0,openml__connect-4__146195,"['tabzilla', 'cc18', 'amlb']",True,True,True 57 | 1596,7593.0,581012.0,55.0,7.0,0.0,openml__covertype__7593,"['tabzilla', 'amlb']",False,False,False 58 | 29,29.0,690.0,16.0,2.0,67.0,openml__credit-approval__29,"['tabzilla', 'cc18']",True,True,True 59 | 31,31.0,1000.0,21.0,2.0,0.0,openml__credit-g__31,"['tabzilla', 'cc18', 'amlb']",True,True,True 60 | 6332,14954.0,540.0,40.0,2.0,999.0,openml__cylinder-bands__14954,"['tabzilla', 'cc18']",True,True,True 61 | 35,35.0,366.0,35.0,6.0,8.0,openml__dermatology__35,['tabzilla'],False,False,False 62 | 37,37.0,768.0,9.0,2.0,0.0,openml__diabetes__37,"['tabzilla', 'cc18']",True,True,True 63 | 41163,168909.0,10000.0,2001.0,5.0,0.0,openml__dilbert__168909,"['tabzilla', 'amlb']",False,False,False 64 | 40670,167140.0,3186.0,181.0,3.0,0.0,openml__dna__167140,"['tabzilla', 'cc18', 'amlb', 'amlb']",False,True,True 65 | 23381,125920.0,500.0,13.0,2.0,835.0,openml__dresses-sales__125920,"['tabzilla', 'cc18']",True,True,True 66 | 39,145977.0,336.0,8.0,8.0,0.0,openml__ecoli__145977,['tabzilla'],False,False,False 67 | 1471,14951.0,14980.0,15.0,2.0,0.0,openml__eeg-eye-state__14951,['tabzilla'],False,False,False 68 | 151,219.0,45312.0,9.0,2.0,0.0,openml__electricity__219,"['tabzilla', 'cc18']",True,True,True 69 | 846,3711.0,16599.0,19.0,2.0,0.0,openml__elevators__3711,['tabzilla'],False,False,False 70 | 188,2079.0,736.0,20.0,5.0,448.0,openml__eucalyptus__2079,"['tabzilla', 'cc18', 'amlb', 'amlb']",True,True,True 71 | 1044,3897.0,10936.0,28.0,3.0,0.0,openml__eye_movements__3897,['tabzilla'],False,False,False 72 | 41164,168910.0,8237.0,801.0,7.0,0.0,openml__fabert__168910,"['tabzilla', 'amlb']",False,False,False 73 | 1473,9984.0,100.0,10.0,2.0,0.0,openml__fertility__9984,['tabzilla'],False,False,False 74 | 1475,9985.0,6118.0,52.0,6.0,0.0,openml__first-order-theorem-proving__9985,"['tabzilla', 'cc18', 'amlb', 'amlb']",True,True,True 75 | 477,3566.0,67.0,16.0,5.0,0.0,openml__fl2000__3566,['tabzilla'],False,False,False 76 | 754,3620.0,100.0,6.0,2.0,0.0,openml__fri_c0_100_5__3620,['tabzilla'],False,False,False 77 | 916,3779.0,100.0,6.0,2.0,0.0,openml__fri_c3_100_5__3779,['tabzilla'],False,False,False 78 | 1477,9987.0,13910.0,130.0,6.0,0.0,openml__gas-drift-different-concentrations__9987,['tabzilla'],False,False,False 79 | 1476,9986.0,13910.0,129.0,6.0,0.0,openml__gas-drift__9986,['tabzilla'],False,False,False 80 | 1038,3891.0,3468.0,971.0,2.0,0.0,openml__gina_agnostic__3891,['tabzilla'],False,False,False 81 | 41,40.0,214.0,10.0,6.0,0.0,openml__glass__40,['tabzilla'],False,False,False 82 | 41159,168337.0,20000.0,4297.0,2.0,0.0,openml__guillermo__168337,"['tabzilla', 'amlb']",False,False,False 83 | 43,42.0,306.0,4.0,2.0,0.0,openml__haberman__42,['tabzilla'],False,False,False 84 | 1478,14970.0,10299.0,562.0,6.0,0.0,openml__har__14970,"['tabzilla', 'cc18']",False,False,True 85 | 329,146063.0,160.0,5.0,3.0,0.0,openml__hayes-roth__146063,['tabzilla'],False,False,False 86 | 49,48.0,303.0,14.0,2.0,7.0,openml__heart-c__48,['tabzilla'],False,False,False 87 | 51,50.0,294.0,14.0,2.0,782.0,openml__heart-h__50,['tabzilla'],False,False,False 88 | 55,54.0,155.0,20.0,2.0,167.0,openml__hepatitis__54,['tabzilla'],False,False,False 89 | 23512,146606.0,98050.0,29.0,2.0,9.0,openml__higgs__146606,['tabzilla'],False,False,False 90 | 1479,145847.0,1212.0,101.0,2.0,0.0,openml__hill-valley__145847,['tabzilla'],False,False,False 91 | 821,3686.0,22784.0,17.0,2.0,0.0,openml__house_16H__3686,['tabzilla'],False,False,False 92 | 1480,9971.0,583.0,11.0,2.0,0.0,openml__ilpd__9971,"['tabzilla', 'cc18']",True,True,True 93 | 59,145984.0,351.0,35.0,2.0,0.0,openml__ionosphere__145984,['tabzilla'],False,False,False 94 | 61,59.0,150.0,5.0,3.0,0.0,openml__iris__59,['tabzilla'],False,False,False 95 | 451,3543.0,500.0,6.0,2.0,32.0,openml__irish__3543,['tabzilla'],False,False,False 96 | 300,3481.0,7797.0,618.0,26.0,0.0,openml__isolet__3481,"['tabzilla', 'cc18']",False,False,True 97 | 41168,168330.0,83733.0,55.0,4.0,0.0,openml__jannis__168330,"['tabzilla', 'amlb']",False,False,False 98 | 41143,168911.0,2984.0,145.0,2.0,0.0,openml__jasmine__168911,"['tabzilla', 'amlb']",False,False,False 99 | 1053,3904.0,10885.0,22.0,2.0,25.0,openml__jm1__3904,"['tabzilla', 'cc18']",True,True,True 100 | 41027,167119.0,44819.0,7.0,3.0,0.0,openml__jungle_chess_2pcs_raw_endgame_complete__167119,"['tabzilla', 'cc18', 'amlb']",True,True,True 101 | 1067,3917.0,2109.0,22.0,2.0,0.0,openml__kc1__3917,"['tabzilla', 'cc18', 'amlb']",True,True,True 102 | 1063,3913.0,522.0,22.0,2.0,0.0,openml__kc2__3913,"['tabzilla', 'cc18']",True,True,True 103 | 3,3.0,3196.0,37.0,2.0,0.0,openml__kr-vs-kp__3,"['tabzilla', 'cc18', 'amlb']",True,True,True 104 | 184,2076.0,28056.0,7.0,18.0,0.0,openml__kropt__2076,['tabzilla'],False,False,False 105 | 4,4.0,57.0,17.0,2.0,326.0,openml__labor__4,['tabzilla'],False,False,False 106 | 1483,9974.0,164860.0,8.0,11.0,0.0,openml__ldpa__9974,['tabzilla'],False,False,False 107 | 6,6.0,20000.0,17.0,26.0,0.0,openml__letter__6,"['tabzilla', 'cc18']",False,False,True 108 | 42855,360948.0,360.0,105.0,,0.0,openml__libras__360948,['tabzilla'],False,False,False 109 | 163,146024.0,32.0,57.0,3.0,5.0,openml__lung-cancer__146024,['tabzilla'],False,False,False 110 | 10,10.0,148.0,19.0,4.0,0.0,openml__lymph__10,['tabzilla'],False,False,False 111 | 1485,9976.0,2600.0,501.0,2.0,0.0,openml__madelon__9976,"['tabzilla', 'cc18']",False,False,True 112 | 40679,146206.0,19020.0,11.0,2.0,0.0,openml__magic__146206,['tabzilla'],False,False,False 113 | 12,12.0,2000.0,217.0,10.0,0.0,openml__mfeat-factors__12,"['tabzilla', 'cc18', 'amlb']",False,True,True 114 | 14,14.0,2000.0,77.0,10.0,0.0,openml__mfeat-fourier__14,"['tabzilla', 'cc18']",True,True,True 115 | 16,16.0,2000.0,65.0,10.0,0.0,openml__mfeat-karhunen__16,"['tabzilla', 'cc18']",True,True,True 116 | 18,18.0,2000.0,7.0,10.0,0.0,openml__mfeat-morphological__18,"['tabzilla', 'cc18']",True,True,True 117 | 40979,146824.0,2000.0,241.0,10.0,0.0,openml__mfeat-pixel__146824,"['tabzilla', 'cc18']",False,True,True 118 | 22,22.0,2000.0,48.0,10.0,0.0,openml__mfeat-zernike__22,"['tabzilla', 'cc18']",True,True,True 119 | 554,3573.0,70000.0,785.0,10.0,0.0,openml__mnist_784__3573,"['tabzilla', 'cc18']",False,False,True 120 | 334,146065.0,601.0,7.0,2.0,0.0,openml__monks-problems-2__146065,['tabzilla'],False,False,False 121 | 24,24.0,8124.0,23.0,2.0,2480.0,openml__mushroom__24,['tabzilla'],False,False,False 122 | 1116,3950.0,6598.0,168.0,2.0,0.0,openml__musk__3950,['tabzilla'],False,False,False 123 | 1486,9977.0,34465.0,119.0,2.0,0.0,openml__nomao__9977,"['tabzilla', 'cc18', 'amlb']",False,True,True 124 | 23517,167120.0,96320.0,22.0,2.0,0.0,openml__numerai28.6__167120,"['tabzilla', 'cc18', 'amlb']",True,True,True 125 | 1568,9892.0,12958.0,9.0,4.0,0.0,openml__nursery__9892,['tabzilla'],False,False,False 126 | 1493,9956.0,1599.0,65.0,100.0,0.0,openml__one-hundred-plants-texture__9956,['tabzilla'],False,False,False 127 | 28,28.0,5620.0,65.0,10.0,0.0,openml__optdigits__28,"['tabzilla', 'cc18']",True,True,True 128 | 1487,9978.0,2534.0,73.0,2.0,0.0,openml__ozone-level-8hr__9978,"['tabzilla', 'cc18', 'amlb', 'amlb']",True,True,True 129 | 30,30.0,5473.0,11.0,5.0,0.0,openml__page-blocks__30,['tabzilla'],False,False,False 130 | 1068,3918.0,1109.0,22.0,2.0,0.0,openml__pc1__3918,"['tabzilla', 'cc18']",True,True,True 131 | 1050,3903.0,1563.0,38.0,2.0,0.0,openml__pc3__3903,"['tabzilla', 'cc18']",True,True,True 132 | 1049,3902.0,1458.0,38.0,2.0,0.0,openml__pc4__3902,"['tabzilla', 'cc18', 'amlb', 'amlb']",True,True,True 133 | 32,32.0,10992.0,17.0,10.0,0.0,openml__pendigits__32,"['tabzilla', 'cc18']",True,True,True 134 | 41145,190410.0,5832.0,309.0,2.0,0.0,openml__philippine__190410,"['tabzilla', 'amlb', 'amlb']",False,False,False 135 | 1489,9952.0,5404.0,6.0,2.0,0.0,openml__phoneme__9952,"['tabzilla', 'cc18', 'amlb']",True,True,True 136 | 1567,9890.0,1025009.0,11.0,10.0,0.0,openml__poker-hand__9890,['tabzilla'],False,False,False 137 | 871,3735.0,3848.0,6.0,2.0,0.0,openml__pollen__3735,['tabzilla'],False,False,False 138 | 40683,146210.0,88.0,9.0,2.0,0.0,openml__postoperative-patient-data__146210,['tabzilla'],False,False,False 139 | 171,146032.0,339.0,18.0,21.0,225.0,openml__primary-tumor__146032,['tabzilla'],False,False,False 140 | 470,3561.0,672.0,10.0,2.0,1200.0,openml__profb__3561,['tabzilla'],False,False,False 141 | 1494,9957.0,1055.0,42.0,2.0,0.0,openml__qsar-biodeg__9957,"['tabzilla', 'cc18', 'amlb', 'amlb']",True,True,True 142 | 782,3647.0,120.0,3.0,2.0,0.0,openml__rabe_266__3647,['tabzilla'],False,False,False 143 | 41161,168338.0,20000.0,4297.0,2.0,0.0,openml__riccardo__168338,"['tabzilla', 'amlb']",False,False,False 144 | 41165,168332.0,10000.0,7201.0,10.0,0.0,openml__robert__168332,"['tabzilla', 'amlb']",False,False,False 145 | 182,2074.0,6430.0,37.0,6.0,0.0,openml__satimage__2074,"['tabzilla', 'cc18']",True,True,True 146 | 312,3485.0,2407.0,300.0,2.0,0.0,openml__scene__3485,['tabzilla'],False,False,False 147 | 40984,146822.0,2310.0,20.0,7.0,0.0,openml__segment__146822,"['tabzilla', 'cc18', 'amlb']",True,True,True 148 | 1501,9964.0,1593.0,257.0,10.0,0.0,openml__semeion__9964,"['tabzilla', 'cc18']",False,True,True 149 | 40685,146212.0,58000.0,10.0,7.0,0.0,openml__shuttle__146212,"['tabzilla', 'amlb']",False,False,False 150 | 38,3021.0,3772.0,30.0,2.0,6064.0,openml__sick__3021,"['tabzilla', 'cc18']",True,True,True 151 | 1502,9965.0,245057.0,4.0,2.0,0.0,openml__skin-segmentation__9965,['tabzilla'],False,False,False 152 | 934,3797.0,1156.0,6.0,2.0,0.0,openml__socmob__3797,['tabzilla'],False,False,False 153 | 174,2068.0,1066.0,13.0,3.0,0.0,openml__solar-flare__2068,['tabzilla'],False,False,False 154 | 40,39.0,208.0,61.0,2.0,0.0,openml__sonar__39,['tabzilla'],False,False,False 155 | 42,41.0,683.0,36.0,19.0,2337.0,openml__soybean__41,['tabzilla'],False,False,False 156 | 44,43.0,4601.0,58.0,2.0,0.0,openml__spambase__43,"['tabzilla', 'cc18']",True,True,True 157 | 46,45.0,3190.0,61.0,3.0,0.0,openml__splice__45,"['tabzilla', 'cc18']",True,True,True 158 | 40982,146817.0,1941.0,28.0,7.0,0.0,openml__steel-plates-fault__146817,"['tabzilla', 'cc18', 'amlb', 'amlb']",True,True,True 159 | 1036,3889.0,14395.0,217.0,2.0,0.0,openml__sylva_agnostic__3889,['tabzilla'],False,False,False 160 | 41146,168912.0,5124.0,21.0,2.0,0.0,openml__sylvine__168912,"['tabzilla', 'amlb']",False,False,False 161 | 377,3512.0,600.0,61.0,6.0,0.0,openml__synthetic_control__3512,['tabzilla'],False,False,False 162 | 48,47.0,151.0,6.0,3.0,0.0,openml__tae__47,['tabzilla'],False,False,False 163 | 40499,125922.0,5500.0,41.0,11.0,0.0,openml__texture__125922,"['tabzilla', 'cc18']",False,True,True 164 | 50,49.0,958.0,10.0,2.0,0.0,openml__tic-tac-toe__49,"['tabzilla', 'cc18']",True,True,True 165 | 885,3748.0,131.0,4.0,2.0,0.0,openml__transplant__3748,['tabzilla'],False,False,False 166 | 54,53.0,846.0,19.0,4.0,0.0,openml__vehicle__53,"['tabzilla', 'cc18', 'amlb']",True,True,True 167 | 736,3602.0,111.0,4.0,2.0,0.0,openml__visualizing_environmental__3602,['tabzilla'],False,False,False 168 | 867,3731.0,130.0,3.0,2.0,0.0,openml__visualizing_livestock__3731,['tabzilla'],False,False,False 169 | 41166,168331.0,58310.0,181.0,10.0,0.0,openml__volkert__168331,"['tabzilla', 'amlb']",False,False,False 170 | 307,3022.0,990.0,13.0,11.0,0.0,openml__vowel__3022,"['tabzilla', 'cc18']",False,True,True 171 | 1509,9945.0,149332.0,5.0,22.0,0.0,openml__walking-activity__9945,['tabzilla'],False,False,False 172 | 1497,9960.0,5456.0,25.0,4.0,0.0,openml__wall-robot-navigation__9960,"['tabzilla', 'cc18']",True,True,True 173 | 1510,9946.0,569.0,31.0,2.0,0.0,openml__wdbc__9946,"['tabzilla', 'cc18']",True,True,True 174 | 40983,146820.0,4839.0,6.0,2.0,0.0,openml__wilt__146820,"['tabzilla', 'cc18', 'amlb', 'amlb']",True,True,True 175 | 40733,145793.0,1269.0,9.0,4.0,0.0,openml__yeast__145793,['tabzilla'],False,False,False 176 | 44089,361055.0,16714.0,11.0,2.0,0.0,,['grins'],False,False,False 177 | 44120,361060.0,38474.0,8.0,2.0,0.0,,['grins'],False,False,False 178 | 44121,361061.0,566602.0,11.0,2.0,0.0,,['grins'],False,False,False 179 | 44122,361062.0,10082.0,27.0,2.0,0.0,,['grins'],False,False,False 180 | 44123,361063.0,13488.0,17.0,2.0,0.0,,['grins'],False,False,False 181 | 44125,361065.0,13376.0,11.0,2.0,0.0,,['grins'],False,False,False 182 | 44126,361066.0,10578.0,8.0,2.0,0.0,,['grins'],False,False,False 183 | 44128,361068.0,72998.0,51.0,2.0,0.0,,['grins'],False,False,False 184 | 44129,361069.0,940160.0,25.0,2.0,0.0,,['grins'],False,False,False 185 | 44130,361070.0,7608.0,21.0,2.0,0.0,,['grins'],False,False,False 186 | 45022,361273.0,71090.0,8.0,2.0,0.0,,['grins'],False,False,False 187 | 45021,361274.0,57580.0,55.0,2.0,0.0,,['grins'],False,False,False 188 | 45020,361275.0,13272.0,21.0,2.0,0.0,,['grins'],False,False,False 189 | 45019,361276.0,3434.0,420.0,2.0,0.0,,['grins'],False,False,False 190 | 45028,361277.0,20634.0,9.0,2.0,0.0,,['grins'],False,False,False 191 | 45026,361278.0,10000.0,23.0,2.0,0.0,,['grins'],False,False,False 192 | 44156,361110.0,38474.0,9.0,2.0,0.0,,['grins'],False,False,False 193 | 44157,361111.0,7608.0,24.0,2.0,0.0,,['grins'],False,False,False 194 | 44159,361113.0,423680.0,55.0,2.0,0.0,,['grins'],False,False,False 195 | 45035,361282.0,58252.0,32.0,2.0,0.0,,['grins'],False,False,False 196 | 45036,361283.0,13272.0,22.0,2.0,0.0,,['grins'],False,False,False 197 | 45038,361285.0,111762.0,33.0,2.0,0.0,,['grins'],False,False,False 198 | 45039,361286.0,4966.0,12.0,2.0,0.0,,['grins'],False,False,False 199 | 181,2073.0,1484.0,9.0,10.0,0.0,,"['amlb', 'amlb']",False,False,False 200 | 1111,3945.0,50000.0,231.0,2.0,8024152.0,,['amlb'],False,False,False 201 | 1457,10090.0,1500.0,10001.0,50.0,0.0,,"['amlb', 'amlb']",False,False,False 202 | 41167,189355.0,416188.0,61.0,355.0,0.0,,['amlb'],False,False,False 203 | 41158,189922.0,3153.0,971.0,2.0,0.0,,"['amlb', 'amlb']",False,False,False 204 | 41144,190392.0,3140.0,260.0,2.0,0.0,,"['amlb', 'amlb']",False,False,False 205 | 41156,190411.0,4147.0,49.0,2.0,0.0,,"['amlb', 'amlb']",False,False,False 206 | 41157,190412.0,100.0,10001.0,2.0,0.0,,"['amlb', 'amlb']",False,False,False 207 | 4541,211986.0,101766.0,50.0,3.0,0.0,,"['amlb', 'amlb']",False,False,False 208 | 1515,359953.0,571.0,1301.0,20.0,0.0,,"['amlb', 'amlb']",False,False,False 209 | 40498,359974.0,4898.0,12.0,7.0,0.0,,"['amlb', 'amlb']",False,False,False 210 | 41169,359984.0,65196.0,28.0,100.0,0.0,,['amlb'],False,False,False 211 | 41162,359991.0,72983.0,33.0,2.0,149271.0,,"['amlb', 'amlb']",False,False,False 212 | 42733,359992.0,39948.0,12.0,2.0,0.0,,"['amlb', 'amlb']",False,False,False 213 | 42734,359993.0,50789.0,20.0,3.0,154107.0,,"['amlb', 'amlb']",False,False,False 214 | 42732,359994.0,2215023.0,9.0,2.0,0.0,,"['amlb', 'amlb']",False,False,False 215 | 42746,360112.0,4898431.0,42.0,23.0,0.0,,"['amlb', 'amlb']",False,False,False 216 | 42742,360113.0,595212.0,58.0,2.0,846458.0,,"['amlb', 'amlb']",False,False,False 217 | 42769,360114.0,1000000.0,29.0,2.0,0.0,,"['amlb', 'amlb']",False,False,False 218 | 43072,360975.0,50000.0,14892.0,2.0,19658569.0,,"['amlb', 'amlb']",False,False,False 219 | 20,,2000.0,241.0,10.0,0.0,,['additional'],False,False, 220 | 53,,270.0,14.0,2.0,0.0,,['additional'],False,False, 221 | 56,,435.0,17.0,2.0,392.0,,['additional'],False,False, 222 | 137,,39366.0,10.0,2.0,0.0,,['additional'],False,False, 223 | 273,,120919.0,1002.0,2.0,0.0,,['additional'],False,False, 224 | 285,,194.0,29.0,8.0,0.0,,['additional'],False,False, 225 | 311,,937.0,50.0,2.0,0.0,,['additional'],False,False, 226 | 333,,556.0,7.0,2.0,0.0,,['additional'],False,False, 227 | 335,,554.0,7.0,2.0,0.0,,['additional'],False,False, 228 | 336,,267.0,23.0,2.0,0.0,,['additional'],False,False, 229 | 337,,349.0,45.0,2.0,0.0,,['additional'],False,False, 230 | 338,,155.0,9.0,4.0,0.0,,['additional'],False,False, 231 | 382,,7019.0,61.0,8.0,48089.0,,['additional'],False,False, 232 | 384,,336.0,7903.0,6.0,0.0,,['additional'],False,False, 233 | 387,,414.0,6430.0,9.0,0.0,,['additional'],False,False, 234 | 389,,2463.0,2001.0,17.0,0.0,,['additional'],False,False, 235 | 396,,3204.0,13196.0,6.0,0.0,,['additional'],False,False, 236 | 446,,200.0,8.0,2.0,0.0,,['additional'],False,False, 237 | 449,,163.0,27.0,5.0,9.0,,['additional'],False,False, 238 | 450,,264.0,5.0,2.0,0.0,,['additional'],False,False, 239 | 452,,285.0,8.0,7.0,27.0,,['additional'],False,False, 240 | 460,,379.0,8.0,4.0,1418.0,,['additional'],False,False, 241 | 463,,180.0,32.0,2.0,0.0,,['additional'],False,False, 242 | 464,,250.0,3.0,2.0,0.0,,['additional'],False,False, 243 | 466,,340.0,15.0,2.0,834.0,,['additional'],False,False, 244 | 474,,364.0,33.0,6.0,101.0,,['additional'],False,False, 245 | 475,,400.0,6.0,4.0,0.0,,['additional'],False,False, 246 | 481,,209.0,9.0,2.0,15.0,,['additional'],False,False, 247 | 679,,1024.0,3.0,4.0,0.0,,['additional'],False,False, 248 | 694,,310.0,9.0,9.0,0.0,,['additional'],False,False, 249 | 717,,508.0,11.0,2.0,0.0,,['additional'],False,False, 250 | 721,,200.0,11.0,2.0,0.0,,['additional'],False,False, 251 | 724,,468.0,4.0,2.0,0.0,,['additional'],False,False, 252 | 733,,209.0,7.0,2.0,0.0,,['additional'],False,False, 253 | 738,,195.0,11.0,2.0,2.0,,['additional'],False,False, 254 | 745,,159.0,16.0,2.0,0.0,,['additional'],False,False, 255 | 746,,250.0,26.0,2.0,0.0,,['additional'],False,False, 256 | 747,,167.0,5.0,2.0,0.0,,['additional'],False,False, 257 | 748,,163.0,6.0,2.0,0.0,,['additional'],False,False, 258 | 750,,500.0,8.0,2.0,0.0,,['additional'],False,False, 259 | 753,,194.0,33.0,2.0,0.0,,['additional'],False,False, 260 | 757,,528.0,22.0,2.0,504.0,,['additional'],False,False, 261 | 763,,250.0,11.0,2.0,0.0,,['additional'],False,False, 262 | 764,,450.0,4.0,2.0,0.0,,['additional'],False,False, 263 | 765,,475.0,4.0,2.0,0.0,,['additional'],False,False, 264 | 767,,475.0,4.0,2.0,0.0,,['additional'],False,False, 265 | 774,,662.0,4.0,2.0,0.0,,['additional'],False,False, 266 | 778,,252.0,15.0,2.0,0.0,,['additional'],False,False, 267 | 779,,500.0,26.0,2.0,0.0,,['additional'],False,False, 268 | 786,,303.0,14.0,2.0,6.0,,['additional'],False,False, 269 | 788,,186.0,61.0,2.0,0.0,,['additional'],False,False, 270 | 795,,662.0,4.0,2.0,0.0,,['additional'],False,False, 271 | 796,,209.0,8.0,2.0,0.0,,['additional'],False,False, 272 | 798,,303.0,14.0,2.0,6.0,,['additional'],False,False, 273 | 801,,185.0,3.0,2.0,0.0,,['additional'],False,False, 274 | 802,,1945.0,19.0,2.0,1133.0,,['additional'],False,False, 275 | 805,,500.0,51.0,2.0,0.0,,['additional'],False,False, 276 | 806,,1000.0,51.0,2.0,0.0,,['additional'],False,False, 277 | 810,,418.0,19.0,2.0,1239.0,,['additional'],False,False, 278 | 811,,264.0,3.0,2.0,0.0,,['additional'],False,False, 279 | 814,,468.0,3.0,2.0,0.0,,['additional'],False,False, 280 | 816,,8192.0,9.0,2.0,0.0,,['additional'],False,False, 281 | 820,,235.0,13.0,2.0,0.0,,['additional'],False,False, 282 | 825,,506.0,21.0,2.0,0.0,,['additional'],False,False, 283 | 827,,662.0,4.0,2.0,0.0,,['additional'],False,False, 284 | 831,,398.0,8.0,2.0,6.0,,['additional'],False,False, 285 | 837,,1000.0,51.0,2.0,0.0,,['additional'],False,False, 286 | 838,,500.0,26.0,2.0,0.0,,['additional'],False,False, 287 | 839,,782.0,9.0,2.0,466.0,,['additional'],False,False, 288 | 840,,205.0,26.0,2.0,57.0,,['additional'],False,False, 289 | 841,,950.0,10.0,2.0,0.0,,['additional'],False,False, 290 | 843,,22784.0,9.0,2.0,0.0,,['additional'],False,False, 291 | 844,,286.0,10.0,2.0,9.0,,['additional'],False,False, 292 | 852,,159.0,10.0,2.0,6.0,,['additional'],False,False, 293 | 854,,158.0,8.0,2.0,87.0,,['additional'],False,False, 294 | 860,,380.0,3.0,2.0,0.0,,['additional'],False,False, 295 | 880,,284.0,11.0,2.0,0.0,,['additional'],False,False, 296 | 886,,500.0,8.0,2.0,0.0,,['additional'],False,False, 297 | 895,,222.0,3.0,2.0,0.0,,['additional'],False,False, 298 | 896,,500.0,26.0,2.0,0.0,,['additional'],False,False, 299 | 900,,400.0,7.0,2.0,0.0,,['additional'],False,False, 300 | 903,,1000.0,26.0,2.0,0.0,,['additional'],False,False, 301 | 906,,400.0,8.0,2.0,0.0,,['additional'],False,False, 302 | 907,,400.0,8.0,2.0,0.0,,['additional'],False,False, 303 | 908,,400.0,8.0,2.0,0.0,,['additional'],False,False, 304 | 909,,400.0,8.0,2.0,0.0,,['additional'],False,False, 305 | 911,,250.0,6.0,2.0,0.0,,['additional'],False,False, 306 | 912,,1000.0,6.0,2.0,0.0,,['additional'],False,False, 307 | 913,,1000.0,11.0,2.0,0.0,,['additional'],False,False, 308 | 915,,315.0,14.0,2.0,0.0,,['additional'],False,False, 309 | 925,,323.0,5.0,2.0,0.0,,['additional'],False,False, 310 | 930,,1302.0,34.0,2.0,7830.0,,['additional'],False,False, 311 | 931,,662.0,4.0,2.0,0.0,,['additional'],False,False, 312 | 939,,228.0,9.0,2.0,20.0,,['additional'],False,False, 313 | 940,,527.0,37.0,2.0,542.0,,['additional'],False,False, 314 | 941,,189.0,10.0,2.0,0.0,,['additional'],False,False, 315 | 943,,500.0,11.0,2.0,0.0,,['additional'],False,False, 316 | 949,,559.0,5.0,2.0,0.0,,['additional'],False,False, 317 | 957,,412.0,9.0,2.0,96.0,,['additional'],False,False, 318 | 966,,1340.0,17.0,2.0,20.0,,['additional'],False,False, 319 | 968,,365.0,4.0,2.0,30.0,,['additional'],False,False, 320 | 981,,10108.0,69.0,2.0,2699.0,,['additional'],False,False, 321 | 984,,366.0,5.0,2.0,1.0,,['additional'],False,False, 322 | 987,,500.0,23.0,2.0,0.0,,['additional'],False,False, 323 | 996,,214.0,10.0,2.0,0.0,,['additional'],False,False, 324 | 1000,,3772.0,30.0,2.0,6064.0,,['additional'],False,False, 325 | 1002,,7485.0,56.0,2.0,32427.0,,['additional'],False,False, 326 | 1016,,990.0,14.0,2.0,0.0,,['additional'],False,False, 327 | 1018,,8844.0,57.0,2.0,34843.0,,['additional'],False,False, 328 | 1037,,4562.0,15.0,2.0,88.0,,['additional'],False,False, 329 | 1042,,3468.0,785.0,2.0,0.0,,['additional'],False,False, 330 | 1048,,369.0,9.0,2.0,0.0,,['additional'],False,False, 331 | 1054,,161.0,40.0,2.0,0.0,,['additional'],False,False, 332 | 1071,,403.0,38.0,2.0,0.0,,['additional'],False,False, 333 | 1073,,274.0,9.0,2.0,0.0,,['additional'],False,False, 334 | 1100,,478.0,11.0,3.0,0.0,,['additional'],False,False, 335 | 1112,,50000.0,231.0,2.0,8024152.0,,['additional'],False,False, 336 | 1115,,151.0,7.0,3.0,0.0,,['additional'],False,False, 337 | 1126,,412.0,10936.0,2.0,0.0,,['additional'],False,False, 338 | 1129,,384.0,10936.0,2.0,0.0,,['additional'],False,False, 339 | 1130,,1545.0,10936.0,2.0,0.0,,['additional'],False,False, 340 | 1131,,193.0,10936.0,2.0,0.0,,['additional'],False,False, 341 | 1142,,1545.0,10936.0,2.0,0.0,,['additional'],False,False, 342 | 1156,,275.0,10936.0,2.0,0.0,,['additional'],False,False, 343 | 1162,,322.0,10936.0,2.0,0.0,,['additional'],False,False, 344 | 1412,,226.0,24.0,2.0,0.0,,['additional'],False,False, 345 | 1442,,253.0,38.0,2.0,0.0,,['additional'],False,False, 346 | 1443,,661.0,38.0,2.0,0.0,,['additional'],False,False, 347 | 1444,,1043.0,38.0,2.0,0.0,,['additional'],False,False, 348 | 1446,,296.0,38.0,2.0,0.0,,['additional'],False,False, 349 | 1447,,327.0,38.0,2.0,0.0,,['additional'],False,False, 350 | 1448,,194.0,40.0,2.0,0.0,,['additional'],False,False, 351 | 1451,,705.0,38.0,2.0,0.0,,['additional'],False,False, 352 | 1452,,745.0,37.0,2.0,0.0,,['additional'],False,False, 353 | 1453,,1077.0,38.0,2.0,0.0,,['additional'],False,False, 354 | 1481,,28056.0,7.0,18.0,0.0,,['additional'],False,False, 355 | 1488,,195.0,23.0,2.0,0.0,,['additional'],False,False, 356 | 1490,,182.0,13.0,2.0,0.0,,['additional'],False,False, 357 | 1495,,250.0,7.0,2.0,0.0,,['additional'],False,False, 358 | 1498,,462.0,10.0,2.0,0.0,,['additional'],False,False, 359 | 1499,,210.0,8.0,3.0,0.0,,['additional'],False,False, 360 | 1503,,263256.0,15.0,10.0,0.0,,['additional'],False,False, 361 | 1506,,470.0,17.0,2.0,0.0,,['additional'],False,False, 362 | 1507,,7400.0,21.0,2.0,0.0,,['additional'],False,False, 363 | 1508,,403.0,6.0,5.0,0.0,,['additional'],False,False, 364 | 1511,,440.0,9.0,2.0,0.0,,['additional'],False,False, 365 | 1512,,200.0,14.0,5.0,0.0,,['additional'],False,False, 366 | 1520,,164.0,91.0,5.0,0.0,,['additional'],False,False, 367 | 1523,,310.0,7.0,3.0,0.0,,['additional'],False,False, 368 | 1531,,10176.0,4.0,5.0,0.0,,['additional'],False,False, 369 | 1532,,10668.0,4.0,5.0,0.0,,['additional'],False,False, 370 | 1538,,8753.0,4.0,5.0,0.0,,['additional'],False,False, 371 | 1541,,8654.0,4.0,5.0,0.0,,['additional'],False,False, 372 | 1546,,1112.0,4.0,5.0,0.0,,['additional'],False,False, 373 | 1547,,1000.0,21.0,2.0,0.0,,['additional'],False,False, 374 | 4153,,180.0,68.0,6.0,0.0,,['additional'],False,False, 375 | 23499,,277.0,10.0,2.0,0.0,,['additional'],False,False, 376 | 40646,,1600.0,21.0,2.0,0.0,,['additional'],False,False, 377 | 40663,,399.0,33.0,5.0,0.0,,['additional'],False,False, 378 | 40669,,160.0,7.0,2.0,0.0,,['additional'],False,False, 379 | 40680,,1324.0,11.0,2.0,0.0,,['additional'],False,False, 380 | 40682,,215.0,6.0,3.0,0.0,,['additional'],False,False, 381 | 40690,,512.0,10.0,2.0,0.0,,['additional'],False,False, 382 | 40693,,973.0,10.0,2.0,0.0,,['additional'],False,False, 383 | 40705,,959.0,45.0,2.0,0.0,,['additional'],False,False, 384 | 40706,,1124.0,11.0,2.0,0.0,,['additional'],False,False, 385 | 40710,,303.0,14.0,2.0,0.0,,['additional'],False,False, 386 | 40711,,303.0,8.0,5.0,0.0,,['additional'],False,False, 387 | 41430,,281.0,98.0,2.0,2.0,,['additional'],False,False, 388 | 41538,,246.0,7.0,2.0,0.0,,['additional'],False,False, 389 | 41919,,527.0,23.0,4.0,0.0,,['additional'],False,False, 390 | 41976,,156.0,81.0,2.0,0.0,,['additional'],False,False, 391 | 42172,,202.0,20.0,2.0,17.0,,['additional'],False,False, 392 | 42261,,150.0,5.0,3.0,0.0,,['additional'],False,False, 393 | 42544,,265.0,11.0,8.0,0.0,,['additional'],False,False, 394 | 42585,,344.0,7.0,3.0,18.0,,['additional'],False,False, 395 | 42638,,891.0,8.0,2.0,689.0,,['additional'],False,False, 396 | --------------------------------------------------------------------------------