├── docs ├── source │ ├── .gitignore │ ├── _figures │ │ ├── modular.png │ │ ├── voyageai.webp │ │ ├── OpenAI_Logo.png │ │ ├── architecture.png │ │ ├── cohere-logo.png │ │ ├── hf-logo-with-title.png │ │ ├── pytorch_frame_logo.JPG │ │ ├── pytorch_frame_logo.png │ │ ├── pytorch_frame_logo_text.JPG │ │ └── pytorch_frame_logo_text.png │ ├── modules │ │ ├── root.rst │ │ ├── utils.rst │ │ ├── config.rst │ │ ├── gbdt.rst │ │ ├── datasets.rst │ │ ├── data.rst │ │ ├── nn.rst │ │ └── transforms.rst │ ├── _templates │ │ └── autosummary │ │ │ └── class.rst │ ├── _static │ │ └── js │ │ │ └── version_alert.js │ ├── get_started │ │ └── installation.rst │ ├── index.rst │ └── conf.py ├── requirements.txt ├── Makefile └── README.md ├── torch_frame ├── nn │ ├── encoding │ │ ├── __init__.py │ │ ├── positional_encoding.py │ │ └── cyclic_encoding.py │ ├── decoder │ │ ├── __init__.py │ │ ├── decoder.py │ │ ├── excelformer_decoder.py │ │ └── trompt_decoder.py │ ├── __init__.py │ ├── conv │ │ ├── __init__.py │ │ ├── table_conv.py │ │ └── ft_transformer_convs.py │ ├── models │ │ ├── __init__.py │ │ ├── ft_transformer.py │ │ └── mlp.py │ ├── encoder │ │ ├── __init__.py │ │ ├── encoder.py │ │ └── stypewise_encoder.py │ ├── utils │ │ └── init.py │ └── base.py ├── testing │ ├── __init__.py │ ├── image_embedder.py │ ├── text_embedder.py │ ├── decorators.py │ └── text_tokenizer.py ├── gbdt │ └── __init__.py ├── config │ ├── __init__.py │ ├── model.py │ ├── text_tokenizer.py │ ├── text_embedder.py │ └── image_embedder.py ├── transforms │ ├── __init__.py │ ├── base_transform.py │ ├── fittable_base_transform.py │ └── mutual_information_sort.py ├── utils │ ├── __init__.py │ ├── memory.py │ └── split.py ├── data │ ├── __init__.py │ ├── download.py │ └── loader.py ├── __init__.py ├── datasets │ ├── __init__.py │ ├── titanic.py │ ├── dota2.py │ ├── adult_census_income.py │ ├── amazon_fine_food_reviews.py │ ├── poker_hand.py │ ├── mercari.py │ ├── bank_marketing.py │ ├── diamond_images.py │ ├── amphibians.py │ ├── movielens_1m.py │ ├── mushroom.py │ └── kdd_census_income.py ├── typing.py └── _stype.py ├── .github ├── labeler.yml ├── workflows │ ├── changelog.yml │ ├── labeler.yml │ ├── linting.yml │ ├── auto-merge.yml │ ├── documentation.yml │ ├── release.yml │ └── testing.yml ├── dependabot.yml ├── actions │ └── setup │ │ └── action.yml └── CONTRIBUTING.md ├── readthedocs.yml ├── .gitignore ├── codecov.yml ├── test ├── nn │ ├── conv │ │ ├── test_ft_transformer_convs.py │ │ ├── test_excelformer_conv.py │ │ ├── test_tab_transformer_conv.py │ │ └── test_trompt_conv.py │ ├── decoder │ │ ├── test_excelformer_decoder.py │ │ └── test_trompt_decoder.py │ ├── encoding │ │ ├── test_positional_encoding.py │ │ └── test_cyclic_encoding.py │ ├── models │ │ ├── test_tabnet.py │ │ ├── test_resnet.py │ │ ├── test_ft_transformer.py │ │ ├── test_mlp.py │ │ ├── test_tab_transformer.py │ │ ├── test_trompt.py │ │ ├── test_excelformer.py │ │ └── test_compile.py │ └── test_simple_basecls.py ├── data │ └── test_loader.py ├── utils │ ├── test_split.py │ ├── test_memory.py │ ├── test_concat.py │ ├── test_infer_stype.py │ └── test_io.py ├── test_stype.py ├── datasets │ ├── test_titanic.py │ └── test_movielens_1m.py ├── transforms │ └── test_mutual_information_sort.py ├── conftest.py └── gbdt │ └── test_gbdt.py ├── CITATION.cff ├── LICENSE ├── examples ├── tabpfn_classification.py ├── tuned_gbdt.py └── tabnet.py ├── .pre-commit-config.yaml ├── benchmark └── encoder │ └── README.md └── pyproject.toml /docs/source/.gitignore: 
-------------------------------------------------------------------------------- 1 | generated/ 2 | -------------------------------------------------------------------------------- /docs/source/_figures/modular.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch-frame/HEAD/docs/source/_figures/modular.png -------------------------------------------------------------------------------- /docs/source/_figures/voyageai.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch-frame/HEAD/docs/source/_figures/voyageai.webp -------------------------------------------------------------------------------- /docs/source/_figures/OpenAI_Logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch-frame/HEAD/docs/source/_figures/OpenAI_Logo.png -------------------------------------------------------------------------------- /docs/source/_figures/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch-frame/HEAD/docs/source/_figures/architecture.png -------------------------------------------------------------------------------- /docs/source/_figures/cohere-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch-frame/HEAD/docs/source/_figures/cohere-logo.png -------------------------------------------------------------------------------- /docs/source/_figures/hf-logo-with-title.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch-frame/HEAD/docs/source/_figures/hf-logo-with-title.png -------------------------------------------------------------------------------- /docs/source/_figures/pytorch_frame_logo.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch-frame/HEAD/docs/source/_figures/pytorch_frame_logo.JPG -------------------------------------------------------------------------------- /docs/source/_figures/pytorch_frame_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch-frame/HEAD/docs/source/_figures/pytorch_frame_logo.png -------------------------------------------------------------------------------- /docs/source/_figures/pytorch_frame_logo_text.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch-frame/HEAD/docs/source/_figures/pytorch_frame_logo_text.JPG -------------------------------------------------------------------------------- /docs/source/_figures/pytorch_frame_logo_text.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyg-team/pytorch-frame/HEAD/docs/source/_figures/pytorch_frame_logo_text.png -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp310-cp310-manylinux_2_28_x86_64.whl 2 | git+https://github.com/pyg-team/pyg_sphinx_theme.git 3 | 
-------------------------------------------------------------------------------- /docs/source/modules/root.rst: -------------------------------------------------------------------------------- 1 | torch_frame 2 | =========== 3 | 4 | .. automodule:: torch_frame.stype 5 | :members: 6 | 7 | .. automodule:: torch_frame.typing 8 | :members: 9 | -------------------------------------------------------------------------------- /docs/source/_templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | :show-inheritance: 7 | :members: 8 | -------------------------------------------------------------------------------- /torch_frame/nn/encoding/__init__.py: -------------------------------------------------------------------------------- 1 | r"""Encoding package.""" 2 | from .cyclic_encoding import CyclicEncoding 3 | from .positional_encoding import PositionalEncoding 4 | 5 | __all__ = classes = [ 6 | 'CyclicEncoding', 7 | 'PositionalEncoding', 8 | ] 9 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | SPHINXBUILD = sphinx-build 2 | SPHINXPROJ = pytorch-frame 3 | SOURCEDIR = source 4 | BUILDDIR = build 5 | 6 | .PHONY: help Makefile 7 | 8 | %: Makefile 9 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 10 | -------------------------------------------------------------------------------- /.github/labeler.yml: -------------------------------------------------------------------------------- 1 | documentation: 2 | - docs/**/* 3 | 4 | example: 5 | - examples/**/* 6 | 7 | data: 8 | - torch_frame/data/**/* 9 | 10 | dataset: 11 | - torch_frame/datasets/**/* 12 | 13 | nn: 14 | - torch_frame/nn/**/* 15 | 16 | utils: 17 | - torch_frame/utils/**/* 18 | -------------------------------------------------------------------------------- /torch_frame/testing/__init__.py: -------------------------------------------------------------------------------- 1 | r"""Utility package for testing.""" 2 | from .decorators import ( 3 | has_package, 4 | withPackage, 5 | withCUDA, 6 | onlyCUDA, 7 | ) 8 | 9 | __all__ = [ 10 | 'has_package', 11 | 'withPackage', 12 | 'withCUDA', 13 | 'onlyCUDA', 14 | ] 15 | -------------------------------------------------------------------------------- /torch_frame/nn/decoder/__init__.py: -------------------------------------------------------------------------------- 1 | r"""Decoder package.""" 2 | from .decoder import Decoder 3 | from .trompt_decoder import TromptDecoder 4 | from .excelformer_decoder import ExcelFormerDecoder 5 | 6 | __all__ = classes = [ 7 | 'Decoder', 8 | 'TromptDecoder', 9 | 'ExcelFormerDecoder', 10 | ] 11 | -------------------------------------------------------------------------------- /torch_frame/nn/__init__.py: -------------------------------------------------------------------------------- 1 | r"""Neural network module package.""" 2 | from .base import Module 3 | from .encoder import * # noqa 4 | from .encoding import * # noqa 5 | from .conv import * # noqa 6 | from .decoder import * # noqa 7 | from .models import * # noqa 8 | 9 | __all__ = [ 10 | 'Module', 11 | ] 12 | -------------------------------------------------------------------------------- /readthedocs.yml: -------------------------------------------------------------------------------- 1 |
version: 2 2 | 3 | sphinx: 4 | configuration: docs/source/conf.py 5 | 6 | build: 7 | os: ubuntu-22.04 8 | tools: 9 | python: "3.10" 10 | 11 | python: 12 | install: 13 | - requirements: docs/requirements.txt 14 | - method: pip 15 | path: . 16 | 17 | formats: [] 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .pytest_cache/ 3 | .mypy_cache/ 4 | .DS_Store 5 | build/ 6 | dist/ 7 | alpha/ 8 | .cache/ 9 | .eggs/ 10 | *.egg-info/ 11 | .ipynb_checkpoints 12 | .coverage 13 | .coverage.* 14 | coverage.xml 15 | .vscode 16 | .idea 17 | .venv 18 | venv/* 19 | *.out 20 | data/** 21 | catboost_info/ 22 | .pt_tmp/ 23 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | # See: https://docs.codecov.io/docs/codecov-yaml 2 | coverage: 3 | range: 80..100 4 | round: down 5 | precision: 2 6 | status: 7 | project: 8 | default: 9 | target: 80% 10 | threshold: 1% 11 | patch: 12 | default: 13 | target: 80% 14 | threshold: 1% 15 | -------------------------------------------------------------------------------- /torch_frame/gbdt/__init__.py: -------------------------------------------------------------------------------- 1 | r"""Gradient Boosting Decision Trees package.""" 2 | from .gbdt import GBDT 3 | from .tuned_xgboost import XGBoost 4 | from .tuned_catboost import CatBoost 5 | from .tuned_lightgbm import LightGBM 6 | 7 | __all__ = classes = [ 8 | 'GBDT', 9 | 'XGBoost', 10 | 'CatBoost', 11 | 'LightGBM', 12 | ] 13 | -------------------------------------------------------------------------------- /docs/source/modules/utils.rst: -------------------------------------------------------------------------------- 1 | torch_frame.utils 2 | ================= 3 | 4 | .. contents:: Contents 5 | :local: 6 | 7 | .. currentmodule:: torch_frame.utils 8 | 9 | .. autosummary:: 10 | :nosignatures: 11 | :toctree: ../generated 12 | 13 | {% for name in torch_frame.utils.functions %} 14 | {{ name }} 15 | {% endfor %} 16 | -------------------------------------------------------------------------------- /docs/source/modules/config.rst: -------------------------------------------------------------------------------- 1 | torch_frame.config 2 | ================== 3 | 4 | .. contents:: Contents 5 | :local: 6 | 7 | .. currentmodule:: torch_frame.config 8 | 9 | Text Embedder 10 | ------------- 11 | 12 | .. autosummary:: 13 | :nosignatures: 14 | :toctree: ../generated 15 | 16 | {% for name in torch_frame.config.classes %} 17 | {{ name }} 18 | {% endfor %} 19 | -------------------------------------------------------------------------------- /docs/source/modules/gbdt.rst: -------------------------------------------------------------------------------- 1 | torch_frame.gbdt 2 | ================ 3 | 4 | .. contents:: Contents 5 | :local: 6 | 7 | .. currentmodule:: torch_frame.gbdt 8 | 9 | Gradient Boosted Decision Trees 10 | ------------------------------- 11 | 12 | .. 
autosummary:: 13 | :nosignatures: 14 | :toctree: ../generated 15 | 16 | {% for name in torch_frame.gbdt.classes %} 17 | {{ name }} 18 | {% endfor %} 19 | -------------------------------------------------------------------------------- /test/nn/conv/test_ft_transformer_convs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_frame.nn import FTTransformerConvs 4 | 5 | 6 | def test_ft_transformer_convs(): 7 | x = torch.randn(size=(10, 3, 8)) 8 | conv = FTTransformerConvs(channels=8, num_layers=3) 9 | x, x_cls = conv(x) 10 | # The first added column corresponds to CLS token. 11 | assert x.shape == (10, 3, 8) 12 | assert x_cls.shape == (10, 8) 13 | -------------------------------------------------------------------------------- /torch_frame/config/__init__.py: -------------------------------------------------------------------------------- 1 | r"""Config package.""" 2 | from .text_embedder import TextEmbedderConfig 3 | from .text_tokenizer import TextTokenizerConfig 4 | from .model import ModelConfig 5 | from .image_embedder import ImageEmbedderConfig, ImageEmbedder 6 | 7 | __all__ = classes = [ 8 | 'TextEmbedderConfig', 9 | 'TextTokenizerConfig', 10 | 'ModelConfig', 11 | 'ImageEmbedderConfig', 12 | 'ImageEmbedder', 13 | ] 14 | -------------------------------------------------------------------------------- /torch_frame/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | r"""Transforms package.""" 2 | from .base_transform import BaseTransform 3 | from .fittable_base_transform import FittableBaseTransform 4 | from .cat_to_num_transform import CatToNumTransform 5 | from .mutual_information_sort import MutualInformationSort 6 | 7 | __all__ = functions = [ 8 | 'BaseTransform', 9 | 'FittableBaseTransform', 10 | 'CatToNumTransform', 11 | 'MutualInformationSort', 12 | ] 13 | -------------------------------------------------------------------------------- /torch_frame/utils/__init__.py: -------------------------------------------------------------------------------- 1 | r"""Utility package.""" 2 | from .io import save, load 3 | from .concat import cat 4 | from .split import generate_random_split 5 | from .infer_stype import infer_series_stype, infer_df_stype 6 | from .memory import num_bytes 7 | 8 | __all__ = functions = [ 9 | "save", 10 | "load", 11 | "cat", 12 | "generate_random_split", 13 | "infer_series_stype", 14 | "infer_df_stype", 15 | "num_bytes", 16 | ] 17 | -------------------------------------------------------------------------------- /test/nn/decoder/test_excelformer_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_frame.nn import ExcelFormerDecoder 4 | 5 | 6 | def test_excelformer_decoder(): 7 | batch_size = 10 8 | num_cols = 18 9 | in_channels = 8 10 | out_channels = 3 11 | x = torch.randn(batch_size, num_cols, in_channels) 12 | decoder = ExcelFormerDecoder(in_channels, out_channels, num_cols) 13 | y = decoder(x) 14 | assert y.shape == (batch_size, out_channels) 15 | -------------------------------------------------------------------------------- /test/nn/decoder/test_trompt_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_frame.nn import TromptDecoder 4 | 5 | 6 | def test_trompt_decoder(): 7 | batch_size = 10 8 | num_prompts = 2 9 | in_channels = 8 10 | out_channels = 1 11 | x_prompt = 
torch.randn(batch_size, num_prompts, in_channels) 12 | decoder = TromptDecoder(in_channels, out_channels, num_prompts) 13 | y = decoder(x_prompt) 14 | assert y.shape == (batch_size, out_channels) 15 | -------------------------------------------------------------------------------- /torch_frame/nn/conv/__init__.py: -------------------------------------------------------------------------------- 1 | r"""Convolutional layer package.""" 2 | from .table_conv import TableConv 3 | from .ft_transformer_convs import FTTransformerConvs 4 | from .trompt_conv import TromptConv 5 | from .excelformer_conv import ExcelFormerConv 6 | from .tab_transformer_conv import TabTransformerConv 7 | 8 | __all__ = classes = [ 9 | 'TableConv', 10 | 'FTTransformerConvs', 11 | 'TromptConv', 12 | 'ExcelFormerConv', 13 | 'TabTransformerConv', 14 | ] 15 | -------------------------------------------------------------------------------- /test/nn/encoding/test_positional_encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_frame.nn import PositionalEncoding 4 | 5 | 6 | def test_positional_encoding(): 7 | out_size = 8 8 | for size in [(10, ), (10, 4), (10, 5, 8)]: 9 | input_tensor = torch.randint(0, 10, size=size) 10 | positional_encoding = PositionalEncoding(out_size) 11 | out_tensor = positional_encoding(input_tensor) 12 | assert out_tensor.shape == input_tensor.shape + (out_size, ) 13 | -------------------------------------------------------------------------------- /test/data/test_loader.py: -------------------------------------------------------------------------------- 1 | from torch_frame.data import DataLoader, TensorFrame 2 | 3 | 4 | def test_data_loader(get_fake_tensor_frame): 5 | tf = get_fake_tensor_frame(num_rows=10) 6 | loader = DataLoader(tf, batch_size=3) 7 | assert len(loader) == 4 8 | 9 | for i, batch in enumerate(loader): 10 | assert isinstance(batch, TensorFrame) 11 | if i + 1 < len(loader): 12 | assert len(batch) == 3 13 | else: 14 | assert len(batch) == 1 15 | -------------------------------------------------------------------------------- /test/nn/conv/test_excelformer_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_frame.nn import ExcelFormerConv 4 | 5 | 6 | def test_excelformer_conv(): 7 | batch_size = 10 8 | channels = 16 9 | num_cols = 15 10 | num_heads = 8 11 | # Feature-based embeddings 12 | x = torch.randn(size=(batch_size, num_cols, channels)) 13 | conv = ExcelFormerConv(channels, num_cols, num_heads=num_heads) 14 | x_out = conv(x) 15 | assert x_out.shape == (batch_size, num_cols, channels) 16 | -------------------------------------------------------------------------------- /torch_frame/nn/models/__init__.py: -------------------------------------------------------------------------------- 1 | r"""Model package.""" 2 | from .trompt import Trompt 3 | from .ft_transformer import FTTransformer 4 | from .excelformer import ExcelFormer 5 | from .tabnet import TabNet 6 | from .resnet import ResNet 7 | from .tab_transformer import TabTransformer 8 | from .mlp import MLP 9 | 10 | __all__ = classes = [ 11 | 'Trompt', 12 | 'FTTransformer', 13 | 'ExcelFormer', 14 | 'TabNet', 15 | 'ResNet', 16 | 'TabTransformer', 17 | 'MLP', 18 | ] 19 | -------------------------------------------------------------------------------- /test/nn/conv/test_tab_transformer_conv.py: -------------------------------------------------------------------------------- 1 | import torch 
2 | 3 | from torch_frame.nn import TabTransformerConv 4 | 5 | 6 | def test_tab_transformer_conv(): 7 | batch_size = 10 8 | channels = 16 9 | num_cols = 15 10 | num_heads = 8 11 | # Feature-based embeddings 12 | x = torch.randn(size=(batch_size, num_cols, channels)) 13 | conv = TabTransformerConv(channels, num_heads=num_heads, attn_dropout=0.) 14 | x_out = conv(x) 15 | assert x_out.shape == (batch_size, num_cols, channels) 16 | -------------------------------------------------------------------------------- /.github/workflows/changelog.yml: -------------------------------------------------------------------------------- 1 | name: Changelog Enforcer 2 | 3 | on: # yamllint disable-line rule:truthy 4 | pull_request: 5 | types: [opened, synchronize, reopened, ready_for_review, labeled, unlabeled] 6 | 7 | jobs: 8 | 9 | changelog: 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - name: Checkout repository 14 | uses: actions/checkout@v6 15 | 16 | - name: Enforce changelog entry 17 | uses: dangoslen/changelog-enforcer@v3 18 | with: 19 | skipLabels: 'skip-changelog' 20 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Building Documentation 2 | 3 | To build the documentation: 4 | 5 | 1. [Build and install](https://github.com/pyg-team/pytorch-frame/blob/master/.github/CONTRIBUTING.md) PyTorch Frame from source. 6 | 1. Install [Sphinx](https://www.sphinx-doc.org/en/master/) theme via 7 | ``` 8 | pip install git+https://github.com/pyg-team/pyg_sphinx_theme.git 9 | ``` 10 | 1. Generate the documentation file via: 11 | ``` 12 | cd docs 13 | make html 14 | ``` 15 | 16 | The documentation is now available to view by opening `docs/build/html/index.html`. 
17 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # https://docs.github.com/en/code-security/dependabot/working-with-dependabot/dependabot-options-reference 2 | version: 2 3 | updates: 4 | - package-ecosystem: "github-actions" 5 | directories: 6 | - "/" 7 | - "/.github/actions/setup" 8 | schedule: 9 | interval: "daily" 10 | time: "00:00" 11 | labels: 12 | - "ci" 13 | - "skip-changelog" 14 | pull-request-branch-name: 15 | separator: "-" 16 | open-pull-requests-limit: 10 17 | reviewers: 18 | - "akihironitta" 19 | assignees: 20 | - "akihironitta" 21 | -------------------------------------------------------------------------------- /test/nn/conv/test_trompt_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_frame.nn import TromptConv 4 | 5 | 6 | def test_trompt_conv(): 7 | batch_size = 10 8 | channels = 8 9 | num_cols = 5 10 | num_prompts = 2 11 | # Feature-based embeddings 12 | x = torch.randn(size=(batch_size, num_cols, channels)) 13 | # Prompt embeddings 14 | x_prompt = torch.randn(size=(batch_size, num_prompts, channels)) 15 | conv = TromptConv(channels=8, num_cols=num_cols, num_prompts=num_prompts) 16 | x_prompt = conv(x, x_prompt) 17 | assert x_prompt.shape == (batch_size, num_prompts, channels) 18 | -------------------------------------------------------------------------------- /torch_frame/config/model.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | from dataclasses import dataclass 3 | 4 | from torch import Tensor 5 | 6 | from torch_frame.typing import TensorData 7 | 8 | 9 | @dataclass 10 | class ModelConfig: 11 | r"""Learnable model that maps a single-column :class:`TensorData` object 12 | into row embeddings. 13 | 14 | Args: 15 | model (callable): A callable model that takes a :obj:`TensorData` 16 | object of shape :obj:`[batch_size, 1, *]` as input and outputs 17 | embeddings of shape :obj:`[batch_size, 1, out_channels]`. 18 | out_channels (int): Model output channels. 19 | 20 | """ 21 | model: Callable[[TensorData], Tensor] 22 | out_channels: int 23 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | --- 2 | cff-version: 1.2.0 3 | message: "Please cite our paper if you use this code in your own work." 
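A quick sketch of the `ModelConfig` contract from `torch_frame/config/model.py` above. This is illustrative only: `ColumnLinear` is a hypothetical module, not part of the library; the sole requirement is that the wrapped callable maps a `[batch_size, 1, *]` input to a `[batch_size, 1, out_channels]` output.

```python
import torch
from torch import Tensor, nn

from torch_frame.config import ModelConfig


class ColumnLinear(nn.Module):
    # Hypothetical single-column model (for illustration only):
    # [batch_size, 1, in_channels] -> [batch_size, 1, out_channels].
    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.lin = nn.Linear(in_channels, out_channels)

    def forward(self, x: Tensor) -> Tensor:
        return self.lin(x)


config = ModelConfig(model=ColumnLinear(16, 8), out_channels=8)
out = config.model(torch.randn(32, 1, 16))  # single-column input
assert out.shape == (32, 1, 8)
```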
4 | title: "PyTorch Frame: A Deep Learning Framework for Tabular Data" 5 | authors: 6 | - family-names: "Hu" 7 | given-names: "Weihua" 8 | - family-names: "Yuan" 9 | given-names: "Yiwen" 10 | - family-names: "Zhang" 11 | given-names: "Zecheng" 12 | - family-names: "Nitta" 13 | given-names: "Akihiro" 14 | - family-names: "Cao" 15 | given-names: "Kaidi" 16 | - family-names: "Kocijan" 17 | given-names: "Vid" 18 | - family-names: "Leskovec" 19 | given-names: "Jure" 20 | - family-names: "Fey" 21 | given-names: "Matthias" 22 | date-released: 2023-10-24 23 | license: MIT 24 | url: "https://github.com/pyg-team/pytorch-frame" 25 | -------------------------------------------------------------------------------- /.github/workflows/labeler.yml: -------------------------------------------------------------------------------- 1 | name: PR Labeler 2 | 3 | on: # yamllint disable-line rule:truthy 4 | pull_request: 5 | 6 | jobs: 7 | 8 | triage: 9 | if: github.repository == 'pyg-team/pytorch-frame' 10 | runs-on: ubuntu-latest 11 | 12 | permissions: 13 | contents: read 14 | pull-requests: write 15 | 16 | steps: 17 | - name: Add PR labels 18 | uses: actions/labeler@v6 19 | continue-on-error: true 20 | with: 21 | repo-token: "${{ secrets.GITHUB_TOKEN }}" 22 | sync-labels: true 23 | 24 | - name: Add PR author 25 | uses: samspills/assign-pr-to-author@v1.0 26 | if: github.event_name == 'pull_request' 27 | continue-on-error: true 28 | with: 29 | repo-token: "${{ secrets.GITHUB_TOKEN }}" 30 | -------------------------------------------------------------------------------- /docs/source/_static/js/version_alert.js: -------------------------------------------------------------------------------- 1 | function warnOnLatestVersion() { 2 | if (!window.READTHEDOCS_DATA || window.READTHEDOCS_DATA.version !== "latest") { 3 | return; // not on ReadTheDocs and not latest. 4 | } 5 | 6 | var note = document.createElement('div'); 7 | note.setAttribute('class', 'admonition note'); 8 | note.innerHTML = "
<p class='admonition-title'>Note</p>" + 9 | "<p>" + 10 | "This documentation is for an unreleased development version. " + 11 | "Click <a href='https://pytorch-frame.readthedocs.io/en/stable'>here</a> to access the documentation of the current stable release." + 12 | "</p>
"; 13 | 14 | var parent = document.querySelector('#pyg-documentation'); 15 | if (parent) 16 | parent.insertBefore(note, parent.querySelector('h1')); 17 | } 18 | 19 | document.addEventListener('DOMContentLoaded', warnOnLatestVersion); 20 | -------------------------------------------------------------------------------- /torch_frame/testing/image_embedder.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch 4 | from PIL import Image 5 | from torch import Tensor 6 | 7 | from torch_frame.config.image_embedder import ImageEmbedder 8 | 9 | 10 | class RandomImageEmbedder(ImageEmbedder): 11 | r"""A random-based light-weight image embedder for testing 12 | purposes. It opens each image and generates a random embedding 13 | with :obj:`out_channels` embedding size. 14 | 15 | Args: 16 | out_channels (int): The output dimensionality 17 | """ 18 | def __init__( 19 | self, 20 | out_channels: int, 21 | ) -> None: 22 | super().__init__() 23 | self.out_channels = out_channels 24 | 25 | def forward_embed(self, images: list[Image.Image]) -> Tensor: 26 | return torch.rand(len(images), self.out_channels) 27 | -------------------------------------------------------------------------------- /test/utils/test_split.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from torch_frame.utils.split import SPLIT_TO_NUM, generate_random_split 4 | 5 | 6 | def test_generate_random_split(): 7 | num_data = 20 8 | train_ratio = 0.8 9 | val_ratio = 0.1 10 | test_ratio = 0.1 11 | 12 | split = generate_random_split(num_data, seed=42, train_ratio=train_ratio, 13 | val_ratio=val_ratio) 14 | assert (split == SPLIT_TO_NUM['train']).sum() == int(num_data * 15 | train_ratio) 16 | assert (split == SPLIT_TO_NUM['val']).sum() == int(num_data * val_ratio) 17 | assert (split == SPLIT_TO_NUM['test']).sum() == int(num_data * test_ratio) 18 | assert np.allclose( 19 | split, 20 | np.array([0, 1, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0]), 21 | ) 22 | -------------------------------------------------------------------------------- /torch_frame/data/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | from .tensor_frame import TensorFrame 4 | from .multi_embedding_tensor import MultiEmbeddingTensor 5 | from .multi_nested_tensor import MultiNestedTensor 6 | from .stats import StatType 7 | from .dataset import Dataset, DataFrameToTensorFrameConverter 8 | from .loader import DataLoader 9 | from .download import download_url 10 | 11 | data_classes = [ 12 | 'TensorFrame', 13 | 'MultiEmbeddingTensor', 14 | 'MultiNestedTensor', 15 | 'Dataset', 16 | ] 17 | 18 | loader_classes = [ 19 | 'DataLoader', 20 | ] 21 | 22 | stats_classes = [ 23 | 'StatType', 24 | ] 25 | 26 | helper_functions = [ 27 | 'download_url', 28 | 'DataFrameToTensorFrameConverter', 29 | ] 30 | 31 | __all__ = data_classes + loader_classes + stats_classes + helper_functions 32 | 33 | classes = data_classes + loader_classes + stats_classes 34 | -------------------------------------------------------------------------------- /docs/source/get_started/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | :pyf:`PyTorch Frame` is available for `Python 3.10` to `Python 3.13` on Linux, Windows and macOS. 5 | 6 | Installation via PyPI 7 | --------------------- 8 | 9 | .. 
code-block:: bash 10 | 11 | pip install pytorch-frame 12 | 13 | # Install with optional dependencies 14 | pip install pytorch-frame[full] 15 | 16 | 17 | Installation from master 18 | ------------------------ 19 | 20 | .. code-block:: bash 21 | 22 | pip install git+https://github.com/pyg-team/pytorch-frame.git 23 | 24 | 25 | Installation for development 26 | ---------------------------- 27 | 28 | .. code-block:: bash 29 | 30 | git clone https://github.com/pyg-team/pytorch-frame.git 31 | cd pytorch-frame 32 | pip install -e .[dev] 33 | 34 | # Install with optional dependencies 35 | pip install -e .[dev,full] 36 | -------------------------------------------------------------------------------- /test/nn/models/test_tabnet.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from torch_frame.data import Dataset 4 | from torch_frame.datasets import FakeDataset 5 | from torch_frame.nn import TabNet 6 | 7 | 8 | @pytest.mark.parametrize('batch_size', [0, 5]) 9 | def test_tabnet(batch_size): 10 | dataset: Dataset = FakeDataset(num_rows=10, with_nan=False) 11 | dataset.materialize() 12 | tensor_frame = dataset.tensor_frame[:batch_size] 13 | out_channels = 12 14 | model = TabNet(out_channels=out_channels, num_layers=3, 15 | split_feat_channels=8, split_attn_channels=8, gamma=1.2, 16 | col_stats=dataset.col_stats, 17 | col_names_dict=tensor_frame.col_names_dict) 18 | model.reset_parameters() 19 | out, reg = model(tensor_frame, return_reg=True) 20 | assert out.shape == (len(tensor_frame), out_channels) 21 | assert reg >= 0 22 | -------------------------------------------------------------------------------- /test/nn/models/test_resnet.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from torch_frame.data.dataset import Dataset 4 | from torch_frame.datasets import FakeDataset 5 | from torch_frame.nn import ResNet 6 | 7 | 8 | @pytest.mark.parametrize('batch_size', [0, 5]) 9 | def test_resnet(batch_size): 10 | channels = 8 11 | out_channels = 1 12 | num_layers = 3 13 | dataset: Dataset = FakeDataset(num_rows=10, with_nan=False) 14 | dataset.materialize() 15 | tensor_frame = dataset.tensor_frame[:batch_size] 16 | # Feature-based embeddings 17 | model = ResNet( 18 | channels=channels, 19 | out_channels=out_channels, 20 | num_layers=num_layers, 21 | col_stats=dataset.col_stats, 22 | col_names_dict=tensor_frame.col_names_dict, 23 | ) 24 | model.reset_parameters() 25 | out = model(tensor_frame) 26 | assert out.shape == (batch_size, out_channels) 27 | -------------------------------------------------------------------------------- /test/nn/models/test_ft_transformer.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from torch_frame.data.dataset import Dataset 4 | from torch_frame.datasets import FakeDataset 5 | from torch_frame.nn import FTTransformer 6 | 7 | 8 | @pytest.mark.parametrize('batch_size', [0, 5]) 9 | def test_ft_transformer(batch_size): 10 | channels = 8 11 | out_channels = 1 12 | num_layers = 3 13 | dataset: Dataset = FakeDataset(num_rows=10, with_nan=False) 14 | dataset.materialize() 15 | tensor_frame = dataset.tensor_frame[:batch_size] 16 | # Feature-based embeddings 17 | model = FTTransformer( 18 | channels=channels, 19 | out_channels=out_channels, 20 | num_layers=num_layers, 21 | col_stats=dataset.col_stats, 22 | col_names_dict=tensor_frame.col_names_dict, 23 | ) 24 | model.reset_parameters() 25 | out = 
model(tensor_frame) 26 | assert out.shape == (batch_size, out_channels) 27 | -------------------------------------------------------------------------------- /test/nn/encoding/test_cyclic_encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_frame.nn import CyclicEncoding 4 | 5 | 6 | def test_cyclic_encoding_shape(): 7 | out_size = 8 8 | for size in [(10, ), (10, 4), (10, 5, 8)]: 9 | input_tensor = torch.rand(size) 10 | cyclic_encoding = CyclicEncoding(out_size) 11 | out_tensor = cyclic_encoding(input_tensor) 12 | assert out_tensor.shape == input_tensor.shape + (out_size, ) 13 | 14 | 15 | def test_cyclic_encoding_values(): 16 | out_size = 8 17 | for size in [(10, ), (10, 4), (10, 5, 8)]: 18 | cyclic_encoding = CyclicEncoding(out_size) 19 | input_zeros_tensor = torch.zeros(size) 20 | input_ones_tensor = torch.ones(size) 21 | out_zeros_tensor = cyclic_encoding(input_zeros_tensor) 22 | out_ones_tensor = cyclic_encoding(input_ones_tensor) 23 | assert torch.allclose(out_zeros_tensor, out_ones_tensor, atol=1e-5) 24 | -------------------------------------------------------------------------------- /torch_frame/nn/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | r"""Encoder package.""" 2 | from .encoder import FeatureEncoder 3 | from .stypewise_encoder import StypeWiseFeatureEncoder 4 | from .stype_encoder import ( 5 | StypeEncoder, 6 | EmbeddingEncoder, 7 | MultiCategoricalEmbeddingEncoder, 8 | LinearEncoder, 9 | LinearBucketEncoder, 10 | LinearPeriodicEncoder, 11 | ExcelFormerEncoder, 12 | LinearEmbeddingEncoder, 13 | LinearModelEncoder, 14 | StackEncoder, 15 | TimestampEncoder, 16 | ) 17 | 18 | __all__ = classes = [ 19 | 'FeatureEncoder', 20 | 'StypeWiseFeatureEncoder', 21 | 'StypeEncoder', 22 | 'EmbeddingEncoder', 23 | 'MultiCategoricalEmbeddingEncoder', 24 | 'LinearEncoder', 25 | 'LinearBucketEncoder', 26 | 'LinearPeriodicEncoder', 27 | 'ExcelFormerEncoder', 28 | 'LinearEmbeddingEncoder', 29 | 'LinearModelEncoder', 30 | 'StackEncoder', 31 | 'TimestampEncoder', 32 | ] 33 | -------------------------------------------------------------------------------- /.github/workflows/linting.yml: -------------------------------------------------------------------------------- 1 | name: Linting 2 | 3 | on: # yamllint disable-line rule:truthy 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }} 11 | cancel-in-progress: ${{ github.event_name == 'pull_request' }} 12 | 13 | jobs: 14 | mypy: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - name: Checkout repository 18 | uses: actions/checkout@v6 19 | 20 | - name: Set up Python 21 | uses: actions/setup-python@v6 22 | with: 23 | python-version: '3.10' 24 | 25 | - name: Install dependencies 26 | run: | 27 | # TODO: Use the latest PyTorch version once type issues are addressed in the codebase: 28 | pip install -e '.[full,test]' 'torch==2.7.*' -f https://download.pytorch.org/whl/cpu 29 | pip list 30 | 31 | - name: Check type hints 32 | run: mypy 33 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/pyg-team/pytorch-frame 2 | 3 | PyTorch Frame Documentation 4 | =========================== 5 | :pyf:`null` **PyTorch Frame** is a library built upon :pytorch:`null` `PyTorch <https://pytorch.org>`_
to easily write and train tabular deep learning models. 6 | 7 | .. slack_button:: 8 | 9 | .. toctree:: 10 | :maxdepth: 1 11 | :caption: Get Started 12 | 13 | get_started/installation 14 | get_started/introduction 15 | get_started/modular_design 16 | 17 | .. toctree:: 18 | :maxdepth: 1 19 | :caption: Handling Advanced Stypes 20 | 21 | handling_advanced_stypes/handle_heterogeneous_stypes 22 | handling_advanced_stypes/handle_text 23 | 24 | .. toctree:: 25 | :maxdepth: 1 26 | :caption: Package Reference 27 | 28 | modules/root 29 | modules/data 30 | modules/datasets 31 | modules/nn 32 | modules/gbdt 33 | modules/config 34 | modules/transforms 35 | modules/utils 36 | -------------------------------------------------------------------------------- /torch_frame/nn/conv/table_conv.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any 3 | 4 | from torch import Tensor 5 | from torch.nn import Module 6 | 7 | 8 | class TableConv(Module, ABC): 9 | r"""Base class for table convolution that transforms the input column-wise 10 | pytorch tensor. 11 | """ 12 | @abstractmethod 13 | def forward(self, x: Tensor, *args: Any, **kwargs: Any) -> Any: 14 | r"""Process column-wise 3-dimensional tensor into another column-wise 15 | 3-dimensional tensor. 16 | 17 | Args: 18 | x (torch.Tensor): Input column-wise tensor of shape 19 | :obj:`[batch_size, num_cols, hidden_channels]`. 20 | args (Any): Extra arguments. 21 | kwargs (Any): Extra keyword arguments. 22 | """ 23 | raise NotImplementedError 24 | 25 | def reset_parameters(self) -> None: 26 | r"""Resets all learnable parameters of the module.""" 27 | -------------------------------------------------------------------------------- /torch_frame/utils/memory.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from torch import Tensor 4 | 5 | from torch_frame.data.multi_embedding_tensor import MultiEmbeddingTensor 6 | from torch_frame.data.multi_nested_tensor import MultiNestedTensor 7 | 8 | 9 | def num_bytes(data: Any) -> int: 10 | r"""Returns the number of bytes the tensor data consumes. 11 | 12 | Args: 13 | data: The tensor data. 
14 | """ 15 | if isinstance(data, Tensor): 16 | return data.element_size() * data.numel() 17 | if isinstance(data, MultiNestedTensor | MultiEmbeddingTensor): 18 | return num_bytes(data.values) + num_bytes(data.offset) 19 | if isinstance(data, list): 20 | return sum([num_bytes(value) for value in data]) 21 | if isinstance(data, dict): 22 | return sum([num_bytes(value) for value in data.values()]) 23 | 24 | raise NotImplementedError(f"'num_bytes' not implemented for " 25 | f"'{type(data)}'") 26 | -------------------------------------------------------------------------------- /test/nn/models/test_mlp.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from torch_frame.data.dataset import Dataset 4 | from torch_frame.datasets import FakeDataset 5 | from torch_frame.nn import MLP 6 | 7 | 8 | @pytest.mark.parametrize('batch_size', [0, 5]) 9 | @pytest.mark.parametrize('normalization', ["layer_norm", "batch_norm"]) 10 | def test_mlp(batch_size, normalization): 11 | channels = 8 12 | out_channels = 1 13 | num_layers = 3 14 | dataset: Dataset = FakeDataset(num_rows=10, with_nan=False) 15 | dataset.materialize() 16 | tensor_frame = dataset.tensor_frame[:batch_size] 17 | # Feature-based embeddings 18 | model = MLP( 19 | channels=channels, 20 | out_channels=out_channels, 21 | num_layers=num_layers, 22 | col_stats=dataset.col_stats, 23 | col_names_dict=tensor_frame.col_names_dict, 24 | normalization=normalization, 25 | ) 26 | model.reset_parameters() 27 | out = model(tensor_frame) 28 | assert out.shape == (batch_size, out_channels) 29 | -------------------------------------------------------------------------------- /docs/source/modules/datasets.rst: -------------------------------------------------------------------------------- 1 | torch_frame.datasets 2 | ==================== 3 | 4 | .. contents:: Contents 5 | :local: 6 | 7 | .. currentmodule:: torch_frame.datasets 8 | 9 | Real-World Datasets 10 | ------------------- 11 | 12 | .. autosummary:: 13 | :nosignatures: 14 | :toctree: ../generated 15 | :template: autosummary/class.rst 16 | 17 | {% for name in torch_frame.datasets.real_world_datasets %} 18 | {{ name }} 19 | {% endfor %} 20 | 21 | Synthetic Datasets 22 | ------------------ 23 | 24 | .. autosummary:: 25 | :nosignatures: 26 | :toctree: ../generated 27 | :template: autosummary/class.rst 28 | 29 | {% for name in torch_frame.datasets.synthetic_datasets %} 30 | {{ name }} 31 | {% endfor %} 32 | 33 | Other Datasets 34 | -------------- 35 | 36 | .. 
autosummary:: 37 | :nosignatures: 38 | :toctree: ../generated 39 | :template: autosummary/class.rst 40 | 41 | {% for name in torch_frame.datasets.other_datasets %} 42 | {{ name }} 43 | {% endfor %} 44 | -------------------------------------------------------------------------------- /.github/workflows/auto-merge.yml: -------------------------------------------------------------------------------- 1 | name: Dependabot auto-merge 2 | 3 | on: # yamllint disable-line rule:truthy 4 | pull_request_target: 5 | types: [opened, reopened] 6 | 7 | permissions: 8 | contents: write 9 | pull-requests: write 10 | 11 | jobs: 12 | auto-merge: 13 | runs-on: ubuntu-latest 14 | if: ${{ github.event.pull_request.user.login == 'dependabot[bot]' || github.event.pull_request.user.login == 'pre-commit-ci[bot]' }} 15 | steps: 16 | - uses: actions/checkout@v6 17 | 18 | - name: Label bot PRs 19 | run: gh pr edit --add-label "ci,skip-changelog" ${{ github.event.pull_request.html_url }} 20 | env: 21 | GITHUB_TOKEN: ${{ secrets.PAT }} 22 | 23 | - name: Auto-approve 24 | uses: hmarr/auto-approve-action@v4 25 | with: 26 | github-token: ${{ secrets.PAT }} 27 | 28 | - name: Enable auto-merge 29 | run: gh pr merge --auto --squash ${{ github.event.pull_request.html_url }} 30 | env: 31 | GITHUB_TOKEN: ${{ secrets.PAT }} 32 | -------------------------------------------------------------------------------- /torch_frame/nn/decoder/decoder.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any 3 | 4 | from torch import Tensor 5 | from torch.nn import Module 6 | 7 | 8 | class Decoder(Module, ABC): 9 | r"""Base class for decoder that transforms the input column-wise PyTorch 10 | tensor into an output tensor on which the prediction head is applied. 11 | """ 12 | @abstractmethod 13 | def forward(self, x: Tensor, *args: Any, **kwargs: Any) -> Any: 14 | r"""Decode :obj:`x` of shape :obj:`[batch_size, num_cols, channels]` 15 | into an output tensor of shape :obj:`[batch_size, out_channels]`. 16 | 17 | Args: 18 | x (torch.Tensor): Input column-wise tensor of shape 19 | :obj:`[batch_size, num_cols, hidden_channels]`. 20 | args (Any): Extra arguments. 21 | kwargs (Any): Extra keyword arguments.
22 | """ 23 | raise NotImplementedError 24 | 25 | def reset_parameters(self) -> None: 26 | r"""Resets all learnable parameters of the module.""" 27 | -------------------------------------------------------------------------------- /test/utils/test_memory.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from torch_frame.data import MultiEmbeddingTensor, MultiNestedTensor 5 | from torch_frame.utils import num_bytes 6 | 7 | 8 | def test_num_bytes(): 9 | data = torch.randn(4, 8) 10 | assert num_bytes(data) == 4 * 8 * 4 11 | 12 | data = MultiNestedTensor( 13 | num_rows=3, 14 | num_cols=2, 15 | values=torch.randn(12), 16 | offset=torch.tensor([0, 2, 4, 6, 8, 10, 12]), 17 | ) 18 | assert num_bytes(data) == 12 * 4 + 7 * 8 19 | 20 | data = MultiEmbeddingTensor( 21 | num_rows=2, 22 | num_cols=3, 23 | values=torch.randn(2, 10), 24 | offset=torch.tensor([0, 3, 5, 10]), 25 | ) 26 | assert num_bytes(data) == 2 * 10 * 4 + 4 * 8 27 | 28 | mapping = { 29 | 'A': data, 30 | 'B': data, 31 | } 32 | assert num_bytes(mapping) == 2 * num_bytes(data) 33 | 34 | seq = [data, data, data] 35 | assert num_bytes(seq) == 3 * num_bytes(data) 36 | 37 | with pytest.raises(NotImplementedError): 38 | num_bytes("unsupported") 39 | -------------------------------------------------------------------------------- /torch_frame/config/text_tokenizer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Callable 4 | from dataclasses import dataclass 5 | 6 | from torch_frame.typing import TextTokenizationOutputs 7 | 8 | 9 | @dataclass 10 | class TextTokenizerConfig: 11 | r"""Text tokenizer that maps a list of strings/sentences into a 12 | dictionary of :class:`MultiNestedTensor`. 13 | 14 | Args: 15 | text_tokenizer (callable): A callable text tokenizer that takes a 16 | list of strings as input and outputs a list of dictionaries. 17 | Each dictionary contains keys that are arguments to the text 18 | encoder model and values are corresponding tensors such as 19 | tokens and attention masks. 20 | batch_size (int, optional): Batch size to use when tokenizing the 21 | sentences. If set to :obj:`None`, the text embeddings will 22 | be obtained in a full-batch manner. (default: :obj:`None`) 23 | 24 | """ 25 | text_tokenizer: Callable[[list[str]], TextTokenizationOutputs] 26 | batch_size: int | None = None 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2023 PyG Team 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /docs/source/modules/data.rst: -------------------------------------------------------------------------------- 1 | torch_frame.data 2 | ================ 3 | 4 | .. contents:: Contents 5 | :local: 6 | 7 | .. currentmodule:: torch_frame.data 8 | 9 | Data Objects 10 | ------------ 11 | 12 | .. autosummary:: 13 | :nosignatures: 14 | :toctree: ../generated 15 | :template: autosummary/class.rst 16 | 17 | {% for name in torch_frame.data.data_classes %} 18 | {{ name }} 19 | {% endfor %} 20 | 21 | Stats 22 | ----- 23 | 24 | .. autosummary:: 25 | :nosignatures: 26 | :toctree: ../generated 27 | :template: autosummary/class.rst 28 | 29 | {% for name in torch_frame.data.stats_classes %} 30 | {{ name }} 31 | {% endfor %} 32 | 33 | Data Loaders 34 | ------------ 35 | 36 | .. autosummary:: 37 | :nosignatures: 38 | :toctree: ../generated 39 | :template: autosummary/class.rst 40 | 41 | {% for name in torch_frame.data.loader_classes %} 42 | {{ name }} 43 | {% endfor %} 44 | 45 | Helper Functions 46 | ---------------- 47 | 48 | .. autosummary:: 49 | :nosignatures: 50 | :toctree: ../generated 51 | 52 | {% for name in torch_frame.data.helper_functions %} 53 | {{ name }} 54 | {% endfor %} 55 | -------------------------------------------------------------------------------- /torch_frame/config/text_embedder.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Callable 4 | from dataclasses import dataclass 5 | 6 | from torch import Tensor 7 | 8 | 9 | @dataclass 10 | class TextEmbedderConfig: 11 | r"""Text embedder model that maps a list of strings/sentences into PyTorch 12 | Tensor embeddings. 13 | 14 | Args: 15 | text_embedder (callable): A callable text embedder that takes a list 16 | of strings as input and outputs the PyTorch Tensor embeddings for 17 | that list of strings. 18 | batch_size (int, optional): Batch size to use when encoding the 19 | sentences. If set to :obj:`None`, the text embeddings will 20 | be obtained in a full-batch manner. (default: :obj:`None`) 21 | 22 | """ 23 | text_embedder: Callable[[list[str]], Tensor] 24 | # Batch size to use when encoding the sentences. It is recommended to set 25 | # it to a reasonable value when one uses a heavy text embedding model 26 | # (e.g., Transformer) on GPU. If set to :obj:`None`, the text embeddings 27 | # will be obtained in a full-batch manner. 
28 | batch_size: int | None = None 29 | -------------------------------------------------------------------------------- /test/test_stype.py: -------------------------------------------------------------------------------- 1 | import torch_frame 2 | 3 | 4 | def test_stype(): 5 | assert len(torch_frame.stype) == 9 6 | assert torch_frame.numerical == torch_frame.stype('numerical') 7 | assert not torch_frame.numerical.is_text_stype 8 | assert torch_frame.categorical == torch_frame.stype('categorical') 9 | assert not torch_frame.categorical.is_text_stype 10 | assert torch_frame.multicategorical == torch_frame.stype( 11 | 'multicategorical') 12 | assert not torch_frame.multicategorical.is_text_stype 13 | assert torch_frame.sequence_numerical == torch_frame.stype( 14 | 'sequence_numerical') 15 | assert not torch_frame.sequence_numerical.is_text_stype 16 | assert torch_frame.text_embedded == torch_frame.stype('text_embedded') 17 | assert torch_frame.text_embedded.is_text_stype 18 | assert torch_frame.text_tokenized == torch_frame.stype('text_tokenized') 19 | assert torch_frame.text_tokenized.is_text_stype 20 | assert torch_frame.image_embedded == torch_frame.stype('image_embedded') 21 | assert torch_frame.image_embedded.is_image_stype 22 | assert torch_frame.embedding == torch_frame.stype('embedding') 23 | assert torch_frame.embedding.use_multi_embedding_tensor 24 | -------------------------------------------------------------------------------- /test/nn/models/test_tab_transformer.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from torch_frame import stype 4 | from torch_frame.data.dataset import Dataset 5 | from torch_frame.datasets import FakeDataset 6 | from torch_frame.nn import TabTransformer 7 | 8 | 9 | @pytest.mark.parametrize('stypes', [[stype.categorical, stype.numerical], 10 | [stype.categorical], [stype.numerical]]) 11 | @pytest.mark.parametrize('batch_size', [0, 5]) 12 | def test_tab_transformer(stypes, batch_size): 13 | channels = 8 14 | out_channels = 1 15 | num_layers = 3 16 | num_heads = 2 17 | encoder_pad_size = 2 18 | dataset: Dataset = FakeDataset(num_rows=10, with_nan=False, stypes=stypes) 19 | dataset.materialize() 20 | tensor_frame = dataset.tensor_frame[:batch_size] 21 | model = TabTransformer( 22 | channels=channels, 23 | out_channels=out_channels, 24 | num_layers=num_layers, 25 | num_heads=num_heads, 26 | encoder_pad_size=encoder_pad_size, 27 | attn_dropout=0., 28 | ffn_dropout=0., 29 | col_stats=dataset.col_stats, 30 | col_names_dict=tensor_frame.col_names_dict, 31 | ) 32 | model.reset_parameters() 33 | out = model(tensor_frame) 34 | assert out.shape == (batch_size, out_channels) 35 | -------------------------------------------------------------------------------- /torch_frame/__init__.py: -------------------------------------------------------------------------------- 1 | r"""Utility package.""" 2 | from ._stype import ( 3 | stype, 4 | numerical, 5 | categorical, 6 | text_embedded, 7 | text_tokenized, 8 | multicategorical, 9 | sequence_numerical, 10 | timestamp, 11 | image_embedded, 12 | embedding, 13 | ) 14 | from .data import TensorFrame 15 | from .typing import ( 16 | TaskType, 17 | Metric, 18 | DataFrame, 19 | NAStrategy, 20 | WITH_PT24, 21 | ) 22 | from torch_frame.utils import save, load, cat # noqa 23 | import torch_frame.data # noqa 24 | import torch_frame.datasets # noqa 25 | import torch_frame.nn # noqa 26 | import torch_frame.gbdt # noqa 27 | 28 | if WITH_PT24: 29 | import torch 30 | 31 | 
torch.serialization.add_safe_globals([ 32 | stype, 33 | torch_frame.data.stats.StatType, 34 | ]) 35 | 36 | # https://peps.python.org/pep-0440/ 37 | __version__ = '0.4.0.dev0' 38 | 39 | __all__ = [ 40 | 'DataFrame', 41 | 'stype', 42 | 'numerical', 43 | 'categorical', 44 | 'text_embedded', 45 | 'text_tokenized', 46 | 'multicategorical', 47 | 'sequence_numerical', 48 | 'timestamp', 49 | 'image_embedded', 50 | 'embedding', 51 | 'TaskType', 52 | 'Metric', 53 | 'NAStrategy', 54 | 'TensorFrame', 55 | 'save', 56 | 'load', 57 | 'cat', 58 | 'torch_frame', 59 | '__version__', 60 | ] 61 | -------------------------------------------------------------------------------- /torch_frame/testing/text_embedder.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch.nn import Embedding 6 | 7 | 8 | class HashTextEmbedder: 9 | r"""A hash-based light-weight text embedder for testing purposes. 10 | It hashes each sentence into an index modulo :obj:`num_hash` and then 11 | uses :class:`torch.nn.Embedding` to look up the index to obtain the 12 | sentence embedding. 13 | 14 | Args: 15 | out_channels (int): The output dimensionality 16 | num_hash_bins (int): Number of hash bins to use. 17 | (default: :obj:`64`) 18 | device (torch.device, optional): The device to put :class:`Embedding` 19 | module. (default: :obj:`None`) 20 | """ 21 | def __init__( 22 | self, 23 | out_channels: int, 24 | num_hash_bins: int = 64, 25 | device: torch.device | None = None, 26 | ) -> None: 27 | self.out_channels = out_channels 28 | self.num_hash_bins = num_hash_bins 29 | self.device = device 30 | self.embedding = Embedding(num_hash_bins, out_channels).to(device) 31 | 32 | def __call__(self, sentences: list[str]) -> Tensor: 33 | idx = torch.tensor([hash(s) % self.num_hash_bins for s in sentences], 34 | device=self.device) 35 | return self.embedding(idx).detach() 36 | -------------------------------------------------------------------------------- /torch_frame/nn/utils/init.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch.nn.init import _calculate_correct_fan, calculate_gain 6 | 7 | 8 | def attenuated_kaiming_uniform_( 9 | tensor: Tensor, 10 | scale: float = 0.1, 11 | a: float = math.sqrt(5), 12 | mode: str = 'fan_in', 13 | nonlinearity: str = 'leaky_relu', 14 | ) -> Tensor: 15 | r"""Attenuated Kaiming Uniform Initialization. 16 | 17 | Args: 18 | tensor (tensor): Input tensor to be initialized 19 | scale (float): Positive rescaling constant to the variance. 20 | a (float): Negative slope of the rectifier used after this layer 21 | mode (str): Either 'fan_in' (default) or 'fan_out'. Choosing 22 | 'fan_in' preserves the magnitude of the variance of the weights 23 | in the forward pass. Choosing 'fan_out' preserves the magnitudes 24 | in the backwards pass. 25 | nonlinearity (str) : the non-linear function (nn.functional name), 26 | recommended to use only with 'relu' or 'leaky_relu'. 
27 | """ 28 | with torch.no_grad(): 29 | fan = _calculate_correct_fan(tensor, mode) 30 | gain = calculate_gain(nonlinearity, a) 31 | std = gain * scale / math.sqrt(fan) 32 | # Calculate uniform bounds from standard deviation 33 | bound = math.sqrt(3.0) * std 34 | return tensor.uniform_(-bound, bound) 35 | -------------------------------------------------------------------------------- /torch_frame/nn/encoder/encoder.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from torch import Tensor 4 | from torch.nn import Module 5 | 6 | from torch_frame import TensorFrame 7 | 8 | 9 | class FeatureEncoder(Module, ABC): 10 | r"""Base class for feature encoder that transforms input 11 | :class:`torch_frame.TensorFrame` into :obj:`(x, col_names)`, 12 | where :obj:`x` is the colum-wise PyTorch tensor of shape 13 | :obj:`[batch_size, num_cols, channels]` and :obj:`col_names` is the 14 | names of the columns. This class contains learnable parameters and missing 15 | value handling. 16 | """ 17 | @abstractmethod 18 | def forward(self, tf: TensorFrame) -> tuple[Tensor, list[str]]: 19 | r"""Encode :class:`TensorFrame` object into a tuple 20 | :obj:`(x, col_names)`. 21 | 22 | Args: 23 | tf (:class:`torch_frame.TensorFrame`): Input :class:`TensorFrame` 24 | object. 25 | 26 | Returns: 27 | (torch.Tensor, List[str]): A tuple of an output column-wise 28 | :class:`torch.Tensor` of shape 29 | :obj:`[batch_size, num_cols, hidden_channels]` and a list of 30 | column names of :obj:`x`. The length needs to be 31 | :obj:`num_cols`. 32 | """ 33 | raise NotImplementedError 34 | 35 | def reset_parameters(self) -> None: 36 | r"""Resets all learnable parameters of the module.""" 37 | -------------------------------------------------------------------------------- /.github/workflows/documentation.yml: -------------------------------------------------------------------------------- 1 | name: Documentation 2 | 3 | on: # yamllint disable-line rule:truthy 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }} 11 | cancel-in-progress: ${{ github.event_name == 'pull_request' }} 12 | 13 | jobs: 14 | 15 | make_html: 16 | runs-on: ubuntu-latest 17 | 18 | steps: 19 | - name: Checkout repository 20 | uses: actions/checkout@v6 21 | with: 22 | fetch-depth: 40 23 | 24 | # Skip workflow if only certain files have been changed. 25 | - name: Get changed files 26 | id: changed-files-specific 27 | uses: tj-actions/changed-files@v47 28 | with: 29 | files: | 30 | examples/** 31 | README.md 32 | CHANGELOG.md 33 | 34 | - name: Setup packages 35 | if: steps.changed-files-specific.outputs.only_changed != 'true' 36 | uses: ./.github/actions/setup 37 | 38 | - name: Install main package 39 | if: steps.changed-files-specific.outputs.only_changed != 'true' 40 | run: | 41 | pip install git+https://github.com/pyg-team/pyg_sphinx_theme.git 42 | pip install -e . 43 | pip list 44 | 45 | 46 | - name: Build documentation 47 | if: steps.changed-files-specific.outputs.only_changed != 'true' 48 | run: | 49 | cd docs && make clean && make html SPHINXOPTS="-W" # Fail on warning. 
50 | -------------------------------------------------------------------------------- /torch_frame/data/download.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | import os.path as osp 5 | import ssl 6 | import sys 7 | import urllib.request 8 | 9 | 10 | def download_url( 11 | url: str, 12 | root: str, 13 | filename: str | None = None, 14 | *, 15 | log: bool = True, 16 | ) -> str: 17 | r"""Downloads the content of :obj:`url` to the specified folder 18 | :obj:`root`. 19 | 20 | Args: 21 | url (str): The URL. 22 | root (str): The root folder. 23 | filename (str, optional): If set, will rename the downloaded file. 24 | (default: :obj:`None`) 25 | log (bool, optional): If :obj:`False`, will not print anything to the 26 | console. (default: :obj:`True`) 27 | """ 28 | if filename is None: 29 | filename = url.rpartition('/')[2] 30 | if filename[0] != '?': 31 | filename = filename.split('?')[0] 32 | 33 | path = osp.join(root, filename) 34 | 35 | if osp.exists(path): 36 | return path 37 | 38 | if log and 'pytest' not in sys.modules: 39 | print(f'Downloading {url}', file=sys.stderr) 40 | 41 | os.makedirs(root, exist_ok=True) 42 | 43 | context = ssl._create_unverified_context() 44 | data = urllib.request.urlopen(url, context=context) 45 | 46 | with open(path, 'wb') as f: 47 | while True: 48 | chunk = data.read(10 * 1024 * 1024) 49 | if not chunk: 50 | break 51 | f.write(chunk) 52 | 53 | return path 54 | -------------------------------------------------------------------------------- /torch_frame/nn/encoding/positional_encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.nn import Module 4 | 5 | 6 | class PositionalEncoding(Module): 7 | r"""Positional encoding introduced in `"Attention Is All You Need" 8 | `_ paper. Given an input tensor of shape 9 | :obj:`(*, )`, this encoding expands it into an output tensor of shape 10 | :obj:`(*, out_size)`. 11 | 12 | Args: 13 | out_size (int): The output dimension size. 14 | """ 15 | def __init__(self, out_size: int) -> None: 16 | super().__init__() 17 | if out_size % 2 != 0: 18 | raise ValueError( 19 | f"out_size should be divisible by 2 (got {out_size}).") 20 | self.out_size = out_size 21 | self.mult_term: Tensor 22 | self.register_buffer( 23 | "mult_term", 24 | torch.pow( 25 | 1 / 10000.0, 26 | torch.arange(0, self.out_size, 2) / out_size, 27 | ), 28 | ) 29 | 30 | def forward(self, input_tensor: Tensor) -> Tensor: 31 | assert torch.all(input_tensor >= 0) 32 | # (*, 1) * (1, ..., 1, out_size // 2) -> (*, out_size // 2) 33 | mult_tensor = input_tensor.unsqueeze(-1) * self.mult_term.reshape( 34 | (1, ) * input_tensor.ndim + (-1, )) 35 | # cat([(*, out_size // 2), (*, out_size // 2)]) -> (*, out_size) 36 | return torch.cat([torch.sin(mult_tensor), 37 | torch.cos(mult_tensor)], dim=-1) 38 | -------------------------------------------------------------------------------- /torch_frame/nn/encoding/cyclic_encoding.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch.nn import Module 6 | 7 | 8 | class CyclicEncoding(Module): 9 | r"""Cyclic encoding for input data containing values between 0 and 1. 10 | This function maps each value in the input using sine and cosine 11 | functions of different wavelengths to preserve the cyclical nature. 
This 12 | is particularly useful for encoding cyclical features like hours of a 13 | day, days of the week, etc. Given an input tensor of shape 14 | :obj:`(*, )`, this encoding expands it into an output tensor of shape 15 | :obj:`(*, out_size)`. 16 | 17 | Args: 18 | out_size (int): The output dimension size. 19 | """ 20 | def __init__(self, out_size: int) -> None: 21 | super().__init__() 22 | if out_size % 2 != 0: 23 | raise ValueError( 24 | f"out_size should be divisible by 2 (got {out_size}).") 25 | self.out_size = out_size 26 | self.mult_term: Tensor 27 | self.register_buffer( 28 | "mult_term", 29 | torch.arange(1, self.out_size // 2 + 1), 30 | ) 31 | 32 | def forward(self, input_tensor: Tensor) -> Tensor: 33 | assert torch.all((input_tensor >= 0) & (input_tensor <= 1)) 34 | mult_tensor = input_tensor.unsqueeze(-1) * self.mult_term.reshape( 35 | (1, ) * input_tensor.ndim + (-1, )) 36 | return torch.cat([ 37 | torch.sin(mult_tensor * 2 * math.pi), 38 | torch.cos(mult_tensor * 2 * math.pi) 39 | ], dim=-1) 40 | -------------------------------------------------------------------------------- /torch_frame/utils/split.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # Mapping split name to integer. 4 | SPLIT_TO_NUM = {'train': 0, 'val': 1, 'test': 2} 5 | 6 | 7 | def generate_random_split(length: int, seed: int, train_ratio: float = 0.8, 8 | val_ratio: float = 0.1, 9 | include_test: bool = True) -> np.ndarray: 10 | r"""Generate a list of random split assignments of the specified length. 11 | The elements are either :obj:`0`, :obj:`1`, or :obj:`2`, representing 12 | train, val, test, respectively. Note that this function relies on the fact 13 | that numpy's shuffle is consistent across versions, which has been 14 | historically the case.
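A minimal usage sketch (the class counts follow deterministically from the ratios; only the ordering depends on :obj:`seed`):

.. code-block:: python

    split = generate_random_split(length=10, seed=0)
    assert (split == SPLIT_TO_NUM['train']).sum() == 8  # train_ratio=0.8
    assert (split == SPLIT_TO_NUM['val']).sum() == 1  # val_ratio=0.1
    assert (split == SPLIT_TO_NUM['test']).sum() == 1  # remainder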
15 | """ 16 | assert train_ratio > 0 17 | assert val_ratio > 0 18 | 19 | if include_test: 20 | assert train_ratio + val_ratio < 1 21 | train_num = int(length * train_ratio) 22 | val_num = int(length * val_ratio) 23 | test_num = length - train_num - val_num 24 | arr = np.concatenate([ 25 | np.full(train_num, SPLIT_TO_NUM['train']), 26 | np.full(val_num, SPLIT_TO_NUM['val']), 27 | np.full(test_num, SPLIT_TO_NUM['test']) 28 | ]) 29 | else: 30 | assert train_ratio + val_ratio == 1 31 | train_num = int(length * train_ratio) 32 | val_num = length - train_num 33 | arr = np.concatenate([ 34 | np.full(train_num, SPLIT_TO_NUM['train']), 35 | np.full(val_num, SPLIT_TO_NUM['val']), 36 | ]) 37 | np.random.seed(seed) 38 | np.random.shuffle(arr) 39 | 40 | return arr 41 | -------------------------------------------------------------------------------- /torch_frame/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from .adult_census_income import AdultCensusIncome 3 | from .amazon_fine_food_reviews import AmazonFineFoodReviews 4 | from .amphibians import Amphibians 5 | from .bank_marketing import BankMarketing 6 | from .data_frame_benchmark import DataFrameBenchmark 7 | from .data_frame_text_benchmark import DataFrameTextBenchmark 8 | from .diamond_images import DiamondImages 9 | from .dota2 import Dota2 10 | from .fake import FakeDataset 11 | from .forest_cover_type import ForestCoverType 12 | from .huggingface_dataset import HuggingFaceDatasetDict 13 | from .kdd_census_income import KDDCensusIncome 14 | from .mercari import Mercari 15 | from .movielens_1m import Movielens1M 16 | from .multimodal_text_benchmark import MultimodalTextBenchmark 17 | from .mushroom import Mushroom 18 | from .poker_hand import PokerHand 19 | from .tabular_benchmark import TabularBenchmark 20 | from .titanic import Titanic 21 | from .yandex import Yandex 22 | 23 | real_world_datasets = [ 24 | 'AdultCensusIncome', 25 | 'AmazonFineFoodReviews', 26 | 'Amphibians', 27 | 'BankMarketing', 28 | 'DataFrameBenchmark', 29 | 'DataFrameTextBenchmark', 30 | 'Dota2', 31 | 'Titanic', 32 | 'ForestCoverType', 33 | 'HuggingFaceDatasetDict', 34 | 'KDDCensusIncome', 35 | 'Mercari', 36 | 'Movielens1M', 37 | 'MultimodalTextBenchmark', 38 | 'Mushroom', 39 | 'PokerHand', 40 | 'TabularBenchmark', 41 | 'Yandex', 42 | 'DiamondImages', 43 | ] 44 | 45 | synthetic_datasets = [ 46 | 'FakeDataset', 47 | ] 48 | 49 | other_datasets = [ 50 | 'HuggingFaceDatasetDict', 51 | ] 52 | 53 | __all__ = real_world_datasets + synthetic_datasets + other_datasets 54 | -------------------------------------------------------------------------------- /torch_frame/datasets/titanic.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | import torch_frame 4 | 5 | 6 | class Titanic(torch_frame.data.Dataset): 7 | r"""The Titanic dataset from the `Titanic Kaggle competition 8 | `_. 9 | The Titanic dataset is known as the MNIST equivalent for tabular learning. 10 | The goal is to predict which passenger survived using passenger data 11 | (*i.e.* gender, age, etc). 12 | 13 | **STATS:** 14 | 15 | .. 
list-table:: 16 | :widths: 10 10 10 10 20 10 17 | :header-rows: 1 18 | 19 | * - #rows 20 | - #cols (numerical) 21 | - #cols (categorical) 22 | - #classes 23 | - Task 24 | - Missing value ratio 25 | * - 891 26 | - 4 27 | - 3 28 | - 2 29 | - binary_classification 30 | - 8.8% 31 | """ 32 | 33 | url = 'https://github.com/datasciencedojo/datasets/raw/master/titanic.csv' 34 | 35 | def __init__(self, root: str) -> None: 36 | path = self.download_url(self.url, root) 37 | df = pd.read_csv(path, index_col=['PassengerId']) 38 | 39 | col_to_stype = { # TODO Use 'Name', 'Ticket' and 'Cabin'. 40 | 'Survived': torch_frame.categorical, 41 | 'Pclass': torch_frame.categorical, 42 | 'Sex': torch_frame.categorical, 43 | 'Age': torch_frame.numerical, 44 | 'SibSp': torch_frame.numerical, 45 | 'Parch': torch_frame.numerical, 46 | 'Fare': torch_frame.numerical, 47 | 'Embarked': torch_frame.categorical, 48 | } 49 | 50 | super().__init__(df, col_to_stype, target_col='Survived') 51 | -------------------------------------------------------------------------------- /.github/actions/setup/action.yml: -------------------------------------------------------------------------------- 1 | name: Setup 2 | 3 | inputs: 4 | python-version: 5 | required: false 6 | default: '3.10' 7 | torch-version: 8 | required: false 9 | default: '2.9' 10 | cuda-version: 11 | required: false 12 | default: cpu 13 | 14 | runs: 15 | using: composite 16 | 17 | steps: 18 | - name: Set up Python ${{ inputs.python-version }} 19 | uses: actions/setup-python@v6 20 | with: 21 | python-version: ${{ inputs.python-version }} 22 | check-latest: true 23 | cache: pip 24 | cache-dependency-path: | 25 | pyproject.toml 26 | 27 | - name: Pre-install NumPy<2 if necessary 28 | run: | 29 | [[ ${{ inputs.torch-version }} =~ ^2\.[0-2] ]] && pip install 'numpy<2' || echo "Skipping NumPy<2 installation" 30 | shell: bash 31 | 32 | - name: Install PyTorch ${{ inputs.torch-version }}+${{ inputs.cuda-version }} 33 | if: ${{ inputs.torch-version != 'nightly' }} 34 | run: | 35 | pip install torch==${{ inputs.torch-version }}.* --extra-index-url https://download.pytorch.org/whl/${{ inputs.cuda-version }} 36 | shell: bash 37 | 38 | - name: Install PyTorch ${{ inputs.torch-version }}+${{ inputs.cuda-version }} 39 | if: ${{ inputs.torch-version == 'nightly' }} 40 | run: | 41 | pip install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/${{ inputs.cuda-version }} 42 | shell: bash 43 | 44 | - name: List installed packages 45 | run: | 46 | pip list 47 | python -c "import torch; print('PyTorch:', torch.__version__)" 48 | python -c "import torch; print('CUDA:', torch.version.cuda)" 49 | shell: bash 50 | -------------------------------------------------------------------------------- /torch_frame/data/loader.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch 4 | 5 | from torch_frame.data import Dataset, TensorFrame 6 | from torch_frame.typing import IndexSelectType 7 | 8 | 9 | class DataLoader(torch.utils.data.DataLoader): 10 | r"""A data loader which creates mini-batches from a 11 | :class:`torch_frame.Dataset` or :class:`torch_frame.TensorFrame` object. 12 | 13 | .. code-block:: python 14 | 15 | import torch_frame 16 | 17 | dataset = ... 18 | 19 | loader = torch_frame.data.DataLoader( 20 | dataset, 21 | batch_size=512, 22 | shuffle=True, 23 | ) 24 | 25 | Args: 26 | dataset (Dataset or TensorFrame): The dataset or tensor frame from 27 | which to load the data. 
28 | *args (optional): Additional arguments of 29 | :class:`torch.utils.data.DataLoader`. 30 | **kwargs (optional): Additional keyword arguments of 31 | :class:`torch.utils.data.DataLoader`. 32 | """ 33 | def __init__( 34 | self, 35 | dataset: Dataset | TensorFrame, 36 | *args, 37 | **kwargs, 38 | ): 39 | kwargs.pop('collate_fn', None) 40 | 41 | if isinstance(dataset, Dataset): 42 | self.tensor_frame: TensorFrame = dataset.materialize().tensor_frame 43 | else: 44 | self.tensor_frame: TensorFrame = dataset 45 | 46 | super().__init__( 47 | range(len(dataset)), 48 | *args, 49 | collate_fn=self.collate_fn, 50 | **kwargs, 51 | ) 52 | 53 | def collate_fn(self, index: IndexSelectType) -> TensorFrame: 54 | return self.tensor_frame[index] 55 | -------------------------------------------------------------------------------- /test/nn/models/test_trompt.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from torch_frame import stype 4 | from torch_frame.data.dataset import Dataset 5 | from torch_frame.datasets import FakeDataset 6 | from torch_frame.nn import EmbeddingEncoder, LinearEncoder, Trompt 7 | 8 | 9 | @pytest.mark.parametrize('batch_size', [0, 5]) 10 | @pytest.mark.parametrize('use_stype_encoder_dicts', [ 11 | True, 12 | False, 13 | ]) 14 | def test_trompt(batch_size, use_stype_encoder_dicts): 16 | channels = 8 17 | out_channels = 1 18 | num_prompts = 2 19 | num_layers = 3 20 | dataset: Dataset = FakeDataset(num_rows=10, with_nan=False) 21 | dataset.materialize() 22 | tensor_frame = dataset.tensor_frame[:batch_size] 23 | if use_stype_encoder_dicts: 24 | stype_encoder_dicts = [ 25 | { 26 | stype.numerical: LinearEncoder(), 27 | stype.categorical: EmbeddingEncoder(), 28 | }, 29 | { 30 | stype.numerical: LinearEncoder(), 31 | stype.categorical: EmbeddingEncoder(), 32 | }, 33 | { 34 | stype.numerical: LinearEncoder(), 35 | stype.categorical: EmbeddingEncoder(), 36 | }, 37 | ] 38 | else: 39 | stype_encoder_dicts = None 40 | model = Trompt( 41 | channels=channels, 42 | out_channels=out_channels, 43 | num_prompts=num_prompts, 44 | num_layers=num_layers, 45 | col_stats=dataset.col_stats, 46 | col_names_dict=tensor_frame.col_names_dict, 47 | stype_encoder_dicts=stype_encoder_dicts, 48 | ) 49 | model.reset_parameters() 50 | pred = model(tensor_frame) 51 | assert pred.shape == (batch_size, num_layers, out_channels) 52 | -------------------------------------------------------------------------------- /torch_frame/nn/decoder/excelformer_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.nn import Linear, PReLU 4 | 5 | from torch_frame.nn.decoder import Decoder 6 | 7 | 8 | class ExcelFormerDecoder(Decoder): 9 | r"""The ExcelFormer decoder introduced in the 10 | `"ExcelFormer: A Neural Network Surpassing GBDTs on Tabular Data" 11 | <https://arxiv.org/abs/2301.02819>`_ paper. 12 | 13 | Args: 14 | in_channels (int): Input channel dimensionality 15 | out_channels (int): Output channel dimensionality 16 | num_cols (int): Number of columns.
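In :meth:`forward`, an input of shape :obj:`[batch_size, num_cols, in_channels]` is projected across the column dimension (:obj:`num_cols` to :obj:`out_channels`), passed through a :class:`~torch.nn.PReLU`, and finally reduced across the channel dimension (:obj:`in_channels` to :obj:`1`), yielding :obj:`[batch_size, out_channels]`.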
17 | """ 18 | def __init__( 19 | self, 20 | in_channels: int, 21 | out_channels: int, 22 | num_cols: int, 23 | ) -> None: 24 | super().__init__() 25 | self.in_channels = in_channels 26 | self.out_channels = out_channels 27 | self.lin_f = Linear(num_cols, self.out_channels) 28 | self.activation = PReLU() 29 | self.lin_d = Linear(self.in_channels, 1) 30 | self.reset_parameters() 31 | 32 | def reset_parameters(self) -> None: 33 | self.lin_f.reset_parameters() 34 | self.lin_d.reset_parameters() 35 | with torch.no_grad(): 36 | self.activation.weight.fill_(0.25) 37 | 38 | def forward(self, x: Tensor) -> Tensor: 39 | r"""Transforming :obj:`x` into output predictions. 40 | 41 | Args: 42 | x (Tensor): Input column-wise tensor of shape 43 | [batch_size, num_cols, in_channels] 44 | 45 | Returns: 46 | Tensor: [batch_size, out_channels]. 47 | """ 48 | x = x.transpose(1, 2) 49 | x = self.lin_f(x) 50 | x = self.activation(x) 51 | x = self.lin_d(x.transpose(1, 2)).squeeze(2) 52 | return x 53 | -------------------------------------------------------------------------------- /docs/source/modules/nn.rst: -------------------------------------------------------------------------------- 1 | torch_frame.nn 2 | ============== 3 | 4 | .. contents:: Contents 5 | :local: 6 | 7 | .. currentmodule:: torch_frame.nn 8 | 9 | torch_frame.nn.encoder 10 | ---------------------- 11 | 12 | .. currentmodule:: torch_frame.nn.encoder 13 | 14 | .. autosummary:: 15 | :nosignatures: 16 | :toctree: ../generated 17 | :template: autosummary/class.rst 18 | 19 | {% for name in torch_frame.nn.encoder.classes %} 20 | {{ name }} 21 | {% endfor %} 22 | 23 | 24 | torch_frame.nn.encoding 25 | ------------------------ 26 | 27 | .. currentmodule:: torch_frame.nn.encoding 28 | 29 | .. autosummary:: 30 | :nosignatures: 31 | :toctree: ../generated 32 | :template: autosummary/class.rst 33 | 34 | {% for name in torch_frame.nn.encoding.classes %} 35 | {{ name }} 36 | {% endfor %} 37 | 38 | torch_frame.nn.conv 39 | ------------------------ 40 | 41 | .. currentmodule:: torch_frame.nn.conv 42 | 43 | .. autosummary:: 44 | :nosignatures: 45 | :toctree: ../generated 46 | :template: autosummary/class.rst 47 | 48 | {% for name in torch_frame.nn.conv.classes %} 49 | {{ name }} 50 | {% endfor %} 51 | 52 | torch_frame.nn.decoder 53 | ------------------------ 54 | 55 | .. currentmodule:: torch_frame.nn.decoder 56 | 57 | .. autosummary:: 58 | :nosignatures: 59 | :toctree: ../generated 60 | :template: autosummary/class.rst 61 | 62 | {% for name in torch_frame.nn.decoder.classes %} 63 | {{ name }} 64 | {% endfor %} 65 | 66 | torch_frame.nn.models 67 | ------------------------ 68 | 69 | .. currentmodule:: torch_frame.nn.models 70 | 71 | .. 
autosummary:: 72 | :nosignatures: 73 | :toctree: ../generated 74 | :template: autosummary/class.rst 75 | 76 | {% for name in torch_frame.nn.models.classes %} 77 | {{ name }} 78 | {% endfor %} 79 | -------------------------------------------------------------------------------- /test/datasets/test_titanic.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | 3 | import torch 4 | 5 | import torch_frame 6 | from torch_frame.data.stats import StatType 7 | from torch_frame.datasets import Titanic 8 | 9 | 10 | def test_titanic(): 11 | with tempfile.TemporaryDirectory() as temp_dir: 12 | dataset = Titanic(temp_dir) 13 | assert str(dataset) == 'Titanic()' 14 | assert len(dataset) == 891 15 | assert dataset.feat_cols == [ 16 | 'Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked' 17 | ] 18 | 19 | dataset = dataset.materialize() 20 | 21 | tensor_frame = dataset.tensor_frame 22 | assert len(tensor_frame.feat_dict) == 2 23 | assert tensor_frame.feat_dict[torch_frame.numerical].dtype == torch.float 24 | assert tensor_frame.feat_dict[torch_frame.numerical].size() == (891, 4) 25 | assert tensor_frame.feat_dict[torch_frame.categorical].dtype == torch.long 26 | assert tensor_frame.feat_dict[torch_frame.categorical].size() == (891, 3) 27 | assert tensor_frame.col_names_dict == { 28 | torch_frame.categorical: ['Embarked', 'Pclass', 'Sex'], 29 | torch_frame.numerical: ['Age', 'Fare', 'Parch', 'SibSp'], 30 | } 31 | assert tensor_frame.y.size() == (891, ) 32 | assert tensor_frame.y.min() == 0 and tensor_frame.y.max() == 1 33 | 34 | col_stats = dataset.col_stats 35 | assert len(col_stats) == 8 36 | assert StatType.COUNT in col_stats['Survived'] 37 | assert StatType.COUNT in col_stats['Pclass'] 38 | assert StatType.COUNT in col_stats['Sex'] 39 | assert StatType.COUNT in col_stats['Embarked'] 40 | assert StatType.MEAN and StatType.STD in col_stats['Age'] 41 | assert StatType.MEAN and StatType.STD in col_stats['SibSp'] 42 | assert StatType.MEAN and StatType.STD in col_stats['Parch'] 43 | assert StatType.MEAN and StatType.STD in col_stats['Fare'] 44 | -------------------------------------------------------------------------------- /torch_frame/nn/decoder/trompt_decoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from torch import Tensor 3 | from torch.nn import LayerNorm, Linear, ReLU, Sequential 4 | 5 | from torch_frame.nn.decoder import Decoder 6 | 7 | 8 | class TromptDecoder(Decoder): 9 | r"""The Trompt downstream introduced in 10 | `"Trompt: Towards a Better Deep Neural Network for Tabular Data" 11 | `_ paper. 12 | 13 | Args: 14 | in_channels (int): Input channel dimensionality 15 | out_channels (int): Output channel dimensionality 16 | num_prompts (int): Number of prompt columns. 
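In :meth:`forward`, a scalar attention weight per prompt is computed by a linear layer and softmax-normalized over the :obj:`num_prompts` dimension; the prompt embeddings are then averaged with these weights and passed through a two-layer MLP to produce :obj:`[batch_size, out_channels]`.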
17 | """ 18 | def __init__( 19 | self, 20 | in_channels: int, 21 | out_channels: int, 22 | num_prompts: int, 23 | ) -> None: 24 | super().__init__() 25 | self.in_channels = in_channels 26 | self.num_prompts = num_prompts 27 | self.lin_attn = Linear(in_channels, 1) 28 | self.mlp = Sequential( 29 | Linear(in_channels, in_channels), 30 | ReLU(), 31 | LayerNorm(in_channels), 32 | Linear(in_channels, out_channels), 33 | ) 34 | self.reset_parameters() 35 | 36 | def reset_parameters(self) -> None: 37 | self.lin_attn.reset_parameters() 38 | for m in self.mlp: 39 | if not isinstance(m, ReLU): 40 | m.reset_parameters() 41 | 42 | def forward(self, x: Tensor) -> Tensor: 43 | batch_size = len(x) 44 | assert x.shape == (batch_size, self.num_prompts, self.in_channels) 45 | # [batch_size, num_prompts, 1] 46 | w_prompt = F.softmax(self.lin_attn(x), dim=1) 47 | # [batch_size, in_channels] 48 | x = (w_prompt * x).sum(dim=1) 49 | # [batch_size, out_channels] 50 | x = self.mlp(x) 51 | return x 52 | -------------------------------------------------------------------------------- /test/utils/test_concat.py: -------------------------------------------------------------------------------- 1 | import torch_frame 2 | from torch_frame import TensorFrame 3 | 4 | 5 | def test_cat_along_row(get_fake_tensor_frame): 6 | num_rows = 10 7 | num_repeats = 5 8 | tf: TensorFrame = get_fake_tensor_frame(num_rows=num_rows) 9 | tf_cat = torch_frame.cat([tf for _ in range(num_repeats)], dim=0) 10 | assert len(tf_cat) == num_rows * num_repeats 11 | for i in range(num_repeats): 12 | tf_mini = tf_cat[num_rows * i:num_rows * (i + 1)] 13 | assert tf_mini == tf 14 | 15 | 16 | def test_cat_along_col(get_fake_tensor_frame): 17 | num_rows = 10 18 | tf = get_fake_tensor_frame(num_rows=num_rows) 19 | stypes = list(tf.col_names_dict.keys()) 20 | feat_dict1 = {} 21 | feat_dict2 = {} 22 | col_names_dict1 = {} 23 | col_names_dict2 = {} 24 | for stype in stypes: 25 | if stype.use_dict_multi_nested_tensor: 26 | feat_dict1[stype] = { 27 | name: tf.feat_dict[stype][name][:, :1] 28 | for name in tf.feat_dict[stype].keys() 29 | } 30 | col_names_dict1[stype] = tf.col_names_dict[stype][:1] 31 | feat_dict2[stype] = { 32 | name: tf.feat_dict[stype][name][:, 1:] 33 | for name in tf.feat_dict[stype].keys() 34 | } 35 | col_names_dict2[stype] = tf.col_names_dict[stype][1:] 36 | else: 37 | feat_dict1[stype] = tf.feat_dict[stype][:, :1] 38 | col_names_dict1[stype] = tf.col_names_dict[stype][:1] 39 | feat_dict2[stype] = tf.feat_dict[stype][:, 1:] 40 | col_names_dict2[stype] = tf.col_names_dict[stype][1:] 41 | 42 | tf1 = TensorFrame(feat_dict1, col_names_dict1, tf.y) 43 | tf2 = TensorFrame(feat_dict2, col_names_dict2, None) 44 | tf_cat = torch_frame.cat([tf1, tf2], dim=1) 45 | assert tf_cat == tf 46 | -------------------------------------------------------------------------------- /torch_frame/testing/decorators.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | from importlib import import_module 3 | from importlib.util import find_spec 4 | 5 | import torch 6 | from packaging.requirements import Requirement 7 | from packaging.version import Version 8 | 9 | 10 | def has_package(package: str) -> bool: 11 | r"""Returns :obj:`True` in case :obj:`package` is installed.""" 12 | if '|' in package: 13 | return any(has_package(p) for p in package.split('|')) 14 | 15 | req = Requirement(package) 16 | if find_spec(req.name) is None: 17 | return False 18 | module = import_module(req.name) 19 | if 
not hasattr(module, '__version__'): 20 | return True 21 | 22 | version = Version(module.__version__).base_version 23 | return version in req.specifier 24 | 25 | 26 | def withPackage(*args) -> Callable: 27 | r"""A decorator to skip tests if certain packages are not installed. 28 | Also supports version specification. 29 | """ 30 | na_packages = {package for package in args if not has_package(package)} 31 | 32 | def decorator(func: Callable) -> Callable: 33 | import pytest 34 | return pytest.mark.skipif( 35 | len(na_packages) > 0, 36 | reason=f"Package(s) {na_packages} are not installed", 37 | )(func) 38 | 39 | return decorator 40 | 41 | 42 | def withCUDA(func: Callable): 43 | r"""A decorator to test both on CPU and CUDA (if available).""" 44 | import pytest 45 | 46 | devices = [pytest.param(torch.device('cpu'), id='cpu')] 47 | if torch.cuda.is_available(): 48 | devices.append(pytest.param(torch.device('cuda:0'), id='cuda:0')) 49 | 50 | return pytest.mark.parametrize('device', devices)(func) 51 | 52 | 53 | def onlyCUDA(func: Callable) -> Callable: 54 | r"""A decorator to skip tests if CUDA is not found.""" 55 | import pytest 56 | return pytest.mark.skipif( 57 | not torch.cuda.is_available(), 58 | reason="CUDA not available", 59 | )(func) 60 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os.path as osp 3 | import sys 4 | 5 | import pyg_sphinx_theme 6 | 7 | import torch_frame 8 | 9 | author = 'PyG Team' 10 | project = 'pytorch-frame' 11 | version = torch_frame.__version__ 12 | copyright = f'{datetime.datetime.now().year}, {author}' 13 | 14 | sys.path.append(osp.join(osp.dirname(pyg_sphinx_theme.__file__), 'extension')) 15 | 16 | extensions = [ 17 | 'sphinx.ext.autodoc', 18 | 'sphinx.ext.autosummary', 19 | 'sphinx.ext.intersphinx', 20 | 'sphinx.ext.mathjax', 21 | 'sphinx.ext.napoleon', 22 | 'sphinx.ext.viewcode', 23 | 'sphinx_copybutton', 24 | 'pyg', 25 | ] 26 | 27 | html_theme = 'pyg_sphinx_theme' 28 | html_logo = ('https://raw.githubusercontent.com/pyg-team/pyg_sphinx_theme/' 29 | 'master/pyg_sphinx_theme/static/img/pytorch_frame_logo.png') 30 | html_favicon = ('https://raw.githubusercontent.com/pyg-team/pyg_sphinx_theme/' 31 | 'master/pyg_sphinx_theme/static/img/pytorch_frame_favicon.png') 32 | html_static_path = ['_static'] 33 | templates_path = ['_templates'] 34 | 35 | add_module_names = False 36 | autodoc_member_order = 'bysource' 37 | 38 | suppress_warnings = ['autodoc.import_object'] 39 | 40 | intersphinx_mapping = { 41 | 'python': ('https://docs.python.org/3', None), 42 | 'numpy': ('http://docs.scipy.org/doc/numpy', None), 43 | 'pandas': ('http://pandas.pydata.org/pandas-docs/dev', None), 44 | 'torch': ('https://pytorch.org/docs/stable', None), 45 | 'optuna': ('https://optuna.readthedocs.io/en/stable/', None), 46 | 'xgboost': ('https://xgboost.readthedocs.io/en/stable/', None), 47 | } 48 | 49 | copybutton_prompt_text = r">>> |\.\.\. 
" 50 | copybutton_prompt_is_regexp = True 51 | 52 | 53 | def setup(app): 54 | def rst_jinja_render(app, _, source): 55 | rst_context = {'torch_frame': torch_frame} 56 | source[0] = app.builder.templates.render_string(source[0], rst_context) 57 | 58 | app.connect('source-read', rst_jinja_render) 59 | app.add_js_file('js/version_alert.js') 60 | -------------------------------------------------------------------------------- /torch_frame/transforms/base_transform.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import copy 4 | from abc import ABC, abstractmethod 5 | from typing import Any 6 | 7 | from torch_frame import TensorFrame 8 | from torch_frame.data.stats import StatType 9 | 10 | 11 | class BaseTransform(ABC): 12 | r"""An abstract base class for writing transforms. 13 | 14 | Transforms are a general way to modify and customize 15 | :class:`TensorFrame` 16 | """ 17 | def __init__(self): 18 | self._transformed_stats: dict[str, dict[StatType, Any]] | None = None 19 | 20 | def __call__(self, tf: TensorFrame) -> TensorFrame: 21 | # Shallow-copy the data so that we prevent in-place data modification. 22 | return self.forward(copy.copy(tf)) 23 | 24 | @abstractmethod 25 | def forward(self, tf: TensorFrame) -> TensorFrame: 26 | r"""Process TensorFrame obj into another TensorFrame obj. 27 | 28 | Args: 29 | tf (TensorFrame): Input :class:`TensorFrame`. 30 | 31 | Returns: 32 | TensorFrame: Input :class:`TensorFrame` after transform. 33 | """ 34 | return tf 35 | 36 | @property 37 | def transformed_stats(self) -> dict[str, dict[StatType, Any]]: 38 | r"""The column stats after the transform. 39 | 40 | Returns: 41 | transformed_stats (Dict[str, Dict[StatType, Any]]): 42 | Transformed column stats. The :class:`TensorFrame` object might 43 | be modified by the transform, so the returned 44 | :obj:`transformed_stats` would contain the column stats of the 45 | modified :class:`TensorFrame` object. 46 | """ 47 | if self._transformed_stats is None: 48 | raise ValueError("Transformed column stats is not computed yet. 
" 49 | "Please run necessary functions to compute this" 50 | " first.") 51 | return self._transformed_stats 52 | 53 | def __repr__(self) -> str: 54 | return f'{self.__class__.__name__}()' 55 | -------------------------------------------------------------------------------- /test/datasets/test_movielens_1m.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | 3 | import torch 4 | 5 | import torch_frame 6 | from torch_frame.config.text_embedder import TextEmbedderConfig 7 | from torch_frame.data.stats import StatType 8 | from torch_frame.datasets import Movielens1M 9 | from torch_frame.testing.text_embedder import HashTextEmbedder 10 | 11 | 12 | def test_movielens_1m(): 13 | with tempfile.TemporaryDirectory() as temp_dir: 14 | dataset = Movielens1M( 15 | temp_dir, 16 | col_to_text_embedder_cfg=TextEmbedderConfig( 17 | text_embedder=HashTextEmbedder(10)), 18 | ) 19 | assert str(dataset) == 'Movielens1M()' 20 | assert len(dataset) == 1000209 21 | assert dataset.feat_cols == [ 22 | 'user_id', 'gender', 'age', 'occupation', 'zip', 'movie_id', 'title', 23 | 'genres', 'timestamp' 24 | ] 25 | 26 | dataset = dataset.materialize() 27 | 28 | tensor_frame = dataset.tensor_frame 29 | assert len(tensor_frame.feat_dict) == 4 30 | assert tensor_frame.feat_dict[torch_frame.categorical].dtype == torch.int64 31 | assert tensor_frame.feat_dict[torch_frame.categorical].size() == (1000209, 32 | 6) 33 | assert tensor_frame.feat_dict[ 34 | torch_frame.multicategorical].dtype == torch.int64 35 | assert tensor_frame.feat_dict[torch_frame.embedding].dtype == torch.float32 36 | assert tensor_frame.col_names_dict == { 37 | torch_frame.categorical: 38 | ['age', 'gender', 'movie_id', 'occupation', 'user_id', 'zip'], 39 | torch_frame.multicategorical: ['genres'], 40 | torch_frame.timestamp: ['timestamp'], 41 | torch_frame.embedding: ['title'], 42 | } 43 | assert tensor_frame.y.size() == (1000209, ) 44 | assert tensor_frame.y.min() == 1 and tensor_frame.y.max() == 5 45 | 46 | col_stats = dataset.col_stats 47 | assert len(col_stats) == 10 48 | assert StatType.COUNT in col_stats['user_id'] 49 | assert StatType.MULTI_COUNT in col_stats['genres'] 50 | assert StatType.YEAR_RANGE in col_stats['timestamp'] 51 | assert StatType.EMB_DIM in col_stats['title'] 52 | -------------------------------------------------------------------------------- /test/utils/test_infer_stype.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import pytest 3 | 4 | import torch_frame 5 | from torch_frame.config.text_embedder import TextEmbedderConfig 6 | from torch_frame.datasets import FakeDataset 7 | from torch_frame.testing.text_embedder import HashTextEmbedder 8 | from torch_frame.utils import infer_df_stype 9 | 10 | 11 | def get_fake_dataset( 12 | num_rows: int, 13 | col_to_text_embedder_cfg: TextEmbedderConfig, 14 | with_nan: bool, 15 | ) -> FakeDataset: 16 | stypes = [ 17 | torch_frame.numerical, 18 | torch_frame.categorical, 19 | torch_frame.multicategorical, 20 | torch_frame.text_embedded, 21 | torch_frame.sequence_numerical, 22 | torch_frame.timestamp, 23 | torch_frame.embedding, 24 | ] 25 | dataset = FakeDataset( 26 | num_rows=num_rows, 27 | stypes=stypes, 28 | col_to_text_embedder_cfg=col_to_text_embedder_cfg, 29 | with_nan=with_nan, 30 | ) 31 | return dataset 32 | 33 | 34 | @pytest.mark.parametrize("with_nan", [True, False]) 35 | def test_infer_df_stype(with_nan): 36 | num_rows = 200 37 | col_to_text_embedder_cfg = 
TextEmbedderConfig( 38 | text_embedder=HashTextEmbedder(8)) 39 | dataset = get_fake_dataset(num_rows, col_to_text_embedder_cfg, with_nan) 40 | col_to_stype_inferred = infer_df_stype(dataset.df) 41 | assert col_to_stype_inferred == dataset.col_to_stype 42 | 43 | 44 | def test_infer_stypes(): 45 | # Test when multicategoricals are lists 46 | df = pd.DataFrame({ 47 | 'category': [['Books', 'Mystery, Thriller'], 48 | ['Books', "Children's Books", 'Geography'], 49 | ['Books', 'Health', 'Fitness & Dieting'], 50 | ['Books', 'Teen & Young Adult']] * 50, 51 | 'id': [i for i in range(200)] 52 | }) 53 | col_to_stype_inferred = infer_df_stype(df) 54 | assert col_to_stype_inferred['category'] == torch_frame.multicategorical 55 | 56 | df = pd.DataFrame({'bool': [True] * 50 + [False] * 50}) 57 | 58 | col_to_stype_inferred = infer_df_stype(df) 59 | assert col_to_stype_inferred['bool'] == torch_frame.categorical 60 | -------------------------------------------------------------------------------- /torch_frame/datasets/dota2.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import zipfile 3 | 4 | import pandas as pd 5 | 6 | import torch_frame 7 | 8 | 9 | class Dota2(torch_frame.data.Dataset): 10 | r"""The `Dota2 Game Results 11 | <https://archive.ics.uci.edu/dataset/367/dota2+games+results>`_ 12 | dataset. Dota2 is a popular MOBA game with two teams of 5 players. 13 | At the start of the game, each player chooses a unique hero with 14 | different strengths and weaknesses. The dataset is reasonably sparse 15 | as only 10 of 113 possible heroes are chosen in a given game. All 16 | games were played in a space of 2 hours on the 13th of August 2016. 17 | The classification goal is to predict the winning team. 18 | 19 | **STATS:** 20 | 21 | .. list-table:: 22 | :widths: 10 10 10 10 20 10 23 | :header-rows: 1 24 | 25 | * - #rows 26 | - #cols (numerical) 27 | - #cols (categorical) 28 | - #classes 29 | - Task 30 | - Missing value ratio 31 | * - 92,650 32 | - 0 33 | - 116 34 | - 2 35 | - binary_classification 36 | - 0.0% 37 | """ 38 | 39 | url = 'https://archive.ics.uci.edu/static/public/367/dota2+games+results.zip' # noqa 40 | 41 | def __init__(self, root: str): 42 | path = self.download_url(self.url, root) 43 | names = [ 44 | 'Team won the game', 45 | 'Cluster ID', 46 | 'Game mode', 47 | 'Game type', 48 | ] 49 | num_heroes = 113 50 | names += [f'hero_{i}' for i in range(num_heroes)] 51 | folder_path = osp.dirname(path) 52 | with zipfile.ZipFile(path, 'r') as zip_ref: 53 | zip_ref.extractall(folder_path) 54 | 55 | df = pd.read_csv(osp.join(folder_path, 'dota2Train.csv'), names=names) 56 | 57 | col_to_stype = { 58 | 'Team won the game': torch_frame.categorical, 59 | 'Cluster ID': torch_frame.categorical, 60 | 'Game mode': torch_frame.categorical, 61 | 'Game type': torch_frame.categorical, 62 | } 63 | for i in range(num_heroes): 64 | col_to_stype[f'hero_{i}'] = torch_frame.categorical 65 | 66 | super().__init__(df, col_to_stype, target_col='Team won the game') 67 | -------------------------------------------------------------------------------- /test/transforms/test_mutual_information_sort.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from torch_frame import TaskType, TensorFrame, stype 5 | from torch_frame.data import Dataset 6 | from torch_frame.datasets.fake import FakeDataset 7 | from torch_frame.transforms import MutualInformationSort 8 | 9 | 10 | @pytest.mark.parametrize('with_nan', [True, False]) 11 | def 
test_mutual_information_sort(with_nan): 12 | task_type = TaskType.REGRESSION 13 | dataset: Dataset = FakeDataset(num_rows=10, with_nan=with_nan, 14 | stypes=[stype.numerical], create_split=True, 15 | task_type=task_type) 16 | # modify the FakeDataset so column num_1 would have highest mutual 17 | # information score 18 | dataset.df['num_1'] = dataset.df['target'].astype(float) 19 | dataset.materialize() 20 | 21 | tensor_frame: TensorFrame = dataset.tensor_frame 22 | train_dataset = dataset.get_split('train') 23 | transform = MutualInformationSort(task_type) 24 | transform.fit(train_dataset.tensor_frame, train_dataset.col_stats) 25 | out = transform(tensor_frame) 26 | 27 | # column num_1 ranks the first 28 | assert (out.col_names_dict[stype.numerical][0] == 'num_1') 29 | actual_first_col = out.feat_dict[stype.numerical][:, 0] 30 | actual_first_col_nan_mask = torch.isnan(actual_first_col) 31 | expected_first_col = torch.tensor(dataset.df['num_1'].values, 32 | dtype=torch.float32) 33 | expected_first_col_nan_mask = torch.isnan(expected_first_col) 34 | # if the tensor on first column contains NaNs, make sure the NaNs 35 | # are unchanged 36 | assert (torch.equal(actual_first_col_nan_mask, 37 | expected_first_col_nan_mask)) 38 | actual = actual_first_col[~actual_first_col_nan_mask] 39 | expected = expected_first_col[~expected_first_col_nan_mask] 40 | # make sure that the non NaN values are the same on first column 41 | assert (torch.allclose(actual, expected)) 42 | 43 | # make sure the shapes are unchanged 44 | assert (set(out.col_names_dict[stype.numerical]) == set( 45 | tensor_frame.col_names_dict[stype.numerical])) 46 | assert (out.feat_dict[stype.numerical].size() == tensor_frame.feat_dict[ 47 | stype.numerical].size()) 48 | -------------------------------------------------------------------------------- /test/nn/models/test_excelformer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import pytest 4 | import torch 5 | 6 | from torch_frame import TaskType, stype 7 | from torch_frame.data.dataset import Dataset 8 | from torch_frame.datasets.fake import FakeDataset 9 | from torch_frame.nn import ExcelFormer 10 | 11 | 12 | @pytest.mark.parametrize('task_type', [ 13 | TaskType.REGRESSION, 14 | TaskType.BINARY_CLASSIFICATION, 15 | TaskType.MULTICLASS_CLASSIFICATION, 16 | ]) 17 | @pytest.mark.parametrize('batch_size', [0, 5]) 18 | @pytest.mark.parametrize('mixup', [None, 'feature', 'hidden']) 19 | def test_excelformer(task_type, batch_size, mixup): 20 | in_channels = 8 21 | num_heads = 2 22 | num_layers = 6 23 | dataset: Dataset = FakeDataset(num_rows=10, with_nan=False, 24 | stypes=[stype.numerical], 25 | task_type=task_type) 26 | dataset.materialize() 27 | if task_type.is_classification: 28 | out_channels = dataset.num_classes 29 | else: 30 | out_channels = 1 31 | num_cols = len(dataset.col_stats) - 1 32 | tensor_frame = dataset.tensor_frame[:batch_size] 33 | model = ExcelFormer( 34 | in_channels=in_channels, 35 | out_channels=out_channels, 36 | num_cols=num_cols, 37 | num_layers=num_layers, 38 | num_heads=num_heads, 39 | mixup=mixup, 40 | col_stats=dataset.col_stats, 41 | col_names_dict=tensor_frame.col_names_dict, 42 | ) 43 | model.reset_parameters() 44 | 45 | # Test the original forward pass 46 | out = model(tensor_frame) 47 | assert out.shape == (batch_size, out_channels) 48 | 49 | # Test the mixup forward pass 50 | feat_num = copy.copy(tensor_frame.feat_dict[stype.numerical]) 51 | # Set lazy mutual information scores for `feature` 
mixup 52 | tensor_frame.mi_scores = torch.rand(torch.Size((feat_num.shape[1], ))) 53 | out_mixedup, y_mixedup = model(tensor_frame, mixup_encoded=True) 54 | assert out_mixedup.shape == (batch_size, out_channels) 55 | # Make sure the numerical feature is not modified. 56 | assert torch.allclose(feat_num, tensor_frame.feat_dict[stype.numerical]) 57 | 58 | if task_type.is_classification: 59 | assert y_mixedup.shape == (batch_size, out_channels) 60 | else: 61 | assert y_mixedup.shape == tensor_frame.y.shape 62 | -------------------------------------------------------------------------------- /torch_frame/datasets/adult_census_income.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | import torch_frame 4 | 5 | 6 | class AdultCensusIncome(torch_frame.data.Dataset): 7 | r"""The `Adult Census Income 8 | `_ 9 | dataset from Kaggle. It's extracted from census bureau database and the 10 | task is to predict whether a person's income exceeds $50K/year. 11 | 12 | **STATS:** 13 | 14 | .. list-table:: 15 | :widths: 10 10 10 10 20 10 16 | :header-rows: 1 17 | 18 | * - #rows 19 | - #cols (numerical) 20 | - #cols (categorical) 21 | - #classes 22 | - Task 23 | - Missing value ratio 24 | * - 32,561 25 | - 4 26 | - 8 27 | - 2 28 | - binary_classification 29 | - 0.0% 30 | """ 31 | 32 | url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data' # noqa 33 | 34 | def __init__(self, root: str): 35 | path = self.download_url(self.url, root) 36 | names = [ 37 | 'age', 38 | 'workclass', 39 | 'fnlwgt', 40 | 'education', 41 | 'education.num', 42 | 'marital.status', 43 | 'occupation', 44 | 'relationship', 45 | 'race', 46 | 'sex', 47 | 'capital.gain', 48 | 'capital.loss', 49 | 'hours.per.week', 50 | 'native.country', 51 | 'income', 52 | ] 53 | df = pd.read_csv(path, names=names) 54 | 55 | col_to_stype = { 56 | 'age': torch_frame.numerical, 57 | 'workclass': torch_frame.categorical, 58 | 'education': torch_frame.categorical, 59 | 'marital.status': torch_frame.categorical, 60 | 'occupation': torch_frame.categorical, 61 | 'relationship': torch_frame.categorical, 62 | 'race': torch_frame.categorical, 63 | 'sex': torch_frame.categorical, 64 | 'capital.gain': torch_frame.numerical, 65 | 'capital.loss': torch_frame.numerical, 66 | 'hours.per.week': torch_frame.numerical, 67 | 'native.country': torch_frame.categorical, 68 | 'income': torch_frame.categorical, 69 | } 70 | 71 | super().__init__(df, col_to_stype, target_col='income') 72 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Publish Package 2 | 3 | on: # yamllint disable-line rule:truthy 4 | push: 5 | branches: 6 | - master 7 | release: 8 | types: 9 | - published 10 | 11 | defaults: 12 | run: 13 | shell: bash 14 | 15 | jobs: 16 | build-package: 17 | runs-on: ubuntu-latest 18 | steps: 19 | - uses: actions/checkout@v6 20 | 21 | - uses: actions/setup-python@v6 22 | with: 23 | python-version: '3.10' 24 | 25 | - name: Install build tools 26 | run: | 27 | pip install -U pip 28 | pip install -U flit twine 29 | pip list 30 | 31 | - name: Build package 32 | run: | 33 | flit build --no-use-vcs 34 | twine check dist/* --strict 35 | 36 | - uses: actions/upload-artifact@v5 37 | with: 38 | name: release-${{ github.sha }} 39 | path: dist 40 | 41 | publish-package-test: 42 | if: github.event_name == 'release' 43 | runs-on: ubuntu-latest 44 | 
needs: [build-package] 45 | steps: 46 | - uses: actions/checkout@v6 47 | 48 | - uses: actions/setup-python@v6 49 | with: 50 | python-version: '3.10' 51 | 52 | - name: Download package 53 | uses: actions/download-artifact@v6 54 | with: 55 | name: release-${{ github.sha }} 56 | path: dist 57 | 58 | - name: Publish package to TestPyPI 59 | uses: pypa/gh-action-pypi-publish@release/v1 60 | with: 61 | repository-url: https://test.pypi.org/legacy/ 62 | username: __token__ 63 | password: ${{ secrets.TEST_PYPI_API_TOKEN }} 64 | verbose: true 65 | print-hash: true 66 | 67 | publish-package: 68 | if: github.event_name == 'release' 69 | runs-on: ubuntu-latest 70 | needs: [publish-package-test] 71 | permissions: 72 | id-token: write 73 | steps: 74 | - uses: actions/checkout@v6 75 | 76 | - uses: actions/setup-python@v6 77 | with: 78 | python-version: '3.10' 79 | 80 | - name: Download package 81 | uses: actions/download-artifact@v6 82 | with: 83 | name: release-${{ github.sha }} 84 | path: dist 85 | 86 | - name: Publish package to PyPI 87 | uses: pypa/gh-action-pypi-publish@release/v1 88 | with: 89 | username: __token__ 90 | password: ${{ secrets.PYPI_API_TOKEN }} 91 | verbose: true 92 | print-hash: true 93 | -------------------------------------------------------------------------------- /.github/workflows/testing.yml: -------------------------------------------------------------------------------- 1 | name: Testing 2 | 3 | on: # yamllint disable-line rule:truthy 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }} 11 | cancel-in-progress: ${{ github.event_name == 'pull_request' }} 12 | 13 | jobs: 14 | 15 | pytest: 16 | runs-on: ubuntu-latest 17 | strategy: 18 | fail-fast: false 19 | matrix: 20 | include: 21 | - {torch-version: '2.7', python-version: '3.10'} 22 | - {torch-version: '2.7', python-version: '3.13'} 23 | - {torch-version: '2.9', python-version: '3.10'} 24 | - {torch-version: '2.9', python-version: '3.13'} 25 | - {torch-version: 'nightly', python-version: '3.10'} 26 | - {torch-version: 'nightly', python-version: '3.13'} 27 | 28 | steps: 29 | - name: Checkout repository 30 | uses: actions/checkout@v6 31 | with: 32 | fetch-depth: 40 33 | 34 | # Skip workflow if only certain files have been changed. 
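# (Each step below is additionally gated on steps.changed-files-specific.outputs.only_changed != 'true', so pushes touching only these paths skip package setup, tests, and the coverage upload.)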
35 | - name: Get changed files 36 | id: changed-files-specific 37 | uses: tj-actions/changed-files@v47 38 | with: 39 | files: | 40 | docs/** 41 | examples/** 42 | README.md 43 | CHANGELOG.md 44 | .github/workflows/changelog.yml 45 | .github/workflows/dependabot-auto-merge.yml 46 | .github/workflows/documentation.yml 47 | .github/workflows/labeler.yml 48 | .github/workflows/linting.yml 49 | .github/workflows/release.yml 50 | 51 | - name: Setup packages 52 | if: steps.changed-files-specific.outputs.only_changed != 'true' 53 | uses: ./.github/actions/setup 54 | with: 55 | python-version: ${{ matrix.python-version }} 56 | torch-version: ${{ matrix.torch-version }} 57 | cuda-version: cpu 58 | 59 | - name: Install main package 60 | if: steps.changed-files-specific.outputs.only_changed != 'true' 61 | run: | 62 | pip install -e .[full,test] 63 | pip list 64 | 65 | - name: Run tests 66 | if: steps.changed-files-specific.outputs.only_changed != 'true' 67 | run: | 68 | pytest --cov --cov-report=xml 69 | 70 | - name: Upload coverage 71 | if: steps.changed-files-specific.outputs.only_changed != 'true' 72 | uses: codecov/codecov-action@v5 73 | with: 74 | fail_ci_if_error: false 75 | -------------------------------------------------------------------------------- /docs/source/modules/transforms.rst: -------------------------------------------------------------------------------- 1 | torch_frame.transforms 2 | ====================== 3 | 4 | .. contents:: Contents 5 | :local: 6 | 7 | .. currentmodule:: torch_frame.transforms 8 | 9 | Transforms 10 | ---------- 11 | 12 | :pyf:`PyTorch Frame` allows for data transformation across different :obj:`stype`'s or within the same :obj:`stype`. Transforms take in both a :class:`TensorFrame` and column stats. 13 | 14 | Let's look at an example, where we apply :class:`CatToNumTransform` to transform the categorical features into numerical features. 15 | 16 | .. code-block:: python 17 | 18 | from torch_frame.datasets import Yandex 19 | from torch_frame.transforms import CatToNumTransform 20 | from torch_frame import stype 21 | 22 | dataset = Yandex(root='/tmp/adult', name='adult') 23 | dataset.materialize() 24 | transform = CatToNumTransform() 25 | train_dataset = dataset.get_split('train') 26 | 27 | train_dataset.tensor_frame.col_names_dict[stype.categorical] 28 | >>> ['C_feature_0', 'C_feature_1', 'C_feature_2', 'C_feature_3', 'C_feature_4', 'C_feature_5', 'C_feature_6', 'C_feature_7'] 29 | 30 | test_dataset = dataset.get_split('test') 31 | transform.fit(train_dataset.tensor_frame, dataset.col_stats) 32 | 33 | transformed_col_stats = transform.transformed_stats 34 | 35 | transformed_col_stats.keys() 36 | >>> dict_keys(['C_feature_0_0', 'C_feature_1_0', 'C_feature_2_0', 'C_feature_3_0', 'C_feature_4_0', 'C_feature_5_0', 'C_feature_6_0', 'C_feature_7_0']) 37 | 38 | transformed_col_stats['C_feature_0_0'] 39 | >>> {<StatType.MEAN: 'MEAN'>: 0.6984029484029484, <StatType.STD: 'STD'>: 0.45895127199411595, <StatType.QUANTILES: 'QUANTILES'>: [0.0, 0.0, 1.0, 1.0, 1.0]} 40 | 41 | transform(test_dataset.tensor_frame) 42 | >>> TensorFrame( 43 | num_cols=14, 44 | num_rows=16281, 45 | numerical (14): ['N_feature_0', 'N_feature_1', 'N_feature_2', 'N_feature_3', 'N_feature_4', 'N_feature_5', 'C_feature_0_0', 'C_feature_1_0', 'C_feature_2_0', 'C_feature_3_0', 'C_feature_4_0', 'C_feature_5_0', 'C_feature_6_0', 'C_feature_7_0'], 46 | has_target=True, 47 | device=cpu, 48 | ) 49 | 50 | You can see that after the transform, the column names of the categorical features change and the categorical features are transformed into numerical features. 51 | 52 | 53 | .. 
autosummary:: 54 | :nosignatures: 55 | :toctree: ../generated 56 | 57 | {% for name in torch_frame.transforms.functions %} 58 | {{ name }} 59 | {% endfor %} 60 | -------------------------------------------------------------------------------- /torch_frame/datasets/amazon_fine_food_reviews.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pandas as pd 4 | 5 | import torch_frame 6 | from torch_frame.config.text_embedder import TextEmbedderConfig 7 | from torch_frame.config.text_tokenizer import TextTokenizerConfig 8 | 9 | 10 | class AmazonFineFoodReviews(torch_frame.data.Dataset): 11 | r"""The `Amazon Fine Food Reviews `_ 12 | dataset. It consists of reviews of fine foods from amazon. 13 | 14 | Args: 15 | text_stype (torch_frame.stype): Text stype to use for text columns 16 | in the dataset. (default: :obj:`torch_frame.text_embedded`) 17 | 18 | **STATS:** 19 | 20 | .. list-table:: 21 | :widths: 10 10 10 10 10 20 10 22 | :header-rows: 1 23 | 24 | * - #rows 25 | - #cols (numerical) 26 | - #cols (categorical) 27 | - #cols (text) 28 | - #classes 29 | - Task 30 | - Missing value ratio 31 | * - 568,454 32 | - 2 33 | - 3 34 | - 2 35 | - 5 36 | - multiclass_classification 37 | - 0.0% 38 | """ 39 | 40 | url = "https://data.pyg.org/datasets/tables/amazon_fine_food_reviews.zip" 41 | 42 | def __init__( 43 | self, 44 | root: str, 45 | text_stype: torch_frame.stype = torch_frame.text_embedded, 46 | col_to_text_embedder_cfg: dict[str, TextEmbedderConfig] 47 | | TextEmbedderConfig | None = None, 48 | col_to_text_tokenizer_cfg: dict[str, TextTokenizerConfig] 49 | | TextTokenizerConfig | None = None, 50 | ) -> None: 51 | self.root = root 52 | self.text_stype = text_stype 53 | path = self.download_url(self.url, root) 54 | 55 | col_to_stype = { 56 | 'ProductId': torch_frame.categorical, 57 | 'UserId': torch_frame.categorical, 58 | 'HelpfulnessNumerator': torch_frame.numerical, 59 | 'HelpfulnessDenominator': torch_frame.numerical, 60 | 'Score': torch_frame.categorical, 61 | # 'Time': torch_frame.categorical, # TODO: change to timestamp 62 | 'Summary': text_stype, 63 | 'Text': text_stype, 64 | } 65 | 66 | df = pd.read_csv(path)[list(col_to_stype.keys())] 67 | 68 | super().__init__( 69 | df, 70 | col_to_stype, 71 | target_col='Score', 72 | col_to_text_embedder_cfg=col_to_text_embedder_cfg, 73 | col_to_text_tokenizer_cfg=col_to_text_tokenizer_cfg, 74 | ) 75 | -------------------------------------------------------------------------------- /test/nn/models/test_compile.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from torch_frame import stype 5 | from torch_frame.datasets import FakeDataset 6 | from torch_frame.nn.models import ( 7 | ExcelFormer, 8 | FTTransformer, 9 | ResNet, 10 | TabNet, 11 | TabTransformer, 12 | Trompt, 13 | ) 14 | from torch_frame.testing import withPackage 15 | 16 | 17 | @withPackage("torch>=2.6.0") 18 | @pytest.mark.parametrize( 19 | "model_cls, model_kwargs, stypes, expected_graph_breaks", 20 | [ 21 | pytest.param( 22 | FTTransformer, 23 | dict(channels=8), 24 | None, 25 | 2, 26 | id="FTTransformer", 27 | ), 28 | pytest.param(ResNet, dict(channels=8), None, 2, id="ResNet"), 29 | pytest.param( 30 | TabNet, 31 | dict( 32 | split_feat_channels=2, 33 | split_attn_channels=2, 34 | gamma=0.1, 35 | ), 36 | None, 37 | 2, 38 | id="TabNet", 39 | ), 40 | pytest.param( 41 | TabTransformer, 42 | dict( 43 | channels=8, 44 | num_heads=2, 
45 | encoder_pad_size=2, 46 | attn_dropout=0.5, 47 | ffn_dropout=0.5, 48 | ), 49 | None, 50 | 0, 51 | id="TabTransformer", 52 | ), 53 | pytest.param( 54 | Trompt, 55 | dict(channels=8, num_prompts=2), 56 | None, 57 | 3, 58 | id="Trompt", 59 | ), 60 | pytest.param( 61 | ExcelFormer, 62 | dict(in_channels=8, num_cols=3, num_heads=1), 63 | [stype.numerical], 64 | 1, 65 | id="ExcelFormer", 66 | ), 67 | ], 68 | ) 69 | def test_compile_graph_break( 70 | model_cls, 71 | model_kwargs, 72 | stypes, 73 | expected_graph_breaks, 74 | ): 75 | torch._dynamo.config.suppress_errors = True 76 | 77 | dataset = FakeDataset( 78 | num_rows=10, 79 | with_nan=False, 80 | stypes=stypes or [stype.categorical, stype.numerical], 81 | ) 82 | dataset.materialize() 83 | tf = dataset.tensor_frame 84 | model = model_cls( 85 | out_channels=1, 86 | num_layers=2, 87 | col_stats=dataset.col_stats, 88 | col_names_dict=tf.col_names_dict, 89 | **model_kwargs, 90 | ) 91 | explanation = torch._dynamo.explain(model)(tf) 92 | graph_breaks = explanation.graph_break_count 93 | assert graph_breaks == expected_graph_breaks 94 | -------------------------------------------------------------------------------- /torch_frame/datasets/poker_hand.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import zipfile 3 | 4 | import pandas as pd 5 | 6 | import torch_frame 7 | 8 | 9 | class PokerHand(torch_frame.data.Dataset): 10 | r"""The `Poker Hand 11 | `_ 12 | dataset. It's a task to predict 5-card poker hand. 13 | 14 | **STATS:** 15 | 16 | .. list-table:: 17 | :widths: 10 10 10 10 20 10 18 | :header-rows: 1 19 | 20 | * - #rows 21 | - #cols (numerical) 22 | - #cols (categorical) 23 | - #classes 24 | - Task 25 | - Missing value ratio 26 | * - 1,025,010 27 | - 5 28 | - 5 29 | - 10 30 | - multiclass_classification 31 | - 0.0% 32 | """ 33 | 34 | url = 'https://archive.ics.uci.edu/static/public/158/poker+hand.zip' 35 | 36 | def __init__(self, root: str): 37 | path = self.download_url(self.url, root) 38 | folder_path = osp.dirname(path) 39 | 40 | with zipfile.ZipFile(path, 'r') as zip_ref: 41 | zip_ref.extractall(folder_path) 42 | 43 | train_path = osp.join(folder_path, 'poker-hand-training-true.data') 44 | test_path = osp.join(folder_path, 'poker-hand-testing.data') 45 | 46 | names = [ 47 | 'Suit of card #1', 48 | 'Rank of card #1', 49 | 'Suit of card #2', 50 | 'Rank of card #2', 51 | 'Suit of card #3', 52 | 'Rank of card #3', 53 | 'Suit of card #4', 54 | 'Rank of card #4', 55 | 'Suit of card #5', 56 | 'Rank of card #5', 57 | 'Poker Hand', 58 | ] 59 | train_df = pd.read_csv(train_path, names=names) 60 | test_df = pd.read_csv(test_path, names=names) 61 | df = pd.concat([train_df, test_df], ignore_index=True) 62 | 63 | col_to_stype = { 64 | 'Suit of card #1': torch_frame.categorical, 65 | 'Rank of card #1': torch_frame.numerical, 66 | 'Suit of card #2': torch_frame.categorical, 67 | 'Rank of card #2': torch_frame.numerical, 68 | 'Suit of card #3': torch_frame.categorical, 69 | 'Rank of card #3': torch_frame.numerical, 70 | 'Suit of card #4': torch_frame.categorical, 71 | 'Rank of card #4': torch_frame.numerical, 72 | 'Suit of card #5': torch_frame.categorical, 73 | 'Rank of card #5': torch_frame.numerical, 74 | 'Poker Hand': torch_frame.categorical, 75 | } 76 | 77 | super().__init__(df, col_to_stype, target_col='Poker Hand') 78 | -------------------------------------------------------------------------------- /torch_frame/datasets/mercari.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os.path as osp 4 | 5 | import pandas as pd 6 | 7 | import torch_frame 8 | from torch_frame.config.text_embedder import TextEmbedderConfig 9 | from torch_frame.utils.split import SPLIT_TO_NUM 10 | 11 | SPLIT_COL = 'split_col' 12 | 13 | 14 | class Mercari(torch_frame.data.Dataset): 15 | r"""The `Mercari Price Suggestion Challenge 16 | `_ 17 | dataset from Kaggle. 18 | 19 | Args: 20 | num_rows (int, optional): Number of rows to subsample. 21 | (default: :obj:`None`) 22 | 23 | **STATS:** 24 | 25 | .. list-table:: 26 | :widths: 10 10 10 10 20 10 27 | :header-rows: 1 28 | 29 | * - #rows 30 | - #cols (numerical) 31 | - #cols (categorical) 32 | - #cols (text_embedded) 33 | - Task 34 | - Missing value ratio 35 | * - 1,482,535 36 | - 1 37 | - 4 38 | - 2 39 | - regression 40 | - 0.0% 41 | """ 42 | base_url = 'https://data.pyg.org/datasets/tables/mercari_price_suggestion/' 43 | files = ['train', 'test_stg2'] 44 | 45 | def __init__( 46 | self, 47 | root: str, 48 | num_rows: int | None = None, 49 | col_to_text_embedder_cfg: dict[str, TextEmbedderConfig] 50 | | TextEmbedderConfig | None = None, 51 | ) -> None: 52 | col_to_stype = { 53 | 'name': torch_frame.text_embedded, 54 | 'item_condition_id': torch_frame.categorical, 55 | 'category_name': torch_frame.multicategorical, 56 | 'brand_name': torch_frame.categorical, 57 | 'price': torch_frame.numerical, 58 | 'shipping': torch_frame.categorical, 59 | 'item_description': torch_frame.text_embedded, 60 | } 61 | train_path = self.download_url( 62 | osp.join(self.base_url, 'train.csv'), root) 63 | df_train = pd.read_csv(train_path) 64 | test_path = self.download_url( 65 | osp.join(self.base_url, 'test_stg2.csv'), root) 66 | df_test = pd.read_csv(test_path) 67 | df_train[SPLIT_COL] = SPLIT_TO_NUM['train'] 68 | df_test[SPLIT_COL] = SPLIT_TO_NUM['test'] 69 | df = pd.concat([df_train, df_test], axis=0, ignore_index=True) 70 | if num_rows is not None: 71 | df = df.head(num_rows) 72 | df.drop(['train_id'], axis=1, inplace=True) 73 | super().__init__(df, col_to_stype, target_col='price', col_to_sep="/", 74 | col_to_text_embedder_cfg=col_to_text_embedder_cfg, 75 | split_col=SPLIT_COL) 76 | -------------------------------------------------------------------------------- /examples/tabpfn_classification.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os.path as osp 3 | 4 | import numpy as np 5 | import torch 6 | from tabpfn import TabPFNClassifier 7 | # Please run `pip install tabpfn` to install the package 8 | from tqdm import tqdm 9 | 10 | from torch_frame.data import DataLoader 11 | from torch_frame.datasets import ( 12 | ForestCoverType, 13 | KDDCensusIncome, 14 | Mushroom, 15 | Titanic, 16 | ) 17 | 18 | parser = argparse.ArgumentParser( 19 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 20 | parser.add_argument( 21 | '--dataset', type=str, default="Titanic", 22 | choices=["Titanic", "Mushroom", "ForestCoverType", "KDDCensusIncome"]) 23 | parser.add_argument('--train_batch_size', type=int, default=4096) 24 | parser.add_argument('--test_batch_size', type=int, default=128) 25 | parser.add_argument('--seed', type=int, default=0) 26 | args = parser.parse_args() 27 | 28 | torch.manual_seed(args.seed) 29 | 30 | # Prepare datasets 31 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 32 | args.dataset) 33 | 34 | if args.dataset ==
"Titanic": 35 | dataset = Titanic(root=path) 36 | elif args.dataset == "ForestCoverType": 37 | dataset = ForestCoverType(root=path) 38 | elif args.dataset == "KDDCensusIncome": 39 | dataset = KDDCensusIncome(root=path) 40 | else: 41 | dataset = Mushroom(root=path) 42 | 43 | dataset.materialize() 44 | assert dataset.task_type.is_classification 45 | dataset = dataset.shuffle() 46 | train_dataset, test_dataset = dataset[:0.9], dataset[0.9:] 47 | train_tensor_frame = train_dataset.tensor_frame 48 | test_tensor_frame = test_dataset.tensor_frame 49 | train_loader = DataLoader( 50 | train_tensor_frame, 51 | batch_size=args.train_batch_size, 52 | shuffle=True, 53 | ) 54 | X_train = [] 55 | train_data = next(iter(train_loader)) 56 | for stype in train_data.stypes: 57 | X_train.append(train_data.feat_dict[stype]) 58 | X_train: torch.Tensor = torch.cat(X_train, dim=1) 59 | clf = TabPFNClassifier() 60 | clf.fit(X_train, train_data.y) 61 | test_loader = DataLoader(test_tensor_frame, batch_size=args.test_batch_size) 62 | 63 | 64 | @torch.no_grad() 65 | def test() -> float: 66 | accum = total_count = 0 67 | for test_data in tqdm(test_loader): 68 | X_test = [] 69 | for stype in train_data.stypes: 70 | X_test.append(test_data.feat_dict[stype]) 71 | X_test = torch.cat(X_test, dim=1) 72 | pred: np.ndarray = clf.predict_proba(X_test) 73 | pred_class = pred.argmax(axis=-1) 74 | accum += float((test_data.y.numpy() == pred_class).sum()) 75 | total_count += len(test_data.y) 76 | 77 | return accum / total_count 78 | 79 | 80 | acc = test() 81 | print(f"Accuracy: {acc:.4f}") 82 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3.10 3 | 4 | ci: 5 | autofix_prs: true 6 | autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions' 7 | autoupdate_schedule: weekly 8 | # mypy exceeds free tier 250MiB limit on pre-commit.ci 9 | # https://github.com/pre-commit-ci/issues/issues/171 10 | skip: [mypy] 11 | 12 | repos: 13 | - repo: https://github.com/pre-commit/pre-commit-hooks 14 | rev: v6.0.0 15 | hooks: 16 | - id: no-commit-to-branch 17 | name: No commits to master 18 | - id: end-of-file-fixer 19 | name: End-of-file fixer 20 | - id: trailing-whitespace 21 | name: Remove trailing whitespaces 22 | - id: check-toml 23 | name: Check toml 24 | - id: check-yaml 25 | name: Check yaml 26 | 27 | - repo: https://github.com/adrienverge/yamllint.git 28 | rev: v1.37.1 29 | hooks: 30 | - id: yamllint 31 | name: Lint yaml 32 | args: [-d, '{extends: default, rules: {line-length: disable, document-start: disable, truthy: {level: error}, braces: {max-spaces-inside: 1}}}'] 33 | 34 | - repo: https://github.com/PyCQA/autoflake 35 | rev: v2.3.1 36 | hooks: 37 | - id: autoflake 38 | name: Remove unused imports and variables 39 | args: [ 40 | --remove-all-unused-imports, 41 | --remove-unused-variables, 42 | --remove-duplicate-keys, 43 | --ignore-init-module-imports, 44 | --in-place, 45 | ] 46 | 47 | - repo: https://github.com/google/yapf 48 | rev: v0.43.0 49 | hooks: 50 | - id: yapf 51 | name: Format code 52 | additional_dependencies: [toml] 53 | 54 | - repo: https://github.com/pycqa/isort 55 | rev: 7.0.0 56 | hooks: 57 | - id: isort 58 | name: Sort imports 59 | 60 | - repo: https://github.com/PyCQA/flake8 61 | rev: 7.3.0 62 | hooks: 63 | - id: flake8 64 | name: Check PEP8 65 | additional_dependencies: [Flake8-pyproject] 66 | 67 | - repo: 
https://github.com/astral-sh/ruff-pre-commit 68 | rev: v0.14.8 69 | hooks: 70 | - id: ruff 71 | name: Ruff formatting 72 | args: [--fix, --exit-non-zero-on-fix] 73 | 74 | - repo: https://github.com/pre-commit/mirrors-mypy 75 | rev: v1.19.0 76 | hooks: 77 | - id: mypy 78 | name: Check types 79 | additional_dependencies: [torch==2.9.*] 80 | exclude: "^test/|^examples/|^benchmark/" 81 | 82 | - repo: https://github.com/executablebooks/mdformat 83 | rev: 1.0.0 84 | hooks: 85 | - id: mdformat 86 | name: Format Markdown 87 | additional_dependencies: 88 | - mdformat-gfm 89 | - mdformat-front-matters 90 | - mdformat-footnote 91 | -------------------------------------------------------------------------------- /torch_frame/datasets/bank_marketing.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import zipfile 3 | 4 | import pandas as pd 5 | 6 | import torch_frame 7 | 8 | 9 | class BankMarketing(torch_frame.data.Dataset): 10 | r"""The `Bank Marketing 11 | `_ 12 | dataset. It is related to direct marketing campaigns of 13 | a Portuguese banking institution. The marketing campaigns 14 | were based on phone calls. Often, more than one contact with 15 | the same client was required in order to assess whether the 16 | product (a bank term deposit) would be subscribed. 17 | The classification goal is to predict whether the client will 18 | subscribe to a term deposit. 19 | 20 | **STATS:** 21 | 22 | .. list-table:: 23 | :widths: 10 10 10 10 20 10 24 | :header-rows: 1 25 | 26 | * - #rows 27 | - #cols (numerical) 28 | - #cols (categorical) 29 | - #classes 30 | - Task 31 | - Missing value ratio 32 | * - 45,211 33 | - 7 34 | - 9 35 | - 2 36 | - binary_classification 37 | - 0.0% 38 | """ 39 | 40 | url = 'https://archive.ics.uci.edu/static/public/222/bank+marketing.zip' # noqa 41 | 42 | def __init__(self, root: str): 43 | path = self.download_url(self.url, root) 44 | folder_path = osp.dirname(path) 45 | with zipfile.ZipFile(path, 'r') as zip_ref: 46 | zip_ref.extractall(folder_path) 47 | data_path = osp.join(folder_path, 'bank.zip') 48 | data_subfolder_path = osp.join(folder_path, 'bank') 49 | with zipfile.ZipFile(data_path, 'r') as zip_ref: 50 | zip_ref.extractall(data_subfolder_path) 51 | df = pd.read_csv(osp.join(data_subfolder_path, 'bank-full.csv'), 52 | sep=';') 53 | 54 | col_to_stype = { 55 | 'age': torch_frame.numerical, 56 | 'job': torch_frame.categorical, 57 | 'marital': torch_frame.categorical, 58 | 'education': torch_frame.categorical, 59 | 'default': torch_frame.categorical, 60 | 'balance': torch_frame.numerical, 61 | 'housing': torch_frame.categorical, 62 | 'loan': torch_frame.categorical, 63 | 'contact': torch_frame.categorical, 64 | 'day': torch_frame.numerical, 65 | 'month': torch_frame.categorical, 66 | 'duration': torch_frame.numerical, 67 | 'campaign': torch_frame.numerical, 68 | 'pdays': torch_frame.numerical, 69 | 'previous': torch_frame.numerical, 70 | 'poutcome': torch_frame.categorical, 71 | 'y': torch_frame.categorical, 72 | } 73 | 74 | super().__init__(df, col_to_stype, target_col='y') 75 | -------------------------------------------------------------------------------- /torch_frame/config/image_embedder.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from abc import ABC, abstractmethod 4 | from collections.abc import Callable 5 | from dataclasses import dataclass 6 | 7 | from PIL import Image 8 | from torch import Tensor 9 | 10 | 11 | class
ImageEmbedder(ABC): 12 | r"""Parent class for the :obj:`image_embedder` of 13 | :class:`ImageEmbedderConfig`. This class first retrieves images based 14 | on given paths stored in the data frame and then embeds the retrieved 15 | images into a tensor. Users are responsible for implementing 16 | :meth:`forward_embed`, which takes a list of images and returns an 17 | embedding tensor. Users can also override :meth:`forward_retrieve`, which 18 | takes the paths to images and returns a list of :obj:`PIL.Image.Image`. 19 | """ 20 | def forward_retrieve(self, path_to_images: list[str]) -> list[Image.Image]: 21 | r"""Retrieval function that reads a list of images from 22 | a list of file paths with the :obj:`RGB` mode. 23 | """ 24 | images: list[Image.Image] = [] 25 | for path_to_image in path_to_images: 26 | image = Image.open(path_to_image) 27 | images.append(image.copy()) 28 | image.close() 29 | images = [image.convert('RGB') for image in images] 30 | return images 31 | 32 | @abstractmethod 33 | def forward_embed(self, images: list[Image.Image]) -> Tensor: 34 | r"""Embedding function that takes a list of images and returns 35 | an embedding tensor. 36 | """ 37 | raise NotImplementedError 38 | 39 | def __call__(self, path_to_images: list[str]) -> Tensor: 40 | images = self.forward_retrieve(path_to_images) 41 | return self.forward_embed(images) 42 | 43 | 44 | @dataclass 45 | class ImageEmbedderConfig: 46 | r"""Image embedder model that maps a list of images into PyTorch 47 | Tensor embeddings. 48 | 49 | Args: 50 | image_embedder (callable): A callable image embedder that takes a 51 | list of paths to images as input and outputs the PyTorch Tensor 52 | embeddings for that list of images. Usually it contains a retriever 53 | to load image files and then an embedder converting images to 54 | embeddings. 55 | batch_size (int, optional): Batch size to use when encoding the 56 | images. If set to :obj:`None`, the image embeddings will 57 | be obtained in a full-batch manner. (default: :obj:`None`) 58 | 59 | """ 60 | image_embedder: Callable[[list[str]], Tensor] 61 | # Batch size to use when encoding the images. It is recommended to set 62 | # it to a reasonable value when one uses a heavy image embedding model 63 | # (e.g., ViT) on GPU. If set to :obj:`None`, the image embeddings 64 | # will be obtained in a full-batch manner. 65 | batch_size: int | None = None 66 | -------------------------------------------------------------------------------- /torch_frame/datasets/diamond_images.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os.path as osp 4 | import zipfile 5 | 6 | import pandas as pd 7 | 8 | import torch_frame 9 | from torch_frame.config.image_embedder import ImageEmbedderConfig 10 | 11 | 12 | class DiamondImages(torch_frame.data.Dataset): 13 | r"""The `Diamond Images 14 | `_ 15 | dataset from Kaggle. The target is to predict the :obj:`colour` of each 16 | diamond. 17 | 18 | **STATS:** 19 | 20 | ..
list-table:: 21 | :widths: 10 10 10 10 10 20 10 22 | :header-rows: 1 23 | 24 | * - #rows 25 | - #cols (numerical) 26 | - #cols (categorical) 27 | - #cols (image) 28 | - #classes 29 | - Task 30 | - Missing value ratio 31 | * - 48,764 32 | - 4 33 | - 7 34 | - 1 35 | - 23 36 | - multiclass_classification 37 | - 0.167% 38 | """ 39 | 40 | url = 'https://data.pyg.org/datasets/tables/diamond.zip' 41 | 42 | def __init__( 43 | self, 44 | root: str, 45 | col_to_image_embedder_cfg: ImageEmbedderConfig 46 | | dict[str, ImageEmbedderConfig], 47 | ): 48 | path = self.download_url(self.url, root) 49 | 50 | folder_path = osp.dirname(path) 51 | with zipfile.ZipFile(path, "r") as zip_ref: 52 | zip_ref.extractall(folder_path) 53 | 54 | subfolder_path = osp.join(folder_path, "diamond") 55 | csv_path = osp.join(subfolder_path, "diamond_data.csv") 56 | df = pd.read_csv(csv_path) 57 | df = df.drop(columns=["stock_number"]) 58 | 59 | image_paths = [] 60 | for path_to_img in df["path_to_img"]: 61 | path_to_img = path_to_img.replace("web_scraped/", "") 62 | image_paths.append(osp.join(subfolder_path, path_to_img)) 63 | image_df = pd.DataFrame({"image_path": image_paths}) 64 | df = pd.concat([df, image_df], axis=1) 65 | df = df.drop(columns=["path_to_img"]) 66 | 67 | col_to_stype = { 68 | "shape": torch_frame.categorical, 69 | "carat": torch_frame.numerical, 70 | "clarity": torch_frame.categorical, 71 | "colour": torch_frame.categorical, 72 | "cut": torch_frame.categorical, 73 | "polish": torch_frame.categorical, 74 | "symmetry": torch_frame.categorical, 75 | "fluorescence": torch_frame.categorical, 76 | "lab": torch_frame.categorical, 77 | "length": torch_frame.numerical, 78 | "width": torch_frame.numerical, 79 | "depth": torch_frame.numerical, 80 | "image_path": torch_frame.image_embedded, 81 | } 82 | 83 | super().__init__(df, col_to_stype, target_col="colour", 84 | col_to_image_embedder_cfg=col_to_image_embedder_cfg) 85 | -------------------------------------------------------------------------------- /torch_frame/nn/base.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import copy 4 | from inspect import signature 5 | from typing import Any 6 | 7 | import torch 8 | 9 | 10 | class Module(torch.nn.Module): 11 | r"""A base class for defining modules in which attributes may be defined 12 | in a later stage. As such, users only need to define the dynamic 13 | hyperparameters of a module, and do not need to care about connecting the 14 | module to the underlying data, e.g., specifying the number of input or 15 | output channels. 16 | 17 | This is achieved by postponing submodule creation 18 | (via :meth:`init_modules`) until all attributes in :obj:`LAZY_ATTRS` are 19 | fully-specified. 
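 
    A minimal sketch of the intended usage (the subclass below is
    hypothetical and for illustration only):

    .. code-block:: python

        class LazyLinear(Module):
            LAZY_ATTRS = {'in_channels'}

            def __init__(self, out_channels: int,
                         in_channels: int | None = None):
                super().__init__(out_channels, in_channels=in_channels)

            def init_modules(self):
                self.lin = torch.nn.Linear(self.in_channels,
                                           self.out_channels)

        lin = LazyLinear(out_channels=8)  # 'in_channels' is not yet known.
        lin.in_channels = 16  # Now fully-specified: init_modules() runs.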
20 | """ 21 | LAZY_ATTRS: set[str] = set() 22 | 23 | def init_modules(self): 24 | pass 25 | 26 | def __init__(self, *args, **kwargs): 27 | super().__init__() 28 | 29 | self._in_init = True 30 | self._missing_attrs = copy.copy(self.LAZY_ATTRS) 31 | 32 | for key, value in zip( 33 | signature(self.__init__).parameters, args, strict=False): 34 | setattr(self, key, value) 35 | for key, value in kwargs.items(): 36 | setattr(self, key, value) 37 | 38 | self._in_init = False 39 | 40 | if self.is_fully_specified: 41 | self._init_modules() 42 | 43 | def __setattr__(self, key: str, value: Any): 44 | super().__setattr__(key, value) 45 | if value is not None and key in getattr(self, '_missing_attrs', {}): 46 | self._missing_attrs.remove(key) 47 | if not self._in_init and self.is_fully_specified: 48 | self._init_modules() 49 | 50 | @property 51 | def is_fully_specified(self) -> bool: 52 | return len(self._missing_attrs) == 0 53 | 54 | def validate(self): 55 | if len(self._missing_attrs) > 0: 56 | raise ValueError( 57 | f"The '{self.__class__.__name__}' module is not fully-" 58 | f"specified yet. It is missing the following attribute(s): " 59 | f"{self._missing_attrs}. Please specify them before using " 60 | f"this module in a deep learning pipeline.") 61 | 62 | def _init_modules(self): 63 | self.validate() 64 | self.init_modules() 65 | 66 | def _apply(self, *args, **kwargs) -> Any: 67 | self.validate() 68 | return super()._apply(*args, **kwargs) 69 | 70 | def named_parameters(self, *args, **kwargs) -> Any: 71 | self.validate() 72 | return super().named_parameters(*args, **kwargs) 73 | 74 | def named_children(self) -> Any: 75 | self.validate() 76 | return super().named_children() 77 | 78 | def named_modules(self, *args, **kwargs) -> Any: 79 | self.validate() 80 | return super().named_modules(*args, **kwargs) 81 | 82 | def __call__(self, *args, **kwargs) -> Any: 83 | self.validate() 84 | return super().__call__(*args, **kwargs) 85 | -------------------------------------------------------------------------------- /torch_frame/datasets/amphibians.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import zipfile 3 | 4 | import pandas as pd 5 | 6 | import torch_frame 7 | 8 | 9 | class Amphibians(torch_frame.data.Dataset): 10 | r"""The `Amphibians 11 | `_ 12 | dataset. The task is to predict which of the 7 frog types appeared 13 | in the habitat. 14 | 15 | **STATS:** 16 | 17 | ..
list-table:: 18 | :widths: 10 10 10 10 20 10 19 | :header-rows: 1 20 | 21 | * - #rows 22 | - #cols (numerical) 23 | - #cols (categorical) 24 | - #cols (text_embedded) 25 | - Task 26 | - Missing value ratio 27 | * - 189 28 | - 3 29 | - 20 30 | - 0 31 | - multilabel_classification 32 | - 0.0% 33 | """ 34 | url = 'https://archive.ics.uci.edu/static/public/528/amphibians.zip' 35 | 36 | def __init__(self, root: str): 37 | path = self.download_url(self.url, root) 38 | folder_path = osp.dirname(path) 39 | 40 | with zipfile.ZipFile(path, 'r') as zip_ref: 41 | zip_ref.extractall(folder_path) 42 | 43 | data_path = osp.join(folder_path, 'dataset.csv') 44 | names = [ 45 | 'ID', 'MV', 'SR', 'NR', 'TR', 'VR', 'SUR1', 'SUR2', 'SUR3', 'UR', 46 | 'FR', 'OR', 'RR', 'BR', 'MR', 'CR', 't1', 't2', 't3', 't4', 't5', 47 | 't6', 't7' 48 | ] 49 | df = pd.read_csv(data_path, names=names, sep=';') 50 | # Drop the first 2 rows containing metadata 51 | df = df.iloc[2:].reset_index(drop=True) 52 | target_cols = ['t1', 't2', 't3', 't4', 't5', 't6', 't7'] 53 | df['t'] = df.apply( 54 | lambda row: [col for col in target_cols if row[col] == '1'], 55 | axis=1) 56 | df = df.drop(target_cols, axis=1) 57 | 58 | # Write and re-read the CSV so that pandas re-infers the column dtypes 59 | path = osp.join(root, 'amphibians_posprocess.csv') 60 | df.to_csv(path, index=False) 61 | df = pd.read_csv(path) 62 | 63 | col_to_stype = { 64 | 'ID': torch_frame.numerical, 65 | 'MV': torch_frame.categorical, 66 | 'SR': torch_frame.numerical, 67 | 'NR': torch_frame.numerical, 68 | 'TR': torch_frame.categorical, 69 | 'VR': torch_frame.categorical, 70 | 'SUR1': torch_frame.categorical, 71 | 'SUR2': torch_frame.categorical, 72 | 'SUR3': torch_frame.categorical, 73 | 'UR': torch_frame.categorical, 74 | 'FR': torch_frame.categorical, 75 | 'OR': torch_frame.numerical, 76 | 'RR': torch_frame.categorical, # Support Ordinal Encoding 77 | 'BR': torch_frame.categorical, # Support Ordinal Encoding 78 | 'MR': torch_frame.categorical, 79 | 'CR': torch_frame.categorical, 80 | 't': torch_frame.multicategorical, 81 | } 82 | super().__init__(df, col_to_stype, target_col='t', col_to_sep=None) 83 | -------------------------------------------------------------------------------- /torch_frame/datasets/movielens_1m.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os.path as osp 4 | import zipfile 5 | 6 | import pandas as pd 7 | 8 | import torch_frame 9 | from torch_frame.config.text_embedder import TextEmbedderConfig 10 | 11 | 12 | class Movielens1M(torch_frame.data.Dataset): 13 | r"""The MovieLens 1M rating dataset, assembled by GroupLens Research 14 | from the MovieLens web site, consisting of movies (3,883 nodes) and 15 | users (6,040 nodes) with approximately 1 million ratings between them. 16 | 17 | **STATS:** 18 | 19 | ..
list-table:: 20 | :widths: 10 10 10 10 20 21 | :header-rows: 1 22 | 23 | * - #Users 24 | - #Items 25 | - #User Field 26 | - #Item Field 27 | - #Samples 28 | * - 6040 29 | - 3952 30 | - 5 31 | - 3 32 | - 1000209 33 | """ 34 | 35 | url = 'https://files.grouplens.org/datasets/movielens/ml-1m.zip' 36 | 37 | def __init__( 38 | self, 39 | root: str, 40 | col_to_text_embedder_cfg: dict[str, TextEmbedderConfig] 41 | | TextEmbedderConfig | None = None, 42 | ): 43 | path = self.download_url(self.url, root) 44 | folder_path = osp.dirname(path) 45 | 46 | with zipfile.ZipFile(path, 'r') as zip_ref: 47 | zip_ref.extractall(folder_path) 48 | 49 | data_path = osp.join(folder_path, 'ml-1m') 50 | users = pd.read_csv( 51 | osp.join(data_path, 'users.dat'), 52 | header=None, 53 | names=['user_id', 'gender', 'age', 'occupation', 'zip'], 54 | sep='::', 55 | engine='python', 56 | ) 57 | movies = pd.read_csv( 58 | osp.join(data_path, 'movies.dat'), 59 | header=None, 60 | names=['movie_id', 'title', 'genres'], 61 | sep='::', 62 | engine='python', 63 | encoding='ISO-8859-1', 64 | ) 65 | ratings = pd.read_csv( 66 | osp.join(data_path, 'ratings.dat'), 67 | header=None, 68 | names=['user_id', 'movie_id', 'rating', 'timestamp'], 69 | sep='::', 70 | engine='python', 71 | ) 72 | 73 | df = pd.merge(pd.merge(ratings, users), movies) \ 74 | .sort_values(by='timestamp') \ 75 | .reset_index().drop('index', axis=1) 76 | 77 | col_to_stype = { 78 | 'user_id': torch_frame.categorical, 79 | 'gender': torch_frame.categorical, 80 | 'age': torch_frame.categorical, 81 | 'occupation': torch_frame.categorical, 82 | 'zip': torch_frame.categorical, 83 | 'movie_id': torch_frame.categorical, 84 | 'title': torch_frame.text_embedded, 85 | 'genres': torch_frame.multicategorical, 86 | 'rating': torch_frame.numerical, 87 | 'timestamp': torch_frame.timestamp, 88 | } 89 | super().__init__(df, col_to_stype, target_col='rating', col_to_sep='|', 90 | col_to_text_embedder_cfg=col_to_text_embedder_cfg) 91 | -------------------------------------------------------------------------------- /torch_frame/testing/text_tokenizer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from torch_frame.data import MultiNestedTensor 7 | from torch_frame.typing import TextTokenizationOutputs 8 | 9 | 10 | class WhiteSpaceHashTokenizer: 11 | r"""A simple white space tokenizer for testing purposes. 12 | It splits sentences into tokens via white space and hashes each 13 | token into an index modulo :obj:`num_hash_bins`. 14 | 15 | Args: 16 | num_hash_bins (int): Number of hash bins to use. 17 | (default: :obj:`64`) 18 | device (torch.device, optional): The device to put tokens. 19 | (default: :obj:`None`) 20 | batched (bool): Whether to tokenize in a batched format. 21 | If :obj:`True`, the tokenizer returns Mapping[str, 2-dim Tensor], 22 | else List[Mapping[str, 1-dim Tensor]]. (default: :obj:`False`)
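 
    A usage sketch (illustrative):

    .. code-block:: python

        tokenizer = WhiteSpaceHashTokenizer(num_hash_bins=16)
        outs = tokenizer(['hello world'])
        outs[0]['input_ids']       # Two token indices in [0, 16).
        outs[0]['attention_mask']  # tensor([True, True])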
(default: :obj:`False`) 23 | """ 24 | def __init__( 25 | self, 26 | num_hash_bins: int = 64, 27 | device: torch.device | None = None, 28 | batched: bool = False, 29 | ): 30 | self.device = device 31 | self.num_hash_bins = num_hash_bins 32 | self.batched = batched 33 | 34 | def __call__(self, sentences: list[str]) -> TextTokenizationOutputs: 35 | input_ids = [] 36 | attention_mask = [] 37 | for s in sentences: 38 | tokens = s.split(' ') 39 | idx = torch.tensor([hash(t) % self.num_hash_bins for t in tokens]) 40 | input_ids.append(idx) 41 | attention_mask.append(torch.ones(idx.shape, dtype=torch.bool)) 42 | 43 | if self.batched: 44 | max_length = max(t.size(0) for t in input_ids) 45 | padded_input_ids = [ 46 | F.pad(t, (0, max_length - t.size(0)), value=-1) 47 | for t in input_ids 48 | ] 49 | input_ids = torch.stack(padded_input_ids) 50 | padded_attention_mask = [ 51 | F.pad(t, (0, max_length - t.size(0)), value=False) 52 | for t in attention_mask 53 | ] 54 | attention_mask = torch.stack(padded_attention_mask) 55 | return {'input_ids': input_ids, 'attention_mask': attention_mask} 56 | else: 57 | return [{ 58 | 'input_ids': input_ids[i], 59 | 'attention_mask': attention_mask[i] 60 | } for i in range(len(sentences))] 61 | 62 | 63 | class RandomTextModel(torch.nn.Module): 64 | r"""A text embedding model that takes the tokenized input from 65 | :class:`WhiteSpaceHashTokenizer` and outputs random embeddings. Should be 66 | used only for testing purposes. 67 | """ 68 | def __init__(self, text_emb_channels: int): 69 | self.text_emb_channels = text_emb_channels 70 | super().__init__() 71 | 72 | def forward(self, feat: dict[str, MultiNestedTensor]): 73 | input_ids = feat['input_ids'].to_dense(fill_value=0) 74 | _ = feat['attention_mask'].to_dense(fill_value=0) 75 | return torch.rand(size=(input_ids.shape[0], 1, self.text_emb_channels)) 76 | -------------------------------------------------------------------------------- /examples/tuned_gbdt.py: -------------------------------------------------------------------------------- 1 | """Reported (reproduced) results of Tuned XGBoost on TabularBenchmark of 2 | the Trompt paper https://arxiv.org/abs/2305.18446. 
3 | 4 | electricity (A4): 88.52 (91.09) 5 | eye_movements (A5): 66.57 (64.21) 6 | MagicTelescope (B2): 86.05 (86.50) 7 | bank-marketing (B4): 80.34 (80.41) 8 | california (B5): 90.12 (89.71) 9 | credit (B7): 77.26 (77.4) 10 | pol (B14): 98.09 (97.5) 11 | jannis (mathcal B4): 79.67 (77.81) 12 | 13 | Reported (reproduced) results of Tuned CatBoost on TabularBenchmark of 14 | the Trompt paper: https://arxiv.org/abs/2305.18446 15 | 16 | electricity (A4): 87.73 (88.09) 17 | eye_movements (A5): 66.84 (64.27) 18 | MagicTelescope (B2): 85.92 (87.18) 19 | bank-marketing (B4): 80.39 (80.50) 20 | california (B5): 90.32 (87.56) 21 | credit (B7): 77.59 (77.29) 22 | pol (B14): 98.49 (98.21) 23 | jannis (mathcal B4): 79.89 (78.96) 24 | """ 25 | import argparse 26 | import os.path as osp 27 | import random 28 | 29 | import numpy as np 30 | import torch 31 | 32 | from torch_frame.datasets import TabularBenchmark 33 | from torch_frame.gbdt import CatBoost, LightGBM, XGBoost 34 | from torch_frame.typing import Metric 35 | 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument('--gbdt_type', type=str, default='xgboost', 38 | choices=['xgboost', 'catboost', 'lightgbm']) 39 | parser.add_argument('--dataset', type=str, default='eye_movements') 40 | parser.add_argument('--saved_model_path', type=str, 41 | default='storage/gbdts.txt') 42 | # Add this flag to match the reported number. 43 | parser.add_argument('--seed', type=int, default=0) 44 | args = parser.parse_args() 45 | 46 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 47 | 48 | random.seed(args.seed) 49 | np.random.seed(args.seed) 50 | torch.manual_seed(args.seed) 51 | 52 | # Prepare datasets 53 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 54 | args.dataset) 55 | dataset = TabularBenchmark(root=path, name=args.dataset) 56 | dataset.materialize() 57 | dataset = dataset.shuffle() 58 | # Split ratio following https://arxiv.org/abs/2207.08815 59 | # 70% is used for training. 30% of the remaining is used for validation. 60 | # The final remainder is used for testing.
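# In fractions: 0.70 train, 0.09 validation (30% of the remaining 0.30),
# and 0.21 test, hence the 0.7 and 0.79 cut points below.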
61 | train_dataset, val_dataset, test_dataset = dataset[:0.7], dataset[ 62 | 0.7:0.79], dataset[0.79:] 63 | 64 | num_classes = None 65 | metric = None 66 | task_type = dataset.task_type 67 | if dataset.task_type.is_classification: 68 | metric = Metric.ACCURACY 69 | num_classes = dataset.num_classes 70 | else: 71 | metric = Metric.RMSE 72 | num_classes = None 73 | 74 | gbdt_cls_dict = { 75 | 'xgboost': XGBoost, 76 | 'catboost': CatBoost, 77 | 'lightgbm': LightGBM, 78 | } 79 | gbdt = gbdt_cls_dict[args.gbdt_type]( 80 | task_type=task_type, 81 | num_classes=num_classes, 82 | metric=metric, 83 | ) 84 | 85 | if osp.exists(args.saved_model_path): 86 | gbdt.load(args.saved_model_path) 87 | else: 88 | gbdt.tune(tf_train=train_dataset.tensor_frame, 89 | tf_val=val_dataset.tensor_frame, num_trials=20) 90 | gbdt.save(args.saved_model_path) 91 | 92 | pred = gbdt.predict(tf_test=test_dataset.tensor_frame) 93 | score = gbdt.compute_metric(test_dataset.tensor_frame.y, pred) 94 | print(f"{gbdt.metric} : {score}") 95 | -------------------------------------------------------------------------------- /test/nn/test_simple_basecls.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | import torch_frame 5 | from torch_frame import TensorFrame, stype 6 | from torch_frame.nn import Decoder, FeatureEncoder, TableConv 7 | 8 | 9 | def test_simple_basecls(): 10 | # Instantiate each base class with a simple class and test e2e pipeline. 11 | class SimpleFeatureEncoder(FeatureEncoder): 12 | def __init__( 13 | self, 14 | out_channels: int, 15 | num_numerical: int, 16 | num_categories: list[int], 17 | ): 18 | super().__init__() 19 | 20 | self.out_channels = out_channels 21 | self.num_numerical = num_numerical 22 | self.num_categories = num_categories 23 | 24 | self.lins = torch.nn.ModuleList([ 25 | torch.nn.Linear(1, out_channels) for _ in range(num_numerical) 26 | ]) 27 | self.embs = torch.nn.ModuleList([ 28 | torch.nn.Embedding(num_category, out_channels) 29 | for num_category in num_categories 30 | ]) 31 | 32 | def forward(self, tf: TensorFrame) -> tuple[Tensor, list[str]]: 33 | xs = [] 34 | for i, lin in enumerate(self.lins): 35 | xs.append(lin(tf.feat_dict[torch_frame.numerical][:, i:i + 1])) 36 | for i, emb in enumerate(self.embs): 37 | xs.append(emb(tf.feat_dict[torch_frame.categorical][:, i])) 38 | 39 | x = torch.stack(xs, dim=1) 40 | col_names = (tf.col_names_dict[stype.numerical] + 41 | tf.col_names_dict[stype.categorical]) 42 | 43 | return x, col_names 44 | 45 | class SimpleTableConv(TableConv): 46 | def __init__(self, in_channels: int, out_channels: int): 47 | super().__init__() 48 | self.lin = torch.nn.Linear(in_channels, out_channels) 49 | 50 | def forward(self, x: Tensor) -> Tensor: 51 | B, C, H = x.shape 52 | x = x.view(-1, H) 53 | return self.lin(x).view(B, C, -1) 54 | 55 | class SimpleDecoder(Decoder): 56 | def forward(self, x: Tensor) -> Tensor: 57 | # Pool along the column axis 58 | return torch.mean(x, dim=1) 59 | 60 | tf = TensorFrame( 61 | feat_dict={ 62 | torch_frame.numerical: torch.randn(10, 2), 63 | torch_frame.categorical: torch.randint(0, 5, (10, 2)), 64 | }, 65 | col_names_dict={ 66 | torch_frame.numerical: ['num1', 'num2'], 67 | torch_frame.categorical: ['cat1', 'cat2'], 68 | }, 69 | ) 70 | 71 | feat_encoder = SimpleFeatureEncoder( 72 | out_channels=8, 73 | num_numerical=2, 74 | num_categories=[5, 5], 75 | ) 76 | table_conv1 = SimpleTableConv(in_channels=8, out_channels=16) 77 | table_conv2 = 
SimpleTableConv(in_channels=16, out_channels=8) 78 | decoder = SimpleDecoder() 79 | 80 | x, col_names = feat_encoder(tf) 81 | # [batch_size, num_cols, hidden_channels] 82 | assert x.shape == (10, 4, 8) 83 | assert col_names == ['num1', 'num2', 'cat1', 'cat2'] 84 | x = table_conv1(x) 85 | assert x.shape == (10, 4, 16) 86 | x = table_conv2(x) 87 | assert x.shape == (10, 4, 8) 88 | x = decoder(x) 89 | assert x.shape == (10, 8) 90 | -------------------------------------------------------------------------------- /test/conftest.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections.abc import Callable 3 | 4 | import pytest 5 | import torch 6 | 7 | import torch_frame 8 | from torch_frame import TensorFrame 9 | from torch_frame.data import MultiEmbeddingTensor, MultiNestedTensor 10 | 11 | 12 | @pytest.fixture() 13 | def get_fake_tensor_frame() -> Callable: 14 | def _get_fake_tensor_frame(num_rows: int) -> TensorFrame: 15 | col_names_dict = { 16 | torch_frame.categorical: ['cat_1', 'cat_2', 'cat_3'], 17 | torch_frame.numerical: ['num_1', 'num_2'], 18 | torch_frame.multicategorical: ['multicat_1', 'multicat_2'], 19 | torch_frame.text_embedded: 20 | ['text_embedded_1', 'text_embedded_2', 'text_embedded_3'], 21 | torch_frame.text_tokenized: 22 | ['text_tokenized_1', 'text_tokenized_2'], 23 | torch_frame.sequence_numerical: ['seq_num_1', 'seq_num_2'], 24 | torch_frame.embedding: ['emb_1', 'emb_2'], 25 | } 26 | feat_dict = { 27 | torch_frame.categorical: 28 | torch.randint( 29 | 0, 3, 30 | size=(num_rows, len(col_names_dict[torch_frame.categorical]))), 31 | torch_frame.numerical: 32 | torch.randn(size=(num_rows, 33 | len(col_names_dict[torch_frame.numerical]))), 34 | torch_frame.multicategorical: 35 | MultiNestedTensor.from_tensor_mat([[ 36 | torch.arange(random.randint(0, 10)) for _ in range( 37 | len(col_names_dict[torch_frame.multicategorical])) 38 | ] for _ in range(num_rows)]), 39 | torch_frame.text_embedded: 40 | MultiEmbeddingTensor.from_tensor_list([ 41 | torch.randn(num_rows, random.randint(1, 5)) 42 | for _ in range(len(col_names_dict[torch_frame.text_embedded])) 43 | ]), 44 | torch_frame.text_tokenized: { 45 | 'input_id': 46 | MultiNestedTensor.from_tensor_mat([[ 47 | torch.randint(0, 5, size=(random.randint(0, 10), )) 48 | for _ in range( 49 | len(col_names_dict[torch_frame.text_tokenized])) 50 | ] for _ in range(num_rows)]), 51 | 'mask': 52 | MultiNestedTensor.from_tensor_mat([[ 53 | torch.randint(0, 5, size=(random.randint(0, 10), )) 54 | for _ in range( 55 | len(col_names_dict[torch_frame.text_tokenized])) 56 | ] for _ in range(num_rows)]), 57 | }, 58 | torch_frame.sequence_numerical: 59 | MultiNestedTensor.from_tensor_mat([[ 60 | torch.randn(random.randint(0, 10)) for _ in range( 61 | len(col_names_dict[torch_frame.sequence_numerical])) 62 | ] for _ in range(num_rows)]), 63 | torch_frame.embedding: 64 | MultiEmbeddingTensor.from_tensor_list([ 65 | torch.randn(num_rows, random.randint(1, 5)) 66 | for _ in range(len(col_names_dict[torch_frame.embedding])) 67 | ]) 68 | } 69 | 70 | y = torch.randn(num_rows) 71 | 72 | return TensorFrame( 73 | feat_dict=feat_dict, 74 | col_names_dict=col_names_dict, 75 | y=y, 76 | ) 77 | 78 | return _get_fake_tensor_frame 79 | -------------------------------------------------------------------------------- /test/gbdt/test_gbdt.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import tempfile 3 | 4 | import pytest 5 | import 
torch 6 | 7 | from torch_frame import Metric, TaskType, stype 8 | from torch_frame.config.text_embedder import TextEmbedderConfig 9 | from torch_frame.data.dataset import Dataset 10 | from torch_frame.datasets.fake import FakeDataset 11 | from torch_frame.gbdt import CatBoost, LightGBM, XGBoost 12 | from torch_frame.testing.text_embedder import HashTextEmbedder 13 | 14 | 15 | @pytest.mark.parametrize('gbdt_cls', [ 16 | CatBoost, 17 | XGBoost, 18 | LightGBM, 19 | ]) 20 | @pytest.mark.parametrize('stypes', [ 21 | [stype.numerical], 22 | [stype.categorical], 23 | [stype.text_embedded], 24 | [stype.numerical, stype.numerical, stype.text_embedded], 25 | ]) 26 | @pytest.mark.parametrize('task_type_and_metric', [ 27 | (TaskType.REGRESSION, Metric.RMSE), 28 | (TaskType.REGRESSION, Metric.MAE), 29 | (TaskType.BINARY_CLASSIFICATION, Metric.ACCURACY), 30 | (TaskType.BINARY_CLASSIFICATION, Metric.ROCAUC), 31 | (TaskType.MULTICLASS_CLASSIFICATION, Metric.ACCURACY), 32 | ]) 33 | def test_gbdt_with_save_load(gbdt_cls, stypes, task_type_and_metric): 34 | task_type, metric = task_type_and_metric 35 | dataset: Dataset = FakeDataset( 36 | num_rows=30, 37 | with_nan=True, 38 | stypes=stypes, 39 | create_split=True, 40 | task_type=task_type, 41 | col_to_text_embedder_cfg=TextEmbedderConfig( 42 | text_embedder=HashTextEmbedder(8)), 43 | ) 44 | dataset.materialize() 45 | gbdt = gbdt_cls( 46 | task_type=task_type, 47 | num_classes=dataset.num_classes 48 | if task_type == TaskType.MULTICLASS_CLASSIFICATION else None, 49 | metric=metric, 50 | ) 51 | 52 | with tempfile.TemporaryDirectory() as temp_dir: 53 | path = osp.join(temp_dir, 'model.json') 54 | with pytest.raises(RuntimeError, match="is not yet fitted"): 55 | gbdt.save(path) 56 | 57 | if issubclass(gbdt_cls, XGBoost): 58 | gbdt.tune(tf_train=dataset.tensor_frame, 59 | tf_val=dataset.tensor_frame, num_trials=2, 60 | num_boost_round=1000, early_stopping_rounds=2) 61 | assert gbdt.model.best_iteration is not None 62 | else: 63 | gbdt.tune( 64 | tf_train=dataset.tensor_frame, 65 | tf_val=dataset.tensor_frame, 66 | num_trials=2, 67 | num_boost_round=2, 68 | ) 69 | gbdt.save(path) 70 | 71 | loaded_gbdt = gbdt_cls( 72 | task_type=task_type, 73 | num_classes=dataset.num_classes 74 | if task_type == TaskType.MULTICLASS_CLASSIFICATION else None, 75 | metric=metric, 76 | ) 77 | loaded_gbdt.load(path) 78 | 79 | pred = gbdt.predict(tf_test=dataset.tensor_frame) 80 | score = gbdt.compute_metric(dataset.tensor_frame.y, pred) 81 | 82 | loaded_score = loaded_gbdt.compute_metric(dataset.tensor_frame.y, pred) 83 | dataset.tensor_frame.y = None 84 | loaded_pred = loaded_gbdt.predict(tf_test=dataset.tensor_frame) 85 | 86 | assert torch.allclose(pred, loaded_pred, atol=1e-5) 87 | assert gbdt.metric == metric 88 | assert score == loaded_score 89 | if task_type == TaskType.REGRESSION: 90 | assert (score >= 0) 91 | elif task_type == TaskType.BINARY_CLASSIFICATION: 92 | assert (0 <= score <= 1) 93 | elif task_type == TaskType.MULTICLASS_CLASSIFICATION: 94 | assert (0 <= score <= 1) 95 | -------------------------------------------------------------------------------- /torch_frame/datasets/mushroom.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import zipfile 3 | 4 | import pandas as pd 5 | 6 | import torch_frame 7 | 8 | 9 | class Mushroom(torch_frame.data.Dataset): 10 | r"""The `Mushroom classification Kaggle competition 11 | `_ 12 | dataset. The task is to predict whether a mushroom is edible 13 | or poisonous.
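 
    A minimal usage sketch (illustrative; any writable ``root`` works):

    .. code-block:: python

        dataset = Mushroom(root='data/mushroom')
        dataset.materialize()
        tensor_frame = dataset.tensor_frame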
14 | 15 | **STATS:** 16 | 17 | .. list-table:: 18 | :widths: 10 10 10 10 20 10 19 | :header-rows: 1 20 | 21 | * - #rows 22 | - #cols (numerical) 23 | - #cols (categorical) 24 | - #classes 25 | - Task 26 | - Missing value ratio 27 | * - 8,124 28 | - 0 29 | - 22 30 | - 2 31 | - binary_classification 32 | - 0.0% 33 | """ 34 | 35 | url = 'http://archive.ics.uci.edu/static/public/73/mushroom.zip' 36 | 37 | def __init__(self, root: str): 38 | path = self.download_url(self.url, root) 39 | folder_path = osp.dirname(path) 40 | 41 | with zipfile.ZipFile(path, 'r') as zip_ref: 42 | zip_ref.extractall(folder_path) 43 | 44 | data_path = osp.join(folder_path, 'agaricus-lepiota.data') 45 | 46 | names = [ 47 | 'class', 48 | 'cap-shape', 49 | 'cap-surface', 50 | 'cap-color', 51 | 'bruises', 52 | 'odor', 53 | 'gill-attachment', 54 | 'gill-spacing', 55 | 'gill-size', 56 | 'gill-color', 57 | 'stalk-shape', 58 | 'stalk-root', 59 | 'stalk-surface-above-ring', 60 | 'stalk-surface-below-ring', 61 | 'stalk-color-above-ring', 62 | 'stalk-color-below-ring', 63 | 'veil-type', 64 | 'veil-color', 65 | 'ring-number', 66 | 'ring-type', 67 | 'spore-print-color', 68 | 'population', 69 | 'habitat', 70 | ] 71 | df = pd.read_csv(data_path, names=names) 72 | 73 | col_to_stype = { 74 | 'class': torch_frame.categorical, 75 | 'cap-shape': torch_frame.categorical, 76 | 'cap-surface': torch_frame.categorical, 77 | 'cap-color': torch_frame.categorical, 78 | 'bruises': torch_frame.categorical, 79 | 'odor': torch_frame.categorical, 80 | 'gill-attachment': torch_frame.categorical, 81 | 'gill-spacing': torch_frame.categorical, 82 | 'gill-size': torch_frame.categorical, 83 | 'gill-color': torch_frame.categorical, 84 | 'stalk-shape': torch_frame.categorical, 85 | 'stalk-root': torch_frame.categorical, 86 | 'stalk-surface-above-ring': torch_frame.categorical, 87 | 'stalk-surface-below-ring': torch_frame.categorical, 88 | 'stalk-color-above-ring': torch_frame.categorical, 89 | 'stalk-color-below-ring': torch_frame.categorical, 90 | 'veil-type': torch_frame.categorical, 91 | 'veil-color': torch_frame.categorical, 92 | 'ring-number': torch_frame.categorical, 93 | 'ring-type': torch_frame.categorical, 94 | 'spore-print-color': torch_frame.categorical, 95 | 'population': torch_frame.categorical, 96 | 'habitat': torch_frame.categorical, 97 | } 98 | 99 | super().__init__(df, col_to_stype, target_col='class') 100 | -------------------------------------------------------------------------------- /torch_frame/nn/conv/ft_transformer_convs.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch.nn import ( 6 | LayerNorm, 7 | Parameter, 8 | TransformerEncoder, 9 | TransformerEncoderLayer, 10 | ) 11 | 12 | from torch_frame.nn.conv import TableConv 13 | 14 | 15 | class FTTransformerConvs(TableConv): 16 | r"""The FT-Transformer backbone in the 17 | `"Revisiting Deep Learning Models for Tabular Data" 18 | `_ paper. 19 | 20 | This module concatenates a learnable CLS token embedding :obj:`x_cls` to 21 | the input tensor :obj:`x` and applies a multi-layer Transformer on the 22 | concatenated tensor. After the Transformer layer, the output tensor is 23 | divided into two parts: (1) :obj:`x`, corresponding to the original input 24 | tensor, and (2) :obj:`x_cls`, corresponding to the CLS token tensor. 
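 
    A shape-level sketch (illustrative):

    .. code-block:: python

        conv = FTTransformerConvs(channels=16)
        x = torch.randn(8, 5, 16)  # [batch_size, num_cols, channels]
        x, x_cls = conv(x)         # x: [8, 5, 16], x_cls: [8, 16]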
25 | 26 | Args: 27 | channels (int): Input/output channel dimensionality 28 | feedforward_channels (int, optional): Hidden channels used by the 29 | feedforward network of the Transformer model. If :obj:`None`, it 30 | will be set to :obj:`channels` (default: :obj:`None`) 31 | num_layers (int): Number of transformer encoder layers. (default: 3) 32 | nhead (int): Number of heads in multi-head attention (default: 8) 33 | dropout (float): The dropout value (default: 0.2) 34 | activation (str): The activation function (default: :obj:`relu`) 35 | """ 36 | def __init__( 37 | self, 38 | channels: int, 39 | feedforward_channels: int | None = None, 40 | # Arguments for Transformer 41 | num_layers: int = 3, 42 | nhead: int = 8, 43 | dropout: float = 0.2, 44 | activation: str = 'relu', 45 | ): 46 | super().__init__() 47 | 48 | encoder_layer = TransformerEncoderLayer( 49 | d_model=channels, 50 | nhead=nhead, 51 | dim_feedforward=feedforward_channels or channels, 52 | dropout=dropout, 53 | activation=activation, 54 | # Input and output tensors are provided as 55 | # [batch_size, seq_len, channels] 56 | batch_first=True, 57 | ) 58 | encoder_norm = LayerNorm(channels) 59 | self.transformer = TransformerEncoder(encoder_layer=encoder_layer, 60 | num_layers=num_layers, 61 | norm=encoder_norm) 62 | self.cls_embedding = Parameter(torch.empty(channels)) 63 | self.reset_parameters() 64 | 65 | def reset_parameters(self): 66 | torch.nn.init.normal_(self.cls_embedding, std=0.01) 67 | for p in self.transformer.parameters(): 68 | if p.dim() > 1: 69 | torch.nn.init.xavier_uniform_(p) 70 | 71 | def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: 72 | r"""CLS-token augmented Transformer convolution. 73 | 74 | Args: 75 | x (Tensor): Input tensor of shape [batch_size, num_cols, channels] 76 | 77 | Returns: 78 | (torch.Tensor, torch.Tensor): (Output tensor of shape 79 | [batch_size, num_cols, channels] corresponding to the input 80 | columns, Output tensor of shape [batch_size, channels], 81 | corresponding to the added CLS token column.) 82 | """ 83 | B, _, _ = x.shape 84 | # [batch_size, 1, channels] 85 | x_cls = self.cls_embedding.repeat(B, 1, 1) 86 | # [batch_size, num_cols + 1, channels] 87 | x_concat = torch.cat([x_cls, x], dim=1) 88 | # [batch_size, num_cols + 1, channels] 89 | x_concat = self.transformer(x_concat) 90 | x_cls, x = x_concat[:, 0, :], x_concat[:, 1:, :] 91 | return x, x_cls 92 | -------------------------------------------------------------------------------- /torch_frame/nn/encoder/stypewise_encoder.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any 4 | 5 | import torch 6 | from torch import Tensor 7 | from torch.nn import ModuleDict 8 | 9 | import torch_frame 10 | from torch_frame import TensorFrame 11 | from torch_frame.data.stats import StatType 12 | from torch_frame.nn.encoder import FeatureEncoder 13 | from torch_frame.nn.encoder.stype_encoder import StypeEncoder 14 | 15 | 16 | class StypeWiseFeatureEncoder(FeatureEncoder): 17 | r"""Feature encoder that transforms each stype tensor into embeddings and 18 | performs the final concatenation. 19 | 20 | Args: 21 | out_channels (int): Output dimensionality. 22 | col_stats 23 | (dict[str, dict[:class:`torch_frame.data.stats.StatType`, Any]]): 24 | A dictionary that maps column names to stats. Available as 25 | :obj:`dataset.col_stats`.
26 | col_names_dict (dict[:class:`torch_frame.stype`, list[str]]): A 27 | dictionary that maps stype to a list of column names. The column 28 | names are sorted based on the ordering that appears in 29 | :obj:`tensor_frame.feat_dict`. 30 | Available as :obj:`tensor_frame.col_names_dict`. 31 | stype_encoder_dict 32 | (dict[:class:`torch_frame.stype`, 33 | :class:`torch_frame.nn.encoder.StypeEncoder`]): 34 | A dictionary that maps :class:`torch_frame.stype` to a 35 | :class:`torch_frame.nn.encoder.StypeEncoder` instance. Only 36 | parent :class:`stypes ` are supported 37 | as keys. 38 | """ 39 | def __init__( 40 | self, 41 | out_channels: int, 42 | col_stats: dict[str, dict[StatType, Any]], 43 | col_names_dict: dict[torch_frame.stype, list[str]], 44 | stype_encoder_dict: dict[torch_frame.stype, StypeEncoder], 45 | ) -> None: 46 | super().__init__() 47 | 48 | self.col_stats = col_stats 49 | self.col_names_dict = col_names_dict 50 | self.encoder_dict = ModuleDict() 51 | for stype, stype_encoder in stype_encoder_dict.items(): 52 | if stype != stype.parent: 53 | if stype.parent in stype_encoder_dict: 54 | msg = ( 55 | f"You can delete this {stype} directly since the " 56 | f"encoder for parent stype {stype.parent} is already declared." 57 | ) 58 | else: 59 | msg = (f"To resolve the issue, you can change the key from" 60 | f" {stype} to {stype.parent}.") 61 | raise ValueError(f"{stype} is an invalid stype to use in the " 62 | f"stype_encoder_dict. {msg}") 63 | if stype not in stype_encoder.supported_stypes: 64 | raise ValueError( 65 | f"{stype_encoder} does not support encoding {stype}.") 66 | 67 | if stype in col_names_dict: 68 | stats_list = [ 69 | self.col_stats[col_name] 70 | for col_name in self.col_names_dict[stype] 71 | ] 72 | # Set lazy attributes 73 | stype_encoder.stype = stype 74 | stype_encoder.out_channels = out_channels 75 | stype_encoder.stats_list = stats_list 76 | self.encoder_dict[stype.value] = stype_encoder 77 | 78 | def forward(self, tf: TensorFrame) -> tuple[Tensor, list[str]]: 79 | all_col_names = [] 80 | xs = [] 81 | for stype in tf.stypes: 82 | feat = tf.feat_dict[stype] 83 | col_names = self.col_names_dict[stype] 84 | x = self.encoder_dict[stype.value](feat, col_names) 85 | xs.append(x) 86 | all_col_names.extend(col_names) 87 | x = torch.cat(xs, dim=1) 88 | return x, all_col_names 89 | -------------------------------------------------------------------------------- /torch_frame/transforms/fittable_base_transform.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import copy 4 | from abc import abstractmethod 5 | from typing import Any 6 | 7 | import torch 8 | from torch import Tensor 9 | 10 | from torch_frame import NAStrategy, TensorFrame 11 | from torch_frame.data.stats import StatType 12 | from torch_frame.transforms import BaseTransform 13 | 14 | 15 | class FittableBaseTransform(BaseTransform): 16 | r"""An abstract base class for writing fittable transforms. 17 | Fittable transforms must be fitted on training data before use. 18 | """ 19 | def __init__(self): 20 | super().__init__() 21 | self._is_fitted: bool = False 22 | 23 | def __call__(self, tf: TensorFrame) -> TensorFrame: 24 | # Shallow-copy the data so that we prevent in-place data modification.
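# Note: forward() below raises a ValueError if fit() has not been called yet.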
25 | return self.forward(copy.copy(tf)) 26 | 27 | @property 28 | def is_fitted(self) -> bool: 29 | r"""Whether the transform is already fitted.""" 30 | return self._is_fitted 31 | 32 | def _replace_nans(self, x: Tensor, na_strategy: NAStrategy): 33 | r"""Replace NaNs based on NAStrategy. 34 | 35 | Args: 36 | x (Tensor): Input :class:`Tensor` whose NaN 37 | values in categorical columns will be replaced. 38 | na_strategy (NAStrategy): The :class:`NAStrategy` used to 39 | replace NaN values. 40 | 41 | Returns: 42 | Tensor: Output :class:`Tensor` with NaN values 43 | replaced. 44 | """ 45 | x = x.clone() 46 | for col in range(x.size(1)): 47 | column_data = x[:, col] 48 | if na_strategy.is_numerical_strategy: 49 | nan_mask = torch.isnan(column_data) 50 | else: 51 | nan_mask = column_data < 0 52 | if nan_mask.all(): 53 | raise ValueError("Column contains only nan values.") 54 | if not nan_mask.any(): 55 | continue 56 | valid_data = column_data[~nan_mask] 57 | if na_strategy == NAStrategy.MEAN: 58 | fill_value = valid_data.mean() 59 | elif na_strategy in [NAStrategy.ZEROS, NAStrategy.MOST_FREQUENT]: 60 | fill_value = torch.tensor(0.) 61 | else: 62 | raise ValueError(f'{na_strategy} is not supported.') 63 | column_data[nan_mask] = fill_value 64 | return x 65 | 66 | def fit( 67 | self, 68 | tf: TensorFrame, 69 | col_stats: dict[str, dict[StatType, Any]], 70 | ): 71 | r"""Fit the transform with train data. 72 | 73 | Args: 74 | tf (TensorFrame): Input :class:`TensorFrame` object representing 75 | the training data. 76 | col_stats (Dict[str, Dict[StatType, Any]]): The column 77 | stats of the input :class:`TensorFrame`. 78 | """ 79 | self._fit(tf, col_stats) 80 | self._is_fitted = True 81 | 82 | def forward(self, tf: TensorFrame) -> TensorFrame: 83 | if not self.is_fitted: 84 | raise ValueError(f"'{self.__class__.__name__}' is not yet fitted. " 85 | f"Please run `fit()` first before attempting to " 86 | f"transform the TensorFrame.") 87 | 88 | transformed_tf = self._forward(tf) 89 | transformed_tf.validate() 90 | return transformed_tf 91 | 92 | @abstractmethod 93 | def _fit( 94 | self, 95 | tf: TensorFrame, 96 | col_stats: dict[str, dict[StatType, Any]], 97 | ): 98 | raise NotImplementedError 99 | 100 | @abstractmethod 101 | def _forward(self, tf: TensorFrame) -> TensorFrame: 102 | raise NotImplementedError 103 | 104 | def state_dict(self) -> dict[str, Any]: 105 | return self.__dict__ 106 | 107 | def load_state_dict(self, state_dict: dict[str, Any]): 108 | self.__dict__.update(state_dict) 109 | return self 110 | -------------------------------------------------------------------------------- /torch_frame/typing.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Mapping 4 | from enum import Enum 5 | from typing import TypeAlias 6 | 7 | import pandas as pd 8 | import torch 9 | from torch import Tensor 10 | 11 | from torch_frame.data.multi_embedding_tensor import MultiEmbeddingTensor 12 | from torch_frame.data.multi_nested_tensor import MultiNestedTensor 13 | 14 | WITH_PT20 = int(torch.__version__.split('.')[0]) >= 2 15 | WITH_PT24 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 4 16 | 17 | 18 | class Metric(Enum): 19 | r"""The metric.
20 | 21 | Attributes: 22 | ACCURACY: accuracy 23 | ROCAUC: rocauc 24 | RMSE: rmse 25 | MAE: mae 26 | """ 27 | ACCURACY = 'accuracy' 28 | ROCAUC = 'rocauc' 29 | RMSE = 'rmse' 30 | MAE = 'mae' 31 | R2 = 'r2' 32 | 33 | def supports_task_type(self, task_type: TaskType) -> bool: 34 | return self in task_type.supported_metrics 35 | 36 | 37 | class TaskType(Enum): 38 | r"""The type of the task. 39 | 40 | Attributes: 41 | REGRESSION: Regression task. 42 | MULTICLASS_CLASSIFICATION: Multi-class classification task. 43 | BINARY_CLASSIFICATION: Binary classification task. 44 | """ 45 | REGRESSION = 'regression' 46 | MULTICLASS_CLASSIFICATION = 'multiclass_classification' 47 | BINARY_CLASSIFICATION = 'binary_classification' 48 | MULTILABEL_CLASSIFICATION = 'multilabel_classification' 49 | 50 | @property 51 | def is_classification(self) -> bool: 52 | return self in (TaskType.BINARY_CLASSIFICATION, 53 | TaskType.MULTICLASS_CLASSIFICATION) 54 | 55 | @property 56 | def is_regression(self) -> bool: 57 | return self == TaskType.REGRESSION 58 | 59 | @property 60 | def supported_metrics(self) -> list[Metric]: 61 | if self == TaskType.REGRESSION: 62 | return [Metric.RMSE, Metric.MAE, Metric.R2] 63 | elif self == TaskType.BINARY_CLASSIFICATION: 64 | return [Metric.ACCURACY, Metric.ROCAUC] 65 | elif self == TaskType.MULTICLASS_CLASSIFICATION: 66 | return [Metric.ACCURACY] 67 | else: 68 | return [] 69 | 70 | 71 | class NAStrategy(Enum): 72 | r"""Strategy for dealing with NaN values in columns. 73 | 74 | Attributes: 75 | MEAN: Replaces NaN values with the mean of a 76 | :obj:`torch_frame.numerical` column. 77 | ZEROS: Replaces NaN values with zeros in a 78 | :obj:`torch_frame.numerical` column. 79 | MOST_FREQUENT: Replaces NaN values with the most frequent category of a 80 | :obj:`torch_frame.categorical` column. 81 | """ 82 | MEAN = 'mean' 83 | MOST_FREQUENT = 'most_frequent' 84 | ZEROS = 'zeros' 85 | OLDEST_TIMESTAMP = 'oldest_timestamp' 86 | NEWEST_TIMESTAMP = 'newest_timestamp' 87 | MEDIAN_TIMESTAMP = 'median_timestamp' 88 | 89 | @property 90 | def is_categorical_strategy(self) -> bool: 91 | return self == NAStrategy.MOST_FREQUENT 92 | 93 | @property 94 | def is_multicategorical_strategy(self) -> bool: 95 | return self == NAStrategy.ZEROS 96 | 97 | @property 98 | def is_numerical_strategy(self) -> bool: 99 | return self in [NAStrategy.MEAN, NAStrategy.ZEROS] 100 | 101 | @property 102 | def is_timestamp_strategy(self) -> bool: 103 | return self in [ 104 | NAStrategy.NEWEST_TIMESTAMP, 105 | NAStrategy.OLDEST_TIMESTAMP, 106 | NAStrategy.MEDIAN_TIMESTAMP, 107 | ] 108 | 109 | 110 | Series: TypeAlias = pd.Series 111 | DataFrame: TypeAlias = pd.DataFrame 112 | 113 | IndexSelectType: TypeAlias = int | list[int] | range | slice | Tensor 114 | ColumnSelectType: TypeAlias = str | list[str] 115 | TextTokenizationMapping: TypeAlias = Mapping[str, Tensor] 116 | TextTokenizationOutputs: TypeAlias = \ 117 | list[TextTokenizationMapping] | TextTokenizationMapping 118 | TensorData: TypeAlias = (Tensor | MultiNestedTensor | MultiEmbeddingTensor 119 | | dict[str, MultiNestedTensor]) 120 | -------------------------------------------------------------------------------- /torch_frame/_stype.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class stype(Enum): 5 | r"""The semantic type of a column. 
6 | 7 | A semantic type denotes the semantic meaning of a column, and denotes how 8 | columns are encoded into an embedding space within tabular deep learning 9 | models: 10 | 11 | .. code-block:: python 12 | 13 | import torch_frame 14 | 15 | stype = torch_frame.numerical # Numerical columns 16 | stype = torch_frame.categorical # Categorical columns 17 | ... 18 | 19 | Attributes: 20 | numerical: Numerical columns. 21 | categorical: Categorical columns. 22 | text_embedded: Pre-computed embeddings of text columns. 23 | text_tokenized: Tokenized text columns for finetuning. 24 | multicategorical: Multicategorical columns. 25 | sequence_numerical: Sequence of numerical values. 26 | embedding: Embedding columns. 27 | timestamp: Timestamp columns. 28 | image_embedded: Pre-computed embeddings of image columns. 29 | """ 30 | numerical = 'numerical' 31 | categorical = 'categorical' 32 | text_embedded = 'text_embedded' 33 | text_tokenized = 'text_tokenized' 34 | multicategorical = 'multicategorical' 35 | sequence_numerical = 'sequence_numerical' 36 | timestamp = 'timestamp' 37 | image_embedded = 'image_embedded' 38 | embedding = 'embedding' 39 | 40 | @property 41 | def is_text_stype(self) -> bool: 42 | return self in [stype.text_embedded, stype.text_tokenized] 43 | 44 | @property 45 | def is_image_stype(self) -> bool: 46 | return self in [stype.image_embedded] 47 | 48 | @property 49 | def use_multi_nested_tensor(self) -> bool: 50 | r"""This property indicates if the data of an stype is stored in 51 | :class:`torch_frame.data.MultiNestedTensor`. 52 | """ 53 | return self in [stype.multicategorical, stype.sequence_numerical] 54 | 55 | @property 56 | def use_multi_embedding_tensor(self) -> bool: 57 | r"""This property indicates if the data of an stype is stored in 58 | :class:`torch_frame.data.MultiEmbeddingTensor`. 59 | """ 60 | return self in [ 61 | stype.text_embedded, stype.image_embedded, stype.embedding 62 | ] 63 | 64 | @property 65 | def use_dict_multi_nested_tensor(self) -> bool: 66 | r"""This property indicates if the data of an stype is stored in 67 | a dictionary of :class:`torch_frame.data.MultiNestedTensor`. 68 | """ 69 | return self in [stype.text_tokenized] 70 | 71 | @property 72 | def use_multi_tensor(self) -> bool: 73 | r"""This property indicates if the data of an 74 | :class:`~torch_frame.stype` is stored in 75 | :class:`torch_frame.data._MultiTensor`. 76 | """ 77 | return self.use_multi_nested_tensor or self.use_multi_embedding_tensor 78 | 79 | @property 80 | def parent(self): 81 | r"""This property indicates if an :class:`~torch_frame.stype` is a 82 | user-facing column :obj:`stype` or an internal :obj:`stype` for grouping 83 | columns in :obj:`TensorFrame`. User-facing :class:`~torch_frame.stype` 84 | will be mapped to its parent during materialization. For 85 | :class:`~torch_frame.stype` that are both internal and 86 | user-facing, the parent maps to itself.
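 
        For example (illustrative):

        .. code-block:: python

            assert stype.text_embedded.parent == stype.embedding
            assert stype.numerical.parent == stype.numerical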
87 | """ 88 | if self == stype.text_embedded: 89 | return stype.embedding 90 | elif self == stype.image_embedded: 91 | return stype.embedding 92 | else: 93 | return self 94 | 95 | def __str__(self) -> str: 96 | return f'{self.name}' 97 | 98 | 99 | numerical = stype('numerical') 100 | categorical = stype('categorical') 101 | text_embedded = stype('text_embedded') 102 | text_tokenized = stype('text_tokenized') 103 | multicategorical = stype('multicategorical') 104 | sequence_numerical = stype('sequence_numerical') 105 | timestamp = stype('timestamp') 106 | image_embedded = stype('image_embedded') 107 | embedding = stype('embedding') 108 | -------------------------------------------------------------------------------- /torch_frame/datasets/kdd_census_income.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | import os.path as osp 5 | import tarfile 6 | import zipfile 7 | 8 | import pandas as pd 9 | 10 | import torch_frame 11 | from torch_frame import stype 12 | 13 | 14 | class KDDCensusIncome(torch_frame.data.Dataset): 15 | r"""The `KDD Census Income 16 | `_ 17 | dataset. It's a task of forest cover type classification 18 | based on attributes such as elevation, slop and soil type etc. 19 | 20 | **STATS:** 21 | 22 | .. list-table:: 23 | :widths: 10 10 10 10 20 10 24 | :header-rows: 1 25 | 26 | * - #rows 27 | - #cols (numerical) 28 | - #cols (categorical) 29 | - #classes 30 | - Task 31 | - Missing value ratio 32 | * - 199,523 33 | - 7 34 | - 34 35 | - 2 36 | - binary_classification 37 | - 0.0% 38 | """ 39 | 40 | url = 'https://archive.ics.uci.edu/static/public/117/census+income+kdd.zip' 41 | 42 | def __init__(self, root: str): 43 | data_dir = osp.join(root, 'census') 44 | filename = osp.join(data_dir, 'census-income.data') 45 | if not osp.exists(filename): 46 | path = self.download_url(self.url, root) 47 | tar_gz_path = osp.join(root, 'census.tar.gz') 48 | with zipfile.ZipFile(path, 'r') as zip_ref: 49 | zip_ref.extractall(root) 50 | with tarfile.open(tar_gz_path, 'r:gz') as tar_ref: 51 | tar_ref.extractall(data_dir) 52 | os.remove(tar_gz_path) 53 | os.remove(path) 54 | 55 | names = [ 56 | 'age', 57 | 'class of worker', 58 | 'industry code', 59 | 'occupation code', 60 | 'education', 61 | 'wage per hour', 62 | 'enrolled in edu inst last wk', 63 | 'marital status', 64 | 'major industry code', 65 | 'major occupation code', 66 | 'race', 67 | 'hispanic Origin', 68 | 'sex', 69 | 'member of a labor union', 70 | 'reason for unemployment', 71 | 'full or part time employment stat', 72 | 'capital gains', 73 | 'capital losses', 74 | 'divdends from stocks', 75 | 'tax filer status', 76 | 'region of previous residence', 77 | 'state of previous residence', 78 | 'detailed household and family stat', 79 | 'detailed household summary in household', 80 | 'migration code-change in msa', 81 | 'migration code-change in reg', 82 | 'migration code-move within reg', 83 | 'live in this house 1 year ago', 84 | 'migration prev res in sunbelt', 85 | 'family members under 18', 86 | 'num persons worked for employer', 87 | 'country of birth father', 88 | 'country of birth mother', 89 | 'country of birth self', 90 | 'citizenship', 91 | 'total person income', 92 | 'own business or self employed', 93 | "fill inc questionnaire for veteran's admin", 94 | 'veterans benefits', 95 | 'weeks worked in year', 96 | 'year', 97 | 'income above 50000', 98 | ] 99 | 100 | continuous_cols = { 101 | 'age', 102 | 'wage per hour', 103 | 'capital gains', 
104 | 'capital losses', 105 | 'divdends from stocks', 106 | 'num persons worked for employer', 107 | 'weeks worked in year', 108 | } 109 | 110 | col_to_stype: dict[str, stype] = {} 111 | for name in names: 112 | if name in continuous_cols: 113 | col_to_stype[name] = torch_frame.numerical 114 | else: 115 | col_to_stype[name] = torch_frame.categorical 116 | 117 | df = pd.read_csv(filename, names=names) 118 | 119 | super().__init__(df, col_to_stype, target_col='income above 50000') 120 | -------------------------------------------------------------------------------- /torch_frame/transforms/mutual_information_sort.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from torch_frame import NAStrategy, TaskType, TensorFrame, stype 9 | from torch_frame.data.stats import StatType 10 | from torch_frame.transforms import FittableBaseTransform 11 | 12 | 13 | class MutualInformationSort(FittableBaseTransform): 14 | r"""A transform that sorts the numerical features of the input 15 | :class:`TensorFrame` object based on mutual information. 16 | 17 | Args: 18 | task_type (TaskType): The task type. 19 | na_strategy (NAStrategy): Strategy used for imputing NaN values 20 | in numerical features. 21 | """ 22 | def __init__(self, task_type: TaskType, 23 | na_strategy: NAStrategy = NAStrategy.MEAN): 24 | super().__init__() 25 | 26 | from sklearn.feature_selection import ( 27 | mutual_info_classif, 28 | mutual_info_regression, 29 | ) 30 | 31 | if task_type in [ 32 | TaskType.MULTICLASS_CLASSIFICATION, 33 | TaskType.BINARY_CLASSIFICATION 34 | ]: 35 | self.mi_func = mutual_info_classif 36 | elif task_type == TaskType.REGRESSION: 37 | self.mi_func = mutual_info_regression 38 | else: 39 | raise ValueError( 40 | f"'{self.__class__.__name__}' can only be used on binary " 41 | "classification, multiclass classification or regression " 42 | f"tasks, but got {task_type}.") 43 | if not na_strategy.is_numerical_strategy: 44 | raise RuntimeError( 45 | f"Cannot use {na_strategy} for numerical features.") 46 | self.na_strategy = na_strategy 47 | 48 | def _fit(self, tf_train: TensorFrame, col_stats: dict[str, dict[StatType, 49 | Any]]): 50 | if tf_train.y is None: 51 | raise RuntimeError( 52 | f"'{self.__class__.__name__}' cannot be used when the target" 53 | " column is None.") 54 | if (stype.categorical in tf_train.col_names_dict 55 | and len(tf_train.col_names_dict[stype.categorical]) != 0): 56 | raise ValueError(f"'{self.__class__.__name__}' can only be used" 57 | " on a TensorFrame with only numerical features.") 58 | feat_train = tf_train.feat_dict[stype.numerical] 59 | y_train = tf_train.y 60 | if torch.isnan(feat_train).any(): 61 | feat_train = self._replace_nans(feat_train, self.na_strategy) 62 | if torch.isnan(tf_train.y).any(): 63 | not_nan_indices = ~torch.isnan(y_train) 64 | if not not_nan_indices.any(): 65 | raise ValueError(f"'{self.__class__.__name__}' cannot be" 66 | " performed when all target values are" 67 | " NaN.") 68 | y_train = y_train[not_nan_indices] 69 | feat_train = feat_train[not_nan_indices] 70 | mi_scores = self.mi_func(feat_train.cpu(), y_train.cpu()) 71 | self.mi_ranks = np.argsort(-mi_scores) 72 | self.mi_scores = mi_scores[self.mi_ranks] 73 | col_names = tf_train.col_names_dict[stype.numerical] 74 | ranks = {col_names[self.mi_ranks[i]]: i for i in range(len(col_names))} 75 | self.reordered_col_names = tf_train.col_names_dict[ 76 |
stype.numerical].copy() 77 | 78 | for col, rank in ranks.items(): 79 | self.reordered_col_names[rank] = col 80 | self._transformed_stats = col_stats 81 | 82 | def _forward(self, tf: TensorFrame) -> TensorFrame: 83 | if tf.col_names_dict.keys() != {stype.numerical}: 84 | raise ValueError("The transform can only be used on a TensorFrame" 85 | " with only numerical features.") 86 | 87 | tf.feat_dict[stype.numerical] = tf.feat_dict[ 88 | stype.numerical][:, self.mi_ranks] 89 | 90 | tf.col_names_dict[stype.numerical] = self.reordered_col_names 91 | 92 | # set lazy attribute for meta features 93 | tf.mi_scores = torch.tensor(self.mi_scores, dtype=torch.float32, 94 | device=tf.device) 95 | 96 | return tf 97 | -------------------------------------------------------------------------------- /torch_frame/nn/models/ft_transformer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any 4 | 5 | from torch import Tensor 6 | from torch.nn import LayerNorm, Linear, Module, ReLU, Sequential 7 | 8 | import torch_frame 9 | from torch_frame import TensorFrame, stype 10 | from torch_frame.data.stats import StatType 11 | from torch_frame.nn.conv import FTTransformerConvs 12 | from torch_frame.nn.encoder.stype_encoder import ( 13 | EmbeddingEncoder, 14 | LinearEncoder, 15 | StypeEncoder, 16 | ) 17 | from torch_frame.nn.encoder.stypewise_encoder import StypeWiseFeatureEncoder 18 | 19 | 20 | class FTTransformer(Module): 21 | r"""The FT-Transformer model introduced in the 22 | `"Revisiting Deep Learning Models for Tabular Data" 23 | <https://arxiv.org/abs/2106.11959>`_ paper. 24 | 25 | .. note:: 26 | 27 | For an example of using FTTransformer, see `examples/revisiting.py 28 | <https://github.com/pyg-team/pytorch-frame/blob/master/examples/ 29 | revisiting.py>`_. 30 | 31 | Args: 32 | channels (int): Hidden channel dimensionality 33 | out_channels (int): Output channel dimensionality 34 | num_layers (int): Number of layers. 35 | col_stats(dict[str,dict[:class:`torch_frame.data.stats.StatType`,Any]]): 36 | A dictionary that maps column name into stats. 37 | Available as :obj:`dataset.col_stats`. 38 | col_names_dict (dict[:obj:`torch_frame.stype`, list[str]]): A 39 | dictionary that maps stype to a list of column names. The column 40 | names are sorted based on the ordering that appears in 41 | :obj:`tensor_frame.feat_dict`. Available as 42 | :obj:`tensor_frame.col_names_dict`. 43 | stype_encoder_dict 44 | (dict[:class:`torch_frame.stype`, 45 | :class:`torch_frame.nn.encoder.StypeEncoder`], optional): 46 | A dictionary mapping stypes into their stype encoders.
47 | (default: :obj:`None`, will call 48 | :class:`torch_frame.nn.encoder.EmbeddingEncoder()` for categorical 49 | features and :class:`torch_frame.nn.encoder.LinearEncoder()` 50 | for numerical features) 51 | """ 52 | def __init__( 53 | self, 54 | channels: int, 55 | out_channels: int, 56 | num_layers: int, 57 | col_stats: dict[str, dict[StatType, Any]], 58 | col_names_dict: dict[torch_frame.stype, list[str]], 59 | stype_encoder_dict: dict[torch_frame.stype, StypeEncoder] 60 | | None = None, 61 | ) -> None: 62 | super().__init__() 63 | if num_layers <= 0: 64 | raise ValueError( 65 | f"num_layers must be a positive integer (got {num_layers})") 66 | 67 | if stype_encoder_dict is None: 68 | stype_encoder_dict = { 69 | stype.categorical: EmbeddingEncoder(), 70 | stype.numerical: LinearEncoder(), 71 | } 72 | 73 | self.encoder = StypeWiseFeatureEncoder( 74 | out_channels=channels, 75 | col_stats=col_stats, 76 | col_names_dict=col_names_dict, 77 | stype_encoder_dict=stype_encoder_dict, 78 | ) 79 | self.backbone = FTTransformerConvs(channels=channels, 80 | num_layers=num_layers) 81 | self.decoder = Sequential( 82 | LayerNorm(channels), 83 | ReLU(), 84 | Linear(channels, out_channels), 85 | ) 86 | self.reset_parameters() 87 | 88 | def reset_parameters(self) -> None: 89 | self.encoder.reset_parameters() 90 | self.backbone.reset_parameters() 91 | for m in self.decoder: 92 | if not isinstance(m, ReLU): 93 | m.reset_parameters() 94 | 95 | def forward(self, tf: TensorFrame) -> Tensor: 96 | r"""Transform a :class:`TensorFrame` object into an output prediction. 97 | 98 | Args: 99 | tf (TensorFrame): 100 | Input :class:`TensorFrame` object. 101 | 102 | Returns: 103 | torch.Tensor: Output of shape [batch_size, out_channels]. 104 | """ 105 | x, _ = self.encoder(tf) 106 | x, x_cls = self.backbone(x) 107 | out = self.decoder(x_cls) 108 | return out 109 | -------------------------------------------------------------------------------- /test/utils/test_io.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import shutil 4 | import tempfile 5 | 6 | import pytest 7 | 8 | import torch_frame 9 | from torch_frame import TensorFrame, load, save 10 | from torch_frame.config.text_embedder import TextEmbedderConfig 11 | from torch_frame.config.text_tokenizer import TextTokenizerConfig 12 | from torch_frame.datasets import FakeDataset 13 | from torch_frame.testing.text_embedder import HashTextEmbedder 14 | from torch_frame.testing.text_tokenizer import WhiteSpaceHashTokenizer 15 | 16 | TEST_DIR = tempfile.TemporaryDirectory() 17 | TEST_DATASET_NAME = 'test_dataset_tf.pt' 18 | TEST_SAVE_LOAD_NAME = 'tf.pt' 19 | 20 | 21 | def teardown_module(): 22 | if osp.exists(TEST_DIR.name): 23 | shutil.rmtree(TEST_DIR.name, ignore_errors=True) 24 | 25 | 26 | def get_fake_dataset( 27 | num_rows: int, 28 | col_to_text_embedder_cfg: TextEmbedderConfig, 29 | col_to_text_tokenizer_cfg: TextTokenizerConfig, 30 | ) -> FakeDataset: 31 | stypes = [ 32 | torch_frame.numerical, 33 | torch_frame.categorical, 34 | torch_frame.multicategorical, 35 | torch_frame.text_embedded, 36 | torch_frame.text_tokenized, 37 | torch_frame.sequence_numerical, 38 | torch_frame.embedding, 39 | ] 40 | dataset = FakeDataset( 41 | num_rows=num_rows, 42 | stypes=stypes, 43 | col_to_text_embedder_cfg=col_to_text_embedder_cfg, 44 | col_to_text_tokenizer_cfg=col_to_text_tokenizer_cfg, 45 | ) 46 | return dataset 47 | 48 | 49 | def test_dataset_cache(): 50 | num_rows = 10 51 | out_channels = 8 52 | 53
| col_to_text_embedder_cfg = TextEmbedderConfig( 54 | text_embedder=HashTextEmbedder(out_channels)) 55 | col_to_text_tokenizer_cfg = TextTokenizerConfig( 56 | text_tokenizer=WhiteSpaceHashTokenizer()) 57 | dataset = get_fake_dataset( 58 | num_rows, 59 | col_to_text_embedder_cfg, 60 | col_to_text_tokenizer_cfg, 61 | ) 62 | 63 | path = osp.join(TEST_DIR.name, TEST_DATASET_NAME) 64 | dataset.materialize(path=path) 65 | 66 | new_dataset = get_fake_dataset( 67 | num_rows, 68 | col_to_text_embedder_cfg, 69 | col_to_text_tokenizer_cfg, 70 | ) 71 | new_dataset.df = dataset.df 72 | 73 | # Test materialize via caching 74 | new_dataset.materialize(path=path) 75 | assert new_dataset.is_materialized 76 | assert dataset.col_stats == new_dataset.col_stats 77 | assert dataset.tensor_frame == new_dataset.tensor_frame 78 | 79 | # Test `tensor_frame` converter 80 | tf = new_dataset._to_tensor_frame_converter(dataset.df) 81 | assert dataset.tensor_frame == tf 82 | 83 | # Remove saved tensor frame object 84 | os.remove(path) 85 | 86 | new_dataset = get_fake_dataset( 87 | num_rows, 88 | col_to_text_embedder_cfg, 89 | col_to_text_tokenizer_cfg, 90 | ) 91 | new_dataset.df = dataset.df 92 | 93 | # Test materialize again with specified path 94 | new_dataset.materialize() 95 | new_dataset.materialize(path=path) 96 | 97 | assert new_dataset.is_materialized 98 | 99 | 100 | def test_save_load_tensor_frame(): 101 | num_rows = 10 102 | out_channels = 8 103 | col_to_text_embedder_cfg = TextEmbedderConfig( 104 | text_embedder=HashTextEmbedder(out_channels)) 105 | col_to_text_tokenizer_cfg = TextTokenizerConfig( 106 | text_tokenizer=WhiteSpaceHashTokenizer(), 107 | batch_size=None, 108 | ) 109 | dataset = get_fake_dataset(num_rows, col_to_text_embedder_cfg, 110 | col_to_text_tokenizer_cfg) 111 | dataset.materialize() 112 | 113 | path = osp.join(TEST_DIR.name, TEST_SAVE_LOAD_NAME) 114 | save(dataset.tensor_frame, dataset.col_stats, path) 115 | 116 | tf, col_stats = load(path) 117 | assert dataset.col_stats == col_stats 118 | assert dataset.tensor_frame == tf 119 | 120 | 121 | class UntrustedClass: 122 | pass 123 | 124 | 125 | @pytest.mark.skipif( 126 | not torch_frame.typing.WITH_PT24, 127 | reason='Requires PyTorch 2.4', 128 | ) 129 | def test_load_weights_only_gracefully(tmpdir): 130 | save( 131 | tensor_frame=TensorFrame({}, {}), 132 | col_stats={'a': UntrustedClass()}, 133 | path=tmpdir.join('tf.pt'), 134 | ) 135 | with pytest.warns(UserWarning, match='Weights only load failed'): 136 | load(tmpdir.join('tf.pt')) 137 | -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to PyTorch Frame 2 | 3 | If you are interested in contributing to PyTorch Frame, your contributions will likely fall into one of the following two categories: 4 | 5 | 1. You want to implement a new feature: 6 | - In general, we accept any features as long as they fit the scope of this package. If you are unsure about this or need help on the design/implementation of your feature, post about it in an issue. 7 | 1. You want to fix a bug: 8 | - Feel free to send a Pull Request (PR) any time you encounter a bug. Please provide a clear and concise description of what the bug was. If you are unsure whether this is a bug at all or how to fix it, post about it in an issue. 9 | 10 | Once you finish implementing a feature or bug-fix, please send a PR to https://github.com/pyg-team/pytorch-frame.
11 | 12 | Your PR will be merged after one or more rounds of reviews by the [pyg-team](https://github.com/pyg-team). 13 | 14 | ## Developing PyTorch Frame 15 | 16 | To develop PyTorch Frame on your machine, here are some tips: 17 | 18 | 1. Ensure that you are running on one of the supported PyTorch versions (*e.g.*, `2.1.0`): 19 | 20 | ```python 21 | import torch 22 | print(torch.__version__) 23 | ``` 24 | 25 | 1. Uninstall all existing PyTorch Frame installations. 26 | It is advised to run this command repeatedly to confirm that installations across all locations are properly removed. 27 | 28 | ```bash 29 | pip uninstall pytorch-frame 30 | ``` 31 | 32 | 1. Fork and clone the PyTorch Frame repository: 33 | 34 | ```bash 35 | git clone https://github.com/<your_username>/pytorch-frame.git 36 | cd pytorch-frame 37 | 38 | ``` 39 | 40 | 1. If you already cloned PyTorch Frame from source, update it: 41 | 42 | ```bash 43 | git pull 44 | ``` 45 | 46 | 1. Install PyTorch Frame in editable mode: 47 | 48 | ```bash 49 | pip install -e ".[dev,full]" 50 | ``` 51 | 52 | This mode will symlink the Python files from the current local source tree into the Python install. 53 | Hence, if you modify a Python file, you do not need to re-install PyTorch Frame again. 54 | 55 | 1. Ensure that you have a working PyTorch Frame installation by running the entire test suite with 56 | 57 | ```bash 58 | pytest 59 | ``` 60 | 61 | 1. Install pre-commit hooks: 62 | 63 | ```bash 64 | pre-commit install 65 | ``` 66 | 67 | ## Unit Testing 68 | 69 | The PyTorch Frame testing suite is located under `test/`. 70 | Run the test suite with 71 | 72 | ```bash 73 | # all test cases 74 | pytest 75 | 76 | # individual test cases 77 | pytest test/utils/test_split.py 78 | ``` 79 | 80 | ## Continuous Integration 81 | 82 | PyTorch Frame uses [GitHub Actions](https://github.com/pyg-team/pytorch-frame/actions) in combination with [CodeCov](https://codecov.io/github/pyg-team/pytorch-frame?branch=master) for continuous integration. 83 | 84 | Every time you send a Pull Request, your commit will be built and checked against the PyTorch Frame guidelines: 85 | 86 | 1. Ensure that your code is formatted correctly by testing against the style guide of [`flake8`](https://github.com/PyCQA/flake8). 87 | We use the [`Flake8-pyproject`](https://pypi.org/project/Flake8-pyproject/) plugin for configuration: 88 | 89 | ```bash 90 | flake8 91 | ``` 92 | 93 | If you do not want to format your code manually, we recommend using [`yapf`](https://github.com/google/yapf). 94 | 95 | 1. Ensure that the entire test suite passes and that code coverage roughly stays the same. 96 | Please feel encouraged to provide a test with your submitted code. 97 | To test, run 98 | 99 | ```bash 100 | pytest --cov 101 | ``` 102 | 103 | (which runs a set of additional but time-consuming tests), depending on your needs. 104 | 105 | 1. Add your feature/bugfix to the [`CHANGELOG.md`](https://github.com/pyg-team/pytorch-frame/blob/master/CHANGELOG.md?plain=1). 106 | If multiple PRs move towards integrating a single feature, it is advised to group them together into one bullet point. 107 | 108 | ## Building Documentation 109 | 110 | To build the documentation: 111 | 112 | 1. [Build and install](#developing-pytorch-frame) PyTorch Frame from source. 113 | 1. Install [Sphinx](https://www.sphinx-doc.org/en/master/) theme via 114 | ```bash 115 | pip install git+https://github.com/pyg-team/pyg_sphinx_theme.git 116 | ``` 117 | 1.
Generate the documentation via: 118 | ```bash 119 | cd docs 120 | make html 121 | ``` 122 | -------------------------------------------------------------------------------- /torch_frame/nn/models/mlp.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any 4 | 5 | import torch 6 | from torch import Tensor 7 | from torch.nn import ( 8 | BatchNorm1d, 9 | Dropout, 10 | LayerNorm, 11 | Linear, 12 | Module, 13 | ReLU, 14 | Sequential, 15 | ) 16 | 17 | import torch_frame 18 | from torch_frame import TensorFrame, stype 19 | from torch_frame.data.stats import StatType 20 | from torch_frame.nn.encoder.stype_encoder import ( 21 | EmbeddingEncoder, 22 | LinearEncoder, 23 | StypeEncoder, 24 | ) 25 | from torch_frame.nn.encoder.stypewise_encoder import StypeWiseFeatureEncoder 26 | 27 | 28 | class MLP(Module): 29 | r"""The lightweight MLP model that mean-pools column embeddings and 30 | applies an MLP on top. 31 | 32 | Args: 33 | channels (int): The number of channels in the backbone layers. 34 | out_channels (int): The number of output channels in the decoder. 35 | num_layers (int): The number of layers in the backbone. 36 | col_stats(dict[str,dict[:class:`torch_frame.data.stats.StatType`,Any]]): 37 | A dictionary that maps column name into stats. 38 | Available as :obj:`dataset.col_stats`. 39 | col_names_dict (dict[:class:`torch_frame.stype`, list[str]]): A 40 | dictionary that maps stype to a list of column names. The column 41 | names are sorted based on the ordering that appears in 42 | :obj:`tensor_frame.feat_dict`. Available as 43 | :obj:`tensor_frame.col_names_dict`. 44 | stype_encoder_dict 45 | (dict[:class:`torch_frame.stype`, 46 | :class:`torch_frame.nn.encoder.StypeEncoder`], optional): 47 | A dictionary mapping stypes into their stype encoders. 48 | (default: :obj:`None`, will call :obj:`EmbeddingEncoder()` 49 | for categorical features and :obj:`LinearEncoder()` for 50 | numerical features) 51 | normalization (str, optional): The type of normalization to use. 52 | :obj:`batch_norm`, :obj:`layer_norm`, or :obj:`None`. 53 | (default: :obj:`layer_norm`) 54 | dropout_prob (float): The dropout probability (default: :obj:`0.2`).
55 | """ 56 | def __init__( 57 | self, 58 | channels: int, 59 | out_channels: int, 60 | num_layers: int, 61 | col_stats: dict[str, dict[StatType, Any]], 62 | col_names_dict: dict[torch_frame.stype, list[str]], 63 | stype_encoder_dict: dict[torch_frame.stype, StypeEncoder] 64 | | None = None, 65 | normalization: str | None = "layer_norm", 66 | dropout_prob: float = 0.2, 67 | ) -> None: 68 | super().__init__() 69 | 70 | if stype_encoder_dict is None: 71 | stype_encoder_dict = { 72 | stype.categorical: EmbeddingEncoder(), 73 | stype.numerical: LinearEncoder(), 74 | } 75 | 76 | self.encoder = StypeWiseFeatureEncoder( 77 | out_channels=channels, 78 | col_stats=col_stats, 79 | col_names_dict=col_names_dict, 80 | stype_encoder_dict=stype_encoder_dict, 81 | ) 82 | 83 | self.mlp = Sequential() 84 | 85 | for _ in range(num_layers - 1): 86 | self.mlp.append(Linear(channels, channels)) 87 | if normalization == "layer_norm": 88 | self.mlp.append(LayerNorm(channels)) 89 | elif normalization == "batch_norm": 90 | self.mlp.append(BatchNorm1d(channels)) 91 | self.mlp.append(ReLU()) 92 | self.mlp.append(Dropout(p=dropout_prob)) 93 | self.mlp.append(Linear(channels, out_channels)) 94 | 95 | self.reset_parameters() 96 | 97 | def reset_parameters(self) -> None: 98 | self.encoder.reset_parameters() 99 | for param in self.mlp: 100 | if hasattr(param, 'reset_parameters'): 101 | param.reset_parameters() 102 | 103 | def forward(self, tf: TensorFrame) -> Tensor: 104 | r"""Transforming :class:`TensorFrame` object into output prediction. 105 | 106 | Args: 107 | tf (TensorFrame): Input :class:`TensorFrame` object. 108 | 109 | Returns: 110 | torch.Tensor: Output of shape [batch_size, out_channels]. 111 | """ 112 | x, _ = self.encoder(tf) 113 | 114 | x = torch.mean(x, dim=1) 115 | 116 | out = self.mlp(x) 117 | return out 118 | -------------------------------------------------------------------------------- /benchmark/encoder/README.md: -------------------------------------------------------------------------------- 1 | # Encoder Benchmark 2 | 3 | ## Usage 4 | 5 | Exemplary command: 6 | 7 | ``` 8 | python encoder_benchmark.py --stype-kv categorical embedding --stype-kv numerical linear 9 | ``` 10 | 11 | It will create a dataset that will contain categorical and numerical columns and will use for them 12 | embedding and linear encoders, respectively. 13 | 14 | Arguments: 15 | 16 | - **--stype-kv**: Specify the stype(s) and corresponding encoder(s) to run. 17 | - **--num-rows**: The number of rows in the dataset (default is `8192`). 18 | - **--out-channels**: The number of output channels (default is `128`). 19 | - **--with-nan**: If specified, the dataset will include NaN values. 20 | - **--runs**: The number of runs for the benchmark (default is `1000`). 21 | - **--warmup-size**: The size of the warmup stage (default is `200`). 22 | - **--torch-profile**: If specified, torch profiling will be enabled. 23 | - **--line-profile**: If specified, line profiling will be enabled. 24 | - **--line-profile-level**: The level of line profiling (default is `'encode_forward'`). 25 | - **--device**: The device to run the benchmark on (default is `'cpu'`). 
26 | 27 | Regardless of whether a profiler is used, the benchmark always reports the latency (single-run execution time), e.g.: 28 | 29 | ``` 30 | Latency: 0.034277s 31 | ``` 32 | 33 | The torch profiler produces a table of operations sorted by execution time, e.g.: 34 | 35 | ``` 36 | ----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ 37 | Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls 38 | ----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ 39 | aten::cat 47.49% 2.027s 48.05% 2.051s 1.025ms 2000 40 | aten::nan_to_num 19.74% 842.584ms 39.35% 1.680s 419.945us 4000 41 | aten::add 10.04% 428.549ms 10.04% 428.549ms 214.274us 2000 42 | aten::index_select 6.28% 268.051ms 7.78% 331.959ms 165.980us 2000 43 | aten::mul 4.96% 211.853ms 4.96% 211.853ms 211.853us 1000 44 | aten::sub 1.36% 58.064ms 1.36% 58.064ms 58.064us 1000 45 | aten::any 1.30% 55.612ms 1.41% 60.278ms 30.139us 2000 46 | aten::div 1.22% 52.159ms 1.22% 52.159ms 52.159us 1000 47 | ... 48 | ----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ 49 | Self CPU time total: 4.268s 50 | ``` 51 | 52 | The line profiler shows what percentage of time was spent on each line of the method, e.g.: 53 | 54 | ``` 55 | Total time: 1.03661 s 56 | File: {PF_BASE_PATH}/pytorch-frame/torch_frame/nn/encoder/stype_encoder.py 57 | Function: encode_forward at line 295 58 | 59 | Line # Hits Time Per Hit % Time Line Contents 60 | ============================================================== 61 | 295 def encode_forward( 62 | 296 self, 63 | 297 feat: Tensor, 64 | 298 col_names: list[str] | None = None, 65 | 299 ) -> Tensor: 66 | 300 # TODO: Make this more efficient.
67 | 301 # Increment the index by one so that NaN index (-1) becomes 0 68 | 302 # (padding_idx) 69 | 303 # feat: [batch_size, num_cols] 70 | 304 1200 29867.4 24.9 2.9 feat = feat + 1 71 | 305 1200 944.3 0.8 0.1 xs = [] 72 | 306 3600 19345.7 5.4 1.9 for i, emb in enumerate(self.embs): 73 | 307 2400 466967.7 194.6 45.0 xs.append(emb(feat[:, i])) 74 | 308 # [batch_size, num_cols, hidden_channels] 75 | 309 1200 519044.9 432.5 50.1 x = torch.stack(xs, dim=1) 76 | 310 1200 440.8 0.4 0.0 return x 77 | ``` 78 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires=["flit_core >=3.12,<4"] 3 | build-backend="flit_core.buildapi" 4 | 5 | [project] 6 | name="pytorch-frame" 7 | authors=[ 8 | {name="PyG Team", email="team@pyg.org"}, 9 | ] 10 | dynamic=["version"] 11 | description="Tabular Deep Learning Library for PyTorch" 12 | readme="README.md" 13 | requires-python=">=3.10" 14 | keywords=[ 15 | "deep-learning", 16 | "pytorch", 17 | "tabular-learning", 18 | "data-frame", 19 | ] 20 | license = "MIT" 21 | license-files = ["LICENSE"] 22 | classifiers=[ 23 | "Development Status :: 4 - Beta", 24 | "Programming Language :: Python", 25 | "Programming Language :: Python :: 3.10", 26 | "Programming Language :: Python :: 3.11", 27 | "Programming Language :: Python :: 3.12", 28 | "Programming Language :: Python :: 3.13", 29 | "Programming Language :: Python :: 3 :: Only", 30 | ] 31 | dependencies=[ 32 | "numpy", 33 | "pandas", 34 | "torch", 35 | "tqdm", 36 | "pyarrow", 37 | "Pillow", 38 | ] 39 | 40 | [project.optional-dependencies] 41 | test=[ 42 | "pytest", 43 | "pytest-cov", 44 | "mypy", 45 | ] 46 | dev=[ 47 | "pytorch-frame[test]", 48 | "pre-commit", 49 | ] 50 | full=[ 51 | "scikit-learn", 52 | "xgboost>=1.7.0, <2.0.0", 53 | "optuna>=3.0.0", 54 | "optuna-integration", 55 | "mpmath==1.3.0", 56 | "catboost", 57 | "lightgbm", 58 | "datasets", 59 | "torchmetrics", 60 | ] 61 | 62 | [project.urls] 63 | homepage="https://pyg.org" 64 | documentation="https://pytorch-frame.readthedocs.io" 65 | repository="https://github.com/pyg-team/pytorch-frame.git" 66 | changelog="https://github.com/pyg-team/pytorch-frame/blob/master/CHANGELOG.md" 67 | 68 | [tool.flit.module] 69 | name="torch_frame" 70 | 71 | [tool.ruff] # https://docs.astral.sh/ruff/rules 72 | target-version = "py310" 73 | src = ["torch_frame", "test", "examples", "benchmark"] 74 | line-length = 80 75 | indent-width = 4 76 | 77 | [tool.ruff.lint] 78 | select = [ 79 | "B", # flake8-bugbear 80 | "D", # pydocstyle 81 | "UP", # pyupgrade 82 | ] 83 | ignore = [ 84 | "D100", # TODO: Don't ignore "Missing docstring in public module" 85 | "D101", # TODO: Don't ignore "Missing docstring in public class" 86 | "D102", # TODO: Don't ignore "Missing docstring in public method" 87 | "D103", # TODO: Don't ignore "Missing docstring in public function" 88 | "D105", # Ignore "Missing docstring in magic method" 89 | "D107", # Ignore "Missing docstring in __init__" 90 | "D205", # Ignore "blank line required between summary line and description" 91 | ] 92 | 93 | [tool.ruff.lint.pydocstyle] 94 | convention = "google" 95 | 96 | [tool.yapf] 97 | based_on_style = "pep8" 98 | split_before_named_assigns = false 99 | blank_line_before_nested_class_or_def = false 100 | 101 | [tool.isort] 102 | multi_line_output = 3 103 | include_trailing_comma = true 104 | skip = [".gitignore", "__init__.py"] 105 | 106 | [tool.flake8] 107 | ignore = 
["F811", "W503", "W504"] 108 | 109 | [tool.mypy] 110 | files = ["torch_frame"] 111 | install_types = true 112 | non_interactive = true 113 | ignore_missing_imports = true 114 | show_error_codes = true 115 | warn_redundant_casts = true 116 | warn_unused_configs = true 117 | warn_unused_ignores = true 118 | 119 | # TODO: the goal is for this ignore list to be empty 120 | [[tool.mypy.overrides]] 121 | ignore_errors = true 122 | # Run this command to generate this list of files 123 | # mypy --no-error-summary 2>&1 | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g; s|src/||g; s|\/|\.|g' | xargs -I {} echo '"{}",' 124 | module = [ 125 | "torch_frame.data.dataset", 126 | "torch_frame.data.loader", 127 | "torch_frame.data.mapper", 128 | "torch_frame.data.stats", 129 | "torch_frame.data.tensor_frame", 130 | "torch_frame.gbdt.gbdt", 131 | "torch_frame.gbdt.tuned_catboost", 132 | "torch_frame.gbdt.tuned_lightgbm", 133 | "torch_frame.gbdt.tuned_xgboost", 134 | "torch_frame.nn.encoder.stype_encoder", 135 | "torch_frame.testing.text_tokenizer", 136 | "torch_frame.transforms.base_transform", 137 | "torch_frame.transforms.cat_to_num_transform", 138 | "torch_frame.transforms.fittable_base_transform", 139 | "torch_frame.transforms.mutual_information_sort", 140 | ] 141 | 142 | [tool.pytest.ini_options] 143 | addopts = [ 144 | "--capture=no", 145 | "--color=yes", 146 | "-vv", 147 | ] 148 | 149 | [tool.coverage.report] 150 | exclude_lines = [ 151 | "pragma: no cover", 152 | "pass", 153 | "raise NotImplementedError", 154 | ] 155 | -------------------------------------------------------------------------------- /examples/tabnet.py: -------------------------------------------------------------------------------- 1 | """Reported (reproduced) results of of TabNet model in the original paper 2 | https://arxiv.org/abs/1908.07442. 
3 | 4 | Forest Cover Type: 96.99 (96.53) 5 | KDD Census Income: 95.5 (95.41) 6 | """ 7 | 8 | import argparse 9 | import os.path as osp 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from torch.optim.lr_scheduler import ExponentialLR 14 | from tqdm import tqdm 15 | 16 | from torch_frame.data import DataLoader 17 | from torch_frame.datasets import ForestCoverType, KDDCensusIncome 18 | from torch_frame.nn import TabNet 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--dataset', type=str, default="ForestCoverType", 22 | choices=["ForestCoverType", "KDDCensusIncome"]) 23 | parser.add_argument('--channels', type=int, default=128) 24 | parser.add_argument('--gamma', type=float, default=1.2) 25 | parser.add_argument('--num_layers', type=int, default=6) 26 | parser.add_argument('--batch_size', type=int, default=4096) 27 | parser.add_argument('--lr', type=float, default=0.005) 28 | parser.add_argument('--epochs', type=int, default=50) 29 | parser.add_argument('--seed', type=int, default=0) 30 | parser.add_argument('--compile', action='store_true') 31 | args = parser.parse_args() 32 | 33 | torch.manual_seed(args.seed) 34 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 35 | 36 | # Prepare datasets 37 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 38 | args.dataset) 39 | if args.dataset == "ForestCoverType": 40 | dataset = ForestCoverType(root=path) 41 | elif args.dataset == "KDDCensusIncome": 42 | dataset = KDDCensusIncome(root=path) 43 | else: 44 | raise ValueError(f"Unsupported dataset: {args.dataset}") 45 | 46 | dataset.materialize() 47 | assert dataset.task_type.is_classification 48 | dataset = dataset.shuffle() 49 | # Split ratio is set to 80% / 10% / 10% (the original TabNet paper does not 50 | # clearly specify its splits) 51 | train_dataset, val_dataset, test_dataset = dataset[:0.8], dataset[ 52 | 0.8:0.9], dataset[0.9:] 53 | 54 | # Set up data loaders 55 | train_tensor_frame = train_dataset.tensor_frame 56 | val_tensor_frame = val_dataset.tensor_frame 57 | test_tensor_frame = test_dataset.tensor_frame 58 | train_loader = DataLoader(train_tensor_frame, batch_size=args.batch_size, 59 | shuffle=True) 60 | val_loader = DataLoader(val_tensor_frame, batch_size=args.batch_size) 61 | test_loader = DataLoader(test_tensor_frame, batch_size=args.batch_size) 62 | 63 | # Set up model and optimizer 64 | model = TabNet( 65 | out_channels=dataset.num_classes, 66 | num_layers=args.num_layers, 67 | split_attn_channels=args.channels, 68 | split_feat_channels=args.channels, 69 | gamma=args.gamma, 70 | col_stats=dataset.col_stats, 71 | col_names_dict=train_tensor_frame.col_names_dict, 72 | ).to(device) 73 | model = torch.compile(model, dynamic=True) if args.compile else model 74 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 75 | lr_scheduler = ExponentialLR(optimizer, gamma=0.95) 76 | 77 | 78 | def train(epoch: int) -> float: 79 | model.train() 80 | loss_accum = total_count = 0 81 | 82 | for tf in tqdm(train_loader, desc=f'Epoch: {epoch}'): 83 | tf = tf.to(device) 84 | pred = model(tf) 85 | loss = F.cross_entropy(pred, tf.y) 86 | optimizer.zero_grad() 87 | loss.backward() 88 | loss_accum += float(loss) * len(tf.y) 89 | total_count += len(tf.y) 90 | optimizer.step() 91 | return loss_accum / total_count 92 | 93 | 94 | @torch.no_grad() 95 | def test(loader: DataLoader) -> float: 96 | model.eval() 97 | accum = total_count = 0 98 | 99 | for tf in loader: 100 | tf = tf.to(device) 101 | pred = model(tf) 102 |
pred_class = pred.argmax(dim=-1) 103 | accum += float((tf.y == pred_class).sum()) 104 | total_count += len(tf.y) 105 | 106 | return accum / total_count 107 | 108 | 109 | best_val_acc = 0 110 | best_test_acc = 0 111 | for epoch in range(1, args.epochs + 1): 112 | train_loss = train(epoch) 113 | train_acc = test(train_loader) 114 | val_acc = test(val_loader) 115 | test_acc = test(test_loader) 116 | if best_val_acc < val_acc: 117 | best_val_acc = val_acc 118 | best_test_acc = test_acc 119 | print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, ' 120 | f'Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}') 121 | lr_scheduler.step() 122 | 123 | print(f'Best Val Acc: {best_val_acc:.4f}, Best Test Acc: {best_test_acc:.4f}') 124 | --------------------------------------------------------------------------------
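For reference, a sketch of how the example above can be launched from the repository root (the flag names come from the argparse section of the script, and the values shown are simply its own defaults; any of them can be overridden):

```bash
python examples/tabnet.py --dataset KDDCensusIncome --epochs 50 --batch_size 4096
```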