├── docs
│   ├── source
│   │   ├── .gitignore
│   │   ├── _figures
│   │   │   ├── modular.png
│   │   │   ├── voyageai.webp
│   │   │   ├── OpenAI_Logo.png
│   │   │   ├── architecture.png
│   │   │   ├── cohere-logo.png
│   │   │   ├── hf-logo-with-title.png
│   │   │   ├── pytorch_frame_logo.JPG
│   │   │   ├── pytorch_frame_logo.png
│   │   │   ├── pytorch_frame_logo_text.JPG
│   │   │   └── pytorch_frame_logo_text.png
│   │   ├── modules
│   │   │   ├── root.rst
│   │   │   ├── utils.rst
│   │   │   ├── config.rst
│   │   │   ├── gbdt.rst
│   │   │   ├── datasets.rst
│   │   │   ├── data.rst
│   │   │   ├── nn.rst
│   │   │   └── transforms.rst
│   │   ├── _templates
│   │   │   └── autosummary
│   │   │       └── class.rst
│   │   ├── _static
│   │   │   └── js
│   │   │       └── version_alert.js
│   │   ├── get_started
│   │   │   └── installation.rst
│   │   ├── index.rst
│   │   └── conf.py
│   ├── requirements.txt
│   ├── Makefile
│   └── README.md
├── torch_frame
│   ├── nn
│   │   ├── encoding
│   │   │   ├── __init__.py
│   │   │   ├── positional_encoding.py
│   │   │   └── cyclic_encoding.py
│   │   ├── decoder
│   │   │   ├── __init__.py
│   │   │   ├── decoder.py
│   │   │   ├── excelformer_decoder.py
│   │   │   └── trompt_decoder.py
│   │   ├── __init__.py
│   │   ├── conv
│   │   │   ├── __init__.py
│   │   │   ├── table_conv.py
│   │   │   └── ft_transformer_convs.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── ft_transformer.py
│   │   │   └── mlp.py
│   │   ├── encoder
│   │   │   ├── __init__.py
│   │   │   ├── encoder.py
│   │   │   └── stypewise_encoder.py
│   │   ├── utils
│   │   │   └── init.py
│   │   └── base.py
│   ├── testing
│   │   ├── __init__.py
│   │   ├── image_embedder.py
│   │   ├── text_embedder.py
│   │   ├── decorators.py
│   │   └── text_tokenizer.py
│   ├── gbdt
│   │   └── __init__.py
│   ├── config
│   │   ├── __init__.py
│   │   ├── model.py
│   │   ├── text_tokenizer.py
│   │   ├── text_embedder.py
│   │   └── image_embedder.py
│   ├── transforms
│   │   ├── __init__.py
│   │   ├── base_transform.py
│   │   ├── fittable_base_transform.py
│   │   └── mutual_information_sort.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── memory.py
│   │   └── split.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── download.py
│   │   └── loader.py
│   ├── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── titanic.py
│   │   ├── dota2.py
│   │   ├── adult_census_income.py
│   │   ├── amazon_fine_food_reviews.py
│   │   ├── poker_hand.py
│   │   ├── mercari.py
│   │   ├── bank_marketing.py
│   │   ├── diamond_images.py
│   │   ├── amphibians.py
│   │   ├── movielens_1m.py
│   │   ├── mushroom.py
│   │   └── kdd_census_income.py
│   ├── typing.py
│   └── _stype.py
├── .github
│   ├── labeler.yml
│   ├── workflows
│   │   ├── changelog.yml
│   │   ├── labeler.yml
│   │   ├── linting.yml
│   │   ├── auto-merge.yml
│   │   ├── documentation.yml
│   │   ├── release.yml
│   │   └── testing.yml
│   ├── dependabot.yml
│   ├── actions
│   │   └── setup
│   │       └── action.yml
│   └── CONTRIBUTING.md
├── readthedocs.yml
├── .gitignore
├── codecov.yml
├── test
│   ├── nn
│   │   ├── conv
│   │   │   ├── test_ft_transformer_convs.py
│   │   │   ├── test_excelformer_conv.py
│   │   │   ├── test_tab_transformer_conv.py
│   │   │   └── test_trompt_conv.py
│   │   ├── decoder
│   │   │   ├── test_excelformer_decoder.py
│   │   │   └── test_trompt_decoder.py
│   │   ├── encoding
│   │   │   ├── test_positional_encoding.py
│   │   │   └── test_cyclic_encoding.py
│   │   ├── models
│   │   │   ├── test_tabnet.py
│   │   │   ├── test_resnet.py
│   │   │   ├── test_ft_transformer.py
│   │   │   ├── test_mlp.py
│   │   │   ├── test_tab_transformer.py
│   │   │   ├── test_trompt.py
│   │   │   ├── test_excelformer.py
│   │   │   └── test_compile.py
│   │   └── test_simple_basecls.py
│   ├── data
│   │   └── test_loader.py
│   ├── utils
│   │   ├── test_split.py
│   │   ├── test_memory.py
│   │   ├── test_concat.py
│   │   ├── test_infer_stype.py
│   │   └── test_io.py
│   ├── test_stype.py
│   ├── datasets
│   │   ├── test_titanic.py
│   │   └── test_movielens_1m.py
│   ├── transforms
│   │   └── test_mutual_information_sort.py
│   ├── conftest.py
│   └── gbdt
│       └── test_gbdt.py
├── CITATION.cff
├── LICENSE
├── examples
│   ├── tabpfn_classification.py
│   ├── tuned_gbdt.py
│   └── tabnet.py
├── .pre-commit-config.yaml
├── benchmark
│   └── encoder
│       └── README.md
└── pyproject.toml
/docs/source/.gitignore:
--------------------------------------------------------------------------------
1 | generated/
2 |
--------------------------------------------------------------------------------
/docs/source/_figures/modular.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch-frame/HEAD/docs/source/_figures/modular.png
--------------------------------------------------------------------------------
/docs/source/_figures/voyageai.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch-frame/HEAD/docs/source/_figures/voyageai.webp
--------------------------------------------------------------------------------
/docs/source/_figures/OpenAI_Logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch-frame/HEAD/docs/source/_figures/OpenAI_Logo.png
--------------------------------------------------------------------------------
/docs/source/_figures/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch-frame/HEAD/docs/source/_figures/architecture.png
--------------------------------------------------------------------------------
/docs/source/_figures/cohere-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch-frame/HEAD/docs/source/_figures/cohere-logo.png
--------------------------------------------------------------------------------
/docs/source/_figures/hf-logo-with-title.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch-frame/HEAD/docs/source/_figures/hf-logo-with-title.png
--------------------------------------------------------------------------------
/docs/source/_figures/pytorch_frame_logo.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch-frame/HEAD/docs/source/_figures/pytorch_frame_logo.JPG
--------------------------------------------------------------------------------
/docs/source/_figures/pytorch_frame_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch-frame/HEAD/docs/source/_figures/pytorch_frame_logo.png
--------------------------------------------------------------------------------
/docs/source/_figures/pytorch_frame_logo_text.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch-frame/HEAD/docs/source/_figures/pytorch_frame_logo_text.JPG
--------------------------------------------------------------------------------
/docs/source/_figures/pytorch_frame_logo_text.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyg-team/pytorch-frame/HEAD/docs/source/_figures/pytorch_frame_logo_text.png
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp310-cp310-manylinux_2_28_x86_64.whl
2 | git+https://github.com/pyg-team/pyg_sphinx_theme.git
3 |
--------------------------------------------------------------------------------
/docs/source/modules/root.rst:
--------------------------------------------------------------------------------
1 | torch_frame
2 | ===========
3 |
4 | .. automodule:: torch_frame.stype
5 | :members:
6 |
7 | .. automodule:: torch_frame.typing
8 | :members:
9 |
--------------------------------------------------------------------------------
/docs/source/_templates/autosummary/class.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. autoclass:: {{ objname }}
6 | :show-inheritance:
7 | :members:
8 |
--------------------------------------------------------------------------------
/torch_frame/nn/encoding/__init__.py:
--------------------------------------------------------------------------------
1 | r"""Encoding package."""
2 | from .cyclic_encoding import CyclicEncoding
3 | from .positional_encoding import PositionalEncoding
4 |
5 | __all__ = classes = [
6 | 'CyclicEncoding',
7 | 'PositionalEncoding',
8 | ]
9 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | SPHINXBUILD = sphinx-build
2 | SPHINXPROJ = pytorch-frame
3 | SOURCEDIR = source
4 | BUILDDIR = build
5 |
6 | .PHONY: help Makefile
7 |
8 | %: Makefile
9 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
10 |
--------------------------------------------------------------------------------
/.github/labeler.yml:
--------------------------------------------------------------------------------
1 | documentation:
2 | - docs/**/*
3 |
4 | example:
5 | - examples/**/*
6 |
7 | data:
8 | - torch_frame/data/**/*
9 |
10 | dataset:
11 | - torch_frame/datasets/**/*
12 |
13 | nn:
14 | - torch_frame/nn/**/*
15 |
16 | utils:
17 | - torch_frame/utils/**/*
18 |
--------------------------------------------------------------------------------
/torch_frame/testing/__init__.py:
--------------------------------------------------------------------------------
1 | r"""Utility package for testing."""
2 | from .decorators import (
3 | has_package,
4 | withPackage,
5 | withCUDA,
6 | onlyCUDA,
7 | )
8 |
9 | __all__ = [
10 | 'has_package',
11 | 'withPackage',
12 | 'withCUDA',
13 | 'onlyCUDA',
14 | ]
15 |
--------------------------------------------------------------------------------
/torch_frame/nn/decoder/__init__.py:
--------------------------------------------------------------------------------
1 | r"""Decoder package."""
2 | from .decoder import Decoder
3 | from .trompt_decoder import TromptDecoder
4 | from .excelformer_decoder import ExcelFormerDecoder
5 |
6 | __all__ = classes = [
7 | 'Decoder',
8 | 'TromptDecoder',
9 | 'ExcelFormerDecoder',
10 | ]
11 |
--------------------------------------------------------------------------------
/torch_frame/nn/__init__.py:
--------------------------------------------------------------------------------
1 | r"""Neural network module package."""
2 | from .base import Module
3 | from .encoder import * # noqa
4 | from .encoding import * # noqa
5 | from .conv import * # noqa
6 | from .decoder import * # noqa
7 | from .models import * # noqa
8 |
9 | __all__ = [
10 | 'Module',
11 | ]
12 |
--------------------------------------------------------------------------------
/readthedocs.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | sphinx:
4 | configuration: docs/source/conf.py
5 |
6 | build:
7 | os: ubuntu-22.04
8 | tools:
9 | python: "3.10"
10 |
11 | python:
12 | install:
13 | - requirements: docs/requirements.txt
14 | - method: pip
15 | path: .
16 |
17 | formats: []
18 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | .pytest_cache/
3 | .mypy_cache/
4 | .DS_Store
5 | build/
6 | dist/
7 | alpha/
8 | .cache/
9 | .eggs/
10 | *.egg-info/
11 | .ipynb_checkpoints
12 | .coverage
13 | .coverage.*
14 | coverage.xml
15 | .vscode
16 | .idea
17 | .venv
18 | venv/*
19 | *.out
20 | data/**
21 | catboost_info/
22 | .pt_tmp/
23 |
--------------------------------------------------------------------------------
/codecov.yml:
--------------------------------------------------------------------------------
1 | # See: https://docs.codecov.io/docs/codecov-yaml
2 | coverage:
3 | range: 80..100
4 | round: down
5 | precision: 2
6 | status:
7 | project:
8 | default:
9 | target: 80%
10 | threshold: 1%
11 | patch:
12 | default:
13 | target: 80%
14 | threshold: 1%
15 |
--------------------------------------------------------------------------------
/torch_frame/gbdt/__init__.py:
--------------------------------------------------------------------------------
1 | r"""Gradient Boosting Decision Trees package."""
2 | from .gbdt import GBDT
3 | from .tuned_xgboost import XGBoost
4 | from .tuned_catboost import CatBoost
5 | from .tuned_lightgbm import LightGBM
6 |
7 | __all__ = classes = [
8 | 'GBDT',
9 | 'XGBoost',
10 | 'CatBoost',
11 | 'LightGBM',
12 | ]
13 |
--------------------------------------------------------------------------------
/docs/source/modules/utils.rst:
--------------------------------------------------------------------------------
1 | torch_frame.utils
2 | =================
3 |
4 | .. contents:: Contents
5 | :local:
6 |
7 | .. currentmodule:: torch_frame.utils
8 |
9 | .. autosummary::
10 | :nosignatures:
11 | :toctree: ../generated
12 |
13 | {% for name in torch_frame.utils.functions %}
14 | {{ name }}
15 | {% endfor %}
16 |
--------------------------------------------------------------------------------
/docs/source/modules/config.rst:
--------------------------------------------------------------------------------
1 | torch_frame.config
2 | ==================
3 |
4 | .. contents:: Contents
5 | :local:
6 |
7 | .. currentmodule:: torch_frame.config
8 |
9 | Text Embedder
10 | -------------
11 |
12 | .. autosummary::
13 | :nosignatures:
14 | :toctree: ../generated
15 |
16 | {% for name in torch_frame.config.classes %}
17 | {{ name }}
18 | {% endfor %}
19 |
--------------------------------------------------------------------------------
/docs/source/modules/gbdt.rst:
--------------------------------------------------------------------------------
1 | torch_frame.gbdt
2 | ================
3 |
4 | .. contents:: Contents
5 | :local:
6 |
7 | .. currentmodule:: torch_frame.gbdt
8 |
9 | Gradient Boosted Decision Trees
10 | -------------------------------
11 |
12 | .. autosummary::
13 | :nosignatures:
14 | :toctree: ../generated
15 |
16 | {% for name in torch_frame.gbdt.classes %}
17 | {{ name }}
18 | {% endfor %}
19 |
--------------------------------------------------------------------------------
/test/nn/conv/test_ft_transformer_convs.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_frame.nn import FTTransformerConvs
4 |
5 |
6 | def test_ft_transformer_convs():
7 | x = torch.randn(size=(10, 3, 8))
8 | conv = FTTransformerConvs(channels=8, num_layers=3)
9 | x, x_cls = conv(x)
10 | # The first added column corresponds to CLS token.
11 | assert x.shape == (10, 3, 8)
12 | assert x_cls.shape == (10, 8)
13 |
--------------------------------------------------------------------------------
/torch_frame/config/__init__.py:
--------------------------------------------------------------------------------
1 | r"""Config package."""
2 | from .text_embedder import TextEmbedderConfig
3 | from .text_tokenizer import TextTokenizerConfig
4 | from .model import ModelConfig
5 | from .image_embedder import ImageEmbedderConfig, ImageEmbedder
6 |
7 | __all__ = classes = [
8 | 'TextEmbedderConfig',
9 | 'TextTokenizerConfig',
10 | 'ModelConfig',
11 | 'ImageEmbedderConfig',
12 | 'ImageEmbedder',
13 | ]
14 |
--------------------------------------------------------------------------------
/torch_frame/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | r"""Transforms package."""
2 | from .base_transform import BaseTransform
3 | from .fittable_base_transform import FittableBaseTransform
4 | from .cat_to_num_transform import CatToNumTransform
5 | from .mutual_information_sort import MutualInformationSort
6 |
7 | __all__ = functions = [
8 | 'BaseTransform',
9 | 'FittableBaseTransform',
10 | 'CatToNumTransform',
11 | 'MutualInformationSort',
12 | ]
13 |
--------------------------------------------------------------------------------
/torch_frame/utils/__init__.py:
--------------------------------------------------------------------------------
1 | r"""Utility package."""
2 | from .io import save, load
3 | from .concat import cat
4 | from .split import generate_random_split
5 | from .infer_stype import infer_series_stype, infer_df_stype
6 | from .memory import num_bytes
7 |
8 | __all__ = functions = [
9 | "save",
10 | "load",
11 | "cat",
12 | "generate_random_split",
13 | "infer_series_stype",
14 | "infer_df_stype",
15 | "num_bytes",
16 | ]
17 |
--------------------------------------------------------------------------------
/test/nn/decoder/test_excelformer_decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_frame.nn import ExcelFormerDecoder
4 |
5 |
6 | def test_excelformer_decoder():
7 | batch_size = 10
8 | num_cols = 18
9 | in_channels = 8
10 | out_channels = 3
11 | x = torch.randn(batch_size, num_cols, in_channels)
12 | decoder = ExcelFormerDecoder(in_channels, out_channels, num_cols)
13 | y = decoder(x)
14 | assert y.shape == (batch_size, out_channels)
15 |
--------------------------------------------------------------------------------
/test/nn/decoder/test_trompt_decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_frame.nn import TromptDecoder
4 |
5 |
6 | def test_trompt_decoder():
7 | batch_size = 10
8 | num_prompts = 2
9 | in_channels = 8
10 | out_channels = 1
11 | x_prompt = torch.randn(batch_size, num_prompts, in_channels)
12 | decoder = TromptDecoder(in_channels, out_channels, num_prompts)
13 | y = decoder(x_prompt)
14 | assert y.shape == (batch_size, out_channels)
15 |
--------------------------------------------------------------------------------
/torch_frame/nn/conv/__init__.py:
--------------------------------------------------------------------------------
1 | r"""Convolutional layer package."""
2 | from .table_conv import TableConv
3 | from .ft_transformer_convs import FTTransformerConvs
4 | from .trompt_conv import TromptConv
5 | from .excelformer_conv import ExcelFormerConv
6 | from .tab_transformer_conv import TabTransformerConv
7 |
8 | __all__ = classes = [
9 | 'TableConv',
10 | 'FTTransformerConvs',
11 | 'TromptConv',
12 | 'ExcelFormerConv',
13 | 'TabTransformerConv',
14 | ]
15 |
--------------------------------------------------------------------------------
/test/nn/encoding/test_positional_encoding.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_frame.nn import PositionalEncoding
4 |
5 |
6 | def test_positional_encoding():
7 | out_size = 8
8 | for size in [(10, ), (10, 4), (10, 5, 8)]:
9 | input_tensor = torch.randint(0, 10, size=size)
10 | positional_encoding = PositionalEncoding(out_size)
11 | out_tensor = positional_encoding(input_tensor)
12 | assert out_tensor.shape == input_tensor.shape + (out_size, )
13 |
--------------------------------------------------------------------------------
/test/data/test_loader.py:
--------------------------------------------------------------------------------
1 | from torch_frame.data import DataLoader, TensorFrame
2 |
3 |
4 | def test_data_loader(get_fake_tensor_frame):
5 | tf = get_fake_tensor_frame(num_rows=10)
6 | loader = DataLoader(tf, batch_size=3)
7 | assert len(loader) == 4
8 |
9 | for i, batch in enumerate(loader):
10 | assert isinstance(batch, TensorFrame)
11 | if i + 1 < len(loader):
12 | assert len(batch) == 3
13 | else:
14 | assert len(batch) == 1
15 |
--------------------------------------------------------------------------------
/test/nn/conv/test_excelformer_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_frame.nn import ExcelFormerConv
4 |
5 |
6 | def test_excelformer_conv():
7 | batch_size = 10
8 | channels = 16
9 | num_cols = 15
10 | num_heads = 8
11 | # Feature-based embeddings
12 | x = torch.randn(size=(batch_size, num_cols, channels))
13 | conv = ExcelFormerConv(channels, num_cols, num_heads=num_heads)
14 | x_out = conv(x)
15 | assert x_out.shape == (batch_size, num_cols, channels)
16 |
--------------------------------------------------------------------------------
/torch_frame/nn/models/__init__.py:
--------------------------------------------------------------------------------
1 | r"""Model package."""
2 | from .trompt import Trompt
3 | from .ft_transformer import FTTransformer
4 | from .excelformer import ExcelFormer
5 | from .tabnet import TabNet
6 | from .resnet import ResNet
7 | from .tab_transformer import TabTransformer
8 | from .mlp import MLP
9 |
10 | __all__ = classes = [
11 | 'Trompt',
12 | 'FTTransformer',
13 | 'ExcelFormer',
14 | 'TabNet',
15 | 'ResNet',
16 | 'TabTransformer',
17 | 'MLP',
18 | ]
19 |
--------------------------------------------------------------------------------
/test/nn/conv/test_tab_transformer_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_frame.nn import TabTransformerConv
4 |
5 |
6 | def test_tab_transformer_conv():
7 | batch_size = 10
8 | channels = 16
9 | num_cols = 15
10 | num_heads = 8
11 | # Feature-based embeddings
12 | x = torch.randn(size=(batch_size, num_cols, channels))
13 | conv = TabTransformerConv(channels, num_heads=num_heads, attn_dropout=0.)
14 | x_out = conv(x)
15 | assert x_out.shape == (batch_size, num_cols, channels)
16 |
--------------------------------------------------------------------------------
/.github/workflows/changelog.yml:
--------------------------------------------------------------------------------
1 | name: Changelog Enforcer
2 |
3 | on: # yamllint disable-line rule:truthy
4 | pull_request:
5 | types: [opened, synchronize, reopened, ready_for_review, labeled, unlabeled]
6 |
7 | jobs:
8 |
9 | changelog:
10 | runs-on: ubuntu-latest
11 |
12 | steps:
13 | - name: Checkout repository
14 | uses: actions/checkout@v6
15 |
16 | - name: Enforce changelog entry
17 | uses: dangoslen/changelog-enforcer@v3
18 | with:
19 | skipLabels: 'skip-changelog'
20 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | # Building Documentation
2 |
3 | To build the documentation:
4 |
5 | 1. [Build and install](https://github.com/pyg-team/pytorch-frame/blob/master/.github/CONTRIBUTING.md) PyTorch Frame from source.
6 | 1. Install the [Sphinx](https://www.sphinx-doc.org/en/master/) theme via:
7 | ```
8 | pip install git+https://github.com/pyg-team/pyg_sphinx_theme.git
9 | ```
10 | 1. Generate the documentation via:
11 | ```
12 | cd docs
13 | make html
14 | ```
15 |
16 | The documentation is now available to view by opening `docs/build/html/index.html`.
17 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | # https://docs.github.com/en/code-security/dependabot/working-with-dependabot/dependabot-options-reference
2 | version: 2
3 | updates:
4 | - package-ecosystem: "github-actions"
5 | directories:
6 | - "/"
7 | - "/.github/actions/setup"
8 | schedule:
9 | interval: "daily"
10 | time: "00:00"
11 | labels:
12 | - "ci"
13 | - "skip-changelog"
14 | pull-request-branch-name:
15 | separator: "-"
16 | open-pull-requests-limit: 10
17 | reviewers:
18 | - "akihironitta"
19 | assignees:
20 | - "akihironitta"
21 |
--------------------------------------------------------------------------------
/test/nn/conv/test_trompt_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_frame.nn import TromptConv
4 |
5 |
6 | def test_trompt_conv():
7 | batch_size = 10
8 | channels = 8
9 | num_cols = 5
10 | num_prompts = 2
11 | # Feature-based embeddings
12 | x = torch.randn(size=(batch_size, num_cols, channels))
13 | # Prompt embeddings
14 | x_prompt = torch.randn(size=(batch_size, num_prompts, channels))
15 | conv = TromptConv(channels=8, num_cols=num_cols, num_prompts=num_prompts)
16 | x_prompt = conv(x, x_prompt)
17 | assert x_prompt.shape == (batch_size, num_prompts, channels)
18 |
--------------------------------------------------------------------------------
/torch_frame/config/model.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Callable
2 | from dataclasses import dataclass
3 |
4 | from torch import Tensor
5 |
6 | from torch_frame.typing import TensorData
7 |
8 |
9 | @dataclass
10 | class ModelConfig:
11 | r"""Learnable model that maps a single-column :class:`TensorData` object
12 | into row embeddings.
13 |
14 | Args:
15 | model (callable): A callable model that takes a :obj:`TensorData`
16 | object of shape :obj:`[batch_size, 1, *]` as input and outputs
17 | embeddings of shape :obj:`[batch_size, 1, out_channels]`.
18 | out_channels (int): Model output channels.
19 |
20 | """
21 | model: Callable[[TensorData], Tensor]
22 | out_channels: int
23 |
--------------------------------------------------------------------------------
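A minimal usage sketch for `ModelConfig` (illustrative only; `ToyColumnModel` is a hypothetical stand-in, not part of the repository):

    import torch
    from torch.nn import Linear, Module

    from torch_frame.config import ModelConfig

    class ToyColumnModel(Module):
        # Maps a [batch_size, 1, in_channels] tensor to
        # [batch_size, 1, out_channels] embeddings, per the docstring above.
        def __init__(self, in_channels: int, out_channels: int) -> None:
            super().__init__()
            self.lin = Linear(in_channels, out_channels)

        def forward(self, feat: torch.Tensor) -> torch.Tensor:
            return self.lin(feat)

    config = ModelConfig(model=ToyColumnModel(4, 8), out_channels=8)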
/CITATION.cff:
--------------------------------------------------------------------------------
1 | ---
2 | cff-version: 1.2.0
3 | message: "Please cite our paper if you use this code in your own work."
4 | title: "PyTorch Frame: A Deep Learning Framework for Tabular Data"
5 | authors:
6 | - family-names: "Hu"
7 | given-names: "Weihua"
8 | - family-names: "Yuan"
9 | given-names: "Yiwen"
10 | - family-names: "Zhang"
11 | given-names: "Zecheng"
12 | - family-names: "Nitta"
13 | given-names: "Akihiro"
14 | - family-names: "Cao"
15 | given-names: "Kaidi"
16 | - family-names: "Kocijan"
17 | given-names: "Vid"
18 | - family-names: "Leskovec"
19 | given-names: "Jure"
20 | - family-names: "Fey"
21 | given-names: "Matthias"
22 | date-released: 2023-10-24
23 | license: MIT
24 | url: "https://github.com/pyg-team/pytorch-frame"
25 |
--------------------------------------------------------------------------------
/.github/workflows/labeler.yml:
--------------------------------------------------------------------------------
1 | name: PR Labeler
2 |
3 | on: # yamllint disable-line rule:truthy
4 | pull_request:
5 |
6 | jobs:
7 |
8 | triage:
9 | if: github.repository == 'pyg-team/pytorch-frame'
10 | runs-on: ubuntu-latest
11 |
12 | permissions:
13 | contents: read
14 | pull-requests: write
15 |
16 | steps:
17 | - name: Add PR labels
18 | uses: actions/labeler@v6
19 | continue-on-error: true
20 | with:
21 | repo-token: "${{ secrets.GITHUB_TOKEN }}"
22 | sync-labels: true
23 |
24 | - name: Add PR author
25 | uses: samspills/assign-pr-to-author@v1.0
26 | if: github.event_name == 'pull_request'
27 | continue-on-error: true
28 | with:
29 | repo-token: "${{ secrets.GITHUB_TOKEN }}"
30 |
--------------------------------------------------------------------------------
/docs/source/_static/js/version_alert.js:
--------------------------------------------------------------------------------
1 | function warnOnLatestVersion() {
2 |   if (!window.READTHEDOCS_DATA || window.READTHEDOCS_DATA.version !== "latest") {
3 |     return;  // Not on ReadTheDocs, or not the "latest" (development) build.
4 |   }
5 |
6 |   var note = document.createElement('div');
7 |   note.setAttribute('class', 'admonition note');
8 |   // NOTE: The stable-release link below assumes the standard ReadTheDocs URL layout.
9 |   note.innerHTML =
10 |     "<p class='admonition-title'>Note</p>" +
11 |     "<p>This documentation is for an unreleased development version. " +
12 |     "Click <a href='/en/stable/'>here</a> to access the documentation of " +
13 |     "the current stable release.</p>";
14 |
15 |   var parent = document.querySelector('#pyg-documentation');
16 |   if (parent)
17 |     parent.insertBefore(note, parent.querySelector('h1'));
18 | }
19 |
20 | document.addEventListener('DOMContentLoaded', warnOnLatestVersion);
21 |
--------------------------------------------------------------------------------
/torch_frame/testing/image_embedder.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import torch
4 | from PIL import Image
5 | from torch import Tensor
6 |
7 | from torch_frame.config.image_embedder import ImageEmbedder
8 |
9 |
10 | class RandomImageEmbedder(ImageEmbedder):
11 | r"""A random-based light-weight image embedder for testing
12 | purposes. It opens each image and generates a random embedding
13 | with :obj:`out_channels` embedding size.
14 |
15 | Args:
16 | out_channels (int): The output dimensionality.
17 | """
18 | def __init__(
19 | self,
20 | out_channels: int,
21 | ) -> None:
22 | super().__init__()
23 | self.out_channels = out_channels
24 |
25 | def forward_embed(self, images: list[Image.Image]) -> Tensor:
26 | return torch.rand(len(images), self.out_channels)
27 |
--------------------------------------------------------------------------------
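A quick sketch of calling the embedder directly (assuming Pillow is installed; the blank images are stand-ins for real inputs):

    from PIL import Image

    from torch_frame.testing.image_embedder import RandomImageEmbedder

    embedder = RandomImageEmbedder(out_channels=16)
    images = [Image.new('RGB', (8, 8)) for _ in range(2)]  # two dummy images
    emb = embedder.forward_embed(images)
    assert emb.shape == (2, 16)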
/test/utils/test_split.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from torch_frame.utils.split import SPLIT_TO_NUM, generate_random_split
4 |
5 |
6 | def test_generate_random_split():
7 | num_data = 20
8 | train_ratio = 0.8
9 | val_ratio = 0.1
10 | test_ratio = 0.1
11 |
12 | split = generate_random_split(num_data, seed=42, train_ratio=train_ratio,
13 | val_ratio=val_ratio)
14 | assert (split == SPLIT_TO_NUM['train']).sum() == int(num_data *
15 | train_ratio)
16 | assert (split == SPLIT_TO_NUM['val']).sum() == int(num_data * val_ratio)
17 | assert (split == SPLIT_TO_NUM['test']).sum() == int(num_data * test_ratio)
18 | assert np.allclose(
19 | split,
20 | np.array([0, 1, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0]),
21 | )
22 |
--------------------------------------------------------------------------------
/torch_frame/data/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 |
3 | from .tensor_frame import TensorFrame
4 | from .multi_embedding_tensor import MultiEmbeddingTensor
5 | from .multi_nested_tensor import MultiNestedTensor
6 | from .stats import StatType
7 | from .dataset import Dataset, DataFrameToTensorFrameConverter
8 | from .loader import DataLoader
9 | from .download import download_url
10 |
11 | data_classes = [
12 | 'TensorFrame',
13 | 'MultiEmbeddingTensor',
14 | 'MultiNestedTensor',
15 | 'Dataset',
16 | ]
17 |
18 | loader_classes = [
19 | 'DataLoader',
20 | ]
21 |
22 | stats_classes = [
23 | 'StatType',
24 | ]
25 |
26 | helper_functions = [
27 | 'download_url',
28 | 'DataFrameToTensorFrameConverter',
29 | ]
30 |
31 | __all__ = data_classes + loader_classes + stats_classes + helper_functions
32 |
33 | classes = data_classes + loader_classes + stats_classes
34 |
--------------------------------------------------------------------------------
/docs/source/get_started/installation.rst:
--------------------------------------------------------------------------------
1 | Installation
2 | ============
3 |
4 | :pyf:`PyTorch Frame` is available for `Python 3.10` to `Python 3.13` on Linux, Windows and macOS.
5 |
6 | Installation via PyPI
7 | ---------------------
8 |
9 | .. code-block:: bash
10 |
11 | pip install pytorch-frame
12 |
13 | # Install with optional dependencies
14 | pip install pytorch-frame[full]
15 |
16 |
17 | Installation from master
18 | ------------------------
19 |
20 | .. code-block:: bash
21 |
22 | pip install git+https://github.com/pyg-team/pytorch-frame.git
23 |
24 |
25 | Installation for development
26 | ----------------------------
27 |
28 | .. code-block:: bash
29 |
30 | git clone https://github.com/pyg-team/pytorch-frame.git
31 | cd pytorch-frame
32 | pip install -e .[dev]
33 |
34 | # Install with optional dependencies
35 | pip install -e .[dev,full]
36 |
--------------------------------------------------------------------------------
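A quick sanity check after any of the installation paths above (the printed version will vary by release):

    import torch_frame
    print(torch_frame.__version__)  # e.g. '0.4.0.dev0' on the development branch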
/test/nn/models/test_tabnet.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from torch_frame.data import Dataset
4 | from torch_frame.datasets import FakeDataset
5 | from torch_frame.nn import TabNet
6 |
7 |
8 | @pytest.mark.parametrize('batch_size', [0, 5])
9 | def test_tabnet(batch_size):
10 | dataset: Dataset = FakeDataset(num_rows=10, with_nan=False)
11 | dataset.materialize()
12 | tensor_frame = dataset.tensor_frame[:batch_size]
13 | out_channels = 12
14 | model = TabNet(out_channels=out_channels, num_layers=3,
15 | split_feat_channels=8, split_attn_channels=8, gamma=1.2,
16 | col_stats=dataset.col_stats,
17 | col_names_dict=tensor_frame.col_names_dict)
18 | model.reset_parameters()
19 | out, reg = model(tensor_frame, return_reg=True)
20 | assert out.shape == (len(tensor_frame), out_channels)
21 | assert reg >= 0
22 |
--------------------------------------------------------------------------------
/test/nn/models/test_resnet.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from torch_frame.data.dataset import Dataset
4 | from torch_frame.datasets import FakeDataset
5 | from torch_frame.nn import ResNet
6 |
7 |
8 | @pytest.mark.parametrize('batch_size', [0, 5])
9 | def test_resnet(batch_size):
10 | channels = 8
11 | out_channels = 1
12 | num_layers = 3
13 | dataset: Dataset = FakeDataset(num_rows=10, with_nan=False)
14 | dataset.materialize()
15 | tensor_frame = dataset.tensor_frame[:batch_size]
16 | # Feature-based embeddings
17 | model = ResNet(
18 | channels=channels,
19 | out_channels=out_channels,
20 | num_layers=num_layers,
21 | col_stats=dataset.col_stats,
22 | col_names_dict=tensor_frame.col_names_dict,
23 | )
24 | model.reset_parameters()
25 | out = model(tensor_frame)
26 | assert out.shape == (batch_size, out_channels)
27 |
--------------------------------------------------------------------------------
/test/nn/models/test_ft_transformer.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from torch_frame.data.dataset import Dataset
4 | from torch_frame.datasets import FakeDataset
5 | from torch_frame.nn import FTTransformer
6 |
7 |
8 | @pytest.mark.parametrize('batch_size', [0, 5])
9 | def test_ft_transformer(batch_size):
10 | channels = 8
11 | out_channels = 1
12 | num_layers = 3
13 | dataset: Dataset = FakeDataset(num_rows=10, with_nan=False)
14 | dataset.materialize()
15 | tensor_frame = dataset.tensor_frame[:batch_size]
16 | # Feature-based embeddings
17 | model = FTTransformer(
18 | channels=channels,
19 | out_channels=out_channels,
20 | num_layers=num_layers,
21 | col_stats=dataset.col_stats,
22 | col_names_dict=tensor_frame.col_names_dict,
23 | )
24 | model.reset_parameters()
25 | out = model(tensor_frame)
26 | assert out.shape == (batch_size, out_channels)
27 |
--------------------------------------------------------------------------------
/test/nn/encoding/test_cyclic_encoding.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_frame.nn import CyclicEncoding
4 |
5 |
6 | def test_cyclic_encoding_shape():
7 | out_size = 8
8 | for size in [(10, ), (10, 4), (10, 5, 8)]:
9 | input_tensor = torch.rand(size)
10 | cyclic_encoding = CyclicEncoding(out_size)
11 | out_tensor = cyclic_encoding(input_tensor)
12 | assert out_tensor.shape == input_tensor.shape + (out_size, )
13 |
14 |
15 | def test_cyclic_encoding_values():
16 | out_size = 8
17 | for size in [(10, ), (10, 4), (10, 5, 8)]:
18 | cyclic_encoding = CyclicEncoding(out_size)
19 | input_zeros_tensor = torch.zeros(size)
20 | input_ones_tensor = torch.ones(size)
21 | out_zeros_tensor = cyclic_encoding(input_zeros_tensor)
22 | out_ones_tensor = cyclic_encoding(input_ones_tensor)
23 | assert torch.allclose(out_zeros_tensor, out_ones_tensor, atol=1e-5)
24 |
--------------------------------------------------------------------------------
/torch_frame/nn/encoder/__init__.py:
--------------------------------------------------------------------------------
1 | r"""Encoder package."""
2 | from .encoder import FeatureEncoder
3 | from .stypewise_encoder import StypeWiseFeatureEncoder
4 | from .stype_encoder import (
5 | StypeEncoder,
6 | EmbeddingEncoder,
7 | MultiCategoricalEmbeddingEncoder,
8 | LinearEncoder,
9 | LinearBucketEncoder,
10 | LinearPeriodicEncoder,
11 | ExcelFormerEncoder,
12 | LinearEmbeddingEncoder,
13 | LinearModelEncoder,
14 | StackEncoder,
15 | TimestampEncoder,
16 | )
17 |
18 | __all__ = classes = [
19 | 'FeatureEncoder',
20 | 'StypeWiseFeatureEncoder',
21 | 'StypeEncoder',
22 | 'EmbeddingEncoder',
23 | 'MultiCategoricalEmbeddingEncoder',
24 | 'LinearEncoder',
25 | 'LinearBucketEncoder',
26 | 'LinearPeriodicEncoder',
27 | 'ExcelFormerEncoder',
28 | 'LinearEmbeddingEncoder',
29 | 'LinearModelEncoder',
30 | 'StackEncoder',
31 | 'TimestampEncoder',
32 | ]
33 |
--------------------------------------------------------------------------------
/.github/workflows/linting.yml:
--------------------------------------------------------------------------------
1 | name: Linting
2 |
3 | on: # yamllint disable-line rule:truthy
4 | push:
5 | branches:
6 | - master
7 | pull_request:
8 |
9 | concurrency:
10 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
11 | cancel-in-progress: ${{ github.event_name == 'pull_request' }}
12 |
13 | jobs:
14 | mypy:
15 | runs-on: ubuntu-latest
16 | steps:
17 | - name: Checkout repository
18 | uses: actions/checkout@v6
19 |
20 | - name: Set up Python
21 | uses: actions/setup-python@v6
22 | with:
23 | python-version: '3.10'
24 |
25 | - name: Install dependencies
26 | run: |
27 | # TODO: Use the latest PyTorch version once type issues are addressed in the codebase:
28 | pip install -e '.[full,test]' 'torch==2.7.*' -f https://download.pytorch.org/whl/cpu
29 | pip list
30 |
31 | - name: Check type hints
32 | run: mypy
33 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | :github_url: https://github.com/pyg-team/pytorch-frame
2 |
3 | PyTorch Frame Documentation
4 | ===========================
5 | :pyf:`null` **PyTorch Frame** is a library built upon :pytorch:`null` `PyTorch <https://pytorch.org>`_ to easily write and train tabular deep learning models.
6 |
7 | .. slack_button::
8 |
9 | .. toctree::
10 | :maxdepth: 1
11 | :caption: Get Started
12 |
13 | get_started/installation
14 | get_started/introduction
15 | get_started/modular_design
16 |
17 | .. toctree::
18 | :maxdepth: 1
19 | :caption: Handling Advanced Stypes
20 |
21 | handling_advanced_stypes/handle_heterogeneous_stypes
22 | handling_advanced_stypes/handle_text
23 |
24 | .. toctree::
25 | :maxdepth: 1
26 | :caption: Package Reference
27 |
28 | modules/root
29 | modules/data
30 | modules/datasets
31 | modules/nn
32 | modules/gbdt
33 | modules/config
34 | modules/transforms
35 | modules/utils
36 |
--------------------------------------------------------------------------------
/torch_frame/nn/conv/table_conv.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Any
3 |
4 | from torch import Tensor
5 | from torch.nn import Module
6 |
7 |
8 | class TableConv(Module, ABC):
9 | r"""Base class for table convolution that transforms the input column-wise
10 | pytorch tensor.
11 | """
12 | @abstractmethod
13 | def forward(self, x: Tensor, *args: Any, **kwargs: Any) -> Any:
14 | r"""Process column-wise 3-dimensional tensor into another column-wise
15 | 3-dimensional tensor.
16 |
17 | Args:
18 | x (torch.Tensor): Input column-wise tensor of shape
19 | :obj:`[batch_size, num_cols, hidden_channels]`.
20 | args (Any): Extra arguments.
21 | kwargs (Any): Extra keyword arguments.
22 | """
23 | raise NotImplementedError
24 |
25 | def reset_parameters(self) -> None:
26 | r"""Resets all learnable parameters of the module."""
27 |
--------------------------------------------------------------------------------
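A minimal sketch of a concrete subclass (hypothetical; it only has to preserve the `[batch_size, num_cols, channels]` shape contract described above):

    import torch
    from torch import Tensor
    from torch.nn import Linear

    from torch_frame.nn.conv import TableConv

    class PerColumnLinearConv(TableConv):
        # Applies one shared linear layer to every column embedding,
        # keeping the [batch_size, num_cols, channels] shape.
        def __init__(self, channels: int) -> None:
            super().__init__()
            self.lin = Linear(channels, channels)

        def forward(self, x: Tensor) -> Tensor:
            return self.lin(x)

    x = torch.randn(10, 5, 8)
    assert PerColumnLinearConv(8)(x).shape == (10, 5, 8)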
/torch_frame/utils/memory.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | from torch import Tensor
4 |
5 | from torch_frame.data.multi_embedding_tensor import MultiEmbeddingTensor
6 | from torch_frame.data.multi_nested_tensor import MultiNestedTensor
7 |
8 |
9 | def num_bytes(data: Any) -> int:
10 | r"""Returns the number of bytes the tensor data consumes.
11 |
12 | Args:
13 | data: The tensor data.
14 | """
15 | if isinstance(data, Tensor):
16 | return data.element_size() * data.numel()
17 | if isinstance(data, MultiNestedTensor | MultiEmbeddingTensor):
18 | return num_bytes(data.values) + num_bytes(data.offset)
19 | if isinstance(data, list):
20 | return sum([num_bytes(value) for value in data])
21 | if isinstance(data, dict):
22 | return sum([num_bytes(value) for value in data.values()])
23 |
24 | raise NotImplementedError(f"'num_bytes' not implemented for "
25 | f"'{type(data)}'")
26 |
--------------------------------------------------------------------------------
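As a worked example, a `float32` tensor of shape `(4, 8)` holds 32 elements at 4 bytes each, i.e. 128 bytes:

    import torch

    from torch_frame.utils import num_bytes

    assert num_bytes(torch.randn(4, 8)) == 4 * 8 * 4  # 128 bytes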
/test/nn/models/test_mlp.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from torch_frame.data.dataset import Dataset
4 | from torch_frame.datasets import FakeDataset
5 | from torch_frame.nn import MLP
6 |
7 |
8 | @pytest.mark.parametrize('batch_size', [0, 5])
9 | @pytest.mark.parametrize('normalization', ["layer_norm", "batch_norm"])
10 | def test_mlp(batch_size, normalization):
11 | channels = 8
12 | out_channels = 1
13 | num_layers = 3
14 | dataset: Dataset = FakeDataset(num_rows=10, with_nan=False)
15 | dataset.materialize()
16 | tensor_frame = dataset.tensor_frame[:batch_size]
17 | # Feature-based embeddings
18 | model = MLP(
19 | channels=channels,
20 | out_channels=out_channels,
21 | num_layers=num_layers,
22 | col_stats=dataset.col_stats,
23 | col_names_dict=tensor_frame.col_names_dict,
24 | normalization=normalization,
25 | )
26 | model.reset_parameters()
27 | out = model(tensor_frame)
28 | assert out.shape == (batch_size, out_channels)
29 |
--------------------------------------------------------------------------------
/docs/source/modules/datasets.rst:
--------------------------------------------------------------------------------
1 | torch_frame.datasets
2 | ====================
3 |
4 | .. contents:: Contents
5 | :local:
6 |
7 | .. currentmodule:: torch_frame.datasets
8 |
9 | Real-World Datasets
10 | -------------------
11 |
12 | .. autosummary::
13 | :nosignatures:
14 | :toctree: ../generated
15 | :template: autosummary/class.rst
16 |
17 | {% for name in torch_frame.datasets.real_world_datasets %}
18 | {{ name }}
19 | {% endfor %}
20 |
21 | Synthetic Datasets
22 | ------------------
23 |
24 | .. autosummary::
25 | :nosignatures:
26 | :toctree: ../generated
27 | :template: autosummary/class.rst
28 |
29 | {% for name in torch_frame.datasets.synthetic_datasets %}
30 | {{ name }}
31 | {% endfor %}
32 |
33 | Other Datasets
34 | --------------
35 |
36 | .. autosummary::
37 | :nosignatures:
38 | :toctree: ../generated
39 | :template: autosummary/class.rst
40 |
41 | {% for name in torch_frame.datasets.other_datasets %}
42 | {{ name }}
43 | {% endfor %}
44 |
--------------------------------------------------------------------------------
/.github/workflows/auto-merge.yml:
--------------------------------------------------------------------------------
1 | name: Dependabot auto-merge
2 |
3 | on: # yamllint disable-line rule:truthy
4 | pull_request_target:
5 | types: [opened, reopened]
6 |
7 | permissions:
8 | contents: write
9 | pull-requests: write
10 |
11 | jobs:
12 | auto-merge:
13 | runs-on: ubuntu-latest
14 | if: ${{ github.event.pull_request.user.login == 'dependabot[bot]' || github.event.pull_request.user.login == 'pre-commit-ci[bot]' }}
15 | steps:
16 | - uses: actions/checkout@v6
17 |
18 | - name: Label bot PRs
19 | run: gh pr edit --add-label "ci,skip-changelog" ${{ github.event.pull_request.html_url }}
20 | env:
21 | GITHUB_TOKEN: ${{ secrets.PAT }}
22 |
23 | - name: Auto-approve
24 | uses: hmarr/auto-approve-action@v4
25 | with:
26 | github-token: ${{ secrets.PAT }}
27 |
28 | - name: Enable auto-merge
29 | run: gh pr merge --auto --squash ${{ github.event.pull_request.html_url }}
30 | env:
31 | GITHUB_TOKEN: ${{ secrets.PAT }}
32 |
--------------------------------------------------------------------------------
/torch_frame/nn/decoder/decoder.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Any
3 |
4 | from torch import Tensor
5 | from torch.nn import Module
6 |
7 |
8 | class Decoder(Module, ABC):
9 | r"""Base class for decoder that transforms the input column-wise PyTorch
10 | tensor into output tensor on which prediction head is applied.
11 | """
12 | @abstractmethod
13 | def forward(self, x: Tensor, *args: Any, **kwargs: Any) -> Any:
14 | r"""Decode :obj:`x` of shape :obj:`[batch_size, num_cols, channels]`
15 | into an output tensor of shape :obj:`[batch_size, out_channels]`.
16 |
17 | Args:
18 | x (torch.Tensor): Input column-wise tensor of shape
19 | :obj:`[batch_size, num_cols, hidden_channels]`.
20 | args (Any): Extra arguments.
21 | kwargs (Any): Extra keyword arguments.
22 | """
23 | raise NotImplementedError
24 |
25 | def reset_parameters(self) -> None:
26 | r"""Resets all learnable parameters of the module."""
27 |
--------------------------------------------------------------------------------
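A minimal sketch of a concrete subclass (hypothetical; it realizes the `[batch_size, num_cols, channels]` to `[batch_size, out_channels]` contract described above):

    import torch
    from torch import Tensor
    from torch.nn import Linear

    from torch_frame.nn.decoder import Decoder

    class MeanPoolDecoder(Decoder):
        # Mean-pools over the column dimension, then projects to the
        # desired number of output channels.
        def __init__(self, in_channels: int, out_channels: int) -> None:
            super().__init__()
            self.lin = Linear(in_channels, out_channels)

        def forward(self, x: Tensor) -> Tensor:
            return self.lin(x.mean(dim=1))

    x = torch.randn(10, 5, 8)
    assert MeanPoolDecoder(8, 3)(x).shape == (10, 3)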
/test/utils/test_memory.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from torch_frame.data import MultiEmbeddingTensor, MultiNestedTensor
5 | from torch_frame.utils import num_bytes
6 |
7 |
8 | def test_num_bytes():
9 | data = torch.randn(4, 8)
10 | assert num_bytes(data) == 4 * 8 * 4
11 |
12 | data = MultiNestedTensor(
13 | num_rows=3,
14 | num_cols=2,
15 | values=torch.randn(12),
16 | offset=torch.tensor([0, 2, 4, 6, 8, 10, 12]),
17 | )
18 | assert num_bytes(data) == 12 * 4 + 7 * 8
19 |
20 | data = MultiEmbeddingTensor(
21 | num_rows=2,
22 | num_cols=3,
23 | values=torch.randn(2, 10),
24 | offset=torch.tensor([0, 3, 5, 10]),
25 | )
26 | assert num_bytes(data) == 2 * 10 * 4 + 4 * 8
27 |
28 | mapping = {
29 | 'A': data,
30 | 'B': data,
31 | }
32 | assert num_bytes(mapping) == 2 * num_bytes(data)
33 |
34 | seq = [data, data, data]
35 | assert num_bytes(seq) == 3 * num_bytes(data)
36 |
37 | with pytest.raises(NotImplementedError):
38 | num_bytes("unsupported")
39 |
--------------------------------------------------------------------------------
/torch_frame/config/text_tokenizer.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from collections.abc import Callable
4 | from dataclasses import dataclass
5 |
6 | from torch_frame.typing import TextTokenizationOutputs
7 |
8 |
9 | @dataclass
10 | class TextTokenizerConfig:
11 | r"""Text tokenizer that maps a list of strings/sentences into a
12 | dictionary of :class:`MultiNestedTensor`.
13 |
14 | Args:
15 | text_tokenizer (callable): A callable text tokenizer that takes a
16 | list of strings as input and outputs a list of dictionaries.
17 | Each dictionary contains keys that are arguments to the text
18 | encoder model and values are corresponding tensors such as
19 | tokens and attention masks.
20 | batch_size (int, optional): Batch size to use when tokenizing the
21 | sentences. If set to :obj:`None`, the text embeddings will
22 | be obtained in a full-batch manner. (default: :obj:`None`)
23 |
24 | """
25 | text_tokenizer: Callable[[list[str]], TextTokenizationOutputs]
26 | batch_size: int | None = None
27 |
--------------------------------------------------------------------------------
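A minimal sketch of a conforming tokenizer callable (a hypothetical whitespace tokenizer; a real setup would typically wrap e.g. a HuggingFace tokenizer):

    import torch

    from torch_frame.config import TextTokenizerConfig

    def toy_tokenizer(sentences: list[str]) -> list[dict[str, torch.Tensor]]:
        # One dictionary per sentence, mapping encoder-argument names
        # (here 'input_ids' and 'attention_mask') to tensors.
        outs = []
        for sentence in sentences:
            num_tokens = max(len(sentence.split()), 1)
            outs.append({
                'input_ids': torch.arange(num_tokens),
                'attention_mask': torch.ones(num_tokens, dtype=torch.long),
            })
        return outs

    cfg = TextTokenizerConfig(text_tokenizer=toy_tokenizer, batch_size=32)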
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2023 PyG Team
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/docs/source/modules/data.rst:
--------------------------------------------------------------------------------
1 | torch_frame.data
2 | ================
3 |
4 | .. contents:: Contents
5 | :local:
6 |
7 | .. currentmodule:: torch_frame.data
8 |
9 | Data Objects
10 | ------------
11 |
12 | .. autosummary::
13 | :nosignatures:
14 | :toctree: ../generated
15 | :template: autosummary/class.rst
16 |
17 | {% for name in torch_frame.data.data_classes %}
18 | {{ name }}
19 | {% endfor %}
20 |
21 | Stats
22 | -----
23 |
24 | .. autosummary::
25 | :nosignatures:
26 | :toctree: ../generated
27 | :template: autosummary/class.rst
28 |
29 | {% for name in torch_frame.data.stats_classes %}
30 | {{ name }}
31 | {% endfor %}
32 |
33 | Data Loaders
34 | ------------
35 |
36 | .. autosummary::
37 | :nosignatures:
38 | :toctree: ../generated
39 | :template: autosummary/class.rst
40 |
41 | {% for name in torch_frame.data.loader_classes %}
42 | {{ name }}
43 | {% endfor %}
44 |
45 | Helper Functions
46 | ----------------
47 |
48 | .. autosummary::
49 | :nosignatures:
50 | :toctree: ../generated
51 |
52 | {% for name in torch_frame.data.helper_functions %}
53 | {{ name }}
54 | {% endfor %}
55 |
--------------------------------------------------------------------------------
/torch_frame/config/text_embedder.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from collections.abc import Callable
4 | from dataclasses import dataclass
5 |
6 | from torch import Tensor
7 |
8 |
9 | @dataclass
10 | class TextEmbedderConfig:
11 | r"""Text embedder model that maps a list of strings/sentences into PyTorch
12 | Tensor embeddings.
13 |
14 | Args:
15 | text_embedder (callable): A callable text embedder that takes a list
16 | of strings as input and outputs the PyTorch Tensor embeddings for
17 | that list of strings.
18 | batch_size (int, optional): Batch size to use when encoding the
19 | sentences. If set to :obj:`None`, the text embeddings will
20 | be obtained in a full-batch manner. (default: :obj:`None`)
21 |
22 | """
23 | text_embedder: Callable[[list[str]], Tensor]
24 | # Batch size to use when encoding the sentences. It is recommended to set
25 | # it to a reasonable value when one uses a heavy text embedding model
26 | # (e.g., Transformer) on GPU. If set to :obj:`None`, the text embeddings
27 | # will be obtained in a full-batch manner.
28 | batch_size: int | None = None
29 |
--------------------------------------------------------------------------------
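As a usage sketch, the light-weight `HashTextEmbedder` from `torch_frame.testing.text_embedder` (included in this repository) satisfies the expected callable interface:

    from torch_frame.config import TextEmbedderConfig
    from torch_frame.testing.text_embedder import HashTextEmbedder

    cfg = TextEmbedderConfig(
        text_embedder=HashTextEmbedder(out_channels=16),
        batch_size=8,
    )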
/test/test_stype.py:
--------------------------------------------------------------------------------
1 | import torch_frame
2 |
3 |
4 | def test_stype():
5 | assert len(torch_frame.stype) == 9
6 | assert torch_frame.numerical == torch_frame.stype('numerical')
7 | assert not torch_frame.numerical.is_text_stype
8 | assert torch_frame.categorical == torch_frame.stype('categorical')
9 | assert not torch_frame.categorical.is_text_stype
10 | assert torch_frame.multicategorical == torch_frame.stype(
11 | 'multicategorical')
12 | assert not torch_frame.multicategorical.is_text_stype
13 | assert torch_frame.sequence_numerical == torch_frame.stype(
14 | 'sequence_numerical')
15 | assert not torch_frame.sequence_numerical.is_text_stype
16 | assert torch_frame.text_embedded == torch_frame.stype('text_embedded')
17 | assert torch_frame.text_embedded.is_text_stype
18 | assert torch_frame.text_tokenized == torch_frame.stype('text_tokenized')
19 | assert torch_frame.text_tokenized.is_text_stype
20 | assert torch_frame.image_embedded == torch_frame.stype('image_embedded')
21 | assert torch_frame.image_embedded.is_image_stype
22 | assert torch_frame.embedding == torch_frame.stype('embedding')
23 | assert torch_frame.embedding.use_multi_embedding_tensor
24 |
--------------------------------------------------------------------------------
/test/nn/models/test_tab_transformer.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from torch_frame import stype
4 | from torch_frame.data.dataset import Dataset
5 | from torch_frame.datasets import FakeDataset
6 | from torch_frame.nn import TabTransformer
7 |
8 |
9 | @pytest.mark.parametrize('stypes', [[stype.categorical, stype.numerical],
10 | [stype.categorical], [stype.numerical]])
11 | @pytest.mark.parametrize('batch_size', [0, 5])
12 | def test_tab_transformer(stypes, batch_size):
13 | channels = 8
14 | out_channels = 1
15 | num_layers = 3
16 | num_heads = 2
17 | encoder_pad_size = 2
18 | dataset: Dataset = FakeDataset(num_rows=10, with_nan=False, stypes=stypes)
19 | dataset.materialize()
20 | tensor_frame = dataset.tensor_frame[:batch_size]
21 | model = TabTransformer(
22 | channels=channels,
23 | out_channels=out_channels,
24 | num_layers=num_layers,
25 | num_heads=num_heads,
26 | encoder_pad_size=encoder_pad_size,
27 | attn_dropout=0.,
28 | ffn_dropout=0.,
29 | col_stats=dataset.col_stats,
30 | col_names_dict=tensor_frame.col_names_dict,
31 | )
32 | model.reset_parameters()
33 | out = model(tensor_frame)
34 | assert out.shape == (batch_size, out_channels)
35 |
--------------------------------------------------------------------------------
/torch_frame/__init__.py:
--------------------------------------------------------------------------------
1 | r"""Utility package."""
2 | from ._stype import (
3 | stype,
4 | numerical,
5 | categorical,
6 | text_embedded,
7 | text_tokenized,
8 | multicategorical,
9 | sequence_numerical,
10 | timestamp,
11 | image_embedded,
12 | embedding,
13 | )
14 | from .data import TensorFrame
15 | from .typing import (
16 | TaskType,
17 | Metric,
18 | DataFrame,
19 | NAStrategy,
20 | WITH_PT24,
21 | )
22 | from torch_frame.utils import save, load, cat # noqa
23 | import torch_frame.data # noqa
24 | import torch_frame.datasets # noqa
25 | import torch_frame.nn # noqa
26 | import torch_frame.gbdt # noqa
27 |
28 | if WITH_PT24:
29 | import torch
30 |
31 | torch.serialization.add_safe_globals([
32 | stype,
33 | torch_frame.data.stats.StatType,
34 | ])
35 |
36 | # https://peps.python.org/pep-0440/
37 | __version__ = '0.4.0.dev0'
38 |
39 | __all__ = [
40 | 'DataFrame',
41 | 'stype',
42 | 'numerical',
43 | 'categorical',
44 | 'text_embedded',
45 | 'text_tokenized',
46 | 'multicategorical',
47 | 'sequence_numerical',
48 | 'timestamp',
49 | 'image_embedded',
50 | 'embedding',
51 | 'TaskType',
52 | 'Metric',
53 | 'NAStrategy',
54 | 'TensorFrame',
55 | 'save',
56 | 'load',
57 | 'cat',
58 | 'torch_frame',
59 | '__version__',
60 | ]
61 |
--------------------------------------------------------------------------------
/torch_frame/testing/text_embedder.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import torch
4 | from torch import Tensor
5 | from torch.nn import Embedding
6 |
7 |
8 | class HashTextEmbedder:
9 | r"""A hash-based light-weight text embedder for testing purposes.
10 | It hashes each sentence into an index modulo :obj:`num_hash` and then
11 | uses :class:`torch.nn.Embedding` to look up the index to obtain the
12 | sentence embedding.
13 |
14 | Args:
15 | out_channels (int): The output dimensionality.
16 | num_hash_bins (int): Number of hash bins to use.
17 | (default: :obj:`64`)
18 | device (torch.device, optional): The device to put :class:`Embedding`
19 | module. (default: :obj:`None`)
20 | """
21 | def __init__(
22 | self,
23 | out_channels: int,
24 | num_hash_bins: int = 64,
25 | device: torch.device | None = None,
26 | ) -> None:
27 | self.out_channels = out_channels
28 | self.num_hash_bins = num_hash_bins
29 | self.device = device
30 | self.embedding = Embedding(num_hash_bins, out_channels).to(device)
31 |
32 | def __call__(self, sentences: list[str]) -> Tensor:
33 | idx = torch.tensor([hash(s) % self.num_hash_bins for s in sentences],
34 | device=self.device)
35 | return self.embedding(idx).detach()
36 |
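37 | # Illustrative usage sketch (added for clarity; not part of the original
38 | # file): embedding a small batch of sentences with `HashTextEmbedder`.
39 | if __name__ == '__main__':
40 |     embedder = HashTextEmbedder(out_channels=16)
41 |     emb = embedder(['hello world', 'tabular learning'])
42 |     assert emb.shape == (2, 16)
43 |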
--------------------------------------------------------------------------------
/torch_frame/nn/utils/init.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch import Tensor
5 | from torch.nn.init import _calculate_correct_fan, calculate_gain
6 |
7 |
8 | def attenuated_kaiming_uniform_(
9 | tensor: Tensor,
10 | scale: float = 0.1,
11 | a: float = math.sqrt(5),
12 | mode: str = 'fan_in',
13 | nonlinearity: str = 'leaky_relu',
14 | ) -> Tensor:
15 | r"""Attenuated Kaiming Uniform Initialization.
16 |
17 | Args:
18 |         tensor (Tensor): Input tensor to be initialized.
19 |         scale (float): Positive constant by which the variance is rescaled.
20 |         a (float): Negative slope of the rectifier used after this layer.
21 |         mode (str): Either 'fan_in' (default) or 'fan_out'. Choosing
22 |             'fan_in' preserves the magnitude of the variance of the weights
23 |             in the forward pass. Choosing 'fan_out' preserves the magnitudes
24 |             in the backward pass.
25 |         nonlinearity (str): The non-linear function (nn.functional name),
26 |             recommended to use only with 'relu' or 'leaky_relu'.
27 | """
28 | with torch.no_grad():
29 | fan = _calculate_correct_fan(tensor, mode)
30 | gain = calculate_gain(nonlinearity, a)
31 | std = gain * scale / math.sqrt(fan)
32 | # Calculate uniform bounds from standard deviation
33 | bound = math.sqrt(3.0) * std
34 | return tensor.uniform_(-bound, bound)
35 |
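36 | # Illustrative usage sketch (added for clarity; not part of the original
37 | # file): initializing a weight matrix in place. With `scale=0.1`, the
38 | # resulting uniform bound is sqrt(3) * gain * scale / sqrt(fan_in).
39 | if __name__ == '__main__':
40 |     weight = torch.empty(16, 8)
41 |     attenuated_kaiming_uniform_(weight, scale=0.1)
42 |     assert weight.abs().max() <= math.sqrt(3.0) * calculate_gain(
43 |         'leaky_relu', math.sqrt(5)) * 0.1 / math.sqrt(8)
44 |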
--------------------------------------------------------------------------------
/torch_frame/nn/encoder/encoder.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | from torch import Tensor
4 | from torch.nn import Module
5 |
6 | from torch_frame import TensorFrame
7 |
8 |
9 | class FeatureEncoder(Module, ABC):
10 | r"""Base class for feature encoder that transforms input
11 | :class:`torch_frame.TensorFrame` into :obj:`(x, col_names)`,
12 | where :obj:`x` is the colum-wise PyTorch tensor of shape
13 | :obj:`[batch_size, num_cols, channels]` and :obj:`col_names` is the
14 | names of the columns. This class contains learnable parameters and missing
15 | value handling.
16 | """
17 | @abstractmethod
18 | def forward(self, tf: TensorFrame) -> tuple[Tensor, list[str]]:
19 | r"""Encode :class:`TensorFrame` object into a tuple
20 | :obj:`(x, col_names)`.
21 |
22 | Args:
23 | tf (:class:`torch_frame.TensorFrame`): Input :class:`TensorFrame`
24 | object.
25 |
26 | Returns:
27 | (torch.Tensor, List[str]): A tuple of an output column-wise
28 | :class:`torch.Tensor` of shape
29 | :obj:`[batch_size, num_cols, hidden_channels]` and a list of
30 | column names of :obj:`x`. The length needs to be
31 | :obj:`num_cols`.
32 | """
33 | raise NotImplementedError
34 |
35 | def reset_parameters(self) -> None:
36 | r"""Resets all learnable parameters of the module."""
37 |
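38 | # Illustrative sketch (added for clarity; not part of the original file):
39 | # a minimal, non-learnable `FeatureEncoder` that satisfies the
40 | # `(x, col_names)` contract by emitting zeros for every column.
41 | import torch  # needed for the sketch below
42 | 
43 | 
44 | class ZeroFeatureEncoder(FeatureEncoder):
45 |     def __init__(self, channels: int) -> None:
46 |         super().__init__()
47 |         self.channels = channels
48 | 
49 |     def forward(self, tf: TensorFrame) -> tuple[Tensor, list[str]]:
50 |         col_names = [
51 |             name for names in tf.col_names_dict.values() for name in names
52 |         ]
53 |         x = torch.zeros(len(tf), len(col_names), self.channels)
54 |         return x, col_names
55 |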
--------------------------------------------------------------------------------
/.github/workflows/documentation.yml:
--------------------------------------------------------------------------------
1 | name: Documentation
2 |
3 | on: # yamllint disable-line rule:truthy
4 | push:
5 | branches:
6 | - master
7 | pull_request:
8 |
9 | concurrency:
10 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
11 | cancel-in-progress: ${{ github.event_name == 'pull_request' }}
12 |
13 | jobs:
14 |
15 | make_html:
16 | runs-on: ubuntu-latest
17 |
18 | steps:
19 | - name: Checkout repository
20 | uses: actions/checkout@v6
21 | with:
22 | fetch-depth: 40
23 |
24 | # Skip workflow if only certain files have been changed.
25 | - name: Get changed files
26 | id: changed-files-specific
27 | uses: tj-actions/changed-files@v47
28 | with:
29 | files: |
30 | examples/**
31 | README.md
32 | CHANGELOG.md
33 |
34 | - name: Setup packages
35 | if: steps.changed-files-specific.outputs.only_changed != 'true'
36 | uses: ./.github/actions/setup
37 |
38 | - name: Install main package
39 | if: steps.changed-files-specific.outputs.only_changed != 'true'
40 | run: |
41 | pip install git+https://github.com/pyg-team/pyg_sphinx_theme.git
42 | pip install -e .
43 | pip list
44 | 
45 |       - name: Build documentation
46 |         if: steps.changed-files-specific.outputs.only_changed != 'true'
47 |         run: |
48 |           cd docs && make clean && make html SPHINXOPTS="-W"  # Fail on warning.
49 |
--------------------------------------------------------------------------------
/torch_frame/data/download.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import os
4 | import os.path as osp
5 | import ssl
6 | import sys
7 | import urllib.request
8 |
9 |
10 | def download_url(
11 | url: str,
12 | root: str,
13 | filename: str | None = None,
14 | *,
15 | log: bool = True,
16 | ) -> str:
17 | r"""Downloads the content of :obj:`url` to the specified folder
18 | :obj:`root`.
19 |
20 | Args:
21 | url (str): The URL.
22 | root (str): The root folder.
23 | filename (str, optional): If set, will rename the downloaded file.
24 | (default: :obj:`None`)
25 | log (bool, optional): If :obj:`False`, will not print anything to the
26 | console. (default: :obj:`True`)
27 | """
28 | if filename is None:
29 | filename = url.rpartition('/')[2]
30 | if filename[0] != '?':
31 | filename = filename.split('?')[0]
32 |
33 | path = osp.join(root, filename)
34 |
35 | if osp.exists(path):
36 | return path
37 |
38 | if log and 'pytest' not in sys.modules:
39 | print(f'Downloading {url}', file=sys.stderr)
40 |
41 | os.makedirs(root, exist_ok=True)
42 |
43 | context = ssl._create_unverified_context()
44 | data = urllib.request.urlopen(url, context=context)
45 |
46 | with open(path, 'wb') as f:
47 | while True:
48 | chunk = data.read(10 * 1024 * 1024)
49 | if not chunk:
50 | break
51 | f.write(chunk)
52 |
53 | return path
54 |
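55 | # Illustrative usage sketch (added for clarity; not part of the original
56 | # file): downloading a CSV once; repeated calls return the cached path.
57 | # The URL is the Titanic CSV used elsewhere in this repository.
58 | if __name__ == '__main__':
59 |     url = ('https://github.com/datasciencedojo/datasets/'
60 |            'raw/master/titanic.csv')
61 |     path = download_url(url, root='/tmp/torch_frame_demo')
62 |     assert path == download_url(url, root='/tmp/torch_frame_demo')
63 |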
--------------------------------------------------------------------------------
/torch_frame/nn/encoding/positional_encoding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from torch.nn import Module
4 |
5 |
6 | class PositionalEncoding(Module):
7 | r"""Positional encoding introduced in `"Attention Is All You Need"
8 | `_ paper. Given an input tensor of shape
9 | :obj:`(*, )`, this encoding expands it into an output tensor of shape
10 | :obj:`(*, out_size)`.
11 |
12 | Args:
13 | out_size (int): The output dimension size.
14 | """
15 | def __init__(self, out_size: int) -> None:
16 | super().__init__()
17 | if out_size % 2 != 0:
18 | raise ValueError(
19 | f"out_size should be divisible by 2 (got {out_size}).")
20 | self.out_size = out_size
21 | self.mult_term: Tensor
22 | self.register_buffer(
23 | "mult_term",
24 | torch.pow(
25 | 1 / 10000.0,
26 | torch.arange(0, self.out_size, 2) / out_size,
27 | ),
28 | )
29 |
30 | def forward(self, input_tensor: Tensor) -> Tensor:
31 | assert torch.all(input_tensor >= 0)
32 | # (*, 1) * (1, ..., 1, out_size // 2) -> (*, out_size // 2)
33 | mult_tensor = input_tensor.unsqueeze(-1) * self.mult_term.reshape(
34 | (1, ) * input_tensor.ndim + (-1, ))
35 | # cat([(*, out_size // 2), (*, out_size // 2)]) -> (*, out_size)
36 | return torch.cat([torch.sin(mult_tensor),
37 | torch.cos(mult_tensor)], dim=-1)
38 |
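39 | # Illustrative usage sketch (added for clarity; not part of the original
40 | # file): encoding non-negative positions into `out_size` channels.
41 | if __name__ == '__main__':
42 |     encoding = PositionalEncoding(out_size=8)
43 |     positions = torch.arange(4, dtype=torch.float)
44 |     out = encoding(positions)
45 |     assert out.shape == (4, 8)
46 |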
--------------------------------------------------------------------------------
/torch_frame/nn/encoding/cyclic_encoding.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch import Tensor
5 | from torch.nn import Module
6 |
7 |
8 | class CyclicEncoding(Module):
9 | r"""Cyclic encoding for input data containing values between 0 and 1.
10 | This function maps each value in the input using sine and cosine
11 | functions of different wavelengths to preserve the cyclical nature. This
12 | is particularly useful for encoding cyclical features like hours of a
13 | day, days of the week, etc. Given an input tensor of shape
14 | :obj:`(*, )`, this encoding expands it into an output tensor of shape
15 | :obj:`(*, out_size)`.
16 |
17 | Args:
18 | out_size (int): The output dimension size.
19 | """
20 | def __init__(self, out_size: int) -> None:
21 | super().__init__()
22 | if out_size % 2 != 0:
23 | raise ValueError(
24 | f"out_size should be divisible by 2 (got {out_size}).")
25 | self.out_size = out_size
26 | self.mult_term: Tensor
27 | self.register_buffer(
28 | "mult_term",
29 | torch.arange(1, self.out_size // 2 + 1),
30 | )
31 |
32 | def forward(self, input_tensor: Tensor) -> Tensor:
33 | assert torch.all((input_tensor >= 0) & (input_tensor <= 1))
34 | mult_tensor = input_tensor.unsqueeze(-1) * self.mult_term.reshape(
35 | (1, ) * input_tensor.ndim + (-1, ))
36 | return torch.cat([
37 | torch.sin(mult_tensor * math.pi),
38 | torch.cos(mult_tensor * 2 * math.pi)
39 | ], dim=-1)
40 |
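41 | # Illustrative usage sketch (added for clarity; not part of the original
42 | # file): encoding hours of the day after normalizing them into [0, 1].
43 | if __name__ == '__main__':
44 |     hours = torch.tensor([0.0, 6.0, 12.0, 23.0]) / 24.0
45 |     encoding = CyclicEncoding(out_size=4)
46 |     out = encoding(hours)
47 |     assert out.shape == (4, 4)
48 |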
--------------------------------------------------------------------------------
/torch_frame/utils/split.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | # Mapping split name to integer.
4 | SPLIT_TO_NUM = {'train': 0, 'val': 1, 'test': 2}
5 |
6 |
7 | def generate_random_split(length: int, seed: int, train_ratio: float = 0.8,
8 | val_ratio: float = 0.1,
9 | include_test: bool = True) -> np.ndarray:
10 | r"""Generate a list of random split assignments of the specified length.
11 | The elements are either :obj:`0`, :obj:`1`, or :obj:`2`, representing
12 | train, val, test, respectively. Note that this function relies on the fact
13 | that numpy's shuffle is consistent across versions, which has been
14 | historically the case.
15 | """
16 | assert train_ratio > 0
17 | assert val_ratio > 0
18 |
19 | if include_test:
20 | assert train_ratio + val_ratio < 1
21 | train_num = int(length * train_ratio)
22 | val_num = int(length * val_ratio)
23 | test_num = length - train_num - val_num
24 | arr = np.concatenate([
25 | np.full(train_num, SPLIT_TO_NUM['train']),
26 | np.full(val_num, SPLIT_TO_NUM['val']),
27 | np.full(test_num, SPLIT_TO_NUM['test'])
28 | ])
29 | else:
30 | assert train_ratio + val_ratio == 1
31 | train_num = int(length * train_ratio)
32 | val_num = length - train_num
33 | arr = np.concatenate([
34 | np.full(train_num, SPLIT_TO_NUM['train']),
35 | np.full(val_num, SPLIT_TO_NUM['val']),
36 | ])
37 | np.random.seed(seed)
38 | np.random.shuffle(arr)
39 |
40 | return arr
41 |
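42 | # Illustrative usage sketch (added for clarity; not part of the original
43 | # file): with the default ratios, a length-10 split yields 8 train,
44 | # 1 val, and 1 test assignment in shuffled order.
45 | if __name__ == '__main__':
46 |     arr = generate_random_split(length=10, seed=0)
47 |     assert (arr == SPLIT_TO_NUM['train']).sum() == 8
48 |     assert (arr == SPLIT_TO_NUM['val']).sum() == 1
49 |     assert (arr == SPLIT_TO_NUM['test']).sum() == 1
50 |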
--------------------------------------------------------------------------------
/torch_frame/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | from .adult_census_income import AdultCensusIncome
3 | from .amazon_fine_food_reviews import AmazonFineFoodReviews
4 | from .amphibians import Amphibians
5 | from .bank_marketing import BankMarketing
6 | from .data_frame_benchmark import DataFrameBenchmark
7 | from .data_frame_text_benchmark import DataFrameTextBenchmark
8 | from .diamond_images import DiamondImages
9 | from .dota2 import Dota2
10 | from .fake import FakeDataset
11 | from .forest_cover_type import ForestCoverType
12 | from .huggingface_dataset import HuggingFaceDatasetDict
13 | from .kdd_census_income import KDDCensusIncome
14 | from .mercari import Mercari
15 | from .movielens_1m import Movielens1M
16 | from .multimodal_text_benchmark import MultimodalTextBenchmark
17 | from .mushroom import Mushroom
18 | from .poker_hand import PokerHand
19 | from .tabular_benchmark import TabularBenchmark
20 | from .titanic import Titanic
21 | from .yandex import Yandex
22 |
23 | real_world_datasets = [
24 |     'AdultCensusIncome',
25 |     'AmazonFineFoodReviews',
26 |     'Amphibians',
27 |     'BankMarketing',
28 |     'DataFrameBenchmark',
29 |     'DataFrameTextBenchmark',
30 |     'DiamondImages',
31 |     'Dota2',
32 |     'ForestCoverType',
33 |     'KDDCensusIncome',
34 |     'Mercari',
35 |     'Movielens1M',
36 |     'MultimodalTextBenchmark',
37 |     'Mushroom',
38 |     'PokerHand',
39 |     'TabularBenchmark',
40 |     'Titanic',
41 |     'Yandex',
42 | ]
43 | 
44 | synthetic_datasets = [
45 |     'FakeDataset',
46 | ]
47 | 
48 | other_datasets = [
49 |     'HuggingFaceDatasetDict',
50 | ]
51 | 
52 | __all__ = real_world_datasets + synthetic_datasets + other_datasets
53 |
--------------------------------------------------------------------------------
/torch_frame/datasets/titanic.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 | import torch_frame
4 |
5 |
6 | class Titanic(torch_frame.data.Dataset):
7 | r"""The Titanic dataset from the `Titanic Kaggle competition
8 | `_.
9 | The Titanic dataset is known as the MNIST equivalent for tabular learning.
10 | The goal is to predict which passenger survived using passenger data
11 | (*i.e.* gender, age, etc).
12 |
13 | **STATS:**
14 |
15 | .. list-table::
16 | :widths: 10 10 10 10 20 10
17 | :header-rows: 1
18 |
19 | * - #rows
20 | - #cols (numerical)
21 | - #cols (categorical)
22 | - #classes
23 | - Task
24 | - Missing value ratio
25 | * - 891
26 | - 4
27 | - 3
28 | - 2
29 | - binary_classification
30 | - 8.8%
31 | """
32 |
33 | url = 'https://github.com/datasciencedojo/datasets/raw/master/titanic.csv'
34 |
35 | def __init__(self, root: str) -> None:
36 | path = self.download_url(self.url, root)
37 | df = pd.read_csv(path, index_col=['PassengerId'])
38 |
39 | col_to_stype = { # TODO Use 'Name', 'Ticket' and 'Cabin'.
40 | 'Survived': torch_frame.categorical,
41 | 'Pclass': torch_frame.categorical,
42 | 'Sex': torch_frame.categorical,
43 | 'Age': torch_frame.numerical,
44 | 'SibSp': torch_frame.numerical,
45 | 'Parch': torch_frame.numerical,
46 | 'Fare': torch_frame.numerical,
47 | 'Embarked': torch_frame.categorical,
48 | }
49 |
50 | super().__init__(df, col_to_stype, target_col='Survived')
51 |
--------------------------------------------------------------------------------
/.github/actions/setup/action.yml:
--------------------------------------------------------------------------------
1 | name: Setup
2 |
3 | inputs:
4 | python-version:
5 | required: false
6 | default: '3.10'
7 | torch-version:
8 | required: false
9 | default: '2.9'
10 | cuda-version:
11 | required: false
12 | default: cpu
13 |
14 | runs:
15 | using: composite
16 |
17 | steps:
18 | - name: Set up Python ${{ inputs.python-version }}
19 | uses: actions/setup-python@v6
20 | with:
21 | python-version: ${{ inputs.python-version }}
22 | check-latest: true
23 | cache: pip
24 | cache-dependency-path: |
25 | pyproject.toml
26 |
27 | - name: Pre-install NumPy<2 if necessary
28 | run: |
29 | [[ ${{ inputs.torch-version }} =~ ^2\.[0-2] ]] && pip install 'numpy<2' || echo "Skipping NumPy<2 installation"
30 | shell: bash
31 |
32 | - name: Install PyTorch ${{ inputs.torch-version }}+${{ inputs.cuda-version }}
33 | if: ${{ inputs.torch-version != 'nightly' }}
34 | run: |
35 | pip install torch==${{ inputs.torch-version }}.* --extra-index-url https://download.pytorch.org/whl/${{ inputs.cuda-version }}
36 | shell: bash
37 |
38 | - name: Install PyTorch ${{ inputs.torch-version }}+${{ inputs.cuda-version }}
39 | if: ${{ inputs.torch-version == 'nightly' }}
40 | run: |
41 | pip install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/${{ inputs.cuda-version }}
42 | shell: bash
43 |
44 | - name: List installed packages
45 | run: |
46 | pip list
47 | python -c "import torch; print('PyTorch:', torch.__version__)"
48 | python -c "import torch; print('CUDA:', torch.version.cuda)"
49 | shell: bash
50 |
--------------------------------------------------------------------------------
/torch_frame/data/loader.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import torch
4 |
5 | from torch_frame.data import Dataset, TensorFrame
6 | from torch_frame.typing import IndexSelectType
7 |
8 |
9 | class DataLoader(torch.utils.data.DataLoader):
10 | r"""A data loader which creates mini-batches from a
11 | :class:`torch_frame.Dataset` or :class:`torch_frame.TensorFrame` object.
12 |
13 | .. code-block:: python
14 |
15 | import torch_frame
16 |
17 | dataset = ...
18 |
19 | loader = torch_frame.data.DataLoader(
20 | dataset,
21 | batch_size=512,
22 | shuffle=True,
23 | )
24 |
25 | Args:
26 | dataset (Dataset or TensorFrame): The dataset or tensor frame from
27 | which to load the data.
28 | *args (optional): Additional arguments of
29 | :class:`torch.utils.data.DataLoader`.
30 | **kwargs (optional): Additional keyword arguments of
31 | :class:`torch.utils.data.DataLoader`.
32 | """
33 | def __init__(
34 | self,
35 | dataset: Dataset | TensorFrame,
36 | *args,
37 | **kwargs,
38 | ):
39 | kwargs.pop('collate_fn', None)
40 |
41 | if isinstance(dataset, Dataset):
42 | self.tensor_frame: TensorFrame = dataset.materialize().tensor_frame
43 | else:
44 | self.tensor_frame: TensorFrame = dataset
45 |
46 | super().__init__(
47 | range(len(dataset)),
48 | *args,
49 | collate_fn=self.collate_fn,
50 | **kwargs,
51 | )
52 |
53 | def collate_fn(self, index: IndexSelectType) -> TensorFrame:
54 | return self.tensor_frame[index]
55 |
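56 | # Illustrative usage sketch (added for clarity; not part of the original
57 | # file): iterating over the loader yields `TensorFrame` mini-batches.
58 | if __name__ == '__main__':
59 |     from torch_frame.datasets import FakeDataset
60 | 
61 |     dataset = FakeDataset(num_rows=10, with_nan=False)
62 |     loader = DataLoader(dataset, batch_size=4, shuffle=True)
63 |     for tf in loader:
64 |         assert len(tf) <= 4  # Each mini-batch is a TensorFrame.
65 |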
--------------------------------------------------------------------------------
/test/nn/models/test_trompt.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from torch_frame import stype
4 | from torch_frame.data.dataset import Dataset
5 | from torch_frame.datasets import FakeDataset
6 | from torch_frame.nn import EmbeddingEncoder, LinearEncoder, Trompt
7 |
8 |
9 | @pytest.mark.parametrize('batch_size', [0, 5])
10 | @pytest.mark.parametrize('use_stype_encoder_dicts', [
11 | True,
12 | False,
13 | ])
14 | def test_trompt(batch_size, use_stype_encoder_dicts):
15 |     channels = 8
16 |     out_channels = 1
17 |     num_prompts = 2
18 |     num_layers = 3
19 |     dataset: Dataset = FakeDataset(num_rows=10, with_nan=False)
20 |     dataset.materialize()
21 |     tensor_frame = dataset.tensor_frame[:batch_size]
22 |     if use_stype_encoder_dicts:
23 |         stype_encoder_dicts = [
24 |             {
25 |                 stype.numerical: LinearEncoder(),
26 |                 stype.categorical: EmbeddingEncoder(),
27 |             },
28 |             {
29 |                 stype.numerical: LinearEncoder(),
30 |                 stype.categorical: EmbeddingEncoder(),
31 |             },
32 |             {
33 |                 stype.numerical: LinearEncoder(),
34 |                 stype.categorical: EmbeddingEncoder(),
35 |             },
36 |         ]
37 |     else:
38 |         stype_encoder_dicts = None
39 |     model = Trompt(
40 |         channels=channels,
41 |         out_channels=out_channels,
42 |         num_prompts=num_prompts,
43 |         num_layers=num_layers,
44 |         col_stats=dataset.col_stats,
45 |         col_names_dict=tensor_frame.col_names_dict,
46 |         stype_encoder_dicts=stype_encoder_dicts,
47 |     )
48 |     model.reset_parameters()
49 |     pred = model(tensor_frame)
50 |     assert pred.shape == (batch_size, num_layers, out_channels)
51 |
--------------------------------------------------------------------------------
/torch_frame/nn/decoder/excelformer_decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from torch.nn import Linear, PReLU
4 |
5 | from torch_frame.nn.decoder import Decoder
6 |
7 |
8 | class ExcelFormerDecoder(Decoder):
9 | r"""The ExcelFormer decoder introduced in the
10 | `"ExcelFormer: A Neural Network Surpassing GBDTs on Tabular Data"
11 | `_ paper.
12 |
13 | Args:
14 | in_channels (int): Input channel dimensionality
15 | out_channels (int): Output channel dimensionality
16 | num_cols (int): Number of columns.
17 | """
18 | def __init__(
19 | self,
20 | in_channels: int,
21 | out_channels: int,
22 | num_cols: int,
23 | ) -> None:
24 | super().__init__()
25 | self.in_channels = in_channels
26 | self.out_channels = out_channels
27 | self.lin_f = Linear(num_cols, self.out_channels)
28 | self.activation = PReLU()
29 | self.lin_d = Linear(self.in_channels, 1)
30 | self.reset_parameters()
31 |
32 | def reset_parameters(self) -> None:
33 | self.lin_f.reset_parameters()
34 | self.lin_d.reset_parameters()
35 | with torch.no_grad():
36 | self.activation.weight.fill_(0.25)
37 |
38 | def forward(self, x: Tensor) -> Tensor:
39 | r"""Transforming :obj:`x` into output predictions.
40 |
41 | Args:
42 | x (Tensor): Input column-wise tensor of shape
43 | [batch_size, num_cols, in_channels]
44 |
45 | Returns:
46 | Tensor: [batch_size, out_channels].
47 | """
48 | x = x.transpose(1, 2)
49 | x = self.lin_f(x)
50 | x = self.activation(x)
51 | x = self.lin_d(x.transpose(1, 2)).squeeze(2)
52 | return x
53 |
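54 | # Illustrative usage sketch (added for clarity; not part of the original
55 | # file): decoding a column-wise representation into predictions.
56 | if __name__ == '__main__':
57 |     decoder = ExcelFormerDecoder(in_channels=8, out_channels=1, num_cols=5)
58 |     x = torch.randn(2, 5, 8)
59 |     out = decoder(x)
60 |     assert out.shape == (2, 1)
61 |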
--------------------------------------------------------------------------------
/docs/source/modules/nn.rst:
--------------------------------------------------------------------------------
1 | torch_frame.nn
2 | ==============
3 |
4 | .. contents:: Contents
5 | :local:
6 |
7 | .. currentmodule:: torch_frame.nn
8 |
9 | torch_frame.nn.encoder
10 | ----------------------
11 |
12 | .. currentmodule:: torch_frame.nn.encoder
13 |
14 | .. autosummary::
15 | :nosignatures:
16 | :toctree: ../generated
17 | :template: autosummary/class.rst
18 |
19 | {% for name in torch_frame.nn.encoder.classes %}
20 | {{ name }}
21 | {% endfor %}
22 |
23 |
24 | torch_frame.nn.encoding
25 | ------------------------
26 |
27 | .. currentmodule:: torch_frame.nn.encoding
28 |
29 | .. autosummary::
30 | :nosignatures:
31 | :toctree: ../generated
32 | :template: autosummary/class.rst
33 |
34 | {% for name in torch_frame.nn.encoding.classes %}
35 | {{ name }}
36 | {% endfor %}
37 |
38 | torch_frame.nn.conv
39 | ------------------------
40 |
41 | .. currentmodule:: torch_frame.nn.conv
42 |
43 | .. autosummary::
44 | :nosignatures:
45 | :toctree: ../generated
46 | :template: autosummary/class.rst
47 |
48 | {% for name in torch_frame.nn.conv.classes %}
49 | {{ name }}
50 | {% endfor %}
51 |
52 | torch_frame.nn.decoder
53 | ------------------------
54 |
55 | .. currentmodule:: torch_frame.nn.decoder
56 |
57 | .. autosummary::
58 | :nosignatures:
59 | :toctree: ../generated
60 | :template: autosummary/class.rst
61 |
62 | {% for name in torch_frame.nn.decoder.classes %}
63 | {{ name }}
64 | {% endfor %}
65 |
66 | torch_frame.nn.models
67 | ------------------------
68 |
69 | .. currentmodule:: torch_frame.nn.models
70 |
71 | .. autosummary::
72 | :nosignatures:
73 | :toctree: ../generated
74 | :template: autosummary/class.rst
75 |
76 | {% for name in torch_frame.nn.models.classes %}
77 | {{ name }}
78 | {% endfor %}
79 |
--------------------------------------------------------------------------------
/test/datasets/test_titanic.py:
--------------------------------------------------------------------------------
1 | import tempfile
2 |
3 | import torch
4 |
5 | import torch_frame
6 | from torch_frame.data.stats import StatType
7 | from torch_frame.datasets import Titanic
8 |
9 |
10 | def test_titanic():
11 | with tempfile.TemporaryDirectory() as temp_dir:
12 | dataset = Titanic(temp_dir)
13 | assert str(dataset) == 'Titanic()'
14 | assert len(dataset) == 891
15 | assert dataset.feat_cols == [
16 | 'Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked'
17 | ]
18 |
19 | dataset = dataset.materialize()
20 |
21 | tensor_frame = dataset.tensor_frame
22 | assert len(tensor_frame.feat_dict) == 2
23 | assert tensor_frame.feat_dict[torch_frame.numerical].dtype == torch.float
24 | assert tensor_frame.feat_dict[torch_frame.numerical].size() == (891, 4)
25 | assert tensor_frame.feat_dict[torch_frame.categorical].dtype == torch.long
26 | assert tensor_frame.feat_dict[torch_frame.categorical].size() == (891, 3)
27 | assert tensor_frame.col_names_dict == {
28 | torch_frame.categorical: ['Embarked', 'Pclass', 'Sex'],
29 | torch_frame.numerical: ['Age', 'Fare', 'Parch', 'SibSp'],
30 | }
31 | assert tensor_frame.y.size() == (891, )
32 | assert tensor_frame.y.min() == 0 and tensor_frame.y.max() == 1
33 |
34 | col_stats = dataset.col_stats
35 | assert len(col_stats) == 8
36 | assert StatType.COUNT in col_stats['Survived']
37 | assert StatType.COUNT in col_stats['Pclass']
38 | assert StatType.COUNT in col_stats['Sex']
39 | assert StatType.COUNT in col_stats['Embarked']
40 |     assert {StatType.MEAN, StatType.STD} <= col_stats['Age'].keys()
41 |     assert {StatType.MEAN, StatType.STD} <= col_stats['SibSp'].keys()
42 |     assert {StatType.MEAN, StatType.STD} <= col_stats['Parch'].keys()
43 |     assert {StatType.MEAN, StatType.STD} <= col_stats['Fare'].keys()
44 |
--------------------------------------------------------------------------------
/torch_frame/nn/decoder/trompt_decoder.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | from torch import Tensor
3 | from torch.nn import LayerNorm, Linear, ReLU, Sequential
4 |
5 | from torch_frame.nn.decoder import Decoder
6 |
7 |
8 | class TromptDecoder(Decoder):
9 | r"""The Trompt downstream introduced in
10 | `"Trompt: Towards a Better Deep Neural Network for Tabular Data"
11 | `_ paper.
12 |
13 | Args:
14 | in_channels (int): Input channel dimensionality
15 | out_channels (int): Output channel dimensionality
16 | num_prompts (int): Number of prompt columns.
17 | """
18 | def __init__(
19 | self,
20 | in_channels: int,
21 | out_channels: int,
22 | num_prompts: int,
23 | ) -> None:
24 | super().__init__()
25 | self.in_channels = in_channels
26 | self.num_prompts = num_prompts
27 | self.lin_attn = Linear(in_channels, 1)
28 | self.mlp = Sequential(
29 | Linear(in_channels, in_channels),
30 | ReLU(),
31 | LayerNorm(in_channels),
32 | Linear(in_channels, out_channels),
33 | )
34 | self.reset_parameters()
35 |
36 | def reset_parameters(self) -> None:
37 | self.lin_attn.reset_parameters()
38 | for m in self.mlp:
39 | if not isinstance(m, ReLU):
40 | m.reset_parameters()
41 |
42 | def forward(self, x: Tensor) -> Tensor:
43 | batch_size = len(x)
44 | assert x.shape == (batch_size, self.num_prompts, self.in_channels)
45 | # [batch_size, num_prompts, 1]
46 | w_prompt = F.softmax(self.lin_attn(x), dim=1)
47 | # [batch_size, in_channels]
48 | x = (w_prompt * x).sum(dim=1)
49 | # [batch_size, out_channels]
50 | x = self.mlp(x)
51 | return x
52 |
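53 | # Illustrative usage sketch (added for clarity; not part of the original
54 | # file): pooling prompt representations into per-row predictions.
55 | if __name__ == '__main__':
56 |     import torch
57 | 
58 |     decoder = TromptDecoder(in_channels=8, out_channels=1, num_prompts=2)
59 |     x = torch.randn(4, 2, 8)
60 |     out = decoder(x)
61 |     assert out.shape == (4, 1)
62 |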
--------------------------------------------------------------------------------
/test/utils/test_concat.py:
--------------------------------------------------------------------------------
1 | import torch_frame
2 | from torch_frame import TensorFrame
3 |
4 |
5 | def test_cat_along_row(get_fake_tensor_frame):
6 | num_rows = 10
7 | num_repeats = 5
8 | tf: TensorFrame = get_fake_tensor_frame(num_rows=num_rows)
9 | tf_cat = torch_frame.cat([tf for _ in range(num_repeats)], dim=0)
10 | assert len(tf_cat) == num_rows * num_repeats
11 | for i in range(num_repeats):
12 | tf_mini = tf_cat[num_rows * i:num_rows * (i + 1)]
13 | assert tf_mini == tf
14 |
15 |
16 | def test_cat_along_col(get_fake_tensor_frame):
17 | num_rows = 10
18 | tf = get_fake_tensor_frame(num_rows=num_rows)
19 | stypes = list(tf.col_names_dict.keys())
20 | feat_dict1 = {}
21 | feat_dict2 = {}
22 | col_names_dict1 = {}
23 | col_names_dict2 = {}
24 | for stype in stypes:
25 | if stype.use_dict_multi_nested_tensor:
26 | feat_dict1[stype] = {
27 | name: tf.feat_dict[stype][name][:, :1]
28 | for name in tf.feat_dict[stype].keys()
29 | }
30 | col_names_dict1[stype] = tf.col_names_dict[stype][:1]
31 | feat_dict2[stype] = {
32 | name: tf.feat_dict[stype][name][:, 1:]
33 | for name in tf.feat_dict[stype].keys()
34 | }
35 | col_names_dict2[stype] = tf.col_names_dict[stype][1:]
36 | else:
37 | feat_dict1[stype] = tf.feat_dict[stype][:, :1]
38 | col_names_dict1[stype] = tf.col_names_dict[stype][:1]
39 | feat_dict2[stype] = tf.feat_dict[stype][:, 1:]
40 | col_names_dict2[stype] = tf.col_names_dict[stype][1:]
41 |
42 | tf1 = TensorFrame(feat_dict1, col_names_dict1, tf.y)
43 | tf2 = TensorFrame(feat_dict2, col_names_dict2, None)
44 | tf_cat = torch_frame.cat([tf1, tf2], dim=1)
45 | assert tf_cat == tf
46 |
--------------------------------------------------------------------------------
/torch_frame/testing/decorators.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Callable
2 | from importlib import import_module
3 | from importlib.util import find_spec
4 |
5 | import torch
6 | from packaging.requirements import Requirement
7 | from packaging.version import Version
8 |
9 |
10 | def has_package(package: str) -> bool:
11 | r"""Returns :obj:`True` in case :obj:`package` is installed."""
12 | if '|' in package:
13 | return any(has_package(p) for p in package.split('|'))
14 |
15 | req = Requirement(package)
16 | if find_spec(req.name) is None:
17 | return False
18 | module = import_module(req.name)
19 | if not hasattr(module, '__version__'):
20 | return True
21 |
22 | version = Version(module.__version__).base_version
23 | return version in req.specifier
24 |
25 |
26 | def withPackage(*args) -> Callable:
27 | r"""A decorator to skip tests if certain packages are not installed.
28 | Also supports version specification.
29 | """
30 | na_packages = {package for package in args if not has_package(package)}
31 |
32 | def decorator(func: Callable) -> Callable:
33 | import pytest
34 | return pytest.mark.skipif(
35 | len(na_packages) > 0,
36 | reason=f"Package(s) {na_packages} are not installed",
37 | )(func)
38 |
39 | return decorator
40 |
41 |
42 | def withCUDA(func: Callable):
43 | r"""A decorator to test both on CPU and CUDA (if available)."""
44 | import pytest
45 |
46 | devices = [pytest.param(torch.device('cpu'), id='cpu')]
47 | if torch.cuda.is_available():
48 | devices.append(pytest.param(torch.device('cuda:0'), id='cuda:0'))
49 |
50 | return pytest.mark.parametrize('device', devices)(func)
51 |
52 |
53 | def onlyCUDA(func: Callable) -> Callable:
54 | r"""A decorator to skip tests if CUDA is not found."""
55 | import pytest
56 | return pytest.mark.skipif(
57 | not torch.cuda.is_available(),
58 | reason="CUDA not available",
59 | )(func)
60 |
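61 | # Illustrative usage sketch (added for clarity; not part of the original
62 | # file): `withPackage` skips a test unless the requirement is satisfied;
63 | # version specifiers and '|'-separated alternatives are supported.
64 | if __name__ == '__main__':
65 |     @withPackage('pandas>=1.0')
66 |     def test_with_pandas() -> None:
67 |         import pandas as pd
68 |         assert pd.DataFrame({'a': [1]}).shape == (1, 1)
69 |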
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import os.path as osp
3 | import sys
4 |
5 | import pyg_sphinx_theme
6 |
7 | import torch_frame
8 |
9 | author = 'PyG Team'
10 | project = 'pytorch-frame'
11 | version = torch_frame.__version__
12 | copyright = f'{datetime.datetime.now().year}, {author}'
13 |
14 | sys.path.append(osp.join(osp.dirname(pyg_sphinx_theme.__file__), 'extension'))
15 |
16 | extensions = [
17 | 'sphinx.ext.autodoc',
18 | 'sphinx.ext.autosummary',
19 | 'sphinx.ext.intersphinx',
20 | 'sphinx.ext.mathjax',
21 | 'sphinx.ext.napoleon',
22 | 'sphinx.ext.viewcode',
23 | 'sphinx_copybutton',
24 | 'pyg',
25 | ]
26 |
27 | html_theme = 'pyg_sphinx_theme'
28 | html_logo = ('https://raw.githubusercontent.com/pyg-team/pyg_sphinx_theme/'
29 | 'master/pyg_sphinx_theme/static/img/pytorch_frame_logo.png')
30 | html_favicon = ('https://raw.githubusercontent.com/pyg-team/pyg_sphinx_theme/'
31 | 'master/pyg_sphinx_theme/static/img/pytorch_frame_favicon.png')
32 | html_static_path = ['_static']
33 | templates_path = ['_templates']
34 |
35 | add_module_names = False
36 | autodoc_member_order = 'bysource'
37 |
38 | suppress_warnings = ['autodoc.import_object']
39 |
40 | intersphinx_mapping = {
41 | 'python': ('https://docs.python.org/3', None),
42 | 'numpy': ('http://docs.scipy.org/doc/numpy', None),
43 | 'pandas': ('http://pandas.pydata.org/pandas-docs/dev', None),
44 | 'torch': ('https://pytorch.org/docs/stable', None),
45 | 'optuna': ('https://optuna.readthedocs.io/en/stable/', None),
46 | 'xgboost': ('https://xgboost.readthedocs.io/en/stable/', None),
47 | }
48 |
49 | copybutton_prompt_text = r">>> |\.\.\. "
50 | copybutton_prompt_is_regexp = True
51 |
52 |
53 | def setup(app):
54 | def rst_jinja_render(app, _, source):
55 | rst_context = {'torch_frame': torch_frame}
56 | source[0] = app.builder.templates.render_string(source[0], rst_context)
57 |
58 | app.connect('source-read', rst_jinja_render)
59 | app.add_js_file('js/version_alert.js')
60 |
--------------------------------------------------------------------------------
/torch_frame/transforms/base_transform.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import copy
4 | from abc import ABC, abstractmethod
5 | from typing import Any
6 |
7 | from torch_frame import TensorFrame
8 | from torch_frame.data.stats import StatType
9 |
10 |
11 | class BaseTransform(ABC):
12 | r"""An abstract base class for writing transforms.
13 |
14 | Transforms are a general way to modify and customize
15 |     :class:`TensorFrame` objects.
16 | """
17 | def __init__(self):
18 | self._transformed_stats: dict[str, dict[StatType, Any]] | None = None
19 |
20 | def __call__(self, tf: TensorFrame) -> TensorFrame:
21 | # Shallow-copy the data so that we prevent in-place data modification.
22 | return self.forward(copy.copy(tf))
23 |
24 | @abstractmethod
25 | def forward(self, tf: TensorFrame) -> TensorFrame:
26 | r"""Process TensorFrame obj into another TensorFrame obj.
27 |
28 | Args:
29 | tf (TensorFrame): Input :class:`TensorFrame`.
30 |
31 | Returns:
32 | TensorFrame: Input :class:`TensorFrame` after transform.
33 | """
34 | return tf
35 |
36 | @property
37 | def transformed_stats(self) -> dict[str, dict[StatType, Any]]:
38 | r"""The column stats after the transform.
39 |
40 | Returns:
41 | transformed_stats (Dict[str, Dict[StatType, Any]]):
42 | Transformed column stats. The :class:`TensorFrame` object might
43 | be modified by the transform, so the returned
44 | :obj:`transformed_stats` would contain the column stats of the
45 | modified :class:`TensorFrame` object.
46 | """
47 | if self._transformed_stats is None:
48 | raise ValueError("Transformed column stats is not computed yet. "
49 | "Please run necessary functions to compute this"
50 | " first.")
51 | return self._transformed_stats
52 |
53 | def __repr__(self) -> str:
54 | return f'{self.__class__.__name__}()'
55 |
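56 | # Illustrative sketch (added for clarity; not part of the original file):
57 | # a no-op transform showing the minimal `BaseTransform` contract. Calling
58 | # the instance shallow-copies the input before `forward` is applied.
59 | class IdentityTransform(BaseTransform):
60 |     def forward(self, tf: TensorFrame) -> TensorFrame:
61 |         return tf
62 |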
--------------------------------------------------------------------------------
/test/datasets/test_movielens_1m.py:
--------------------------------------------------------------------------------
1 | import tempfile
2 |
3 | import torch
4 |
5 | import torch_frame
6 | from torch_frame.config.text_embedder import TextEmbedderConfig
7 | from torch_frame.data.stats import StatType
8 | from torch_frame.datasets import Movielens1M
9 | from torch_frame.testing.text_embedder import HashTextEmbedder
10 |
11 |
12 | def test_movielens_1m():
13 | with tempfile.TemporaryDirectory() as temp_dir:
14 | dataset = Movielens1M(
15 | temp_dir,
16 | col_to_text_embedder_cfg=TextEmbedderConfig(
17 | text_embedder=HashTextEmbedder(10)),
18 | )
19 | assert str(dataset) == 'Movielens1M()'
20 | assert len(dataset) == 1000209
21 | assert dataset.feat_cols == [
22 | 'user_id', 'gender', 'age', 'occupation', 'zip', 'movie_id', 'title',
23 | 'genres', 'timestamp'
24 | ]
25 |
26 | dataset = dataset.materialize()
27 |
28 | tensor_frame = dataset.tensor_frame
29 | assert len(tensor_frame.feat_dict) == 4
30 | assert tensor_frame.feat_dict[torch_frame.categorical].dtype == torch.int64
31 | assert tensor_frame.feat_dict[torch_frame.categorical].size() == (1000209,
32 | 6)
33 | assert tensor_frame.feat_dict[
34 | torch_frame.multicategorical].dtype == torch.int64
35 | assert tensor_frame.feat_dict[torch_frame.embedding].dtype == torch.float32
36 | assert tensor_frame.col_names_dict == {
37 | torch_frame.categorical:
38 | ['age', 'gender', 'movie_id', 'occupation', 'user_id', 'zip'],
39 | torch_frame.multicategorical: ['genres'],
40 | torch_frame.timestamp: ['timestamp'],
41 | torch_frame.embedding: ['title'],
42 | }
43 | assert tensor_frame.y.size() == (1000209, )
44 | assert tensor_frame.y.min() == 1 and tensor_frame.y.max() == 5
45 |
46 | col_stats = dataset.col_stats
47 | assert len(col_stats) == 10
48 | assert StatType.COUNT in col_stats['user_id']
49 | assert StatType.MULTI_COUNT in col_stats['genres']
50 | assert StatType.YEAR_RANGE in col_stats['timestamp']
51 | assert StatType.EMB_DIM in col_stats['title']
52 |
--------------------------------------------------------------------------------
/test/utils/test_infer_stype.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import pytest
3 |
4 | import torch_frame
5 | from torch_frame.config.text_embedder import TextEmbedderConfig
6 | from torch_frame.datasets import FakeDataset
7 | from torch_frame.testing.text_embedder import HashTextEmbedder
8 | from torch_frame.utils import infer_df_stype
9 |
10 |
11 | def get_fake_dataset(
12 | num_rows: int,
13 | col_to_text_embedder_cfg: TextEmbedderConfig,
14 | with_nan: bool,
15 | ) -> FakeDataset:
16 | stypes = [
17 | torch_frame.numerical,
18 | torch_frame.categorical,
19 | torch_frame.multicategorical,
20 | torch_frame.text_embedded,
21 | torch_frame.sequence_numerical,
22 | torch_frame.timestamp,
23 | torch_frame.embedding,
24 | ]
25 | dataset = FakeDataset(
26 | num_rows=num_rows,
27 | stypes=stypes,
28 | col_to_text_embedder_cfg=col_to_text_embedder_cfg,
29 | with_nan=with_nan,
30 | )
31 | return dataset
32 |
33 |
34 | @pytest.mark.parametrize("with_nan", [True, False])
35 | def test_infer_df_stype(with_nan):
36 | num_rows = 200
37 | col_to_text_embedder_cfg = TextEmbedderConfig(
38 | text_embedder=HashTextEmbedder(8))
39 | dataset = get_fake_dataset(num_rows, col_to_text_embedder_cfg, with_nan)
40 | col_to_stype_inferred = infer_df_stype(dataset.df)
41 | assert col_to_stype_inferred == dataset.col_to_stype
42 |
43 |
44 | def test_infer_stypes():
45 | # Test when multicategoricals are lists
46 | df = pd.DataFrame({
47 | 'category': [['Books', 'Mystery, Thriller'],
48 | ['Books', "Children's Books", 'Geography'],
49 | ['Books', 'Health', 'Fitness & Dieting'],
50 | ['Books', 'Teen & oung Adult']] * 50,
51 | 'id': [i for i in range(200)]
52 | })
53 | col_to_stype_inferred = infer_df_stype(df)
54 | assert col_to_stype_inferred['category'] == torch_frame.multicategorical
55 |
56 | df = pd.DataFrame({'bool': [True] * 50 + [False] * 50})
57 |
58 | col_to_stype_inferred = infer_df_stype(df)
59 | assert col_to_stype_inferred['bool'] == torch_frame.categorical
60 |
--------------------------------------------------------------------------------
/torch_frame/datasets/dota2.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import zipfile
3 |
4 | import pandas as pd
5 |
6 | import torch_frame
7 |
8 |
9 | class Dota2(torch_frame.data.Dataset):
10 | r"""The `Dota2 Game Results
11 | `_
12 | dataset. Dota2 is a popular moba game with two teams of 5 players.
13 | At start of the game, each player choose a unique hero with
14 | different strengths and weakness. The dataset is reasonably sparse
15 | as only 10 of 113 possible heroes are chosen in a given game. All
16 | games were played in a space of 2 hours on the 13th of August 2016.
17 | The classification goal is to predict the winning team.
18 |
19 | **STATS:**
20 |
21 | .. list-table::
22 | :widths: 10 10 10 10 20 10
23 | :header-rows: 1
24 |
25 | * - #rows
26 | - #cols (numerical)
27 | - #cols (categorical)
28 | - #classes
29 | - Task
30 | - Missing value ratio
31 | * - 92,650
32 | - 0
33 | - 116
34 | - 2
35 | - binary_classification
36 | - 0.0%
37 | """
38 |
39 | url = 'https://archive.ics.uci.edu/static/public/367/dota2+games+results.zip' # noqa
40 |
41 | def __init__(self, root: str):
42 | path = self.download_url(self.url, root)
43 | names = [
44 | 'Team won the game',
45 | 'Cluster ID',
46 | 'Game mode',
47 | 'Game type',
48 | ]
49 | num_heroes = 113
50 | names += [f'hero_{i}' for i in range(num_heroes)]
51 | folder_path = osp.dirname(path)
52 | with zipfile.ZipFile(path, 'r') as zip_ref:
53 | zip_ref.extractall(folder_path)
54 |
55 | df = pd.read_csv(osp.join(folder_path, 'dota2Train.csv'), names=names)
56 |
57 | col_to_stype = {
58 | 'Team won the game': torch_frame.categorical,
59 | 'Cluster ID': torch_frame.categorical,
60 | 'Game mode': torch_frame.categorical,
61 | 'Game type': torch_frame.categorical,
62 | }
63 | for i in range(num_heroes):
64 | col_to_stype[f'hero_{i}'] = torch_frame.categorical
65 |
66 | super().__init__(df, col_to_stype, target_col='Team won the game')
67 |
--------------------------------------------------------------------------------
/test/transforms/test_mutual_information_sort.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from torch_frame import TaskType, TensorFrame, stype
5 | from torch_frame.data import Dataset
6 | from torch_frame.datasets.fake import FakeDataset
7 | from torch_frame.transforms import MutualInformationSort
8 |
9 |
10 | @pytest.mark.parametrize('with_nan', [True, False])
11 | def test_mutual_information_sort(with_nan):
12 | task_type = TaskType.REGRESSION
13 | dataset: Dataset = FakeDataset(num_rows=10, with_nan=with_nan,
14 | stypes=[stype.numerical], create_split=True,
15 | task_type=task_type)
16 |     # Modify the FakeDataset so that column `num_1` has the highest
17 |     # mutual information score.
18 | dataset.df['num_1'] = dataset.df['target'].astype(float)
19 | dataset.materialize()
20 |
21 | tensor_frame: TensorFrame = dataset.tensor_frame
22 | train_dataset = dataset.get_split('train')
23 | transform = MutualInformationSort(task_type)
24 | transform.fit(train_dataset.tensor_frame, train_dataset.col_stats)
25 | out = transform(tensor_frame)
26 |
27 |     # Column `num_1` ranks first.
28 | assert (out.col_names_dict[stype.numerical][0] == 'num_1')
29 | actual_first_col = out.feat_dict[stype.numerical][:, 0]
30 | actual_first_col_nan_mask = torch.isnan(actual_first_col)
31 | expected_first_col = torch.tensor(dataset.df['num_1'].values,
32 | dtype=torch.float32)
33 | expected_first_col_nan_mask = torch.isnan(expected_first_col)
34 |     # If the tensor in the first column contains NaNs, make sure the
35 |     # NaNs are unchanged.
36 | assert (torch.equal(actual_first_col_nan_mask,
37 | expected_first_col_nan_mask))
38 | actual = actual_first_col[~actual_first_col_nan_mask]
39 | expected = expected_first_col[~expected_first_col_nan_mask]
40 |     # Make sure that the non-NaN values in the first column are the same.
41 | assert (torch.allclose(actual, expected))
42 |
43 |     # Make sure the shapes are unchanged.
44 | assert (set(out.col_names_dict[stype.numerical]) == set(
45 | tensor_frame.col_names_dict[stype.numerical]))
46 | assert (out.feat_dict[stype.numerical].size() == tensor_frame.feat_dict[
47 | stype.numerical].size())
48 |
--------------------------------------------------------------------------------
/test/nn/models/test_excelformer.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import pytest
4 | import torch
5 |
6 | from torch_frame import TaskType, stype
7 | from torch_frame.data.dataset import Dataset
8 | from torch_frame.datasets.fake import FakeDataset
9 | from torch_frame.nn import ExcelFormer
10 |
11 |
12 | @pytest.mark.parametrize('task_type', [
13 | TaskType.REGRESSION,
14 | TaskType.BINARY_CLASSIFICATION,
15 | TaskType.MULTICLASS_CLASSIFICATION,
16 | ])
17 | @pytest.mark.parametrize('batch_size', [0, 5])
18 | @pytest.mark.parametrize('mixup', [None, 'feature', 'hidden'])
19 | def test_excelformer(task_type, batch_size, mixup):
20 | in_channels = 8
21 | num_heads = 2
22 | num_layers = 6
23 | dataset: Dataset = FakeDataset(num_rows=10, with_nan=False,
24 | stypes=[stype.numerical],
25 | task_type=task_type)
26 | dataset.materialize()
27 | if task_type.is_classification:
28 | out_channels = dataset.num_classes
29 | else:
30 | out_channels = 1
31 | num_cols = len(dataset.col_stats) - 1
32 | tensor_frame = dataset.tensor_frame[:batch_size]
33 | model = ExcelFormer(
34 | in_channels=in_channels,
35 | out_channels=out_channels,
36 | num_cols=num_cols,
37 | num_layers=num_layers,
38 | num_heads=num_heads,
39 | mixup=mixup,
40 | col_stats=dataset.col_stats,
41 | col_names_dict=tensor_frame.col_names_dict,
42 | )
43 | model.reset_parameters()
44 |
45 | # Test the original forward pass
46 | out = model(tensor_frame)
47 | assert out.shape == (batch_size, out_channels)
48 |
49 | # Test the mixup forward pass
50 | feat_num = copy.copy(tensor_frame.feat_dict[stype.numerical])
51 | # Set lazy mutual information scores for `feature` mixup
52 | tensor_frame.mi_scores = torch.rand(torch.Size((feat_num.shape[1], )))
53 | out_mixedup, y_mixedup = model(tensor_frame, mixup_encoded=True)
54 | assert out_mixedup.shape == (batch_size, out_channels)
55 | # Make sure the numerical feature is not modified.
56 | assert torch.allclose(feat_num, tensor_frame.feat_dict[stype.numerical])
57 |
58 | if task_type.is_classification:
59 | assert y_mixedup.shape == (batch_size, out_channels)
60 | else:
61 | assert y_mixedup.shape == tensor_frame.y.shape
62 |
--------------------------------------------------------------------------------
/torch_frame/datasets/adult_census_income.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 | import torch_frame
4 |
5 |
6 | class AdultCensusIncome(torch_frame.data.Dataset):
7 | r"""The `Adult Census Income
8 | `_
9 | dataset from Kaggle. It's extracted from census bureau database and the
10 | task is to predict whether a person's income exceeds $50K/year.
11 |
12 | **STATS:**
13 |
14 | .. list-table::
15 | :widths: 10 10 10 10 20 10
16 | :header-rows: 1
17 |
18 | * - #rows
19 | - #cols (numerical)
20 | - #cols (categorical)
21 | - #classes
22 | - Task
23 | - Missing value ratio
24 | * - 32,561
25 | - 4
26 | - 8
27 | - 2
28 | - binary_classification
29 | - 0.0%
30 | """
31 |
32 | url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data' # noqa
33 |
34 | def __init__(self, root: str):
35 | path = self.download_url(self.url, root)
36 | names = [
37 | 'age',
38 | 'workclass',
39 | 'fnlwgt',
40 | 'education',
41 | 'education.num',
42 | 'marital.status',
43 | 'occupation',
44 | 'relationship',
45 | 'race',
46 | 'sex',
47 | 'capital.gain',
48 | 'capital.loss',
49 | 'hours.per.week',
50 | 'native.country',
51 | 'income',
52 | ]
53 | df = pd.read_csv(path, names=names)
54 |
55 | col_to_stype = {
56 | 'age': torch_frame.numerical,
57 | 'workclass': torch_frame.categorical,
58 | 'education': torch_frame.categorical,
59 | 'marital.status': torch_frame.categorical,
60 | 'occupation': torch_frame.categorical,
61 | 'relationship': torch_frame.categorical,
62 | 'race': torch_frame.categorical,
63 | 'sex': torch_frame.categorical,
64 | 'capital.gain': torch_frame.numerical,
65 | 'capital.loss': torch_frame.numerical,
66 | 'hours.per.week': torch_frame.numerical,
67 | 'native.country': torch_frame.categorical,
68 | 'income': torch_frame.categorical,
69 | }
70 |
71 | super().__init__(df, col_to_stype, target_col='income')
72 |
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | name: Publish Package
2 |
3 | on: # yamllint disable-line rule:truthy
4 | push:
5 | branches:
6 | - master
7 | release:
8 | types:
9 | - published
10 |
11 | defaults:
12 | run:
13 | shell: bash
14 |
15 | jobs:
16 | build-package:
17 | runs-on: ubuntu-latest
18 | steps:
19 | - uses: actions/checkout@v6
20 |
21 | - uses: actions/setup-python@v6
22 | with:
23 | python-version: '3.10'
24 |
25 | - name: Install build tools
26 | run: |
27 | pip install -U pip
28 | pip install -U flit twine
29 | pip list
30 |
31 | - name: Build package
32 | run: |
33 | flit build --no-use-vcs
34 | twine check dist/* --strict
35 |
36 | - uses: actions/upload-artifact@v5
37 | with:
38 | name: release-${{ github.sha }}
39 | path: dist
40 |
41 | publish-package-test:
42 | if: github.event_name == 'release'
43 | runs-on: ubuntu-latest
44 | needs: [build-package]
45 | steps:
46 | - uses: actions/checkout@v6
47 |
48 | - uses: actions/setup-python@v6
49 | with:
50 | python-version: '3.10'
51 |
52 | - name: Download package
53 | uses: actions/download-artifact@v6
54 | with:
55 | name: release-${{ github.sha }}
56 | path: dist
57 |
58 | - name: Publish package to TestPyPI
59 | uses: pypa/gh-action-pypi-publish@release/v1
60 | with:
61 | repository-url: https://test.pypi.org/legacy/
62 | username: __token__
63 | password: ${{ secrets.TEST_PYPI_API_TOKEN }}
64 | verbose: true
65 | print-hash: true
66 |
67 | publish-package:
68 | if: github.event_name == 'release'
69 | runs-on: ubuntu-latest
70 | needs: [publish-package-test]
71 | permissions:
72 | id-token: write
73 | steps:
74 | - uses: actions/checkout@v6
75 |
76 | - uses: actions/setup-python@v6
77 | with:
78 | python-version: '3.10'
79 |
80 | - name: Download package
81 | uses: actions/download-artifact@v6
82 | with:
83 | name: release-${{ github.sha }}
84 | path: dist
85 |
86 | - name: Publish package to PyPI
87 | uses: pypa/gh-action-pypi-publish@release/v1
88 | with:
89 | username: __token__
90 | password: ${{ secrets.PYPI_API_TOKEN }}
91 | verbose: true
92 | print-hash: true
93 |
--------------------------------------------------------------------------------
/.github/workflows/testing.yml:
--------------------------------------------------------------------------------
1 | name: Testing
2 |
3 | on: # yamllint disable-line rule:truthy
4 | push:
5 | branches:
6 | - master
7 | pull_request:
8 |
9 | concurrency:
10 | group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
11 | cancel-in-progress: ${{ github.event_name == 'pull_request' }}
12 |
13 | jobs:
14 |
15 | pytest:
16 | runs-on: ubuntu-latest
17 | strategy:
18 | fail-fast: false
19 | matrix:
20 | include:
21 | - {torch-version: '2.7', python-version: '3.10'}
22 | - {torch-version: '2.7', python-version: '3.13'}
23 | - {torch-version: '2.9', python-version: '3.10'}
24 | - {torch-version: '2.9', python-version: '3.13'}
25 | - {torch-version: 'nightly', python-version: '3.10'}
26 | - {torch-version: 'nightly', python-version: '3.13'}
27 |
28 | steps:
29 | - name: Checkout repository
30 | uses: actions/checkout@v6
31 | with:
32 | fetch-depth: 40
33 |
34 | # Skip workflow if only certain files have been changed.
35 | - name: Get changed files
36 | id: changed-files-specific
37 | uses: tj-actions/changed-files@v47
38 | with:
39 | files: |
40 | docs/**
41 | examples/**
42 | README.md
43 | CHANGELOG.md
44 | .github/workflows/changelog.yml
45 |             .github/workflows/auto-merge.yml
46 | .github/workflows/documentation.yml
47 | .github/workflows/labeler.yml
48 | .github/workflows/linting.yml
49 | .github/workflows/release.yml
50 |
51 | - name: Setup packages
52 | if: steps.changed-files-specific.outputs.only_changed != 'true'
53 | uses: ./.github/actions/setup
54 | with:
55 | python-version: ${{ matrix.python-version }}
56 | torch-version: ${{ matrix.torch-version }}
57 | cuda-version: cpu
58 |
59 | - name: Install main package
60 | if: steps.changed-files-specific.outputs.only_changed != 'true'
61 | run: |
62 | pip install -e .[full,test]
63 | pip list
64 |
65 | - name: Run tests
66 | if: steps.changed-files-specific.outputs.only_changed != 'true'
67 | run: |
68 | pytest --cov --cov-report=xml
69 |
70 | - name: Upload coverage
71 | if: steps.changed-files-specific.outputs.only_changed != 'true'
72 | uses: codecov/codecov-action@v5
73 | with:
74 | fail_ci_if_error: false
75 |
--------------------------------------------------------------------------------
/docs/source/modules/transforms.rst:
--------------------------------------------------------------------------------
1 | torch_frame.transforms
2 | ======================
3 |
4 | .. contents:: Contents
5 | :local:
6 |
7 | .. currentmodule:: torch_frame.transforms
8 |
9 | Transforms
10 | ----------
11 |
12 | :pyf:`PyTorch Frame` allows for data transformation across different :obj:`stype`'s or within the same :obj:`stype`. A transform takes in both a :class:`TensorFrame` and column stats.
13 |
14 | Let's look at an example, where we apply :class:`CatToNumTransform` to transform the categorical features into numerical features.
15 |
16 | .. code-block:: python
17 |
18 | from torch_frame.datasets import Yandex
19 | from torch_frame.transforms import CatToNumTransform
20 | from torch_frame import stype
21 |
22 | dataset = Yandex(root='/tmp/adult', name='adult')
23 | dataset.materialize()
24 | transform = CatToNumTransform()
25 | train_dataset = dataset.get_split('train')
26 |
27 | train_dataset.tensor_frame.col_names_dict[stype.categorical]
28 | >>> ['C_feature_0', 'C_feature_1', 'C_feature_2', 'C_feature_3', 'C_feature_4', 'C_feature_5', 'C_feature_6', 'C_feature_7']
29 |
30 | test_dataset = dataset.get_split('test')
31 | transform.fit(train_dataset.tensor_frame, dataset.col_stats)
32 |
33 | transformed_col_stats = transform.transformed_stats
34 |
35 | transformed_col_stats.keys()
36 | >>> dict_keys(['C_feature_0_0', 'C_feature_1_0', 'C_feature_2_0', 'C_feature_3_0', 'C_feature_4_0', 'C_feature_5_0', 'C_feature_6_0', 'C_feature_7_0'])
37 |
38 | transformed_col_stats['C_feature_0_0']
39 |     >>> {<StatType.MEAN: 'MEAN'>: 0.6984029484029484, <StatType.STD: 'STD'>: 0.45895127199411595, <StatType.QUANTILES: 'QUANTILES'>: [0.0, 0.0, 1.0, 1.0, 1.0]}
40 |
41 | transform(test_dataset.tensor_frame)
42 | >>> TensorFrame(
43 | num_cols=14,
44 | num_rows=16281,
45 | numerical (14): ['N_feature_0', 'N_feature_1', 'N_feature_2', 'N_feature_3', 'N_feature_4', 'N_feature_5', 'C_feature_0_0', 'C_feature_1_0', 'C_feature_2_0', 'C_feature_3_0', 'C_feature_4_0', 'C_feature_5_0', 'C_feature_6_0', 'C_feature_7_0'],
46 | has_target=True,
47 | device=cpu,
48 | )
49 |
50 | You can see that after the transform, the column names of the categorical features change and the categorical features are transformed into numerical features.
51 |
52 |
53 | .. autosummary::
54 | :nosignatures:
55 | :toctree: ../generated
56 |
57 | {% for name in torch_frame.transforms.functions %}
58 | {{ name }}
59 | {% endfor %}
60 |
--------------------------------------------------------------------------------
/torch_frame/datasets/amazon_fine_food_reviews.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import pandas as pd
4 |
5 | import torch_frame
6 | from torch_frame.config.text_embedder import TextEmbedderConfig
7 | from torch_frame.config.text_tokenizer import TextTokenizerConfig
8 |
9 |
10 | class AmazonFineFoodReviews(torch_frame.data.Dataset):
11 | r"""The `Amazon Fine Food Reviews `_
12 | dataset. It consists of reviews of fine foods from amazon.
13 |
14 | Args:
15 | text_stype (torch_frame.stype): Text stype to use for text columns
16 | in the dataset. (default: :obj:`torch_frame.text_embedded`)
17 |
18 | **STATS:**
19 |
20 | .. list-table::
21 | :widths: 10 10 10 10 10 20 10
22 | :header-rows: 1
23 |
24 | * - #rows
25 | - #cols (numerical)
26 | - #cols (categorical)
27 | - #cols (text)
28 | - #classes
29 | - Task
30 | - Missing value ratio
31 | * - 568,454
32 | - 2
33 | - 3
34 | - 2
35 | - 5
36 | - multiclass_classification
37 | - 0.0%
38 | """
39 |
40 | url = "https://data.pyg.org/datasets/tables/amazon_fine_food_reviews.zip"
41 |
42 | def __init__(
43 | self,
44 | root: str,
45 | text_stype: torch_frame.stype = torch_frame.text_embedded,
46 | col_to_text_embedder_cfg: dict[str, TextEmbedderConfig]
47 | | TextEmbedderConfig | None = None,
48 | col_to_text_tokenizer_cfg: dict[str, TextTokenizerConfig]
49 | | TextTokenizerConfig | None = None,
50 | ) -> None:
51 | self.root = root
52 | self.text_stype = text_stype
53 | path = self.download_url(self.url, root)
54 |
55 | col_to_stype = {
56 | 'ProductId': torch_frame.categorical,
57 | 'UserId': torch_frame.categorical,
58 | 'HelpfulnessNumerator': torch_frame.numerical,
59 | 'HelpfulnessDenominator': torch_frame.numerical,
60 | 'Score': torch_frame.categorical,
61 | # 'Time': torch_frame.categorical, # TODO: change to timestamp
62 | 'Summary': text_stype,
63 | 'Text': text_stype,
64 | }
65 |
66 | df = pd.read_csv(path)[list(col_to_stype.keys())]
67 |
68 | super().__init__(
69 | df,
70 | col_to_stype,
71 | target_col='Score',
72 | col_to_text_embedder_cfg=col_to_text_embedder_cfg,
73 | col_to_text_tokenizer_cfg=col_to_text_tokenizer_cfg,
74 | )
75 |
--------------------------------------------------------------------------------
/test/nn/models/test_compile.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from torch_frame import stype
5 | from torch_frame.datasets import FakeDataset
6 | from torch_frame.nn.models import (
7 | ExcelFormer,
8 | FTTransformer,
9 | ResNet,
10 | TabNet,
11 | TabTransformer,
12 | Trompt,
13 | )
14 | from torch_frame.testing import withPackage
15 |
16 |
17 | @withPackage("torch>=2.6.0")
18 | @pytest.mark.parametrize(
19 | "model_cls, model_kwargs, stypes, expected_graph_breaks",
20 | [
21 | pytest.param(
22 | FTTransformer,
23 | dict(channels=8),
24 | None,
25 | 2,
26 | id="FTTransformer",
27 | ),
28 | pytest.param(ResNet, dict(channels=8), None, 2, id="ResNet"),
29 | pytest.param(
30 | TabNet,
31 | dict(
32 | split_feat_channels=2,
33 | split_attn_channels=2,
34 | gamma=0.1,
35 | ),
36 | None,
37 | 2,
38 | id="TabNet",
39 | ),
40 | pytest.param(
41 | TabTransformer,
42 | dict(
43 | channels=8,
44 | num_heads=2,
45 | encoder_pad_size=2,
46 | attn_dropout=0.5,
47 | ffn_dropout=0.5,
48 | ),
49 | None,
50 | 0,
51 | id="TabTransformer",
52 | ),
53 | pytest.param(
54 | Trompt,
55 | dict(channels=8, num_prompts=2),
56 | None,
57 | 3,
58 | id="Trompt",
59 | ),
60 | pytest.param(
61 | ExcelFormer,
62 | dict(in_channels=8, num_cols=3, num_heads=1),
63 | [stype.numerical],
64 | 1,
65 | id="ExcelFormer",
66 | ),
67 | ],
68 | )
69 | def test_compile_graph_break(
70 | model_cls,
71 | model_kwargs,
72 | stypes,
73 | expected_graph_breaks,
74 | ):
75 | torch._dynamo.config.suppress_errors = True
76 |
77 | dataset = FakeDataset(
78 | num_rows=10,
79 | with_nan=False,
80 | stypes=stypes or [stype.categorical, stype.numerical],
81 | )
82 | dataset.materialize()
83 | tf = dataset.tensor_frame
84 | model = model_cls(
85 | out_channels=1,
86 | num_layers=2,
87 | col_stats=dataset.col_stats,
88 | col_names_dict=tf.col_names_dict,
89 | **model_kwargs,
90 | )
91 | explanation = torch._dynamo.explain(model)(tf)
92 | graph_breaks = explanation.graph_break_count
93 | assert graph_breaks == expected_graph_breaks
94 |
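95 | # Once the expected graph breaks are verified, the same model can be
96 | # compiled end-to-end (a sketch, reusing `model` and `tf` from above):
97 | #
98 | # compiled_model = torch.compile(model)
99 | # out = compiled_model(tf)  # same result as `model(tf)` after compilation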
--------------------------------------------------------------------------------
/torch_frame/datasets/poker_hand.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import zipfile
3 |
4 | import pandas as pd
5 |
6 | import torch_frame
7 |
8 |
9 | class PokerHand(torch_frame.data.Dataset):
10 | r"""The `Poker Hand
11 | `_
12 | dataset. The task is to predict the 5-card poker hand.
13 |
14 | **STATS:**
15 |
16 | .. list-table::
17 | :widths: 10 10 10 10 20 10
18 | :header-rows: 1
19 |
20 | * - #rows
21 | - #cols (numerical)
22 | - #cols (categorical)
23 | - #classes
24 | - Task
25 | - Missing value ratio
26 | * - 1,025,010
27 | - 5
28 | - 5
29 | - 10
30 | - multiclass_classification
31 | - 0.0%
32 | """
33 |
34 | url = 'https://archive.ics.uci.edu/static/public/158/poker+hand.zip'
35 |
36 | def __init__(self, root: str):
37 | path = self.download_url(self.url, root)
38 | folder_path = osp.dirname(path)
39 |
40 | with zipfile.ZipFile(path, 'r') as zip_ref:
41 | zip_ref.extractall(folder_path)
42 |
43 | train_path = osp.join(folder_path, 'poker-hand-training-true.data')
44 | test_path = osp.join(folder_path, 'poker-hand-testing.data')
45 |
46 | names = [
47 | 'Suit of card #1',
48 | 'Rank of card #1',
49 | 'Suit of card #2',
50 | 'Rank of card #2',
51 | 'Suit of card #3',
52 | 'Rank of card #3',
53 | 'Suit of card #4',
54 | 'Rank of card #4',
55 | 'Suit of card #5',
56 | 'Rank of card #5',
57 | 'Poker Hand',
58 | ]
59 | train_df = pd.read_csv(train_path, names=names)
60 | test_df = pd.read_csv(test_path, names=names)
61 | df = pd.concat([train_df, test_df], ignore_index=True)
62 |
63 | col_to_stype = {
64 | 'Suit of card #1': torch_frame.categorical,
65 | 'Rank of card #1': torch_frame.numerical,
66 | 'Suit of card #2': torch_frame.categorical,
67 | 'Rank of card #2': torch_frame.numerical,
68 | 'Suit of card #3': torch_frame.categorical,
69 | 'Rank of card #3': torch_frame.numerical,
70 | 'Suit of card #4': torch_frame.categorical,
71 | 'Rank of card #4': torch_frame.numerical,
72 | 'Suit of card #5': torch_frame.categorical,
73 | 'Rank of card #5': torch_frame.numerical,
74 | 'Poker Hand': torch_frame.categorical,
75 | }
76 |
77 | super().__init__(df, col_to_stype, target_col='Poker Hand')
78 |
--------------------------------------------------------------------------------
/torch_frame/datasets/mercari.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import os.path as osp
4 |
5 | import pandas as pd
6 |
7 | import torch_frame
8 | from torch_frame.config.text_embedder import TextEmbedderConfig
9 | from torch_frame.utils.split import SPLIT_TO_NUM
10 |
11 | SPLIT_COL = 'split_col'
12 |
13 |
14 | class Mercari(torch_frame.data.Dataset):
15 | r"""The `Mercari Price Suggestion Challenge
16 | `_
17 | dataset from Kaggle.
18 |
19 | Args:
20 | num_rows (int, optional): Number of rows to subsample.
21 | (default: :obj:`None`)
22 |
23 | **STATS:**
24 |
25 | .. list-table::
26 | :widths: 10 10 10 10 20 10
27 | :header-rows: 1
28 |
29 | * - #rows
30 | - #cols (numerical)
31 | - #cols (categorical)
32 | - #cols (text_embedded)
33 | - Task
34 | - Missing value ratio
35 | * - 1,482,535
36 | - 1
37 | - 4
38 | - 2
39 | - regression
40 | - 0.0%
41 | """
42 | base_url = 'https://data.pyg.org/datasets/tables/mercari_price_suggestion/'
43 | files = ['train', 'test_stg2']
44 |
45 | def __init__(
46 | self,
47 | root: str,
48 | num_rows: int | None = None,
49 | col_to_text_embedder_cfg: dict[str, TextEmbedderConfig]
50 | | TextEmbedderConfig | None = None,
51 | ) -> None:
52 | col_to_stype = {
53 | 'name': torch_frame.text_embedded,
54 | 'item_condition_id': torch_frame.categorical,
55 | 'category_name': torch_frame.multicategorical,
56 | 'brand_name': torch_frame.categorical,
57 | 'price': torch_frame.numerical,
58 | 'shipping': torch_frame.categorical,
59 | 'item_description': torch_frame.text_embedded,
60 | }
61 | train_path = self.download_url(
62 | osp.join(self.base_url, 'train.csv'), root)
63 | df_train = pd.read_csv(train_path)
64 | test_path = self.download_url(
65 | osp.join(self.base_url, 'test_stg2.csv'), root)
66 | df_test = pd.read_csv(test_path)
67 | df_train[SPLIT_COL] = SPLIT_TO_NUM['train']
68 | df_test[SPLIT_COL] = SPLIT_TO_NUM['test']
69 | df = pd.concat([df_train, df_test], axis=0, ignore_index=True)
70 | if num_rows is not None:
71 | df = df.head(num_rows)
72 | df.drop(['train_id'], axis=1, inplace=True)
73 | super().__init__(df, col_to_stype, target_col='price', col_to_sep="/",
74 | col_to_text_embedder_cfg=col_to_text_embedder_cfg,
75 | split_col=SPLIT_COL)
76 |
--------------------------------------------------------------------------------
/examples/tabpfn_classification.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os.path as osp
3 |
4 | import numpy as np
5 | import torch
6 | from tabpfn import TabPFNClassifier
7 | # Please run `pip install tabpfn` to install the package
8 | from tqdm import tqdm
9 |
10 | from torch_frame.data import DataLoader
11 | from torch_frame.datasets import (
12 | ForestCoverType,
13 | KDDCensusIncome,
14 | Mushroom,
15 | Titanic,
16 | )
17 |
18 | parser = argparse.ArgumentParser(
19 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
20 | parser.add_argument(
21 | '--dataset', type=str, default="Titanic",
22 | choices=["Titanic", "Mushroom", "ForestCoverType", "KDDCensusIncome"])
23 | parser.add_argument('--train_batch_size', type=int, default=4096)
24 | parser.add_argument('--test_batch_size', type=int, default=128)
25 | parser.add_argument('--seed', type=int, default=0)
26 | args = parser.parse_args()
27 |
28 | torch.manual_seed(args.seed)
29 |
30 | # Prepare datasets
31 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data',
32 | args.dataset)
33 |
34 | if args.dataset == "Titanic":
35 | dataset = Titanic(root=path)
36 | elif args.dataset == "ForestCoverType":
37 | dataset = ForestCoverType(root=path)
38 | elif args.dataset == "KDDCensusIncome":
39 | dataset = KDDCensusIncome(root=path)
40 | else:
41 | dataset = Mushroom(root=path)
42 |
43 | dataset.materialize()
44 | assert dataset.task_type.is_classification
45 | dataset = dataset.shuffle()
46 | train_dataset, test_dataset = dataset[:0.9], dataset[0.9:]
47 | train_tensor_frame = train_dataset.tensor_frame
48 | test_tensor_frame = test_dataset.tensor_frame
49 | train_loader = DataLoader(
50 | train_tensor_frame,
51 | batch_size=args.train_batch_size,
52 | shuffle=True,
53 | )
54 | X_train = []
55 | train_data = next(iter(train_loader))
56 | for stype in train_data.stypes:
57 | X_train.append(train_data.feat_dict[stype])
58 | X_train: torch.Tensor = torch.cat(X_train, dim=1)
59 | clf = TabPFNClassifier()
60 | clf.fit(X_train, train_data.y)
61 | test_loader = DataLoader(test_tensor_frame, batch_size=args.test_batch_size)
62 |
63 |
64 | @torch.no_grad()
65 | def test() -> float:
66 | accum = total_count = 0
67 | for test_data in tqdm(test_loader):
68 | X_test = []
69 | for stype in train_data.stypes:
70 | X_test.append(test_data.feat_dict[stype])
71 | X_test = torch.cat(X_test, dim=1)
72 | pred: np.ndarray = clf.predict_proba(X_test)
73 | pred_class = pred.argmax(axis=-1)
74 | accum += float((test_data.y.numpy() == pred_class).sum())
75 | total_count += len(test_data.y)
76 |
77 | return accum / total_count
78 |
79 |
80 | acc = test()
81 | print(f"Accuracy: {acc:.4f}")
82 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | default_language_version:
2 | python: python3.10
3 |
4 | ci:
5 | autofix_prs: true
6 | autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions'
7 | autoupdate_schedule: weekly
8 | # mypy exceeds free tier 250MiB limit on pre-commit.ci
9 | # https://github.com/pre-commit-ci/issues/issues/171
10 | skip: [mypy]
11 |
12 | repos:
13 | - repo: https://github.com/pre-commit/pre-commit-hooks
14 | rev: v6.0.0
15 | hooks:
16 | - id: no-commit-to-branch
17 | name: No commits to master
18 | - id: end-of-file-fixer
19 | name: End-of-file fixer
20 | - id: trailing-whitespace
21 | name: Remove trailing whitespaces
22 | - id: check-toml
23 | name: Check toml
24 | - id: check-yaml
25 | name: Check yaml
26 |
27 | - repo: https://github.com/adrienverge/yamllint.git
28 | rev: v1.37.1
29 | hooks:
30 | - id: yamllint
31 | name: Lint yaml
32 | args: [-d, '{extends: default, rules: {line-length: disable, document-start: disable, truthy: {level: error}, braces: {max-spaces-inside: 1}}}']
33 |
34 | - repo: https://github.com/PyCQA/autoflake
35 | rev: v2.3.1
36 | hooks:
37 | - id: autoflake
38 | name: Remove unused imports and variables
39 | args: [
40 | --remove-all-unused-imports,
41 | --remove-unused-variables,
42 | --remove-duplicate-keys,
43 | --ignore-init-module-imports,
44 | --in-place,
45 | ]
46 |
47 | - repo: https://github.com/google/yapf
48 | rev: v0.43.0
49 | hooks:
50 | - id: yapf
51 | name: Format code
52 | additional_dependencies: [toml]
53 |
54 | - repo: https://github.com/pycqa/isort
55 | rev: 7.0.0
56 | hooks:
57 | - id: isort
58 | name: Sort imports
59 |
60 | - repo: https://github.com/PyCQA/flake8
61 | rev: 7.3.0
62 | hooks:
63 | - id: flake8
64 | name: Check PEP8
65 | additional_dependencies: [Flake8-pyproject]
66 |
67 | - repo: https://github.com/astral-sh/ruff-pre-commit
68 | rev: v0.14.8
69 | hooks:
70 | - id: ruff
71 | name: Ruff formatting
72 | args: [--fix, --exit-non-zero-on-fix]
73 |
74 | - repo: https://github.com/pre-commit/mirrors-mypy
75 | rev: v1.19.0
76 | hooks:
77 | - id: mypy
78 | name: Check types
79 | additional_dependencies: [torch==2.9.*]
80 | exclude: "^test/|^examples/|^benchmark/"
81 |
82 | - repo: https://github.com/executablebooks/mdformat
83 | rev: 1.0.0
84 | hooks:
85 | - id: mdformat
86 | name: Format Markdown
87 | additional_dependencies:
88 | - mdformat-gfm
89 | - mdformat-front-matters
90 | - mdformat-footnote
91 |
--------------------------------------------------------------------------------
/torch_frame/datasets/bank_marketing.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import zipfile
3 |
4 | import pandas as pd
5 |
6 | import torch_frame
7 |
8 |
9 | class BankMarketing(torch_frame.data.Dataset):
10 | r"""The `Bank Marketing
11 | `_
12 | dataset. It is related to the direct marketing campaigns of
13 | a Portuguese banking institution. The marketing campaigns
14 | were based on phone calls. Often, more than one contact with
15 | the same client was required in order to assess whether the
16 | product (a bank term deposit) would be subscribed.
17 | The classification goal is to predict whether the client
18 | will subscribe to a term deposit.
19 |
20 | **STATS:**
21 |
22 | .. list-table::
23 | :widths: 10 10 10 10 20 10
24 | :header-rows: 1
25 |
26 | * - #rows
27 | - #cols (numerical)
28 | - #cols (categorical)
29 | - #classes
30 | - Task
31 | - Missing value ratio
32 | * - 45,211
33 | - 7
34 | - 9
35 | - 2
36 | - binary_classification
37 | - 0.0%
38 | """
39 |
40 | url = 'https://archive.ics.uci.edu/static/public/222/bank+marketing.zip' # noqa
41 |
42 | def __init__(self, root: str):
43 | path = self.download_url(self.url, root)
44 | folder_path = osp.dirname(path)
45 | with zipfile.ZipFile(path, 'r') as zip_ref:
46 | zip_ref.extractall(folder_path)
47 | data_path = osp.join(folder_path, 'bank.zip')
48 | data_subfolder_path = osp.join(folder_path, 'bank')
49 | with zipfile.ZipFile(data_path, 'r') as zip_ref:
50 | zip_ref.extractall(data_subfolder_path)
51 | df = pd.read_csv(osp.join(data_subfolder_path, 'bank-full.csv'),
52 | sep=';')
53 |
54 | col_to_stype = {
55 | 'age': torch_frame.numerical,
56 | 'job': torch_frame.categorical,
57 | 'marital': torch_frame.categorical,
58 | 'education': torch_frame.categorical,
59 | 'default': torch_frame.categorical,
60 | 'balance': torch_frame.numerical,
61 | 'housing': torch_frame.categorical,
62 | 'loan': torch_frame.categorical,
63 | 'contact': torch_frame.categorical,
64 | 'day': torch_frame.numerical,
65 | 'month': torch_frame.categorical,
66 | 'duration': torch_frame.numerical,
67 | 'campaign': torch_frame.numerical,
68 | 'pdays': torch_frame.numerical,
69 | 'previous': torch_frame.numerical,
70 | 'poutcome': torch_frame.categorical,
71 | 'y': torch_frame.categorical,
72 | }
73 |
74 | super().__init__(df, col_to_stype, target_col='y')
75 |
--------------------------------------------------------------------------------
/torch_frame/config/image_embedder.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from abc import ABC, abstractmethod
4 | from collections.abc import Callable
5 | from dataclasses import dataclass
6 |
7 | from PIL import Image
8 | from torch import Tensor
9 |
10 |
11 | class ImageEmbedder(ABC):
12 | r"""Parent class for the :obj:`image_embedder` of
13 | :class:`ImageEmbedderConfig`. This class first retrieves images based
14 | on the given paths stored in the data frame and then embeds the retrieved
15 | images into a tensor. Users are responsible for implementing
16 | :meth:`forward_embed`, which takes a list of images and returns an
17 | embedding tensor. Users can also override :meth:`forward_retrieve`, which
18 | takes the paths to images and returns a list of :obj:`PIL.Image.Image`.
19 | """
20 | def forward_retrieve(self, path_to_images: list[str]) -> list[Image.Image]:
21 | r"""Retrieval function that reads a list of images from
22 | a list of file paths with the :obj:`RGB` mode.
23 | """
24 | images: list[Image.Image] = []
25 | for path_to_image in path_to_images:
26 | image = Image.open(path_to_image)
27 | images.append(image.copy())
28 | image.close()
29 | images = [image.convert('RGB') for image in images]
30 | return images
31 |
32 | @abstractmethod
33 | def forward_embed(self, images: list[Image.Image]) -> Tensor:
34 | r"""Embedding function that takes a list of images and returns
35 | an embedding tensor.
36 | """
37 | raise NotImplementedError
38 |
39 | def __call__(self, path_to_images: list[str]) -> Tensor:
40 | images = self.forward_retrieve(path_to_images)
41 | return self.forward_embed(images)
42 |
43 |
44 | @dataclass
45 | class ImageEmbedderConfig:
46 | r"""Image embedder model that maps a list of images into PyTorch
47 | Tensor embeddings.
48 |
49 | Args:
50 | image_embedder (callable): A callable image embedder that takes a
51 | list of paths to images as input and outputs the PyTorch Tensor
52 | embeddings for that list of images. Usually it contains a retriever
53 | to load image files and then an embedder converting images to
54 | embeddings.
55 | batch_size (int, optional): Batch size to use when encoding the
56 | images. If set to :obj:`None`, the image embeddings will
57 | be obtained in a full-batch manner. (default: :obj:`None`)
58 |
59 | """
60 | image_embedder: Callable[[list[str]], Tensor]
61 | # Batch size to use when encoding the images. It is recommended to set
62 | # it to a reasonable value when one uses a heavy image embedding model
63 | # (e.g., ViT) on GPU. If set to :obj:`None`, the image embeddings
64 | # will be obtained in a full-batch manner.
65 | batch_size: int | None = None
66 |
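67 | # A minimal sketch of a concrete embedder (hypothetical, for illustration;
68 | # torchvision's `pil_to_tensor` is an assumption, not a dependency of this
69 | # module):
70 | #
71 | # import torch
72 | # from torchvision.transforms.functional import pil_to_tensor
73 | #
74 | # class MeanPixelImageEmbedder(ImageEmbedder):
75 | #     def forward_embed(self, images: list[Image.Image]) -> Tensor:
76 | #         # Embed each image as its per-channel mean pixel value.
77 | #         feats = [pil_to_tensor(img).float().mean(dim=(1, 2)) for img in images]
78 | #         return torch.stack(feats, dim=0)  # [num_images, 3]
79 | #
80 | # cfg = ImageEmbedderConfig(image_embedder=MeanPixelImageEmbedder(),
81 | #                           batch_size=256)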
--------------------------------------------------------------------------------
/torch_frame/datasets/diamond_images.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import os.path as osp
4 | import zipfile
5 |
6 | import pandas as pd
7 |
8 | import torch_frame
9 | from torch_frame.config.image_embedder import ImageEmbedderConfig
10 |
11 |
12 | class DiamondImages(torch_frame.data.Dataset):
13 | r"""The `Diamond Images
14 | `_
15 | dataset from Kaggle. The target is to predict the :obj:`colour` of each
16 | diamond.
17 |
18 | **STATS:**
19 |
20 | .. list-table::
21 | :widths: 10 10 10 10 10 20 10
22 | :header-rows: 1
23 |
24 | * - #rows
25 | - #cols (numerical)
26 | - #cols (categorical)
27 | - #cols (image)
28 | - #classes
29 | - Task
30 | - Missing value ratio
31 | * - 48,764
32 | - 4
33 | - 7
34 | - 1
35 | - 23
36 | - multiclass_classification
37 | - 0.167%
38 | """
39 |
40 | url = 'https://data.pyg.org/datasets/tables/diamond.zip'
41 |
42 | def __init__(
43 | self,
44 | root: str,
45 | col_to_image_embedder_cfg: ImageEmbedderConfig
46 | | dict[str, ImageEmbedderConfig],
47 | ):
48 | path = self.download_url(self.url, root)
49 |
50 | folder_path = osp.dirname(path)
51 | with zipfile.ZipFile(path, "r") as zip_ref:
52 | zip_ref.extractall(folder_path)
53 |
54 | subfolder_path = osp.join(folder_path, "diamond")
55 | csv_path = osp.join(subfolder_path, "diamond_data.csv")
56 | df = pd.read_csv(csv_path)
57 | df = df.drop(columns=["stock_number"])
58 |
59 | image_paths = []
60 | for path_to_img in df["path_to_img"]:
61 | path_to_img = path_to_img.replace("web_scraped/", "")
62 | image_paths.append(osp.join(subfolder_path, path_to_img))
63 | image_df = pd.DataFrame({"image_path": image_paths})
64 | df = pd.concat([df, image_df], axis=1)
65 | df = df.drop(columns=["path_to_img"])
66 |
67 | col_to_stype = {
68 | "shape": torch_frame.categorical,
69 | "carat": torch_frame.numerical,
70 | "clarity": torch_frame.categorical,
71 | "colour": torch_frame.categorical,
72 | "cut": torch_frame.categorical,
73 | "polish": torch_frame.categorical,
74 | "symmetry": torch_frame.categorical,
75 | "fluorescence": torch_frame.categorical,
76 | "lab": torch_frame.categorical,
77 | "length": torch_frame.numerical,
78 | "width": torch_frame.numerical,
79 | "depth": torch_frame.numerical,
80 | "image_path": torch_frame.image_embedded,
81 | }
82 |
83 | super().__init__(df, col_to_stype, target_col="colour",
84 | col_to_image_embedder_cfg=col_to_image_embedder_cfg)
85 |
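86 | # A usage sketch (hypothetical path and embedder; `MyImageEmbedder` stands
87 | # for any `ImageEmbedder` subclass, see torch_frame/config/image_embedder.py):
88 | #
89 | # cfg = ImageEmbedderConfig(image_embedder=MyImageEmbedder(), batch_size=256)
90 | # dataset = DiamondImages(root='/tmp/diamond', col_to_image_embedder_cfg=cfg)
91 | # dataset.materialize()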
--------------------------------------------------------------------------------
/torch_frame/nn/base.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import copy
4 | from inspect import signature
5 | from typing import Any
6 |
7 | import torch
8 |
9 |
10 | class Module(torch.nn.Module):
11 | r"""A base class for defining modules in which attributes may be defined
12 | at a later stage. As such, users only need to define the dynamic
13 | hyperparameters of a module, and do not need to care about connecting the
14 | module to the underlying data, e.g., specifying the number of input or
15 | output channels.
16 |
17 | This is achieved by postponing submodule creation
18 | (via :meth:`init_modules`) until all attributes in :obj:`LAZY_ATTRS` are
19 | fully-specified.
20 | """
21 | LAZY_ATTRS: set[str] = set()
22 |
23 | def init_modules(self):
24 | pass
25 |
26 | def __init__(self, *args, **kwargs):
27 | super().__init__()
28 |
29 | self._in_init = True
30 | self._missing_attrs = copy.copy(self.LAZY_ATTRS)
31 |
32 | for key, value in zip(
33 | signature(self.__init__).parameters, args, strict=False):
34 | setattr(self, key, value)
35 | for key, value in kwargs.items():
36 | setattr(self, key, value)
37 |
38 | self._in_init = False
39 |
40 | if self.is_fully_specified:
41 | self._init_modules()
42 |
43 | def __setattr__(self, key: str, value: Any):
44 | super().__setattr__(key, value)
45 | if value is not None and key in getattr(self, '_missing_attrs', {}):
46 | self._missing_attrs.remove(key)
47 | if not self._in_init and self.is_fully_specified:
48 | self._init_modules()
49 |
50 | @property
51 | def is_fully_specified(self) -> bool:
52 | return len(self._missing_attrs) == 0
53 |
54 | def validate(self):
55 | if len(self._missing_attrs) > 0:
56 | raise ValueError(
57 | f"The '{self.__class__.__name__}' module is not fully-"
58 | f"specified yet. It is missing the following attribute(s): "
59 | f"{self._missing_attrs}. Please specify them before using "
60 | f"this module in a deep learning pipeline.")
61 |
62 | def _init_modules(self):
63 | self.validate()
64 | self.init_modules()
65 |
66 | def _apply(self, *args, **kwargs) -> Any:
67 | self.validate()
68 | return super()._apply(*args, **kwargs)
69 |
70 | def named_parameters(self, *args, **kwargs) -> Any:
71 | self.validate()
72 | return super().named_parameters(*args, **kwargs)
73 |
74 | def named_children(self) -> Any:
75 | self.validate()
76 | return super().named_children()
77 |
78 | def named_modules(self, *args, **kwargs) -> Any:
79 | self.validate()
80 | return super().named_modules(*args, **kwargs)
81 |
82 | def __call__(self, *args, **kwargs) -> Any:
83 | self.validate()
84 | return super().__call__(*args, **kwargs)
85 |
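86 | # A minimal sketch of a lazily-specified module (hypothetical):
87 | #
88 | # class LazyLinear(Module):
89 | #     LAZY_ATTRS = {'in_channels'}
90 | #
91 | #     def __init__(self, out_channels: int, in_channels: int | None = None):
92 | #         super().__init__(out_channels=out_channels, in_channels=in_channels)
93 | #
94 | #     def init_modules(self):
95 | #         # Called automatically once every attribute in LAZY_ATTRS is set.
96 | #         self.lin = torch.nn.Linear(self.in_channels, self.out_channels)
97 | #
98 | # model = LazyLinear(out_channels=8)  # 'in_channels' is still missing
99 | # model.in_channels = 16              # fully-specified: triggers init_modules()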
--------------------------------------------------------------------------------
/torch_frame/datasets/amphibians.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import zipfile
3 |
4 | import pandas as pd
5 |
6 | import torch_frame
7 |
8 |
9 | class Amphibians(torch_frame.data.Dataset):
10 | r"""The `Amphibians
11 | `_
12 | dataset. The task is to predict which of the 7 frog types appeared
13 | in the habitat.
14 |
15 | **STATS:**
16 |
17 | .. list-table::
18 | :widths: 10 10 10 10 20 10
19 | :header-rows: 1
20 |
21 | * - #rows
22 | - #cols (numerical)
23 | - #cols (categorical)
24 | - #cols (text_embedded)
25 | - Task
26 | - Missing value ratio
27 | * - 189
28 | - 3
29 | - 20
30 | - 0
31 | - multilabel_classification
32 | - 0.0%
33 | """
34 | url = 'https://archive.ics.uci.edu/static/public/528/amphibians.zip'
35 |
36 | def __init__(self, root: str):
37 | path = self.download_url(self.url, root)
38 | folder_path = osp.dirname(path)
39 |
40 | with zipfile.ZipFile(path, 'r') as zip_ref:
41 | zip_ref.extractall(folder_path)
42 |
43 | data_path = osp.join(folder_path, 'dataset.csv')
44 | names = [
45 | 'ID', 'MV', 'SR', 'NR', 'TR', 'VR', 'SUR1', 'SUR2', 'SUR3', 'UR',
46 | 'FR', 'OR', 'RR', 'BR', 'MR', 'CR', 't1', 't2', 't3', 't4', 't5',
47 | 't6', 't7'
48 | ]
49 | df = pd.read_csv(data_path, names=names, sep=';')
50 | # Drop the first 2 rows containing metadata
51 | df = df.iloc[2:].reset_index(drop=True)
52 | target_cols = ['t1', 't2', 't3', 't4', 't5', 't6', 't7']
53 | df['t'] = df.apply(
54 | lambda row: [col for col in target_cols if row[col] == '1'],
55 | axis=1)
56 | df = df.drop(target_cols, axis=1)
57 |
58 | # Round-trip through a CSV file so that pandas re-infers the column dtypes
59 | path = osp.join(root, 'amphibians_posprocess.csv')
60 | df.to_csv(path, index=False)
61 | df = pd.read_csv(path)
62 |
63 | col_to_stype = {
64 | 'ID': torch_frame.numerical,
65 | 'MV': torch_frame.categorical,
66 | 'SR': torch_frame.numerical,
67 | 'NR': torch_frame.numerical,
68 | 'TR': torch_frame.categorical,
69 | 'VR': torch_frame.categorical,
70 | 'SUR1': torch_frame.categorical,
71 | 'SUR2': torch_frame.categorical,
72 | 'SUR3': torch_frame.categorical,
73 | 'UR': torch_frame.categorical,
74 | 'FR': torch_frame.categorical,
75 | 'OR': torch_frame.numerical,
76 | 'RR': torch_frame.categorical, # Support Ordinal Encoding
77 | 'BR': torch_frame.categorical, # Support Ordinal Encoding
78 | 'MR': torch_frame.categorical,
79 | 'CR': torch_frame.categorical,
80 | 't': torch_frame.multicategorical,
81 | }
82 | super().__init__(df, col_to_stype, target_col='t', col_to_sep=None)
83 |
--------------------------------------------------------------------------------
/torch_frame/datasets/movielens_1m.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import os.path as osp
4 | import zipfile
5 |
6 | import pandas as pd
7 |
8 | import torch_frame
9 | from torch_frame.config.text_embedder import TextEmbedderConfig
10 |
11 |
12 | class Movielens1M(torch_frame.data.Dataset):
13 | r"""The MovieLens 1M rating dataset, assembled by GroupLens Research
14 | from the MovieLens web site, consisting of 3,883 movies and
15 | 6,040 users with approximately 1 million ratings between them.
16 |
17 | **STATS:**
18 |
19 | .. list-table::
20 | :widths: 10 10 10 10 20
21 | :header-rows: 1
22 |
23 | * - #Users
24 | - #Items
25 | - #User Field
26 | - #Item Field
27 | - #Samples
28 | * - 6040
29 | - 3952
30 | - 5
31 | - 3
32 | - 1000209
33 | """
34 |
35 | url = 'https://files.grouplens.org/datasets/movielens/ml-1m.zip'
36 |
37 | def __init__(
38 | self,
39 | root: str,
40 | col_to_text_embedder_cfg: dict[str, TextEmbedderConfig]
41 | | TextEmbedderConfig | None = None,
42 | ):
43 | path = self.download_url(self.url, root)
44 | folder_path = osp.dirname(path)
45 |
46 | with zipfile.ZipFile(path, 'r') as zip_ref:
47 | zip_ref.extractall(folder_path)
48 |
49 | data_path = osp.join(folder_path, 'ml-1m')
50 | users = pd.read_csv(
51 | osp.join(data_path, 'users.dat'),
52 | header=None,
53 | names=['user_id', 'gender', 'age', 'occupation', 'zip'],
54 | sep='::',
55 | engine='python',
56 | )
57 | movies = pd.read_csv(
58 | osp.join(data_path, 'movies.dat'),
59 | header=None,
60 | names=['movie_id', 'title', 'genres'],
61 | sep='::',
62 | engine='python',
63 | encoding='ISO-8859-1',
64 | )
65 | ratings = pd.read_csv(
66 | osp.join(data_path, 'ratings.dat'),
67 | header=None,
68 | names=['user_id', 'movie_id', 'rating', 'timestamp'],
69 | sep='::',
70 | engine='python',
71 | )
72 |
73 | df = pd.merge(pd.merge(ratings, users), movies) \
74 | .sort_values(by='timestamp') \
75 | .reset_index().drop('index', axis=1)
76 |
77 | col_to_stype = {
78 | 'user_id': torch_frame.categorical,
79 | 'gender': torch_frame.categorical,
80 | 'age': torch_frame.categorical,
81 | 'occupation': torch_frame.categorical,
82 | 'zip': torch_frame.categorical,
83 | 'movie_id': torch_frame.categorical,
84 | 'title': torch_frame.text_embedded,
85 | 'genres': torch_frame.multicategorical,
86 | 'rating': torch_frame.numerical,
87 | 'timestamp': torch_frame.timestamp,
88 | }
89 | super().__init__(df, col_to_stype, target_col='rating', col_to_sep='|',
90 | col_to_text_embedder_cfg=col_to_text_embedder_cfg)
91 |
--------------------------------------------------------------------------------
/torch_frame/testing/text_tokenizer.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | from torch_frame.data import MultiNestedTensor
7 | from torch_frame.typing import TextTokenizationOutputs
8 |
9 |
10 | class WhiteSpaceHashTokenizer:
11 | r"""A simple white space tokenizer for testing purposes.
12 | It splits a sentence into tokens via white space and hashes each
13 | token into an index modulo :obj:`num_hash_bins`.
14 |
15 | Args:
16 | num_hash_bins (int): Number of hash bins to use.
17 | (default: :obj:`64`)
18 | device (torch.device, optional): The device to put tokens.
19 | (default: :obj:`None`)
20 | batched (bool): Whether to tokenize in a batched format.
21 | If :obj:`True`, the tokenizer returns a mapping to 2-dim tensors;
22 | otherwise, a list of mappings to 1-dim tensors. (default: :obj:`False`)
23 | """
24 | def __init__(
25 | self,
26 | num_hash_bins: int = 64,
27 | device: torch.device | None = None,
28 | batched: bool = False,
29 | ):
30 | self.device = device
31 | self.num_hash_bins = num_hash_bins
32 | self.batched = batched
33 |
34 | def __call__(self, sentences: list[str]) -> TextTokenizationOutputs:
35 | input_ids = []
36 | attention_mask = []
37 | for s in sentences:
38 | tokens = s.split(' ')
39 | idx = torch.tensor([hash(t) % self.num_hash_bins for t in tokens])
40 | input_ids.append(idx)
41 | attention_mask.append(torch.ones(idx.shape, dtype=torch.bool))
42 |
43 | if self.batched:
44 | max_length = max(t.size(0) for t in input_ids)
45 | padded_input_ids = [
46 | F.pad(t, (0, max_length - t.size(0)), value=-1)
47 | for t in input_ids
48 | ]
49 | input_ids = torch.stack(padded_input_ids)
50 | padded_attention_mask = [
51 | F.pad(t, (0, max_length - t.size(0)), value=False)
52 | for t in attention_mask
53 | ]
54 | attention_mask = torch.stack(padded_attention_mask)
55 | return {'input_ids': input_ids, 'attention_mask': attention_mask}
56 | else:
57 | return [{
58 | 'input_ids': input_ids[i],
59 | 'attention_mask': attention_mask[i]
60 | } for i in range(len(sentences))]
61 |
62 |
63 | class RandomTextModel(torch.nn.Module):
64 | r"""A text embedding model that takes the tokenized input from
65 | :class:`WhiteSpaceHashTokenizer` and outputs random embeddings. Should be
66 | used only for testing purposes.
67 | """
68 | def __init__(self, text_emb_channels: int):
69 | super().__init__()
70 | self.text_emb_channels = text_emb_channels
71 |
72 | def forward(self, feat: dict[str, MultiNestedTensor]):
73 | input_ids = feat['input_ids'].to_dense(fill_value=0)
74 | _ = feat['attention_mask'].to_dense(fill_value=0)
75 | return torch.rand(size=(input_ids.shape[0], 1, self.text_emb_channels))
76 |
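77 | # A usage sketch:
78 | #
79 | # tokenizer = WhiteSpaceHashTokenizer(num_hash_bins=16, batched=True)
80 | # out = tokenizer(['hello world', 'a longer test sentence'])
81 | # # out['input_ids'] has shape [2, 4]; the shorter row is padded with -1,
82 | # # and out['attention_mask'] is False at the padded positions.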
--------------------------------------------------------------------------------
/examples/tuned_gbdt.py:
--------------------------------------------------------------------------------
1 | """Reported (reproduced) results of Tuned XGBoost on TabularBenchmark of
2 | the Trompt paper https://arxiv.org/abs/2305.18446.
3 |
4 | electricity (A4): 88.52 (91.09)
5 | eye_movements (A5): 66.57 (64.21)
6 | MagicTelescope (B2): 86.05 (86.50)
7 | bank-marketing (B4): 80.34 (80.41)
8 | california (B5): 90.12 (89.71)
9 | credit (B7): 77.26 (77.4)
10 | pol (B14): 98.09 (97.5)
11 | jannis (mathcal B4): 79.67 (77.81)
12 |
13 | Reported (reproduced) results of Tuned CatBoost on TabularBenchmark of
14 | the Trompt paper: https://arxiv.org/abs/2305.18446
15 |
16 | electricity (A4): 87.73 (88.09)
17 | eye_movements (A5): 66.84 (64.27)
18 | MagicTelescope (B2): 85.92 (87.18)
19 | bank-marketing (B4): 80.39 (80.50)
20 | california (B5): 90.32 (87.56)
21 | credit (B7): 77.59 (77.29)
22 | pol (B14): 98.49 (98.21)
23 | jannis (mathcal B4): 79.89 (78.96)
24 | """
25 | import argparse
26 | import os.path as osp
27 | import random
28 |
29 | import numpy as np
30 | import torch
31 |
32 | from torch_frame.datasets import TabularBenchmark
33 | from torch_frame.gbdt import CatBoost, LightGBM, XGBoost
34 | from torch_frame.typing import Metric
35 |
36 | parser = argparse.ArgumentParser()
37 | parser.add_argument('--gbdt_type', type=str, default='xgboost',
38 | choices=['xgboost', 'catboost', 'lightgbm'])
39 | parser.add_argument('--dataset', type=str, default='eye_movements')
40 | parser.add_argument('--saved_model_path', type=str,
41 | default='storage/gbdts.txt')
42 | # Add this flag to match the reported number.
43 | parser.add_argument('--seed', type=int, default=0)
44 | args = parser.parse_args()
45 |
46 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
47 |
48 | random.seed(args.seed)
49 | np.random.seed(args.seed)
50 | torch.manual_seed(args.seed)
51 |
52 | # Prepare datasets
53 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data',
54 | args.dataset)
55 | dataset = TabularBenchmark(root=path, name=args.dataset)
56 | dataset.materialize()
57 | dataset = dataset.shuffle()
58 | # Split ratio following https://arxiv.org/abs/2207.08815
59 | # 70% is used for training. 30% of the remaining data is used for validation.
60 | # The final remainder is used for testing.
61 | train_dataset, val_dataset, test_dataset = dataset[:0.7], dataset[
62 | 0.7:0.79], dataset[0.79:]
63 |
64 | num_classes = None
65 | metric = None
66 | task_type = dataset.task_type
67 | if dataset.task_type.is_classification:
68 | metric = Metric.ACCURACY
69 | num_classes = dataset.num_classes
70 | else:
71 | metric = Metric.RMSE
72 | num_classes = None
73 |
74 | gbdt_cls_dict = {
75 | 'xgboost': XGBoost,
76 | 'catboost': CatBoost,
77 | 'lightgbm': LightGBM,
78 | }
79 | gbdt = gbdt_cls_dict[args.gbdt_type](
80 | task_type=task_type,
81 | num_classes=num_classes,
82 | metric=metric,
83 | )
84 |
85 | if osp.exists(args.saved_model_path):
86 | gbdt.load(args.saved_model_path)
87 | else:
88 | gbdt.tune(tf_train=train_dataset.tensor_frame,
89 | tf_val=val_dataset.tensor_frame, num_trials=20)
90 | gbdt.save(args.saved_model_path)
91 |
92 | pred = gbdt.predict(tf_test=test_dataset.tensor_frame)
93 | score = gbdt.compute_metric(test_dataset.tensor_frame.y, pred)
94 | print(f"{gbdt.metric} : {score}")
95 |
--------------------------------------------------------------------------------
/test/nn/test_simple_basecls.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 |
4 | import torch_frame
5 | from torch_frame import TensorFrame, stype
6 | from torch_frame.nn import Decoder, FeatureEncoder, TableConv
7 |
8 |
9 | def test_simple_basecls():
10 | # Instantiate each base class with a simple class and test e2e pipeline.
11 | class SimpleFeatureEncoder(FeatureEncoder):
12 | def __init__(
13 | self,
14 | out_channels: int,
15 | num_numerical: int,
16 | num_categories: list[int],
17 | ):
18 | super().__init__()
19 |
20 | self.out_channels = out_channels
21 | self.num_numerical = num_numerical
22 | self.num_categories = num_categories
23 |
24 | self.lins = torch.nn.ModuleList([
25 | torch.nn.Linear(1, out_channels) for _ in range(num_numerical)
26 | ])
27 | self.embs = torch.nn.ModuleList([
28 | torch.nn.Embedding(num_category, out_channels)
29 | for num_category in num_categories
30 | ])
31 |
32 | def forward(self, tf: TensorFrame) -> tuple[Tensor, list[str]]:
33 | xs = []
34 | for i, lin in enumerate(self.lins):
35 | xs.append(lin(tf.feat_dict[torch_frame.numerical][:, i:i + 1]))
36 | for i, emb in enumerate(self.embs):
37 | xs.append(emb(tf.feat_dict[torch_frame.categorical][:, i]))
38 |
39 | x = torch.stack(xs, dim=1)
40 | col_names = (tf.col_names_dict[stype.numerical] +
41 | tf.col_names_dict[stype.categorical])
42 |
43 | return x, col_names
44 |
45 | class SimpleTableConv(TableConv):
46 | def __init__(self, in_channels: int, out_channels: int):
47 | super().__init__()
48 | self.lin = torch.nn.Linear(in_channels, out_channels)
49 |
50 | def forward(self, x: Tensor) -> Tensor:
51 | B, C, H = x.shape
52 | x = x.view(-1, H)
53 | return self.lin(x).view(B, C, -1)
54 |
55 | class SimpleDecoder(Decoder):
56 | def forward(self, x: Tensor) -> Tensor:
57 | # Pool along the column axis
58 | return torch.mean(x, dim=1)
59 |
60 | tf = TensorFrame(
61 | feat_dict={
62 | torch_frame.numerical: torch.randn(10, 2),
63 | torch_frame.categorical: torch.randint(0, 5, (10, 2)),
64 | },
65 | col_names_dict={
66 | torch_frame.numerical: ['num1', 'num2'],
67 | torch_frame.categorical: ['cat1', 'cat2'],
68 | },
69 | )
70 |
71 | feat_encoder = SimpleFeatureEncoder(
72 | out_channels=8,
73 | num_numerical=2,
74 | num_categories=[5, 5],
75 | )
76 | table_conv1 = SimpleTableConv(in_channels=8, out_channels=16)
77 | table_conv2 = SimpleTableConv(in_channels=16, out_channels=8)
78 | decoder = SimpleDecoder()
79 |
80 | x, col_names = feat_encoder(tf)
81 | # [batch_size, num_cols, hidden_channels]
82 | assert x.shape == (10, 4, 8)
83 | assert col_names == ['num1', 'num2', 'cat1', 'cat2']
84 | x = table_conv1(x)
85 | assert x.shape == (10, 4, 16)
86 | x = table_conv2(x)
87 | assert x.shape == (10, 4, 8)
88 | x = decoder(x)
89 | assert x.shape == (10, 8)
90 |
--------------------------------------------------------------------------------
/test/conftest.py:
--------------------------------------------------------------------------------
1 | import random
2 | from collections.abc import Callable
3 |
4 | import pytest
5 | import torch
6 |
7 | import torch_frame
8 | from torch_frame import TensorFrame
9 | from torch_frame.data import MultiEmbeddingTensor, MultiNestedTensor
10 |
11 |
12 | @pytest.fixture()
13 | def get_fake_tensor_frame() -> Callable:
14 | def _get_fake_tensor_frame(num_rows: int) -> TensorFrame:
15 | col_names_dict = {
16 | torch_frame.categorical: ['cat_1', 'cat_2', 'cat_3'],
17 | torch_frame.numerical: ['num_1', 'num_2'],
18 | torch_frame.multicategorical: ['multicat_1', 'multicat_2'],
19 | torch_frame.text_embedded:
20 | ['text_embedded_1', 'text_embedded_2', 'text_embedded_3'],
21 | torch_frame.text_tokenized:
22 | ['text_tokenized_1', 'text_tokenized_2'],
23 | torch_frame.sequence_numerical: ['seq_num_1', 'seq_num_2'],
24 | torch_frame.embedding: ['emb_1', 'emb_2'],
25 | }
26 | feat_dict = {
27 | torch_frame.categorical:
28 | torch.randint(
29 | 0, 3,
30 | size=(num_rows, len(col_names_dict[torch_frame.categorical]))),
31 | torch_frame.numerical:
32 | torch.randn(size=(num_rows,
33 | len(col_names_dict[torch_frame.numerical]))),
34 | torch_frame.multicategorical:
35 | MultiNestedTensor.from_tensor_mat([[
36 | torch.arange(random.randint(0, 10)) for _ in range(
37 | len(col_names_dict[torch_frame.multicategorical]))
38 | ] for _ in range(num_rows)]),
39 | torch_frame.text_embedded:
40 | MultiEmbeddingTensor.from_tensor_list([
41 | torch.randn(num_rows, random.randint(1, 5))
42 | for _ in range(len(col_names_dict[torch_frame.text_embedded]))
43 | ]),
44 | torch_frame.text_tokenized: {
45 | 'input_id':
46 | MultiNestedTensor.from_tensor_mat([[
47 | torch.randint(0, 5, size=(random.randint(0, 10), ))
48 | for _ in range(
49 | len(col_names_dict[torch_frame.text_tokenized]))
50 | ] for _ in range(num_rows)]),
51 | 'mask':
52 | MultiNestedTensor.from_tensor_mat([[
53 | torch.randint(0, 5, size=(random.randint(0, 10), ))
54 | for _ in range(
55 | len(col_names_dict[torch_frame.text_tokenized]))
56 | ] for _ in range(num_rows)]),
57 | },
58 | torch_frame.sequence_numerical:
59 | MultiNestedTensor.from_tensor_mat([[
60 | torch.randn(random.randint(0, 10)) for _ in range(
61 | len(col_names_dict[torch_frame.sequence_numerical]))
62 | ] for _ in range(num_rows)]),
63 | torch_frame.embedding:
64 | MultiEmbeddingTensor.from_tensor_list([
65 | torch.randn(num_rows, random.randint(1, 5))
66 | for _ in range(len(col_names_dict[torch_frame.embedding]))
67 | ])
68 | }
69 |
70 | y = torch.randn(num_rows)
71 |
72 | return TensorFrame(
73 | feat_dict=feat_dict,
74 | col_names_dict=col_names_dict,
75 | y=y,
76 | )
77 |
78 | return _get_fake_tensor_frame
79 |
--------------------------------------------------------------------------------
/test/gbdt/test_gbdt.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import tempfile
3 |
4 | import pytest
5 | import torch
6 |
7 | from torch_frame import Metric, TaskType, stype
8 | from torch_frame.config.text_embedder import TextEmbedderConfig
9 | from torch_frame.data.dataset import Dataset
10 | from torch_frame.datasets.fake import FakeDataset
11 | from torch_frame.gbdt import CatBoost, LightGBM, XGBoost
12 | from torch_frame.testing.text_embedder import HashTextEmbedder
13 |
14 |
15 | @pytest.mark.parametrize('gbdt_cls', [
16 | CatBoost,
17 | XGBoost,
18 | LightGBM,
19 | ])
20 | @pytest.mark.parametrize('stypes', [
21 | [stype.numerical],
22 | [stype.categorical],
23 | [stype.text_embedded],
24 | [stype.numerical, stype.numerical, stype.text_embedded],
25 | ])
26 | @pytest.mark.parametrize('task_type_and_metric', [
27 | (TaskType.REGRESSION, Metric.RMSE),
28 | (TaskType.REGRESSION, Metric.MAE),
29 | (TaskType.BINARY_CLASSIFICATION, Metric.ACCURACY),
30 | (TaskType.BINARY_CLASSIFICATION, Metric.ROCAUC),
31 | (TaskType.MULTICLASS_CLASSIFICATION, Metric.ACCURACY),
32 | ])
33 | def test_gbdt_with_save_load(gbdt_cls, stypes, task_type_and_metric):
34 | task_type, metric = task_type_and_metric
35 | dataset: Dataset = FakeDataset(
36 | num_rows=30,
37 | with_nan=True,
38 | stypes=stypes,
39 | create_split=True,
40 | task_type=task_type,
41 | col_to_text_embedder_cfg=TextEmbedderConfig(
42 | text_embedder=HashTextEmbedder(8)),
43 | )
44 | dataset.materialize()
45 | gbdt = gbdt_cls(
46 | task_type=task_type,
47 | num_classes=dataset.num_classes
48 | if task_type == TaskType.MULTICLASS_CLASSIFICATION else None,
49 | metric=metric,
50 | )
51 |
52 | with tempfile.TemporaryDirectory() as temp_dir:
53 | path = osp.join(temp_dir, 'model.json')
54 | with pytest.raises(RuntimeError, match="is not yet fitted"):
55 | gbdt.save(path)
56 |
57 | if gbdt_cls is XGBoost:
58 | gbdt.tune(tf_train=dataset.tensor_frame,
59 | tf_val=dataset.tensor_frame, num_trials=2,
60 | num_boost_round=1000, early_stopping_rounds=2)
61 | assert gbdt.model.best_iteration is not None
62 | else:
63 | gbdt.tune(
64 | tf_train=dataset.tensor_frame,
65 | tf_val=dataset.tensor_frame,
66 | num_trials=2,
67 | num_boost_round=2,
68 | )
69 | gbdt.save(path)
70 |
71 | loaded_gbdt = gbdt_cls(
72 | task_type=task_type,
73 | num_classes=dataset.num_classes
74 | if task_type == TaskType.MULTICLASS_CLASSIFICATION else None,
75 | metric=metric,
76 | )
77 | loaded_gbdt.load(path)
78 |
79 | pred = gbdt.predict(tf_test=dataset.tensor_frame)
80 | score = gbdt.compute_metric(dataset.tensor_frame.y, pred)
81 |
82 | loaded_score = loaded_gbdt.compute_metric(dataset.tensor_frame.y, pred)
83 | dataset.tensor_frame.y = None
84 | loaded_pred = loaded_gbdt.predict(tf_test=dataset.tensor_frame)
85 |
86 | assert torch.allclose(pred, loaded_pred, atol=1e-5)
87 | assert gbdt.metric == metric
88 | assert score == loaded_score
89 | if task_type == TaskType.REGRESSION:
90 | assert score >= 0
91 | elif task_type in (TaskType.BINARY_CLASSIFICATION,
92 | TaskType.MULTICLASS_CLASSIFICATION):
93 | # Accuracy and ROC-AUC scores are bounded in [0, 1].
94 | assert 0 <= score <= 1
95 |
--------------------------------------------------------------------------------
/torch_frame/datasets/mushroom.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import zipfile
3 |
4 | import pandas as pd
5 |
6 | import torch_frame
7 |
8 |
9 | class Mushroom(torch_frame.data.Dataset):
10 | r"""The `Mushroom classification Kaggle competition
11 | `_
12 | dataset. The task is to predict whether a mushroom is edible
13 | or poisonous.
14 |
15 | **STATS:**
16 |
17 | .. list-table::
18 | :widths: 10 10 10 10 20 10
19 | :header-rows: 1
20 |
21 | * - #rows
22 | - #cols (numerical)
23 | - #cols (categorical)
24 | - #classes
25 | - Task
26 | - Missing value ratio
27 | * - 8,124
28 | - 0
29 | - 22
30 | - 2
31 | - binary_classification
32 | - 0.0%
33 | """
34 |
35 | url = 'http://archive.ics.uci.edu/static/public/73/mushroom.zip'
36 |
37 | def __init__(self, root: str):
38 | path = self.download_url(self.url, root)
39 | folder_path = osp.dirname(path)
40 |
41 | with zipfile.ZipFile(path, 'r') as zip_ref:
42 | zip_ref.extractall(folder_path)
43 |
44 | data_path = osp.join(folder_path, 'agaricus-lepiota.data')
45 |
46 | names = [
47 | 'class',
48 | 'cap-shape',
49 | 'cap-surface',
50 | 'cap-color',
51 | 'bruises',
52 | 'odor',
53 | 'gill-attachment',
54 | 'gill-spacing',
55 | 'gill-size',
56 | 'gill-color',
57 | 'stalk-shape',
58 | 'stalk-root',
59 | 'stalk-surface-above-ring',
60 | 'stalk-surface-below-ring',
61 | 'stalk-color-above-ring',
62 | 'stalk-color-below-ring',
63 | 'veil-type',
64 | 'veil-color',
65 | 'ring-number',
66 | 'ring-type',
67 | 'spore-print-color',
68 | 'population',
69 | 'habitat',
70 | ]
71 | df = pd.read_csv(data_path, names=names)
72 |
73 | col_to_stype = {
74 | 'class': torch_frame.categorical,
75 | 'cap-shape': torch_frame.categorical,
76 | 'cap-surface': torch_frame.categorical,
77 | 'cap-color': torch_frame.categorical,
78 | 'bruises': torch_frame.categorical,
79 | 'odor': torch_frame.categorical,
80 | 'gill-attachment': torch_frame.categorical,
81 | 'gill-spacing': torch_frame.categorical,
82 | 'gill-size': torch_frame.categorical,
83 | 'gill-color': torch_frame.categorical,
84 | 'stalk-shape': torch_frame.categorical,
85 | 'stalk-root': torch_frame.categorical,
86 | 'stalk-surface-above-ring': torch_frame.categorical,
87 | 'stalk-surface-below-ring': torch_frame.categorical,
88 | 'stalk-color-above-ring': torch_frame.categorical,
89 | 'stalk-color-below-ring': torch_frame.categorical,
90 | 'veil-type': torch_frame.categorical,
91 | 'veil-color': torch_frame.categorical,
92 | 'ring-number': torch_frame.categorical,
93 | 'ring-type': torch_frame.categorical,
94 | 'spore-print-color': torch_frame.categorical,
95 | 'population': torch_frame.categorical,
96 | 'habitat': torch_frame.categorical,
97 | }
98 |
99 | super().__init__(df, col_to_stype, target_col='class')
100 |
--------------------------------------------------------------------------------
/torch_frame/nn/conv/ft_transformer_convs.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import torch
4 | from torch import Tensor
5 | from torch.nn import (
6 | LayerNorm,
7 | Parameter,
8 | TransformerEncoder,
9 | TransformerEncoderLayer,
10 | )
11 |
12 | from torch_frame.nn.conv import TableConv
13 |
14 |
15 | class FTTransformerConvs(TableConv):
16 | r"""The FT-Transformer backbone in the
17 | `"Revisiting Deep Learning Models for Tabular Data"
18 | `_ paper.
19 |
20 | This module concatenates a learnable CLS token embedding :obj:`x_cls` to
21 | the input tensor :obj:`x` and applies a multi-layer Transformer on the
22 | concatenated tensor. After the Transformer layer, the output tensor is
23 | divided into two parts: (1) :obj:`x`, corresponding to the original input
24 | tensor, and (2) :obj:`x_cls`, corresponding to the CLS token tensor.
25 |
26 | Args:
27 | channels (int): Input/output channel dimensionality
28 | feedforward_channels (int, optional): Hidden channels used by
29 | feedforward network of the Transformer model. If :obj:`None`, it
30 | will be set to :obj:`channels` (default: :obj:`None`)
31 | num_layers (int): Number of transformer encoder layers. (default: 3)
32 | nhead (int): Number of heads in multi-head attention (default: 8)
33 | dropout (float): The dropout value (default: 0.2)
34 | activation (str): The activation function (default: :obj:`relu`)
35 | """
36 | def __init__(
37 | self,
38 | channels: int,
39 | feedforward_channels: int | None = None,
40 | # Arguments for Transformer
41 | num_layers: int = 3,
42 | nhead: int = 8,
43 | dropout: float = 0.2,
44 | activation: str = 'relu',
45 | ):
46 | super().__init__()
47 |
48 | encoder_layer = TransformerEncoderLayer(
49 | d_model=channels,
50 | nhead=nhead,
51 | dim_feedforward=feedforward_channels or channels,
52 | dropout=dropout,
53 | activation=activation,
54 | # Input and output tensors are provided as
55 | # [batch_size, seq_len, channels]
56 | batch_first=True,
57 | )
58 | encoder_norm = LayerNorm(channels)
59 | self.transformer = TransformerEncoder(encoder_layer=encoder_layer,
60 | num_layers=num_layers,
61 | norm=encoder_norm)
62 | self.cls_embedding = Parameter(torch.empty(channels))
63 | self.reset_parameters()
64 |
65 | def reset_parameters(self):
66 | torch.nn.init.normal_(self.cls_embedding, std=0.01)
67 | for p in self.transformer.parameters():
68 | if p.dim() > 1:
69 | torch.nn.init.xavier_uniform_(p)
70 |
71 | def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
72 | r"""CLS-token augmented Transformer convolution.
73 |
74 | Args:
75 | x (Tensor): Input tensor of shape [batch_size, num_cols, channels]
76 |
77 | Returns:
78 | (torch.Tensor, torch.Tensor): (Output tensor of shape
79 | [batch_size, num_cols, channels] corresponding to the input
80 | columns, Output tensor of shape [batch_size, channels],
81 | corresponding to the added CLS token column.)
82 | """
83 | B, _, _ = x.shape
84 | # [batch_size, 1, channels]
85 | x_cls = self.cls_embedding.repeat(B, 1, 1)
86 | # [batch_size, num_cols + 1, channels]
87 | x_concat = torch.cat([x_cls, x], dim=1)
88 | # [batch_size, num_cols + 1, channels]
89 | x_concat = self.transformer(x_concat)
90 | x_cls, x = x_concat[:, 0, :], x_concat[:, 1:, :]
91 | return x, x_cls
92 |
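93 | # A shape-level usage sketch (hypothetical sizes):
94 | #
95 | # conv = FTTransformerConvs(channels=32)  # default nhead=8 divides channels
96 | # x = torch.randn(8, 10, 32)  # [batch_size, num_cols, channels]
97 | # x, x_cls = conv(x)
98 | # # x: [8, 10, 32] (per-column outputs); x_cls: [8, 32] (CLS-token output)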
--------------------------------------------------------------------------------
/torch_frame/nn/encoder/stypewise_encoder.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Any
4 |
5 | import torch
6 | from torch import Tensor
7 | from torch.nn import ModuleDict
8 |
9 | import torch_frame
10 | from torch_frame import TensorFrame
11 | from torch_frame.data.stats import StatType
12 | from torch_frame.nn.encoder import FeatureEncoder
13 | from torch_frame.nn.encoder.stype_encoder import StypeEncoder
14 |
15 |
16 | class StypeWiseFeatureEncoder(FeatureEncoder):
17 | r"""Feature encoder that transforms each stype tensor into embeddings and
18 | performs the final concatenation.
19 |
20 | Args:
21 | out_channels (int): Output dimensionality.
22 | col_stats
23 | (dict[str, dict[:class:`torch_frame.data.stats.StatType`, Any]]):
24 | A dictionary that maps column name into stats. Available as
25 | :obj:`dataset.col_stats`.
26 | col_names_dict (dict[:class:`torch_frame.stype`, list[str]]): A
27 | dictionary that maps stype to a list of column names. The column
28 | names are sorted based on the ordering that appears in
29 | :obj:`tensor_frame.feat_dict`.
30 | Available as :obj:`tensor_frame.col_names_dict`.
31 | stype_encoder_dict
32 | (dict[:class:`torch_frame.stype`,
33 | :class:`torch_frame.nn.encoder.StypeEncoder`]):
34 | A dictionary that maps :class:`torch_frame.stype` into
35 | :class:`torch_frame.nn.encoder.StypeEncoder` class. Only
36 | parent :class:`stypes ` are supported
37 | as keys.
38 | """
39 | def __init__(
40 | self,
41 | out_channels: int,
42 | col_stats: dict[str, dict[StatType, Any]],
43 | col_names_dict: dict[torch_frame.stype, list[str]],
44 | stype_encoder_dict: dict[torch_frame.stype, StypeEncoder],
45 | ) -> None:
46 | super().__init__()
47 |
48 | self.col_stats = col_stats
49 | self.col_names_dict = col_names_dict
50 | self.encoder_dict = ModuleDict()
51 | for stype, stype_encoder in stype_encoder_dict.items():
52 | if stype != stype.parent:
53 | if stype.parent in stype_encoder_dict:
54 | msg = (
55 | f"You can delete this {stype} directly since encoder "
56 | f"for parent stype {stype.parent} is already declared."
57 | )
58 | else:
59 | msg = (f"To resolve the issue, you can change the key from"
60 | f" {stype} to {stype.parent}.")
61 | raise ValueError(f"{stype} is an invalid stype to use in the "
62 | f"stype_encoder_dcit. {msg}")
63 | if stype not in stype_encoder.supported_stypes:
64 | raise ValueError(
65 | f"{stype_encoder} does not support encoding {stype}.")
66 |
67 | if stype in col_names_dict:
68 | stats_list = [
69 | self.col_stats[col_name]
70 | for col_name in self.col_names_dict[stype]
71 | ]
72 | # Set lazy attributes
73 | stype_encoder.stype = stype
74 | stype_encoder.out_channels = out_channels
75 | stype_encoder.stats_list = stats_list
76 | self.encoder_dict[stype.value] = stype_encoder
77 |
78 | def forward(self, tf: TensorFrame) -> tuple[Tensor, list[str]]:
79 | all_col_names = []
80 | xs = []
81 | for stype in tf.stypes:
82 | feat = tf.feat_dict[stype]
83 | col_names = self.col_names_dict[stype]
84 | x = self.encoder_dict[stype.value](feat, col_names)
85 | xs.append(x)
86 | all_col_names.extend(col_names)
87 | x = torch.cat(xs, dim=1)
88 | return x, all_col_names
89 |
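90 | # A usage sketch (assumes a materialized `dataset` and the built-in
91 | # `EmbeddingEncoder`/`LinearEncoder` stype encoders):
92 | #
93 | # from torch_frame.nn.encoder.stype_encoder import (
94 | #     EmbeddingEncoder,
95 | #     LinearEncoder,
96 | # )
97 | #
98 | # encoder = StypeWiseFeatureEncoder(
99 | #     out_channels=32,
100 | #     col_stats=dataset.col_stats,
101 | #     col_names_dict=dataset.tensor_frame.col_names_dict,
102 | #     stype_encoder_dict={
103 | #         torch_frame.categorical: EmbeddingEncoder(),
104 | #         torch_frame.numerical: LinearEncoder(),
105 | #     },
106 | # )
107 | # x, col_names = encoder(dataset.tensor_frame)  # x: [num_rows, num_cols, 32]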
--------------------------------------------------------------------------------
/torch_frame/transforms/fittable_base_transform.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import copy
4 | from abc import abstractmethod
5 | from typing import Any
6 |
7 | import torch
8 | from torch import Tensor
9 |
10 | from torch_frame import NAStrategy, TensorFrame
11 | from torch_frame.data.stats import StatType
12 | from torch_frame.transforms import BaseTransform
13 |
14 |
15 | class FittableBaseTransform(BaseTransform):
16 | r"""An abstract base class for writing fittable transforms.
17 | Fittable transforms must be fitted on training data before use.
18 | """
19 | def __init__(self):
20 | super().__init__()
21 | self._is_fitted: bool = False
22 |
23 | def __call__(self, tf: TensorFrame) -> TensorFrame:
24 | # Shallow-copy the data so that we prevent in-place data modification.
25 | return self.forward(copy.copy(tf))
26 |
27 | @property
28 | def is_fitted(self) -> bool:
29 | r"""Whether the transform is already fitted."""
30 | return self._is_fitted
31 |
32 | def _replace_nans(self, x: Tensor, na_strategy: NAStrategy):
33 | r"""Replace NaNs based on NAStrategy.
34 |
35 | Args:
36 | x (Tensor): Input :class:`Tensor` whose NaN values
37 | will be replaced.
38 | na_strategy (NAStrategy): The :class:`NAStrategy` used to
39 | replace NaN values.
40 |
41 | Returns:
42 | Tensor: Output :class:`Tensor` with NaN values
43 | replaced.
44 | """
45 | x = x.clone()
46 | for col in range(x.size(1)):
47 | column_data = x[:, col]
48 | if na_strategy.is_numerical_strategy:
49 | nan_mask = torch.isnan(column_data)
50 | else:
51 | nan_mask = column_data < 0
52 | if nan_mask.all():
53 | raise ValueError("Column contains only nan values.")
54 | if not nan_mask.any():
55 | continue
56 | valid_data = column_data[~nan_mask]
57 | if na_strategy == NAStrategy.MEAN:
58 | fill_value = valid_data.mean()
59 | elif na_strategy in [NAStrategy.ZEROS, NAStrategy.MOST_FREQUENT]:
60 | fill_value = torch.tensor(0.)
61 | else:
62 | raise ValueError(f'{na_strategy} is not supported.')
63 | column_data[nan_mask] = fill_value
64 | return x
65 |
66 | def fit(
67 | self,
68 | tf: TensorFrame,
69 | col_stats: dict[str, dict[StatType, Any]],
70 | ):
71 | r"""Fit the transform with train data.
72 |
73 | Args:
74 | tf (TensorFrame): Input :class:`TensorFrame` object representing
75 | the training data.
76 | col_stats (dict[str, dict[StatType, Any]]): The column
77 | stats of the input :class:`TensorFrame`.
78 | """
79 | self._fit(tf, col_stats)
80 | self._is_fitted = True
81 |
82 | def forward(self, tf: TensorFrame) -> TensorFrame:
83 | if not self.is_fitted:
84 | raise ValueError(f"'{self.__class__.__name__}' is not yet fitted ."
85 | f"Please run `fit()` first before attempting to "
86 | f"transform the TensorFrame.")
87 |
88 | transformed_tf = self._forward(tf)
89 | transformed_tf.validate()
90 | return transformed_tf
91 |
92 | @abstractmethod
93 | def _fit(
94 | self,
95 | tf: TensorFrame,
96 | col_stats: dict[str, dict[StatType, Any]],
97 | ):
98 | raise NotImplementedError
99 |
100 | @abstractmethod
101 | def _forward(self, tf: TensorFrame) -> TensorFrame:
102 | raise NotImplementedError
103 |
104 | def state_dict(self) -> dict[str, Any]:
105 | return self.__dict__
106 |
107 | def load_state_dict(self, state_dict: dict[str, Any]):
108 | self.__dict__.update(state_dict)
109 | return self
110 |
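111 | # A minimal sketch of a fittable transform (hypothetical): z-score
112 | # normalization of numerical features, fitted on training data only.
113 | #
114 | # import torch_frame
115 | #
116 | # class ZScoreTransform(FittableBaseTransform):
117 | #     def _fit(self, tf: TensorFrame, col_stats):
118 | #         x = tf.feat_dict[torch_frame.numerical]
119 | #         self.mean = x.mean(dim=0)
120 | #         self.std = x.std(dim=0).clamp(min=1e-6)
121 | #
122 | #     def _forward(self, tf: TensorFrame) -> TensorFrame:
123 | #         x = tf.feat_dict[torch_frame.numerical]
124 | #         tf.feat_dict[torch_frame.numerical] = (x - self.mean) / self.std
125 | #         return tf
126 | #
127 | # transform = ZScoreTransform()
128 | # transform.fit(train_tf, train_col_stats)  # then: out = transform(test_tf)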
--------------------------------------------------------------------------------
/torch_frame/typing.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from collections.abc import Mapping
4 | from enum import Enum
5 | from typing import TypeAlias
6 |
7 | import pandas as pd
8 | import torch
9 | from torch import Tensor
10 |
11 | from torch_frame.data.multi_embedding_tensor import MultiEmbeddingTensor
12 | from torch_frame.data.multi_nested_tensor import MultiNestedTensor
13 |
14 | WITH_PT20 = int(torch.__version__.split('.')[0]) >= 2
15 | WITH_PT24 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 4
16 |
17 |
18 | class Metric(Enum):
19 | r"""The metric.
20 |
21 | Attributes:
22 | ACCURACY: accuracy
23 | ROCAUC: rocauc
24 | RMSE: rmse
25 | MAE: mae
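   |
   | Whether a metric is valid for a given task can be checked via
   | :meth:`supports_task_type`:
   |
   | .. code-block:: python
   |
   |     Metric.R2.supports_task_type(TaskType.REGRESSION)  # True
   |     Metric.R2.supports_task_type(TaskType.BINARY_CLASSIFICATION)  # False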
26 | """
27 | ACCURACY = 'accuracy'
28 | ROCAUC = 'rocauc'
29 | RMSE = 'rmse'
30 | MAE = 'mae'
31 | R2 = 'r2'
32 |
33 | def supports_task_type(self, task_type: TaskType) -> bool:
34 | return self in task_type.supported_metrics
35 |
36 |
37 | class TaskType(Enum):
38 | r"""The type of the task.
39 |
40 | Attributes:
41 | REGRESSION: Regression task.
42 | MULTICLASS_CLASSIFICATION: Multi-class classification task.
43 | BINARY_CLASSIFICATION: Binary classification task.
   | MULTILABEL_CLASSIFICATION: Multi-label classification task.
44 | """
45 | REGRESSION = 'regression'
46 | MULTICLASS_CLASSIFICATION = 'multiclass_classification'
47 | BINARY_CLASSIFICATION = 'binary_classification'
48 | MULTILABEL_CLASSIFICATION = 'multilabel_classification'
49 |
50 | @property
51 | def is_classification(self) -> bool:
52 | return self in (TaskType.BINARY_CLASSIFICATION,
53 | TaskType.MULTICLASS_CLASSIFICATION)
54 |
55 | @property
56 | def is_regression(self) -> bool:
57 | return self == TaskType.REGRESSION
58 |
59 | @property
60 | def supported_metrics(self) -> list[Metric]:
61 | if self == TaskType.REGRESSION:
62 | return [Metric.RMSE, Metric.MAE, Metric.R2]
63 | elif self == TaskType.BINARY_CLASSIFICATION:
64 | return [Metric.ACCURACY, Metric.ROCAUC]
65 | elif self == TaskType.MULTICLASS_CLASSIFICATION:
66 | return [Metric.ACCURACY]
67 | else:
68 | return []
69 |
70 |
71 | class NAStrategy(Enum):
72 | r"""Strategy for dealing with NaN values in columns.
73 |
74 | Attributes:
75 | MEAN: Replaces NaN values with the mean of a
76 | :obj:`torch_frame.numerical` column.
77 | ZEROS: Replaces NaN values with zeros in a
78 | :obj:`torch_frame.numerical` or
   | :obj:`torch_frame.multicategorical` column.
79 | MOST_FREQUENT: Replaces NaN values with the most frequent category
80 | of a :obj:`torch_frame.categorical` column.
   | OLDEST_TIMESTAMP: Replaces NaN values with the oldest timestamp of
   | a :obj:`torch_frame.timestamp` column.
   | NEWEST_TIMESTAMP: Replaces NaN values with the newest timestamp of
   | a :obj:`torch_frame.timestamp` column.
   | MEDIAN_TIMESTAMP: Replaces NaN values with the median timestamp of
   | a :obj:`torch_frame.timestamp` column.
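   |
   | Each strategy belongs to a family that can be queried via the
   | properties below, e.g.:
   |
   | .. code-block:: python
   |
   |     NAStrategy.MEAN.is_numerical_strategy              # True
   |     NAStrategy.MOST_FREQUENT.is_categorical_strategy   # True
   |     NAStrategy.MEDIAN_TIMESTAMP.is_timestamp_strategy  # True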
81 | """
82 | MEAN = 'mean'
83 | MOST_FREQUENT = 'most_frequent'
84 | ZEROS = 'zeros'
85 | OLDEST_TIMESTAMP = 'oldest_timestamp'
86 | NEWEST_TIMESTAMP = 'newest_timestamp'
87 | MEDIAN_TIMESTAMP = 'median_timestamp'
88 |
89 | @property
90 | def is_categorical_strategy(self) -> bool:
91 | return self == NAStrategy.MOST_FREQUENT
92 |
93 | @property
94 | def is_multicategorical_strategy(self) -> bool:
95 | return self == NAStrategy.ZEROS
96 |
97 | @property
98 | def is_numerical_strategy(self) -> bool:
99 | return self in [NAStrategy.MEAN, NAStrategy.ZEROS]
100 |
101 | @property
102 | def is_timestamp_strategy(self) -> bool:
103 | return self in [
104 | NAStrategy.NEWEST_TIMESTAMP,
105 | NAStrategy.OLDEST_TIMESTAMP,
106 | NAStrategy.MEDIAN_TIMESTAMP,
107 | ]
108 |
109 |
110 | Series: TypeAlias = pd.Series
111 | DataFrame: TypeAlias = pd.DataFrame
112 |
113 | IndexSelectType: TypeAlias = int | list[int] | range | slice | Tensor
114 | ColumnSelectType: TypeAlias = str | list[str]
115 | TextTokenizationMapping: TypeAlias = Mapping[str, Tensor]
116 | TextTokenizationOutputs: TypeAlias = \
117 | list[TextTokenizationMapping] | TextTokenizationMapping
118 | TensorData: TypeAlias = (Tensor | MultiNestedTensor | MultiEmbeddingTensor
119 | | dict[str, MultiNestedTensor])
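   |
   | # A sketch of how `IndexSelectType` is used (assuming an existing
   | # `TensorFrame` object `tf`): `tf[0]`, `tf[1:3]`, `tf[range(4)]`, and
   | # `tf[torch.tensor([0, 2])]` all select rows.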
120 |
--------------------------------------------------------------------------------
/torch_frame/_stype.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 |
4 | class stype(Enum):
5 | r"""The semantic type of a column.
6 |
7 | A semantic type denotes the semantic meaning of a column, and denotes how
8 | columns are encoded into an embedding space within tabular deep learning
9 | models:
10 |
11 | .. code-block:: python
12 |
13 | import torch_frame
14 |
15 | stype = torch_frame.numerical # Numerical columns
16 | stype = torch_frame.categorical # Categorical columns
17 | ...
18 |
19 | Attributes:
20 | numerical: Numerical columns.
21 | categorical: Categorical columns.
22 | text_embedded: Pre-computed embeddings of text columns.
23 | text_tokenized: Tokenized text columns for finetuning.
24 | multicategorical: Multicategorical columns.
25 | sequence_numerical: Sequence of numerical values.
26 | embedding: Embedding columns.
27 | timestamp: Timestamp columns.
28 | image_embedded: Pre-computed embeddings of image columns.
29 | """
30 | numerical = 'numerical'
31 | categorical = 'categorical'
32 | text_embedded = 'text_embedded'
33 | text_tokenized = 'text_tokenized'
34 | multicategorical = 'multicategorical'
35 | sequence_numerical = 'sequence_numerical'
36 | timestamp = 'timestamp'
37 | image_embedded = 'image_embedded'
38 | embedding = 'embedding'
39 |
40 | @property
41 | def is_text_stype(self) -> bool:
42 | return self in [stype.text_embedded, stype.text_tokenized]
43 |
44 | @property
45 | def is_image_stype(self) -> bool:
46 | return self in [stype.image_embedded]
47 |
48 | @property
49 | def use_multi_nested_tensor(self) -> bool:
50 | r"""This property indicates if the data of an stype is stored in
51 | :class:`torch_frame.data.MultiNestedTensor`.
52 | """
53 | return self in [stype.multicategorical, stype.sequence_numerical]
54 |
55 | @property
56 | def use_multi_embedding_tensor(self) -> bool:
57 | r"""This property indicates if the data of an stype is stored in
58 | :class:`torch_frame.data.MultiNestedTensor`.
59 | """
60 | return self in [
61 | stype.text_embedded, stype.image_embedded, stype.embedding
62 | ]
63 |
64 | @property
65 | def use_dict_multi_nested_tensor(self) -> bool:
66 | r"""This property indicates if the data of an stype is stored in
67 | a dictionary of :class:`torch_frame.data.MultiNestedTensor`.
68 | """
69 | return self in [stype.text_tokenized]
70 |
71 | @property
72 | def use_multi_tensor(self) -> bool:
73 | r"""This property indicates if the data of an
74 | :class:`~torch_frame.stype` is stored in
75 | :class:`torch_frame.data._MultiTensor`.
76 | """
77 | return self.use_multi_nested_tensor or self.use_multi_embedding_tensor
78 |
79 | @property
80 | def parent(self):
81 | r"""This property indicates if an :class:`~torch_frame.stype` is
82 | user-facing column :obj:`stype` or internal :obj:`stype` for grouping
83 | columns in :obj:`TensorFrame`. User-facing :class:`~torch_frame.stype`
84 | will be mapped to its parent during materialization. For
85 | :class:`~torch_frame.stype` that are both internal and
86 | user-facing, the parent maps to itself.
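   |
   | For example:
   |
   | .. code-block:: python
   |
   |     stype.text_embedded.parent   # stype.embedding
   |     stype.image_embedded.parent  # stype.embedding
   |     stype.numerical.parent       # stype.numerical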
87 | """
88 | if self == stype.text_embedded:
89 | return stype.embedding
90 | elif self == stype.image_embedded:
91 | return stype.embedding
92 | else:
93 | return self
94 |
95 | def __str__(self) -> str:
96 | return f'{self.name}'
97 |
98 |
99 | numerical = stype('numerical')
100 | categorical = stype('categorical')
101 | text_embedded = stype('text_embedded')
102 | text_tokenized = stype('text_tokenized')
103 | multicategorical = stype('multicategorical')
104 | sequence_numerical = stype('sequence_numerical')
105 | timestamp = stype('timestamp')
106 | image_embedded = stype('image_embedded')
107 | embedding = stype('embedding')
108 |
--------------------------------------------------------------------------------
/torch_frame/datasets/kdd_census_income.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import os
4 | import os.path as osp
5 | import tarfile
6 | import zipfile
7 |
8 | import pandas as pd
9 |
10 | import torch_frame
11 | from torch_frame import stype
12 |
13 |
14 | class KDDCensusIncome(torch_frame.data.Dataset):
15 | r"""The `KDD Census Income
16 | `_
17 | dataset. It's a task of forest cover type classification
18 | based on attributes such as elevation, slop and soil type etc.
19 |
20 | **STATS:**
21 |
22 | .. list-table::
23 | :widths: 10 10 10 10 20 10
24 | :header-rows: 1
25 |
26 | * - #rows
27 | - #cols (numerical)
28 | - #cols (categorical)
29 | - #classes
30 | - Task
31 | - Missing value ratio
32 | * - 199,523
33 | - 7
34 | - 34
35 | - 2
36 | - binary_classification
37 | - 0.0%
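   |
   | A minimal usage sketch (the :obj:`root` path is illustrative):
   |
   | .. code-block:: python
   |
   |     from torch_frame.datasets import KDDCensusIncome
   |
   |     dataset = KDDCensusIncome(root='/tmp/kdd_census_income')
   |     dataset.materialize()
   |     tensor_frame = dataset.tensor_frame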
38 | """
39 |
40 | url = 'https://archive.ics.uci.edu/static/public/117/census+income+kdd.zip'
41 |
42 | def __init__(self, root: str):
43 | data_dir = osp.join(root, 'census')
44 | filename = osp.join(data_dir, 'census-income.data')
45 | if not osp.exists(filename):
   | # The downloaded UCI zip archive contains `census.tar.gz`, which
   | # in turn holds `census-income.data`:
46 | path = self.download_url(self.url, root)
47 | tar_gz_path = osp.join(root, 'census.tar.gz')
48 | with zipfile.ZipFile(path, 'r') as zip_ref:
49 | zip_ref.extractall(root)
50 | with tarfile.open(tar_gz_path, 'r:gz') as tar_ref:
51 | tar_ref.extractall(data_dir)
52 | os.remove(tar_gz_path)
53 | os.remove(path)
54 |
55 | names = [
56 | 'age',
57 | 'class of worker',
58 | 'industry code',
59 | 'occupation code',
60 | 'education',
61 | 'wage per hour',
62 | 'enrolled in edu inst last wk',
63 | 'marital status',
64 | 'major industry code',
65 | 'major occupation code',
66 | 'race',
67 | 'hispanic origin',
68 | 'sex',
69 | 'member of a labor union',
70 | 'reason for unemployment',
71 | 'full or part time employment stat',
72 | 'capital gains',
73 | 'capital losses',
74 | 'dividends from stocks',
75 | 'tax filer status',
76 | 'region of previous residence',
77 | 'state of previous residence',
78 | 'detailed household and family stat',
79 | 'detailed household summary in household',
80 | 'migration code-change in msa',
81 | 'migration code-change in reg',
82 | 'migration code-move within reg',
83 | 'live in this house 1 year ago',
84 | 'migration prev res in sunbelt',
85 | 'family members under 18',
86 | 'num persons worked for employer',
87 | 'country of birth father',
88 | 'country of birth mother',
89 | 'country of birth self',
90 | 'citizenship',
91 | 'total person income',
92 | 'own business or self employed',
93 | "fill inc questionnaire for veteran's admin",
94 | 'veterans benefits',
95 | 'weeks worked in year',
96 | 'year',
97 | 'income above 50000',
98 | ]
99 |
100 | continuous_cols = {
101 | 'age',
102 | 'wage per hour',
103 | 'capital gains',
104 | 'capital losses',
105 | 'dividends from stocks',
106 | 'num persons worked for employer',
107 | 'weeks worked in year',
108 | }
109 |
110 | col_to_stype: dict[str, stype] = {}
111 | for name in names:
112 | if name in continuous_cols:
113 | col_to_stype[name] = torch_frame.numerical
114 | else:
115 | col_to_stype[name] = torch_frame.categorical
116 |
117 | df = pd.read_csv(filename, names=names)
118 |
119 | super().__init__(df, col_to_stype, target_col='income above 50000')
120 |
--------------------------------------------------------------------------------
/torch_frame/transforms/mutual_information_sort.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Any
4 |
5 | import numpy as np
6 | import torch
7 |
8 | from torch_frame import NAStrategy, TaskType, TensorFrame, stype
9 | from torch_frame.data.stats import StatType
10 | from torch_frame.transforms import FittableBaseTransform
11 |
12 |
13 | class MutualInformationSort(FittableBaseTransform):
14 | r"""A transform that sorts the numerical features of input
15 | :class:`TensorFrame` object based on mutual information.
16 |
17 | Args:
18 | task_type (TaskType): The task type.
19 | na_strategy (NAStrategy): Strategy used for imputing NaN values
20 | in numerical features.
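   |
   | A usage sketch (``train_tf`` and ``col_stats`` are an assumed
   | numerical-only training :class:`TensorFrame` and its column stats):
   |
   | .. code-block:: python
   |
   |     transform = MutualInformationSort(TaskType.BINARY_CLASSIFICATION)
   |     transform.fit(train_tf, col_stats)
   |     train_tf = transform(train_tf)
   |     transform.reordered_col_names  # columns by descending MI score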
21 | """
22 | def __init__(self, task_type: TaskType,
23 | na_strategy: NAStrategy = NAStrategy.MEAN):
24 | super().__init__()
25 |
26 | from sklearn.feature_selection import (
27 | mutual_info_classif,
28 | mutual_info_regression,
29 | )
30 |
31 | if task_type in [
32 | TaskType.MULTICLASS_CLASSIFICATION,
33 | TaskType.BINARY_CLASSIFICATION
34 | ]:
35 | self.mi_func = mutual_info_classif
36 | elif task_type == TaskType.REGRESSION:
37 | self.mi_func = mutual_info_regression
38 | else:
39 | raise ValueError(
40 | f"'{self.__class__.__name__}' can be only used on binary "
41 | "classification, multiclass classification or regression "
42 | f"task, but got {task_type}.")
43 | if not na_strategy.is_numerical_strategy:
44 | raise RuntimeError(
45 | f"Cannot use {na_strategy} for numerical features.")
46 | self.na_strategy = na_strategy
47 |
48 | def _fit(self, tf_train: TensorFrame, col_stats: dict[str, dict[StatType,
49 | Any]]):
50 | if tf_train.y is None:
51 | raise RuntimeError(
52 | "'{self.__class__.__name__}' cannot be used when target column"
53 | " is None.")
54 | if (stype.categorical in tf_train.col_names_dict
55 | and len(tf_train.col_names_dict[stype.categorical]) != 0):
56 | raise ValueError(f"'{self.__class__.__name__}' can only be used"
57 | " on a TensorFrame with numerical features only.")
58 | feat_train = tf_train.feat_dict[stype.numerical]
59 | y_train = tf_train.y
60 | if torch.isnan(feat_train).any():
61 | feat_train = self._replace_nans(feat_train, self.na_strategy)
62 | if torch.isnan(tf_train.y).any():
63 | not_nan_indices = ~torch.isnan(y_train)
64 | if not not_nan_indices.any():
65 | raise ValueError(f"'{self.__class__.__name__}' cannot be "
66 | "performed when all target values are "
67 | "NaN.")
68 | y_train = y_train[not_nan_indices]
69 | feat_train = feat_train[not_nan_indices]
70 | mi_scores = self.mi_func(feat_train.cpu(), y_train.cpu())
71 | self.mi_ranks = np.argsort(-mi_scores)
72 | self.mi_scores = mi_scores[self.mi_ranks]
73 | col_names = tf_train.col_names_dict[stype.numerical]
74 | # Reorder the column names by descending mutual information score:
75 | self.reordered_col_names = [col_names[i] for i in self.mi_ranks]
76 |
77 | self._transformed_stats = col_stats
81 |
82 | def _forward(self, tf: TensorFrame) -> TensorFrame:
83 | if tf.col_names_dict.keys() != {stype.numerical}:
84 | raise ValueError("The transform can only be used on a TensorFrame"
85 | " with numerical features only.")
86 |
87 | tf.feat_dict[stype.numerical] = tf.feat_dict[
88 | stype.numerical][:, self.mi_ranks]
89 |
90 | tf.col_names_dict[stype.numerical] = self.reordered_col_names
91 |
92 | # set lazy attribute for meta features
93 | tf.mi_scores = torch.tensor(self.mi_scores, dtype=torch.float32,
94 | device=tf.device)
95 |
96 | return tf
97 |
--------------------------------------------------------------------------------
/torch_frame/nn/models/ft_transformer.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Any
4 |
5 | from torch import Tensor
6 | from torch.nn import LayerNorm, Linear, Module, ReLU, Sequential
7 |
8 | import torch_frame
9 | from torch_frame import TensorFrame, stype
10 | from torch_frame.data.stats import StatType
11 | from torch_frame.nn.conv import FTTransformerConvs
12 | from torch_frame.nn.encoder.stype_encoder import (
13 | EmbeddingEncoder,
14 | LinearEncoder,
15 | StypeEncoder,
16 | )
17 | from torch_frame.nn.encoder.stypewise_encoder import StypeWiseFeatureEncoder
18 |
19 |
20 | class FTTransformer(Module):
21 | r"""The FT-Transformer model introduced in the
22 | `"Revisiting Deep Learning Models for Tabular Data"
23 | <https://arxiv.org/abs/2106.11959>`_ paper.
24 |
25 | .. note::
26 |
27 | For an example of using FTTransformer, see `examples/revisiting.py
28 | <https://github.com/pyg-team/pytorch-frame/blob/master/examples/
   | revisiting.py>`_.
30 |
31 | Args:
32 | channels (int): Hidden channel dimensionality
33 | out_channels (int): Output channels dimensionality
34 | num_layers (int): Number of layers.
35 | col_stats (dict[str, dict[:class:`torch_frame.data.stats.StatType`, Any]]):
36 | A dictionary that maps column name into stats.
37 | Available as :obj:`dataset.col_stats`.
38 | col_names_dict (dict[:obj:`torch_frame.stype`, list[str]]): A
39 | dictionary that maps stype to a list of column names. The column
40 | names are sorted based on the ordering in which they appear in
41 | :obj:`tensor_frame.feat_dict`. Available as
42 | :obj:`tensor_frame.col_names_dict`.
43 | stype_encoder_dict
44 | (dict[:class:`torch_frame.stype`,
45 | :class:`torch_frame.nn.encoder.StypeEncoder`], optional):
46 | A dictionary mapping stypes into their stype encoders.
47 | (default: :obj:`None`, will call
48 | :class:`torch_frame.nn.encoder.EmbeddingEncoder()` for categorical
49 | features and :class:`torch_frame.nn.encoder.LinearEncoder()`
50 | for numerical features)
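   |
   | A minimal construction sketch (``dataset`` and ``train_tf`` are an
   | assumed materialized dataset and its training :class:`TensorFrame`):
   |
   | .. code-block:: python
   |
   |     model = FTTransformer(
   |         channels=32,
   |         out_channels=dataset.num_classes,
   |         num_layers=3,
   |         col_stats=dataset.col_stats,
   |         col_names_dict=train_tf.col_names_dict,
   |     )
   |     out = model(train_tf)  # [batch_size, out_channels]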
51 | """
52 | def __init__(
53 | self,
54 | channels: int,
55 | out_channels: int,
56 | num_layers: int,
57 | col_stats: dict[str, dict[StatType, Any]],
58 | col_names_dict: dict[torch_frame.stype, list[str]],
59 | stype_encoder_dict: dict[torch_frame.stype, StypeEncoder]
60 | | None = None,
61 | ) -> None:
62 | super().__init__()
63 | if num_layers <= 0:
64 | raise ValueError(
65 | f"num_layers must be a positive integer (got {num_layers})")
66 |
67 | if stype_encoder_dict is None:
68 | stype_encoder_dict = {
69 | stype.categorical: EmbeddingEncoder(),
70 | stype.numerical: LinearEncoder(),
71 | }
72 |
73 | self.encoder = StypeWiseFeatureEncoder(
74 | out_channels=channels,
75 | col_stats=col_stats,
76 | col_names_dict=col_names_dict,
77 | stype_encoder_dict=stype_encoder_dict,
78 | )
79 | self.backbone = FTTransformerConvs(channels=channels,
80 | num_layers=num_layers)
81 | self.decoder = Sequential(
82 | LayerNorm(channels),
83 | ReLU(),
84 | Linear(channels, out_channels),
85 | )
86 | self.reset_parameters()
87 |
88 | def reset_parameters(self) -> None:
89 | self.encoder.reset_parameters()
90 | self.backbone.reset_parameters()
91 | for m in self.decoder:
92 | if not isinstance(m, ReLU):
93 | m.reset_parameters()
94 |
95 | def forward(self, tf: TensorFrame) -> Tensor:
96 | r"""Transforming :class:`TensorFrame` object into output prediction.
97 |
98 | Args:
99 | tf (TensorFrame):
100 | Input :class:`TensorFrame` object.
101 |
102 | Returns:
103 | torch.Tensor: Output of shape [batch_size, out_channels].
104 | """
105 | x, _ = self.encoder(tf)
106 | x, x_cls = self.backbone(x)
107 | out = self.decoder(x_cls)
108 | return out
109 |
--------------------------------------------------------------------------------
/test/utils/test_io.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import shutil
4 | import tempfile
5 |
6 | import pytest
7 |
8 | import torch_frame
9 | from torch_frame import TensorFrame, load, save
10 | from torch_frame.config.text_embedder import TextEmbedderConfig
11 | from torch_frame.config.text_tokenizer import TextTokenizerConfig
12 | from torch_frame.datasets import FakeDataset
13 | from torch_frame.testing.text_embedder import HashTextEmbedder
14 | from torch_frame.testing.text_tokenizer import WhiteSpaceHashTokenizer
15 |
16 | TEST_DIR = tempfile.TemporaryDirectory()
17 | TEST_DATASET_NAME = 'test_dataset_tf.pt'
18 | TEST_SAVE_LOAD_NAME = 'tf.pt'
19 |
20 |
21 | def teardown_module():
22 | if osp.exists(TEST_DIR.name):
23 | shutil.rmtree(TEST_DIR.name, ignore_errors=True)
24 |
25 |
26 | def get_fake_dataset(
27 | num_rows: int,
28 | col_to_text_embedder_cfg: TextEmbedderConfig,
29 | col_to_text_tokenizer_cfg: TextTokenizerConfig,
30 | ) -> FakeDataset:
31 | stypes = [
32 | torch_frame.numerical,
33 | torch_frame.categorical,
34 | torch_frame.multicategorical,
35 | torch_frame.text_embedded,
36 | torch_frame.text_tokenized,
37 | torch_frame.sequence_numerical,
38 | torch_frame.embedding,
39 | ]
40 | dataset = FakeDataset(
41 | num_rows=num_rows,
42 | stypes=stypes,
43 | col_to_text_embedder_cfg=col_to_text_embedder_cfg,
44 | col_to_text_tokenizer_cfg=col_to_text_tokenizer_cfg,
45 | )
46 | return dataset
47 |
48 |
49 | def test_dataset_cache():
50 | num_rows = 10
51 | out_channels = 8
52 |
53 | col_to_text_embedder_cfg = TextEmbedderConfig(
54 | text_embedder=HashTextEmbedder(out_channels))
55 | col_to_text_tokenizer_cfg = TextTokenizerConfig(
56 | text_tokenizer=WhiteSpaceHashTokenizer())
57 | dataset = get_fake_dataset(
58 | num_rows,
59 | col_to_text_embedder_cfg,
60 | col_to_text_tokenizer_cfg,
61 | )
62 |
63 | path = osp.join(TEST_DIR.name, TEST_DATASET_NAME)
64 | dataset.materialize(path=path)
65 |
66 | new_dataset = get_fake_dataset(
67 | num_rows,
68 | col_to_text_embedder_cfg,
69 | col_to_text_tokenizer_cfg,
70 | )
71 | new_dataset.df = dataset.df
72 |
73 | # Test materialize via caching
74 | new_dataset.materialize(path=path)
75 | assert new_dataset.is_materialized
76 | assert dataset.col_stats == new_dataset.col_stats
77 | assert dataset.tensor_frame == new_dataset.tensor_frame
78 |
79 | # Test `tensor_frame` converter
80 | tf = new_dataset._to_tensor_frame_converter(dataset.df)
81 | assert dataset.tensor_frame == tf
82 |
83 | # Remove saved tensor frame object
84 | os.remove(path)
85 |
86 | new_dataset = get_fake_dataset(
87 | num_rows,
88 | col_to_text_embedder_cfg,
89 | col_to_text_tokenizer_cfg,
90 | )
91 | new_dataset.df = dataset.df
92 |
93 | # Test materialize again with specified path
94 | new_dataset.materialize()
95 | new_dataset.materialize(path=path)
96 |
97 | assert new_dataset.is_materialized
98 |
99 |
100 | def test_save_load_tensor_frame():
101 | num_rows = 10
102 | out_channels = 8
103 | col_to_text_embedder_cfg = TextEmbedderConfig(
104 | text_embedder=HashTextEmbedder(out_channels))
105 | col_to_text_tokenizer_cfg = TextTokenizerConfig(
106 | text_tokenizer=WhiteSpaceHashTokenizer(),
107 | batch_size=None,
108 | )
109 | dataset = get_fake_dataset(num_rows, col_to_text_embedder_cfg,
110 | col_to_text_tokenizer_cfg)
111 | dataset.materialize()
112 |
113 | path = osp.join(TEST_DIR.name, TEST_SAVE_LOAD_NAME)
114 | save(dataset.tensor_frame, dataset.col_stats, path)
115 |
116 | tf, col_stats = load(path)
117 | assert dataset.col_stats == col_stats
118 | assert dataset.tensor_frame == tf
119 |
120 |
121 | class UntrustedClass:
122 | pass
123 |
124 |
125 | @pytest.mark.skipif(
126 | not torch_frame.typing.WITH_PT24,
127 | reason='Requires PyTorch 2.4',
128 | )
129 | def test_load_weights_only_gracefully(tmpdir):
130 | save(
131 | tensor_frame=TensorFrame({}, {}),
132 | col_stats={'a': UntrustedClass()},
133 | path=tmpdir.join('tf.pt'),
134 | )
135 | with pytest.warns(UserWarning, match='Weights only load failed'):
136 | load(tmpdir.join('tf.pt'))
137 |
--------------------------------------------------------------------------------
/.github/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to PyTorch Frame
2 |
3 | If you are interested in contributing to PyTorch Frame, your contributions will likely fall into one of the following two categories:
4 |
5 | 1. You want to implement a new feature:
6 | - In general, we accept any features as long as they fit the scope of this package. If you are unsure about this or need help on the design/implementation of your feature, post about it in an issue.
7 | 1. You want to fix a bug:
8 | - Feel free to send a Pull Request (PR) any time you encounter a bug. Please provide a clear and concise description of what the bug was. If you are unsure about if this is a bug at all or how to fix, post about it in an issue.
9 |
10 | Once you finish implementing a feature or bug-fix, please send a PR to https://github.com/pyg-team/pytorch-frame.
11 |
12 | Your PR will be merged after one or more rounds of reviews by the [pyg-team](https://github.com/pyg-team).
13 |
14 | ## Developing PyTorch Frame
15 |
16 | To develop PyTorch Frame on your machine, here are some tips:
17 |
18 | 1. Ensure that you are running on one of the supported PyTorch versions (*e.g.*, `2.1.0`):
19 |
20 | ```python
21 | import torch
22 | print(torch.__version__)
23 | ```
24 |
25 | 1. Uninstall all existing PyTorch Frame installations.
26 | It is advised to run this command repeatedly to confirm that installations across all locations are properly removed.
27 |
28 | ```bash
29 | pip uninstall pytorch-frame
30 | ```
31 |
32 | 1. Fork and clone the PyTorch Frame repository:
33 |
34 | ```bash
35 | git clone https://github.com/<your_username>/pytorch-frame.git
36 | cd pytorch-frame
38 | ```
39 |
40 | 1. If you already cloned PyTorch Frame from source, update it:
41 |
42 | ```bash
43 | git pull
44 | ```
45 |
46 | 1. Install PyTorch Frame in editable mode:
47 |
48 | ```bash
49 | pip install -e ".[dev,full]"
50 | ```
51 |
52 | This mode will symlink the Python files from the current local source tree into the Python install.
53 | Hence, if you modify a Python file, you do not need to re-install PyTorch Frame again.
54 |
55 | 1. Ensure that you have a working PyTorch Frame installation by running the entire test suite with
56 |
57 | ```bash
58 | pytest
59 | ```
60 |
61 | 1. Install pre-commit hooks:
62 |
63 | ```bash
64 | pre-commit install
65 | ```
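   |
   | The hooks then run automatically on every commit. To check all files at once (`pre-commit run --all-files` is standard `pre-commit` usage, not a project-specific command), run:
   |
   | ```bash
   | pre-commit run --all-files
   | ```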
66 |
67 | ## Unit Testing
68 |
69 | The PyTorch Frame testing suite is located under `test/`.
70 | Run the test suite with
71 |
72 | ```bash
73 | # all test cases
74 | pytest
75 |
76 | # individual test cases
77 | pytest test/utils/test_split.py
78 | ```
79 |
80 | ## Continuous Integration
81 |
82 | PyTorch Frame uses [GitHub Actions](https://github.com/pyg-team/pytorch-frame/actions) in combination with [CodeCov](https://codecov.io/github/pyg-team/pytorch-frame?branch=master) for continuous integration.
83 |
84 | Every time you send a Pull Request, your commit will be built and checked against the PyTorch Frame guidelines:
85 |
86 | 1. Ensure that your code is formatted correctly by testing against the style guide of [`flake8`](https://github.com/PyCQA/flake8).
87 | We use the [`Flake8-pyproject`](https://pypi.org/project/Flake8-pyproject/) plugin for configuration:
88 |
89 | ```bash
90 | flake8
91 | ```
92 |
93 | If you do not want to format your code manually, we recommend using [`yapf`](https://github.com/google/yapf).
94 |
95 | 1. Ensure that the entire test suite passes and that code coverage roughly stays the same.
96 | Please feel encouraged to provide a test with your submitted code.
97 | To test with coverage, run
98 |
99 | ```bash
100 | pytest --cov
101 | ```
102 |
103 | Measuring coverage adds overhead, so this run takes longer than a plain `pytest` invocation; choose depending on your needs.
104 |
105 | 1. Add your feature/bugfix to the [`CHANGELOG.md`](https://github.com/pyg-team/pytorch-frame/blob/master/CHANGELOG.md?plain=1).
106 | If multiple PRs move towards integrating a single feature, it is advised to group them together into one bullet point.
107 |
108 | ## Building Documentation
109 |
110 | To build the documentation:
111 |
112 | 1. [Build and install](#developing-pytorch-frame) PyTorch Frame from source.
113 | 1. Install [Sphinx](https://www.sphinx-doc.org/en/master/) theme via
114 | ```bash
115 | pip install git+https://github.com/pyg-team/pyg_sphinx_theme.git
116 | ```
117 | 1. Generate the documentation via:
118 | ```bash
119 | cd docs
120 | make html
121 | ```
122 |
--------------------------------------------------------------------------------
/torch_frame/nn/models/mlp.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Any
4 |
5 | import torch
6 | from torch import Tensor
7 | from torch.nn import (
8 | BatchNorm1d,
9 | Dropout,
10 | LayerNorm,
11 | Linear,
12 | Module,
13 | ReLU,
14 | Sequential,
15 | )
16 |
17 | import torch_frame
18 | from torch_frame import TensorFrame, stype
19 | from torch_frame.data.stats import StatType
20 | from torch_frame.nn.encoder.stype_encoder import (
21 | EmbeddingEncoder,
22 | LinearEncoder,
23 | StypeEncoder,
24 | )
25 | from torch_frame.nn.encoder.stypewise_encoder import StypeWiseFeatureEncoder
26 |
27 |
28 | class MLP(Module):
29 | r"""The light-weight MLP model that mean-pools column embeddings and
30 | applies MLP over it.
31 |
32 | Args:
33 | channels (int): The number of channels in the backbone layers.
34 | out_channels (int): The number of output channels in the decoder.
35 | num_layers (int): The number of layers in the backbone.
36 | col_stats (dict[str, dict[:class:`torch_frame.data.stats.StatType`, Any]]):
37 | A dictionary that maps column name into stats.
38 | Available as :obj:`dataset.col_stats`.
39 | col_names_dict (dict[:class:`torch_frame.stype`, list[str]]): A
40 | dictionary that maps stype to a list of column names. The column
41 | names are sorted based on the ordering in which they appear in
42 | :obj:`tensor_frame.feat_dict`. Available as
43 | :obj:`tensor_frame.col_names_dict`.
44 | stype_encoder_dict
45 | (dict[:class:`torch_frame.stype`,
46 | :class:`torch_frame.nn.encoder.StypeEncoder`], optional):
47 | A dictionary mapping stypes into their stype encoders.
48 | (default: :obj:`None`, will call :obj:`EmbeddingEncoder()`
49 | for categorical features and :obj:`LinearEncoder()` for
50 | numerical features)
51 | normalization (str, optional): The type of normalization to use.
52 | :obj:`batch_norm`, :obj:`layer_norm`, or :obj:`None`.
53 | (default: :obj:`layer_norm`)
54 | dropout_prob (float): The dropout probability (default: :obj:`0.2`).
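   |
   | A construction sketch with batch normalization (``dataset`` and
   | ``train_tf`` are an assumed materialized dataset and its training
   | :class:`TensorFrame`):
   |
   | .. code-block:: python
   |
   |     model = MLP(
   |         channels=64,
   |         out_channels=dataset.num_classes,
   |         num_layers=2,
   |         col_stats=dataset.col_stats,
   |         col_names_dict=train_tf.col_names_dict,
   |         normalization="batch_norm",
   |         dropout_prob=0.1,
   |     )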
55 | """
56 | def __init__(
57 | self,
58 | channels: int,
59 | out_channels: int,
60 | num_layers: int,
61 | col_stats: dict[str, dict[StatType, Any]],
62 | col_names_dict: dict[torch_frame.stype, list[str]],
63 | stype_encoder_dict: dict[torch_frame.stype, StypeEncoder]
64 | | None = None,
65 | normalization: str | None = "layer_norm",
66 | dropout_prob: float = 0.2,
67 | ) -> None:
68 | super().__init__()
69 |
70 | if stype_encoder_dict is None:
71 | stype_encoder_dict = {
72 | stype.categorical: EmbeddingEncoder(),
73 | stype.numerical: LinearEncoder(),
74 | }
75 |
76 | self.encoder = StypeWiseFeatureEncoder(
77 | out_channels=channels,
78 | col_stats=col_stats,
79 | col_names_dict=col_names_dict,
80 | stype_encoder_dict=stype_encoder_dict,
81 | )
82 |
83 | self.mlp = Sequential()
84 |
85 | for _ in range(num_layers - 1):
86 | self.mlp.append(Linear(channels, channels))
87 | if normalization == "layer_norm":
88 | self.mlp.append(LayerNorm(channels))
89 | elif normalization == "batch_norm":
90 | self.mlp.append(BatchNorm1d(channels))
91 | self.mlp.append(ReLU())
92 | self.mlp.append(Dropout(p=dropout_prob))
93 | self.mlp.append(Linear(channels, out_channels))
94 |
95 | self.reset_parameters()
96 |
97 | def reset_parameters(self) -> None:
98 | self.encoder.reset_parameters()
99 | for param in self.mlp:
100 | if hasattr(param, 'reset_parameters'):
101 | param.reset_parameters()
102 |
103 | def forward(self, tf: TensorFrame) -> Tensor:
104 | r"""Transforming :class:`TensorFrame` object into output prediction.
105 |
106 | Args:
107 | tf (TensorFrame): Input :class:`TensorFrame` object.
108 |
109 | Returns:
110 | torch.Tensor: Output of shape [batch_size, out_channels].
111 | """
112 | x, _ = self.encoder(tf)
113 |
114 | x = torch.mean(x, dim=1)
115 |
116 | out = self.mlp(x)
117 | return out
118 |
--------------------------------------------------------------------------------
/benchmark/encoder/README.md:
--------------------------------------------------------------------------------
1 | # Encoder Benchmark
2 |
3 | ## Usage
4 |
5 | Exemplary command:
6 |
7 | ```
8 | python encoder_benchmark.py --stype-kv categorical embedding --stype-kv numerical linear
9 | ```
10 |
11 | This creates a dataset with categorical and numerical columns, encoded with the
12 | embedding and linear encoders, respectively.
13 |
14 | Arguments:
15 |
16 | - **--stype-kv**: Specify the stype(s) and corresponding encoder(s) to run.
17 | - **--num-rows**: The number of rows in the dataset (default is `8192`).
18 | - **--out-channels**: The number of output channels (default is `128`).
19 | - **--with-nan**: If specified, the dataset will include NaN values.
20 | - **--runs**: The number of runs for the benchmark (default is `1000`).
21 | - **--warmup-size**: The size of the warmup stage (default is `200`).
22 | - **--torch-profile**: If specified, torch profiling will be enabled.
23 | - **--line-profile**: If specified, line profiling will be enabled.
24 | - **--line-profile-level**: The level of line profiling (default is `'encode_forward'`).
25 | - **--device**: The device to run the benchmark on (default is `'cpu'`).
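   |
   | For example, to profile the same configuration on GPU over fewer runs (the flag values here are illustrative):
   |
   | ```
   | python encoder_benchmark.py --stype-kv categorical embedding --stype-kv numerical linear --device cuda --runs 100 --torch-profile
   | ```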
26 |
27 | Regardless of whether a profiler is used, the benchmark always outputs the latency (the execution time of a single run), e.g.:
28 |
29 | ```
30 | Latency: 0.034277s
31 | ```
32 |
33 | The torch profiler produces a table of operations sorted by execution time, e.g.:
34 |
35 | ```
36 | ----------------------------- ------------ ------------ ------------ ------------ ------------ ------------
37 | Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
38 | ----------------------------- ------------ ------------ ------------ ------------ ------------ ------------
39 | aten::cat 47.49% 2.027s 48.05% 2.051s 1.025ms 2000
40 | aten::nan_to_num 19.74% 842.584ms 39.35% 1.680s 419.945us 4000
41 | aten::add 10.04% 428.549ms 10.04% 428.549ms 214.274us 2000
42 | aten::index_select 6.28% 268.051ms 7.78% 331.959ms 165.980us 2000
43 | aten::mul 4.96% 211.853ms 4.96% 211.853ms 211.853us 1000
44 | aten::sub 1.36% 58.064ms 1.36% 58.064ms 58.064us 1000
45 | aten::any 1.30% 55.612ms 1.41% 60.278ms 30.139us 2000
46 | aten::div 1.22% 52.159ms 1.22% 52.159ms 52.159us 1000
47 | ...
48 | ----------------------------- ------------ ------------ ------------ ------------ ------------ ------------
49 | Self CPU time total: 4.268s
50 | ```
51 |
52 | The line profiler shows the percentage of time spent on each line of a method, e.g.:
53 |
54 | ```
55 | Total time: 1.03661 s
56 | File: {PF_BASE_PATH}/pytorch-frame/torch_frame/nn/encoder/stype_encoder.py
57 | Function: encode_forward at line 295
58 |
59 | Line # Hits Time Per Hit % Time Line Contents
60 | ==============================================================
61 | 295 def encode_forward(
62 | 296 self,
63 | 297 feat: Tensor,
64 | 298 col_names: list[str] | None = None,
65 | 299 ) -> Tensor:
66 | 300 # TODO: Make this more efficient.
67 | 301 # Increment the index by one so that NaN index (-1) becomes 0
68 | 302 # (padding_idx)
69 | 303 # feat: [batch_size, num_cols]
70 | 304 1200 29867.4 24.9 2.9 feat = feat + 1
71 | 305 1200 944.3 0.8 0.1 xs = []
72 | 306 3600 19345.7 5.4 1.9 for i, emb in enumerate(self.embs):
73 | 307 2400 466967.7 194.6 45.0 xs.append(emb(feat[:, i]))
74 | 308 # [batch_size, num_cols, hidden_channels]
75 | 309 1200 519044.9 432.5 50.1 x = torch.stack(xs, dim=1)
76 | 310 1200 440.8 0.4 0.0 return x
77 | ```
78 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires=["flit_core >=3.12,<4"]
3 | build-backend="flit_core.buildapi"
4 |
5 | [project]
6 | name="pytorch-frame"
7 | authors=[
8 | {name="PyG Team", email="team@pyg.org"},
9 | ]
10 | dynamic=["version"]
11 | description="Tabular Deep Learning Library for PyTorch"
12 | readme="README.md"
13 | requires-python=">=3.10"
14 | keywords=[
15 | "deep-learning",
16 | "pytorch",
17 | "tabular-learning",
18 | "data-frame",
19 | ]
20 | license = "MIT"
21 | license-files = ["LICENSE"]
22 | classifiers=[
23 | "Development Status :: 4 - Beta",
24 | "Programming Language :: Python",
25 | "Programming Language :: Python :: 3.10",
26 | "Programming Language :: Python :: 3.11",
27 | "Programming Language :: Python :: 3.12",
28 | "Programming Language :: Python :: 3.13",
29 | "Programming Language :: Python :: 3 :: Only",
30 | ]
31 | dependencies=[
32 | "numpy",
33 | "pandas",
34 | "torch",
35 | "tqdm",
36 | "pyarrow",
37 | "Pillow",
38 | ]
39 |
40 | [project.optional-dependencies]
41 | test=[
42 | "pytest",
43 | "pytest-cov",
44 | "mypy",
45 | ]
46 | dev=[
47 | "pytorch-frame[test]",
48 | "pre-commit",
49 | ]
50 | full=[
51 | "scikit-learn",
52 | "xgboost>=1.7.0, <2.0.0",
53 | "optuna>=3.0.0",
54 | "optuna-integration",
55 | "mpmath==1.3.0",
56 | "catboost",
57 | "lightgbm",
58 | "datasets",
59 | "torchmetrics",
60 | ]
61 |
62 | [project.urls]
63 | homepage="https://pyg.org"
64 | documentation="https://pytorch-frame.readthedocs.io"
65 | repository="https://github.com/pyg-team/pytorch-frame.git"
66 | changelog="https://github.com/pyg-team/pytorch-frame/blob/master/CHANGELOG.md"
67 |
68 | [tool.flit.module]
69 | name="torch_frame"
70 |
71 | [tool.ruff] # https://docs.astral.sh/ruff/rules
72 | target-version = "py310"
73 | src = ["torch_frame", "test", "examples", "benchmark"]
74 | line-length = 80
75 | indent-width = 4
76 |
77 | [tool.ruff.lint]
78 | select = [
79 | "B", # flake8-bugbear
80 | "D", # pydocstyle
81 | "UP", # pyupgrade
82 | ]
83 | ignore = [
84 | "D100", # TODO: Don't ignore "Missing docstring in public module"
85 | "D101", # TODO: Don't ignore "Missing docstring in public class"
86 | "D102", # TODO: Don't ignore "Missing docstring in public method"
87 | "D103", # TODO: Don't ignore "Missing docstring in public function"
88 | "D105", # Ignore "Missing docstring in magic method"
89 | "D107", # Ignore "Missing docstring in __init__"
90 | "D205", # Ignore "blank line required between summary line and description"
91 | ]
92 |
93 | [tool.ruff.lint.pydocstyle]
94 | convention = "google"
95 |
96 | [tool.yapf]
97 | based_on_style = "pep8"
98 | split_before_named_assigns = false
99 | blank_line_before_nested_class_or_def = false
100 |
101 | [tool.isort]
102 | multi_line_output = 3
103 | include_trailing_comma = true
104 | skip = [".gitignore", "__init__.py"]
105 |
106 | [tool.flake8]
107 | ignore = ["F811", "W503", "W504"]
108 |
109 | [tool.mypy]
110 | files = ["torch_frame"]
111 | install_types = true
112 | non_interactive = true
113 | ignore_missing_imports = true
114 | show_error_codes = true
115 | warn_redundant_casts = true
116 | warn_unused_configs = true
117 | warn_unused_ignores = true
118 |
119 | # TODO: the goal is for this ignore list to be empty
120 | [[tool.mypy.overrides]]
121 | ignore_errors = true
122 | # Run this command to generate this list of files
123 | # mypy --no-error-summary 2>&1 | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g; s|src/||g; s|\/|\.|g' | xargs -I {} echo '"{}",'
124 | module = [
125 | "torch_frame.data.dataset",
126 | "torch_frame.data.loader",
127 | "torch_frame.data.mapper",
128 | "torch_frame.data.stats",
129 | "torch_frame.data.tensor_frame",
130 | "torch_frame.gbdt.gbdt",
131 | "torch_frame.gbdt.tuned_catboost",
132 | "torch_frame.gbdt.tuned_lightgbm",
133 | "torch_frame.gbdt.tuned_xgboost",
134 | "torch_frame.nn.encoder.stype_encoder",
135 | "torch_frame.testing.text_tokenizer",
136 | "torch_frame.transforms.base_transform",
137 | "torch_frame.transforms.cat_to_num_transform",
138 | "torch_frame.transforms.fittable_base_transform",
139 | "torch_frame.transforms.mutual_information_sort",
140 | ]
141 |
142 | [tool.pytest.ini_options]
143 | addopts = [
144 | "--capture=no",
145 | "--color=yes",
146 | "-vv",
147 | ]
148 |
149 | [tool.coverage.report]
150 | exclude_lines = [
151 | "pragma: no cover",
152 | "pass",
153 | "raise NotImplementedError",
154 | ]
155 |
--------------------------------------------------------------------------------
/examples/tabnet.py:
--------------------------------------------------------------------------------
1 | """Reported (reproduced) results of of TabNet model in the original paper
2 | https://arxiv.org/abs/1908.07442.
3 |
4 | Forest Cover Type: 96.99 (96.53)
5 | KDD Census Income: 95.5 (95.41)
6 | """
7 |
8 | import argparse
9 | import os.path as osp
10 |
11 | import torch
12 | import torch.nn.functional as F
13 | from torch.optim.lr_scheduler import ExponentialLR
14 | from tqdm import tqdm
15 |
16 | from torch_frame.data import DataLoader
17 | from torch_frame.datasets import ForestCoverType, KDDCensusIncome
18 | from torch_frame.nn import TabNet
19 |
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument('--dataset', type=str, default="ForestCoverType",
22 | choices=["ForestCoverType", "KDDCensusIncome"])
23 | parser.add_argument('--channels', type=int, default=128)
24 | parser.add_argument('--gamma', type=float, default=1.2)
25 | parser.add_argument('--num_layers', type=int, default=6)
26 | parser.add_argument('--batch_size', type=int, default=4096)
27 | parser.add_argument('--lr', type=float, default=0.005)
28 | parser.add_argument('--epochs', type=int, default=50)
29 | parser.add_argument('--seed', type=int, default=0)
30 | parser.add_argument('--compile', action='store_true')
31 | args = parser.parse_args()
32 |
33 | torch.manual_seed(args.seed)
34 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
35 |
36 | # Prepare datasets
37 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data',
38 | args.dataset)
39 | if args.dataset == "ForestCoverType":
40 | dataset = ForestCoverType(root=path)
41 | elif args.dataset == "KDDCensusIncome":
42 | dataset = KDDCensusIncome(root=path)
43 | else:
44 | raise ValueError(f"Unsupported dataset called {args.dataset}")
45 |
46 | dataset.materialize()
47 | assert dataset.task_type.is_classification
48 | dataset = dataset.shuffle()
49 | # Split ratio is set to 80% / 10% / 10% (the original TabNet paper does
50 | # not clearly specify its split).
51 | train_dataset, val_dataset, test_dataset = dataset[:0.8], dataset[
52 | 0.8:0.9], dataset[0.9:]
53 |
54 | # Set up data loaders
55 | train_tensor_frame = train_dataset.tensor_frame
56 | val_tensor_frame = val_dataset.tensor_frame
57 | test_tensor_frame = test_dataset.tensor_frame
58 | train_loader = DataLoader(train_tensor_frame, batch_size=args.batch_size,
59 | shuffle=True)
60 | val_loader = DataLoader(val_tensor_frame, batch_size=args.batch_size)
61 | test_loader = DataLoader(test_tensor_frame, batch_size=args.batch_size)
62 |
63 | # Set up model and optimizer
64 | model = TabNet(
65 | out_channels=dataset.num_classes,
66 | num_layers=args.num_layers,
67 | split_attn_channels=args.channels,
68 | split_feat_channels=args.channels,
69 | gamma=args.gamma,
70 | col_stats=dataset.col_stats,
71 | col_names_dict=train_tensor_frame.col_names_dict,
72 | ).to(device)
73 | model = torch.compile(model, dynamic=True) if args.compile else model
74 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
75 | lr_scheduler = ExponentialLR(optimizer, gamma=0.95)
76 |
77 |
78 | def train(epoch: int) -> float:
79 | model.train()
80 | loss_accum = total_count = 0
81 |
82 | for tf in tqdm(train_loader, desc=f'Epoch: {epoch}'):
83 | tf = tf.to(device)
84 | pred = model(tf)
85 | loss = F.cross_entropy(pred, tf.y)
86 | optimizer.zero_grad()
87 | loss.backward()
88 | loss_accum += float(loss) * len(tf.y)
89 | total_count += len(tf.y)
90 | optimizer.step()
91 | return loss_accum / total_count
92 |
93 |
94 | @torch.no_grad()
95 | def test(loader: DataLoader) -> float:
96 | model.eval()
97 | accum = total_count = 0
98 |
99 | for tf in loader:
100 | tf = tf.to(device)
101 | pred = model(tf)
102 | pred_class = pred.argmax(dim=-1)
103 | accum += float((tf.y == pred_class).sum())
104 | total_count += len(tf.y)
105 |
106 | return accum / total_count
107 |
108 |
109 | best_val_acc = 0
110 | best_test_acc = 0
111 | for epoch in range(1, args.epochs + 1):
112 | train_loss = train(epoch)
113 | train_acc = test(train_loader)
114 | val_acc = test(val_loader)
115 | test_acc = test(test_loader)
116 | if best_val_acc < val_acc:
117 | best_val_acc = val_acc
118 | best_test_acc = test_acc
119 | print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, '
120 | f'Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')
121 | lr_scheduler.step()
122 |
123 | print(f'Best Val Acc: {best_val_acc:.4f}, Best Test Acc: {best_test_acc:.4f}')
124 |
--------------------------------------------------------------------------------