├── .flake8 ├── .github └── workflows │ └── main.yml ├── .gitignore ├── .isort.cfg ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── examples ├── film │ └── film.py ├── han │ └── han.py ├── heat │ └── heat.py ├── hgt │ ├── .gitignore │ └── hgt.py ├── rgat │ └── rgat.py └── rgcn │ ├── .gitignore │ └── rgcn.py ├── fasten ├── __init__.py ├── nn │ ├── __init__.py │ ├── conv │ │ ├── __init__.py │ │ ├── heat_conv.py │ │ ├── hgt_conv.py │ │ ├── rgat_conv.py │ │ └── rgcn_conv.py │ └── linear │ │ ├── __init__.py │ │ └── linear.py ├── operators │ ├── __init__.py │ ├── torch_ops │ │ ├── __init__.py │ │ └── segment_matmul.py │ └── triton_ops │ │ ├── __init__.py │ │ ├── kernels │ │ ├── __init__.py │ │ └── matmul.py │ │ └── segment_matmul.py ├── ops.py ├── runtime │ ├── __init__.py │ └── stream_pool.py ├── scheduler.py ├── stats.py ├── tensor_slice.py └── utils.py ├── setup.py └── test ├── datasets_csv ├── ACM.csv ├── AIFB.csv ├── AM.csv ├── BGS.csv ├── DBLP.csv ├── Freebase.csv ├── IMDB.csv └── MUTAG.csv ├── test_nn.py ├── test_ops.py ├── test_stats.py ├── test_tensor_slice.py ├── test_triton.py └── utils.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E501,E701,E731 3 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: ci-cpu 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | permissions: write-all 12 | jobs: 13 | build: 14 | runs-on: ubuntu-20.04 15 | steps: 16 | - uses: actions/checkout@v4 17 | 18 | - name: Setup python 19 | uses: actions/setup-python@v4 20 | with: 21 | python-version: 3.11.2 22 | 23 | - name: Caching dependencies 24 | uses: actions/cache@v3 25 | with: 26 | path: ${{env.pythonLocation}} 27 | key: ${{env.pythonLocation}} - ${{ hashFiles('setup.py') }} 28 | 29 | - 
name: Install dependencies 30 | run: | 31 | pip install torch --extra-index-url https://download.pytorch.org/whl/cpu 32 | pip install git+https://github.com/pyg-team/pytorch_geometric.git 33 | pip install pyg-lib -f https://data.pyg.org/whl/torch-2.2.0+cpu.html 34 | pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly 35 | CI_ONLY=1 pip install . 36 | 37 | - name: Run test cases 38 | run: | 39 | cd ./test; 40 | pytest --tb=line -k "cpu and(slices0 or slices1)" test_ops.py 41 | pytest --tb=line -k "cpu" test_nn.py test_triton.py test_tensor_slice.py test_stats.py 42 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | dist 2 | build 3 | *.egg-info 4 | __pycache__ 5 | .vscode 6 | examples/data 7 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | known_local_folder=fasten 3 | line_length=88 4 | py_version=36 5 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.4.0 4 | hooks: 5 | - id: check-symlinks 6 | - id: destroyed-symlinks 7 | - id: trailing-whitespace 8 | - id: end-of-file-fixer 9 | - id: check-yaml 10 | - id: check-toml 11 | - id: check-ast 12 | - id: check-added-large-files 13 | - id: check-merge-conflict 14 | - id: check-executables-have-shebangs 15 | - id: check-shebang-scripts-are-executable 16 | - id: detect-private-key 17 | - id: debug-statements 18 | - repo: https://github.com/PyCQA/isort 19 | rev: 5.12.0 20 | hooks: 21 | - id: isort 22 | stages: [commit, push, manual] 23 | - repo: 
https://github.com/pre-commit/mirrors-autopep8 24 | rev: v1.6.0 25 | hooks: 26 | - id: autopep8 27 | args: ["-i"] 28 | stages: [commit, push, manual] 29 | - repo: https://github.com/pycqa/flake8 30 | rev: 6.0.0 31 | hooks: 32 | - id: flake8 33 | # TODO: uncomment this to enable more flake8 plugins 34 | # additional_dependencies: 35 | # - flake8-bugbear 36 | # - flake8-comprehensions 37 | # - flake8-docstrings 38 | # - flake8-pyi 39 | # - flake8-simplify 40 | stages: [commit, push, manual] 41 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Deep-Learning-Profiling-Tools 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Fasten: A Library of *Fast Segment* Operators 2 | 3 | ## Introduction 4 | 5 | Fasten is a library aimed at speeding up Heterogeneous Graph Neural Network (HGNN) workloads. 6 | The current version of Fasten focuses on improving segmented matrix multiplication, a critical operator in HGNNs. 7 | Fasten implements a simple interface, making it easy to integrate with the existing graph library PyG with minimal changes. 8 | Fasten achieved an average speedup of *13.65x* and *4.72x* in operator-wise benchmarks compared to CUTLASS and cuBLAS, respectively. 9 | 10 | ### Fasten vs CUTLASS 11 | 12 | ![figure9_fasten_vs_cutlass(pyg)](https://github.com/Deep-Learning-Profiling-Tools/fasten/assets/2306281/d88fab7c-a331-4978-9157-08e448afcce5) 13 | 14 | ### Fasten vs cuBLAS 15 | 16 | ![figure9_fasten_vs_cublas(torch)](https://github.com/Deep-Learning-Profiling-Tools/fasten/assets/2306281/4f8fbe5f-f8d4-45b2-9f92-ac3f7cb97c28) 17 | 18 | ## Installation 19 | 20 | ### Build Instructions 21 | 22 | Install PyTorch nightly and Triton nightly. Fasten relies on relatively new Triton features, so older Triton releases may crash. 23 | 24 | ```bash 25 | pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 26 | pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly 27 | ``` 28 | 29 | Until *proton* is distributed with Triton's pip wheel, you may need to build Triton from source. 30 | 31 | ```bash 32 | git clone https://github.com/Deep-Learning-Profiling-Tools/fasten.git && cd fasten 33 | pip install . 34 | ``` 35 | 36 | ## Examples 37 | 38 | Fasten's segmented matrix multiplication operator has been integrated with various HGNN architectures in PyG, such as *RGCN*, *HGT*, and *RGAT*.
39 | The commands below show how to run each example: 40 | 41 | ### GNN Examples 42 | 43 | - RGCN 44 | 45 | ```bash 46 | cd examples/rgcn 47 | # Without fasten 48 | # Available datasets are: AIFB, MUTAG, BGS, AM 49 | python rgcn.py --device cuda --dataset AIFB 50 | # With fasten 51 | python rgcn.py --device cuda --mode fasten --dataset AIFB 52 | ``` 53 | 54 | - HGT 55 | 56 | ```bash 57 | cd examples/hgt 58 | # Without fasten 59 | # Available datasets are: DBLP, Freebase, AIFB, MUTAG, BGS, AM 60 | python hgt.py --device cuda --example DBLP 61 | # With fasten 62 | python hgt.py --device cuda --mode fasten --example DBLP 63 | ``` 64 | 65 | - RGAT 66 | 67 | ```bash 68 | cd examples/rgat 69 | # Without fasten 70 | # Available datasets are: AIFB, MUTAG, BGS, AM 71 | python rgat.py --device cuda --dataset MUTAG 72 | # With fasten 73 | python rgat.py --device cuda --mode fasten --dataset MUTAG 74 | ``` 75 | 76 | ### Benchmarking 77 | 78 | ```bash 79 | cd test 80 | pytest -vs test_ops.py::test_perf 81 | ``` 82 | 83 | ## Compatibility 84 | 85 | ### Supported Platforms 86 | 87 | - Linux 88 | 89 | ### Supported Hardware 90 | 91 | - NVIDIA GPUs (Compute Capability 7.0+) 92 | 93 | ### Software Requirements 94 | 95 | - PyTorch >=2.2.0 96 | - Triton >=3.0.0 97 | - PyG >=2.6.0 98 | 99 | ## Publication 100 | 101 | - Keren Zhou, Karthik Ganapathi Subramanian, Po-Hsun Lin, Matthias Fey, Binqian Yin, and Jiajia Li. 2024. 102 | FASTEN: Fast GPU-accelerated Segmented Matrix Multiplication for Heterogeneous Graph Neural Networks. 103 | In Proceedings of the 38th ACM International Conference on Supercomputing (ICS’24), June 4–7, 2024, Kyoto, Japan.
104 | -------------------------------------------------------------------------------- /examples/film/film.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import time 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from sklearn.metrics import f1_score 7 | from torch.nn import BatchNorm1d 8 | from torch.profiler import record_function 9 | from torch_geometric.datasets import PPI 10 | from torch_geometric.loader import DataLoader 11 | from torch_geometric.nn import FiLMConv 12 | 13 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'PPI') 14 | train_dataset = PPI(path, split='train') 15 | val_dataset = PPI(path, split='val') 16 | test_dataset = PPI(path, split='test') 17 | train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True) 18 | val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False) 19 | test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False) 20 | 21 | 22 | class Net(torch.nn.Module): 23 | def __init__(self, in_channels, hidden_channels, out_channels, num_layers, 24 | dropout=0.0): 25 | super().__init__() 26 | self.dropout = dropout 27 | 28 | self.convs = torch.nn.ModuleList() 29 | self.convs.append(FiLMConv(in_channels, hidden_channels)) 30 | for _ in range(num_layers - 2): 31 | self.convs.append(FiLMConv(hidden_channels, hidden_channels)) 32 | self.convs.append(FiLMConv(hidden_channels, out_channels, act=None)) 33 | 34 | self.norms = torch.nn.ModuleList() 35 | for _ in range(num_layers - 1): 36 | self.norms.append(BatchNorm1d(hidden_channels)) 37 | 38 | def forward(self, x, edge_index): 39 | for conv, norm in zip(self.convs[:-1], self.norms): 40 | x = norm(conv(x, edge_index)) 41 | x = F.dropout(x, p=self.dropout, training=self.training) 42 | x = self.convs[-1](x, edge_index) 43 | return x 44 | 45 | 46 | if torch.cuda.is_available(): 47 | device = torch.device('cuda') 48 | elif hasattr(torch.backends, 'mps') and 
torch.backends.mps.is_available(): 49 | device = torch.device('mps') 50 | else: 51 | device = torch.device('cpu') 52 | 53 | model = Net(in_channels=train_dataset.num_features, hidden_channels=320, 54 | out_channels=train_dataset.num_classes, num_layers=4, 55 | dropout=0.1).to(device) 56 | criterion = torch.nn.BCEWithLogitsLoss() 57 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01) 58 | 59 | 60 | def train(): 61 | model.train() 62 | 63 | total_loss = 0 64 | for data in train_loader: 65 | data = data.to(device) 66 | optimizer.zero_grad() 67 | with record_function("film_inference"): 68 | loss = criterion(model(data.x, data.edge_index), data.y) 69 | total_loss += loss.item() * data.num_graphs 70 | loss.backward() 71 | optimizer.step() 72 | return total_loss / len(train_loader.dataset) 73 | 74 | 75 | @torch.no_grad() 76 | def test(loader): 77 | model.eval() 78 | 79 | ys, preds = [], [] 80 | for data in loader: 81 | ys.append(data.y) 82 | out = model(data.x.to(device), data.edge_index.to(device)) 83 | preds.append((out > 0).float().cpu()) 84 | 85 | y, pred = torch.cat(ys, dim=0).numpy(), torch.cat(preds, dim=0).numpy() 86 | return f1_score(y, pred, average='micro') if pred.sum() > 0 else 0 87 | 88 | 89 | times = [] 90 | for epoch in range(1, 5): 91 | start = time.time() 92 | loss = train() 93 | val_f1 = test(val_loader) 94 | test_f1 = test(test_loader) 95 | print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_f1:.4f}, ' 96 | f'Test: {test_f1:.4f}') 97 | times.append(time.time() - start) 98 | print(f"Median time per epoch: {torch.tensor(times).median():.4f}s") 99 | -------------------------------------------------------------------------------- /examples/han/han.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from typing import Dict, List, Union 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | import torch_geometric.transforms as T 7 | from torch import nn 8 | from torch.profiler 
import record_function 9 | from torch_geometric.datasets import IMDB 10 | from torch_geometric.nn import HANConv 11 | 12 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'IMDB') 13 | 14 | metapaths = [[('movie', 'actor'), ('actor', 'movie')], 15 | [('movie', 'director'), ('director', 'movie')]] 16 | transform = T.AddMetaPaths(metapaths=metapaths, drop_orig_edge_types=True, 17 | drop_unconnected_node_types=True) 18 | dataset = IMDB(path, transform=transform) 19 | data = dataset[0] 20 | print(data) 21 | 22 | 23 | class HAN(nn.Module): 24 | def __init__(self, in_channels: Union[int, Dict[str, int]], 25 | out_channels: int, hidden_channels=128, heads=8): 26 | super().__init__() 27 | self.han_conv = HANConv(in_channels, hidden_channels, heads=heads, 28 | dropout=0.6, metadata=data.metadata()) 29 | self.lin = nn.Linear(hidden_channels, out_channels) 30 | 31 | def forward(self, x_dict, edge_index_dict): 32 | out = self.han_conv(x_dict, edge_index_dict) 33 | out = self.lin(out['movie']) 34 | return out 35 | 36 | 37 | model = HAN(in_channels=-1, out_channels=3) 38 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 39 | data, model = data.to(device), model.to(device) 40 | 41 | with torch.no_grad(): # Initialize lazy modules. 
42 | out = model(data.x_dict, data.edge_index_dict) 43 | 44 | optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.001) 45 | 46 | 47 | def train() -> float: 48 | model.train() 49 | optimizer.zero_grad() 50 | with record_function("han_inference"): 51 | out = model(data.x_dict, data.edge_index_dict) 52 | mask = data['movie'].train_mask 53 | loss = F.cross_entropy(out[mask], data['movie'].y[mask]) 54 | loss.backward() 55 | optimizer.step() 56 | return float(loss) 57 | 58 | 59 | @torch.no_grad() 60 | def test() -> List[float]: 61 | model.eval() 62 | pred = model(data.x_dict, data.edge_index_dict).argmax(dim=-1) 63 | 64 | accs = [] 65 | for split in ['train_mask', 'val_mask', 'test_mask']: 66 | mask = data['movie'][split] 67 | acc = (pred[mask] == data['movie'].y[mask]).sum() / mask.sum() 68 | accs.append(float(acc)) 69 | return accs 70 | 71 | 72 | best_val_acc = 0 73 | start_patience = patience = 100 74 | for epoch in range(1, 5): 75 | 76 | loss = train() 77 | train_acc, val_acc, test_acc = test() 78 | print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, ' 79 | f'Val: {val_acc:.4f}, Test: {test_acc:.4f}') 80 | 81 | if best_val_acc <= val_acc: 82 | patience = start_patience 83 | best_val_acc = val_acc 84 | else: 85 | patience -= 1 86 | 87 | if patience <= 0: 88 | print('Stopping training as validation accuracy did not improve ' 89 | f'for {start_patience} epochs') 90 | break 91 | -------------------------------------------------------------------------------- /examples/heat/heat.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os.path as osp 3 | import random 4 | from typing import List, Tuple 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | import torch_geometric 9 | import torch_geometric.transforms as T 10 | from torch import Tensor 11 | from torch.profiler import ProfilerActivity, profile 12 | from torch_geometric.datasets import DBLP, HGBDataset 13 | 
from torch_geometric.nn import HEATConv, Linear 14 | from torch_geometric.utils import index_sort 15 | from torch_geometric.utils.sparse import index2ptr 16 | from triton.testing import do_bench 17 | 18 | from fasten import Engine, TensorSlice, compact_tensor_types 19 | from fasten.nn import FastenHEATConv 20 | 21 | torch.backends.cuda.matmul.allow_tf32 = True 22 | torch_geometric.backend.use_segment_matmul = True 23 | 24 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'DBLP') 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--device', type=str, default='cpu', 27 | choices=['cpu', 'cuda']) 28 | parser.add_argument('--mode', type=str, default='pyg', 29 | choices=['pyg', 'fasten']) 30 | parser.add_argument('--example', type=str, default='dblp', 31 | choices=['dblp', 'freebase']) 32 | parser.add_argument('--profile', type=str, default='none', 33 | choices=['none', 'profile', 'benchmark']) 34 | parser.add_argument('--hidden_size', type=int, default=32) 35 | args = parser.parse_args() 36 | device = torch.device(args.device) 37 | 38 | 39 | # We initialize conference node features with a single one-vector as feature: 40 | if args.example == 'dblp': 41 | path = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/DBLP') 42 | # We initialize conference node features with a single one-vector as feature: 43 | dataset = DBLP(path, transform=T.Constant(node_types='conference')) 44 | out_channels = 4 # 4 class labels 45 | else: 46 | path = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/HGBD') 47 | transform = T.Compose([T.Constant(value=random.random(), 48 | node_types=['book', 'film', 'music', 'sports', 'people', 'location', 'organization', 'business'])]) 49 | dataset = HGBDataset(path, "Freebase", transform=transform) 50 | out_channels = 7 # 7 class labels 51 | 52 | data = dataset[0] 53 | data = data.to_homogeneous() 54 | # Create ramdom values for edge_attr 55 | data["edge_attr"] = torch.randn((data.edge_index.shape[1], 
2)) 56 | 57 | 58 | def ptr_to_tensor_slice(ptr: List, data: Tensor = None, is_sorted: bool = False) -> Tuple[TensorSlice, List]: 59 | assert ptr is not None 60 | slices = [slice(ptr[i], ptr[i + 1]) for i in range(len(ptr) - 1)] 61 | types = torch.zeros((ptr[-1],), dtype=torch.int) 62 | for i, s in enumerate(slices): 63 | types[s] = i 64 | tensor_slice = compact_tensor_types(data=data, types=types, is_sorted=is_sorted, device=device) 65 | return tensor_slice, slices 66 | 67 | 68 | def tensor_slice_gen(data) -> TensorSlice: 69 | sorted_node_type, _ = index_sort(data.node_type, len(torch.unique(data.node_type))) 70 | ptr = index2ptr(sorted_node_type, len(torch.unique(data.node_type))) 71 | tensor_slice_hl, _ = ptr_to_tensor_slice(ptr, is_sorted=True) 72 | return tensor_slice_hl 73 | 74 | 75 | class HEAT(torch.nn.Module): 76 | def __init__(self, hidden_channels, out_channels, num_heads, num_layers): 77 | super().__init__() 78 | 79 | self.convs = torch.nn.ModuleList() 80 | for _ in range(num_layers): 81 | conv = HEATConv(hidden_channels, hidden_channels, len(torch.unique(data.node_type)), len(torch.unique(data.edge_type)), 5, 2, 6, 82 | num_heads, concat=False) 83 | self.convs.append(conv) 84 | 85 | self.lin_in = Linear(-1, hidden_channels) 86 | self.lin_out = Linear(hidden_channels, out_channels) 87 | 88 | def forward(self, x, edge_index, node_type, edge_type, edge_attr): 89 | x = self.lin_in(x).relu_() 90 | 91 | for conv in self.convs: 92 | x = conv(x, edge_index, node_type, edge_type, edge_attr) 93 | 94 | return self.lin_out(x) 95 | 96 | 97 | class FastenHEAT(torch.nn.Module): 98 | def __init__(self, hidden_channels, out_channels, num_heads, num_layers): 99 | super().__init__() 100 | 101 | self.convs = torch.nn.ModuleList() 102 | for _ in range(num_layers): 103 | conv = FastenHEATConv(hidden_channels, hidden_channels, len(torch.unique(data.node_type)), len(torch.unique(data.edge_type)), 5, 2, 6, 104 | num_heads, concat=False, engine=Engine.TRITON) 105 | 
self.convs.append(conv) 106 | 107 | self.lin_in = Linear(-1, hidden_channels) 108 | self.lin_out = Linear(hidden_channels, out_channels) 109 | 110 | def forward(self, x, edge_index, node_type, edge_type, edge_attr, tensor_slice_hl): 111 | x = self.lin_in(x).relu_() 112 | 113 | for conv in self.convs: 114 | x = conv(x, edge_index, node_type, edge_type, edge_attr, tensor_slice_hl=tensor_slice_hl) 115 | 116 | return self.lin_out(x) 117 | 118 | 119 | model = None 120 | if args.mode == 'fasten': 121 | model = FastenHEAT(hidden_channels=args.hidden_size, out_channels=out_channels, num_heads=2, num_layers=1) 122 | else: 123 | model = HEAT(hidden_channels=args.hidden_size, out_channels=out_channels, num_heads=2, num_layers=1) 124 | 125 | data, model = data.to(device), model.to(device) 126 | optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.001) 127 | tensor_slice_hl = tensor_slice_gen(data) 128 | 129 | 130 | def train(): 131 | model.train() 132 | optimizer.zero_grad() 133 | if args.mode == 'fasten': 134 | out = model(data.x, data.edge_index, data.node_type, data.edge_type, data.edge_attr, tensor_slice_hl) 135 | else: 136 | out = model(data.x, data.edge_index, data.node_type, data.edge_type, data.edge_attr) 137 | mask = data.y != -1 138 | loss = F.cross_entropy(out[mask], data.y[mask]) 139 | loss.backward() 140 | optimizer.step() 141 | return float(loss) 142 | 143 | 144 | @torch.no_grad() 145 | def test(): 146 | model.eval() 147 | if args.mode == 'fasten': 148 | pred = model(data.x, data.edge_index, data.node_type, data.edge_type, data.edge_attr, tensor_slice_hl).argmax(dim=-1) 149 | else: 150 | pred = model(data.x, data.edge_index, data.node_type, data.edge_type, data.edge_attr).argmax(dim=-1) 151 | accs = [] 152 | for split in ['train_mask', 'val_mask', 'test_mask']: 153 | mask = data[split] 154 | 155 | acc = 0 156 | acc_cnt = 0 157 | for i in range(len(data.y[mask])): 158 | if data.y[mask][i] <= 0: continue 159 | acc_cnt += 1 160 | if 
pred[mask][i] == data.y[mask][i]: 161 | acc += 1 162 | 163 | if acc_cnt > 0: 164 | acc /= acc_cnt 165 | accs.append(float(acc)) 166 | 167 | return accs 168 | 169 | 170 | if args.profile == "none": 171 | for epoch in range(1, 5): 172 | loss = train() 173 | train_acc, val_acc, test_acc = test() 174 | print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, ' 175 | f'Val: {val_acc:.4f}, Test: {test_acc:.4f}') 176 | 177 | elif args.profile == "profile": 178 | # warmup 179 | train() 180 | with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=False, record_shapes=False) as prof: 181 | for epoch in range(1, 5): 182 | train() 183 | 184 | print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=15)) 185 | 186 | else: # args.profile == "benchmark" 187 | def pyg_fn(): 188 | model(data.x, data.edge_index, data.node_type, data.edge_type, data.edge_attr) 189 | 190 | def fasten_fn(): 191 | model(data.x, data.edge_index, data.node_type, data.edge_type, data.edge_attr, tensor_slice_hl) 192 | 193 | def train_fn(): 194 | train() 195 | fn = pyg_fn if args.mode == "pyg" else fasten_fn 196 | inference_ms = do_bench(fn) 197 | train_ms = do_bench(train_fn) 198 | print(f"{args.mode} inference: {inference_ms} ms") 199 | print(f"{args.mode} train: {train_ms} ms") 200 | -------------------------------------------------------------------------------- /examples/hgt/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | timemory-*-output 3 | metadata.json 4 | -------------------------------------------------------------------------------- /examples/hgt/hgt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os.path as osp 3 | import random 4 | from typing import List, Tuple 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | import torch_geometric 9 | import torch_geometric.transforms as T 10 | from torch 
import Tensor 11 | from torch.profiler import ProfilerActivity, profile 12 | from torch_geometric.datasets import DBLP, Entities, HGBDataset 13 | from torch_geometric.nn import HGTConv, Linear 14 | from torch_geometric.utils.sparse import index2ptr 15 | from triton.testing import do_bench 16 | 17 | from fasten import Engine, TensorSlice, compact_tensor_types 18 | from fasten.nn import FastenHGTConv 19 | 20 | torch.backends.cuda.matmul.allow_tf32 = True 21 | torch_geometric.backend.use_segment_matmul = True 22 | 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--device', type=str, default='cpu', 25 | choices=['cpu', 'cuda']) 26 | parser.add_argument('--mode', type=str, default='pyg', 27 | choices=['pyg', 'fasten']) 28 | parser.add_argument('--example', type=str, default='DBLP', 29 | choices=['DBLP', 'Freebase', 'AIFB', 'AM', 'BGS', 'MUTAG']) 30 | parser.add_argument('--hidden_size', type=int, default=32) 31 | parser.add_argument('--profile', type=str, default='none', 32 | choices=['none', 'profile', 'benchmark']) 33 | args = parser.parse_args() 34 | 35 | if args.example == 'dblp': 36 | path = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/DBLP') 37 | # We initialize conference node features with a single one-vector as feature: 38 | dataset = DBLP(path, transform=T.Constant(node_types='conference')) 39 | out_channels = 4 # 4 class labels 40 | elif args.example in ['AIFB', 'AM', 'BGS', 'MUTAG']: 41 | path = osp.join(osp.dirname(osp.realpath(__file__)), '../../Entities') 42 | dataset = Entities(path, args.example, hetero=True) 43 | out_channels = dataset.num_classes 44 | else: 45 | path = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/HGBD') 46 | transform = T.Compose([T.Constant(value=random.random(), 47 | node_types=['book', 'film', 'music', 'sports', 'people', 'location', 'organization', 'business'])]) 48 | dataset = HGBDataset(path, "Freebase", transform=transform) 49 | out_channels = 7 # 7 class labels 50 | 51 | data = 
dataset[0] 52 | device = torch.device(args.device) 53 | if args.example in ['aifb', 'am', 'bgs', 'mutag']: 54 | data['v'].train_mask = torch.zeros(data['v'].num_nodes, dtype=torch.bool) 55 | data['v'].test_mask = torch.zeros(data['v'].num_nodes, dtype=torch.bool) 56 | data['v'].y = torch.Tensor([dataset.num_classes + 1] * data['v'].num_nodes).type(torch.LongTensor) 57 | data['v'].x = torch.rand((data['v'].num_nodes, args.hidden_size)) 58 | for idx, i in enumerate(data.train_idx): 59 | data['v'].train_mask[i] = True 60 | data['v'].y[i] = data.train_y[idx] 61 | for idx, i in enumerate(data.test_idx): 62 | data['v'].test_mask[i] = True 63 | data['v'].y[i] = data.test_y[idx] 64 | 65 | 66 | def ptr_to_tensor_slice(ptr: List, data: Tensor = None, is_sorted: bool = False) -> Tuple[TensorSlice, List]: 67 | assert ptr is not None 68 | slices = [slice(ptr[i], ptr[i + 1]) for i in range(len(ptr) - 1)] 69 | types = torch.zeros((ptr[-1],), dtype=torch.int) 70 | for i, s in enumerate(slices): 71 | types[s] = i 72 | tensor_slice = compact_tensor_types(data=data, types=types, is_sorted=is_sorted, device=device) 73 | return tensor_slice, slices 74 | 75 | 76 | def tensor_slice_gen(data, num_heads) -> Tuple[TensorSlice, Tensor, TensorSlice, List]: 77 | # Generating tensor_slice for HeteroDictLinear 78 | ptr = [0] 79 | for key, _ in data.x_dict.items(): 80 | ptr.append(ptr[-1] + data.x_dict[key].shape[0]) 81 | tensor_slice_hdl, slices = ptr_to_tensor_slice(ptr, is_sorted=True) 82 | slices_hdl = slices 83 | 84 | # Generating tensor_slice for HeteroLinear 85 | edge_types = data.metadata()[1] 86 | num_edge_types = len(edge_types) 87 | H = num_heads # No of heads 88 | type_list = [] 89 | edge_map = {edge_type: i for i, edge_type in enumerate(data.metadata()[1])} 90 | 91 | for key, _ in data.edge_index_dict.items(): 92 | N = data.x_dict[key[0]].shape[0] 93 | edge_type_offset = edge_map[key] 94 | type_vec = torch.arange(H, dtype=torch.long).view(-1, 1).repeat(1, N) * num_edge_types + 
edge_type_offset 95 | type_list.append(type_vec) 96 | 97 | type_vec = torch.cat(type_list, dim=1).flatten() 98 | num_types = H * len(edge_types) 99 | ptr = index2ptr(type_vec, num_types) 100 | tensor_slice_hl, _ = ptr_to_tensor_slice(ptr, is_sorted=True) 101 | 102 | return tensor_slice_hl, type_vec, tensor_slice_hdl, slices_hdl 103 | 104 | 105 | class HGT(torch.nn.Module): 106 | def __init__(self, hidden_channels, out_channels, num_heads, num_layers): 107 | super().__init__() 108 | 109 | self.lin_dict = torch.nn.ModuleDict() 110 | for node_type in data.node_types: 111 | self.lin_dict[node_type] = Linear(-1, hidden_channels) 112 | 113 | self.convs = torch.nn.ModuleList() 114 | for _ in range(num_layers): 115 | conv = HGTConv(hidden_channels, hidden_channels, data.metadata(), 116 | num_heads) 117 | self.convs.append(conv) 118 | 119 | self.lin = Linear(hidden_channels, out_channels) 120 | if args.example == "dblp": 121 | self.node_out = "author" 122 | elif args.example in ["aifb", "am", "bgs", "mutag"]: 123 | self.node_out = "v" 124 | else: 125 | self.node_out = "book" 126 | 127 | def forward(self, x_dict, edge_index_dict): 128 | x_dict = { 129 | node_type: self.lin_dict[node_type](x).relu_() 130 | for node_type, x in x_dict.items() 131 | } 132 | 133 | for conv in self.convs: 134 | x_dict = conv(x_dict, edge_index_dict) 135 | 136 | return self.lin(x_dict[self.node_out]) 137 | 138 | 139 | class FastenHGT(torch.nn.Module): 140 | def __init__(self, hidden_channels, out_channels, num_heads, num_layers): 141 | super().__init__() 142 | 143 | self.lin_dict = torch.nn.ModuleDict() 144 | for node_type in data.node_types: 145 | self.lin_dict[node_type] = Linear(-1, hidden_channels) 146 | 147 | self.convs = torch.nn.ModuleList() 148 | for _ in range(num_layers): 149 | conv = FastenHGTConv(hidden_channels, hidden_channels, data.metadata(), 150 | num_heads, engine=Engine.TRITON) 151 | self.convs.append(conv) 152 | 153 | self.lin = Linear(hidden_channels, out_channels) 154 | if 
args.example == "dblp": 155 | self.node_out = "author" 156 | elif args.example in ["aifb", "am", "bgs", "mutag"]: 157 | self.node_out = "v" 158 | else: 159 | self.node_out = "book" 160 | 161 | def forward(self, x_dict, edge_index_dict, tensor_slice_hl, type_vec, tensor_slice_hdl, slices_hdl): 162 | x_dict = { 163 | node_type: self.lin_dict[node_type](x).relu_() 164 | for node_type, x in x_dict.items() 165 | } 166 | 167 | for conv in self.convs: 168 | x_dict = conv(x_dict=x_dict, edge_index_dict=edge_index_dict, tensor_slice_hl=tensor_slice_hl, 169 | type_vec=type_vec, tensor_slice_hdl=tensor_slice_hdl, slices_hdl=slices_hdl) 170 | 171 | return self.lin(x_dict[self.node_out]) 172 | 173 | 174 | if args.mode == 'fasten': 175 | model = FastenHGT(hidden_channels=args.hidden_size, out_channels=out_channels, num_heads=2, num_layers=1) 176 | data, model = data.to(device), model.to(device) 177 | tensor_slice_hl, type_vec, tensor_slice_hdl, slices_hdl = tensor_slice_gen(data, 2) # last argument num_heads 178 | with torch.no_grad(): # Initialize lazy modules. 179 | out = model(data.x_dict, data.edge_index_dict, tensor_slice_hl, type_vec, tensor_slice_hdl, slices_hdl) 180 | 181 | else: 182 | model = HGT(hidden_channels=args.hidden_size, out_channels=out_channels, num_heads=2, num_layers=1) 183 | data, model = data.to(device), model.to(device) 184 | with torch.no_grad(): # Initialize lazy modules. 
185 | out = model(data.x_dict, data.edge_index_dict) 186 | 187 | optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.001) 188 | 189 | 190 | def train(node_out: str): 191 | model.train() 192 | optimizer.zero_grad() 193 | if args.mode == 'fasten': 194 | out = model(data.x_dict, data.edge_index_dict, tensor_slice_hl, type_vec, tensor_slice_hdl, slices_hdl) 195 | else: 196 | out = model(data.x_dict, data.edge_index_dict) 197 | 198 | mask = data[node_out].train_mask 199 | loss = F.cross_entropy(out[mask], data[node_out].y[mask]) 200 | loss.backward() 201 | optimizer.step() 202 | return float(loss) 203 | 204 | 205 | @torch.no_grad() 206 | def test(node_out: str): 207 | model.eval() 208 | if args.mode == 'fasten': 209 | pred = model(data.x_dict, data.edge_index_dict, tensor_slice_hl, type_vec, tensor_slice_hdl, slices_hdl).argmax(dim=-1) 210 | else: 211 | pred = model(data.x_dict, data.edge_index_dict).argmax(dim=-1) 212 | 213 | accs = [] 214 | if args.example == "dblp": 215 | for split in ['train_mask', 'val_mask', 'test_mask']: 216 | mask = data[node_out][split] 217 | acc = (pred[mask] == data[node_out].y[mask]).sum() / mask.sum() 218 | accs.append(float(acc)) 219 | else: 220 | for split in ['train_mask', 'test_mask']: 221 | mask = data[node_out][split] 222 | acc = (pred[mask] == data[node_out].y[mask]).sum() / mask.sum() 223 | accs.append(float(acc)) 224 | 225 | return accs 226 | 227 | 228 | if args.example == 'dblp': 229 | node_out = 'author' 230 | elif args.example in ['aifb', 'am', 'bgs', 'mutag']: 231 | node_out = 'v' 232 | else: 233 | node_out = 'book' 234 | 235 | if args.profile == "none": 236 | for epoch in range(1, 5): 237 | loss = train(node_out) 238 | if args.example == "dblp": 239 | train_acc, val_acc, test_acc = test(node_out) 240 | print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, ' 241 | f'Val: {val_acc:.4f}, Test: {test_acc:.4f}') 242 | else: 243 | train_acc, test_acc = test(node_out) 244 | print(f'Epoch: 
{epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, ' 245 | f'Test: {test_acc:.4f}') 246 | 247 | elif args.profile == "profile": 248 | # warmup 249 | train(node_out) 250 | with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=False, record_shapes=False) as prof: 251 | for epoch in range(1, 5): 252 | train(node_out) 253 | 254 | print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=15)) 255 | 256 | else: # args.profile == "benchmark" 257 | def pyg_fn(): 258 | model(data.x_dict, data.edge_index_dict) 259 | 260 | def fasten_fn(): 261 | model(data.x_dict, data.edge_index_dict, tensor_slice_hl, type_vec, tensor_slice_hdl, slices_hdl) 262 | 263 | def train_fn(): 264 | train(node_out) 265 | fn = pyg_fn if args.mode == "pyg" else fasten_fn 266 | inference_ms = do_bench(fn) 267 | train_ms = do_bench(train_fn) 268 | print(f"{args.mode} inference: {inference_ms} ms") 269 | print(f"{args.mode} train: {train_ms} ms") 270 | -------------------------------------------------------------------------------- /examples/rgat/rgat.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os.path as osp 3 | import time 4 | from typing import Tuple 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | import torch_geometric 9 | from torch.profiler import ProfilerActivity, profile 10 | from torch_geometric.datasets import Entities 11 | from torch_geometric.nn import RGATConv 12 | from torch_geometric.utils import index_sort, k_hop_subgraph 13 | from triton.testing import do_bench 14 | 15 | from fasten import Engine, TensorSlice, compact_tensor_types 16 | from fasten.nn import FastenRGATConv 17 | 18 | torch.backends.cuda.matmul.allow_tf32 = True 19 | torch_geometric.backend.use_segment_matmul = True 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--device', type=str, default='cpu', 23 | choices=['cpu', 'cuda']) 24 | parser.add_argument('--mode', 
type=str, default='pyg', 25 | choices=['pyg', 'fasten']) 26 | parser.add_argument('--dataset', type=str, default='AIFB', 27 | choices=['AIFB', 'MUTAG', 'BGS', 'AM']) 28 | parser.add_argument('--profile', type=str, default='none', 29 | choices=['none', 'profile', 'benchmark']) 30 | parser.add_argument('--hidden_size', type=int, default=32) 31 | args = parser.parse_args() 32 | 33 | device = torch.device(args.device) 34 | 35 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Entities') 36 | dataset = Entities(path, args.dataset) 37 | data = dataset[0] 38 | node_idx = torch.cat([data.train_idx, data.test_idx], dim=0) 39 | node_idx, edge_index, mapping, edge_mask = k_hop_subgraph( 40 | node_idx, 2, data.edge_index, relabel_nodes=True) 41 | data.x = torch.randn(data.num_nodes, args.hidden_size) 42 | 43 | 44 | def tensor_slice_gen(edge_type, edge_index, num_relations) -> Tuple[TensorSlice, torch.Tensor, torch.Tensor]: 45 | if (edge_type[1:] < edge_type[:-1]).any(): 46 | edge_type, perm = index_sort( 47 | edge_type, max_value=num_relations) 48 | edge_index = edge_index[:, perm] 49 | tensor_slice = compact_tensor_types(data=None, types=edge_type, is_sorted=True, device=device) 50 | return tensor_slice, edge_index, edge_type 51 | 52 | 53 | class FastenRGAT(torch.nn.Module): 54 | def __init__(self, in_channels, hidden_channels, out_channels, 55 | num_relations): 56 | super().__init__() 57 | self.conv1 = FastenRGATConv(in_channels, hidden_channels, num_relations, engine=Engine.TRITON) 58 | self.conv2 = FastenRGATConv(hidden_channels, hidden_channels, num_relations, engine=Engine.TRITON) 59 | self.lin = torch.nn.Linear(hidden_channels, out_channels) 60 | 61 | def forward(self, x, edge_index, edge_type, tensor_slice): 62 | x = self.conv1(x, edge_index, edge_type, tensor_slice=tensor_slice).relu() 63 | x = self.conv2(x, edge_index, edge_type, tensor_slice=tensor_slice).relu() 64 | x = self.lin(x) 65 | return F.log_softmax(x, dim=-1) 66 | 67 | 68 | class 
RGAT(torch.nn.Module): 69 | def __init__(self, in_channels, hidden_channels, out_channels, 70 | num_relations): 71 | super().__init__() 72 | self.conv1 = RGATConv(in_channels, hidden_channels, num_relations) 73 | self.conv2 = RGATConv(hidden_channels, hidden_channels, num_relations) 74 | self.lin = torch.nn.Linear(hidden_channels, out_channels) 75 | 76 | def forward(self, x, edge_index, edge_type): 77 | x = self.conv1(x, edge_index, edge_type).relu() 78 | x = self.conv2(x, edge_index, edge_type).relu() 79 | x = self.lin(x) 80 | return F.log_softmax(x, dim=-1) 81 | 82 | 83 | data = data.to(device) 84 | if args.mode == "fasten": 85 | model = FastenRGAT(args.hidden_size, args.hidden_size, dataset.num_classes, dataset.num_relations).to(device) 86 | ptr = [i for i in range(len(data.edge_type) + 1)] 87 | tensor_slice, edge_index, edge_type = tensor_slice_gen(data.edge_type, data.edge_index, dataset.num_relations) 88 | assert tensor_slice is not None 89 | else: 90 | model = RGAT(args.hidden_size, args.hidden_size, dataset.num_classes, dataset.num_relations).to(device) 91 | 92 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0005) 93 | 94 | 95 | def train(): 96 | model.train() 97 | optimizer.zero_grad() 98 | if args.mode == "fasten": 99 | out = model(data.x, edge_index, edge_type, tensor_slice=tensor_slice) 100 | else: 101 | out = model(data.x, data.edge_index, data.edge_type) 102 | loss = F.nll_loss(out[data.train_idx], data.train_y) 103 | loss.backward() 104 | optimizer.step() 105 | return float(loss) 106 | 107 | 108 | @torch.no_grad() 109 | def test(): 110 | model.eval() 111 | if args.mode == "fasten": 112 | pred = model(data.x, edge_index, edge_type, tensor_slice=tensor_slice).argmax(dim=-1) 113 | else: 114 | pred = model(data.x, data.edge_index, data.edge_type).argmax(dim=-1) 115 | train_acc = float((pred[data.train_idx] == data.train_y).float().mean()) 116 | test_acc = float((pred[data.test_idx] == data.test_y).float().mean()) 117 | return 
train_acc, test_acc 118 | 119 | 120 | if args.profile == "none": 121 | times = [] 122 | for epoch in range(1, 5): 123 | start = time.time() 124 | loss = train() 125 | train_acc, test_acc = test() 126 | print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {train_acc:.4f} ' 127 | f'Test: {test_acc:.4f}') 128 | times.append(time.time() - start) 129 | print(f"Median time per epoch: {torch.tensor(times).median():.4f}s") 130 | 131 | elif args.profile == "profile": 132 | # warmup 133 | train() 134 | with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=False, record_shapes=False) as prof: 135 | for epoch in range(1, 5): 136 | train() 137 | 138 | print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=15)) 139 | 140 | else: # args.profile == "benchmark" 141 | def train_fn(): 142 | train() 143 | train_ms = do_bench(train_fn) 144 | print(f"{args.mode} train: {train_ms} ms") 145 | -------------------------------------------------------------------------------- /examples/rgcn/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | timemory-rgcn-output 3 | metadata.json 4 | -------------------------------------------------------------------------------- /examples/rgcn/rgcn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os.path as osp 3 | from typing import Tuple 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import torch_geometric 8 | from torch.profiler import ProfilerActivity, profile 9 | from torch_geometric.datasets import Entities 10 | from torch_geometric.nn import RGCNConv 11 | from torch_geometric.utils import index_sort, k_hop_subgraph 12 | from triton.testing import do_bench 13 | 14 | from fasten import TensorSlice, compact_tensor_types 15 | from fasten.nn import FastenRGCNConv 16 | 17 | torch.backends.cuda.matmul.allow_tf32 = True 18 | torch_geometric.backend.use_segment_matmul = 
True 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--dataset', type=str, default='AIFB', 22 | choices=['AIFB', 'MUTAG', 'BGS', 'AM']) 23 | parser.add_argument('--mode', type=str, default='pyg', 24 | choices=['pyg', 'fasten']) 25 | parser.add_argument('--device', type=str, default='cpu', 26 | choices=['cpu', 'cuda']) 27 | parser.add_argument('--profile', type=str, default='none', 28 | choices=['none', 'profile', 'benchmark']) 29 | parser.add_argument('--hidden_size', type=int, default=32) 30 | args = parser.parse_args() 31 | device = torch.device(args.device) 32 | 33 | 34 | path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Entities') 35 | dataset = Entities(path, args.dataset) 36 | data = dataset[0] 37 | 38 | node_idx = torch.cat([data.train_idx, data.test_idx], dim=0) 39 | node_idx, edge_index, mapping, edge_mask = k_hop_subgraph( 40 | node_idx, 2, data.edge_index, relabel_nodes=True) 41 | 42 | data.num_nodes = node_idx.size(0) 43 | data.edge_index = edge_index 44 | data.edge_type = data.edge_type[edge_mask] 45 | data.train_idx = mapping[:data.train_idx.size(0)] 46 | data.test_idx = mapping[data.train_idx.size(0):] 47 | 48 | input = torch.randn(data.num_nodes, args.hidden_size).to(device) 49 | 50 | 51 | def tensor_slice_gen(edge_type, edge_index, num_relations) -> Tuple[TensorSlice, torch.Tensor, torch.Tensor]: 52 | if (edge_type[1:] < edge_type[:-1]).any(): 53 | edge_type, perm = index_sort( 54 | edge_type, max_value=num_relations) 55 | edge_index = edge_index[:, perm] 56 | tensor_slice = compact_tensor_types(data=None, types=edge_type, is_sorted=True, device=device) 57 | return tensor_slice, edge_index, edge_type 58 | 59 | 60 | class Net(torch.nn.Module): 61 | def __init__(self): 62 | super().__init__() 63 | self.conv1 = RGCNConv(args.hidden_size, args.hidden_size, dataset.num_relations, aggr="add", is_sorted=True) 64 | self.conv2 = RGCNConv(args.hidden_size, dataset.num_classes, dataset.num_relations, aggr="add", 
is_sorted=True) 65 | 66 | def forward(self, input, edge_index, edge_type): 67 | x = F.relu(self.conv1(input, edge_index, edge_type)) 68 | x = self.conv2(x, edge_index, edge_type) 69 | return F.log_softmax(x, dim=1) 70 | 71 | 72 | class FastenNet(torch.nn.Module): 73 | def __init__(self): 74 | super().__init__() 75 | self.conv1 = FastenRGCNConv(args.hidden_size, args.hidden_size, dataset.num_relations, aggr="add", is_sorted=True) 76 | self.conv2 = FastenRGCNConv(args.hidden_size, dataset.num_classes, dataset.num_relations, aggr="add", is_sorted=True) 77 | 78 | def forward(self, input, edge_index, edge_type, tensor_slice): 79 | x = F.relu(self.conv1(input, edge_index, edge_type, tensor_slice)) 80 | x = self.conv2(x, edge_index, edge_type, tensor_slice) 81 | return F.log_softmax(x, dim=1) 82 | 83 | 84 | if args.mode == "fasten": 85 | model, data = FastenNet().to(device), data.to(device) 86 | else: 87 | model, data = Net().to(device), data.to(device) 88 | tensor_slice, edge_index, edge_type = tensor_slice_gen(data.edge_type, data.edge_index, dataset.num_relations) 89 | 90 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0005) 91 | 92 | 93 | def train(): 94 | optimizer.zero_grad() 95 | model.train() 96 | if args.mode == "fasten": 97 | out = model(input, edge_index, None, tensor_slice) 98 | else: 99 | out = model(input, edge_index, edge_type) 100 | loss = F.nll_loss(out[data.train_idx], data.train_y) 101 | loss.backward() 102 | optimizer.step() 103 | return float(loss) 104 | 105 | 106 | @torch.no_grad() 107 | def test(): 108 | model.eval() 109 | if args.mode == "fasten": 110 | pred = model(input, edge_index, None, tensor_slice).argmax(dim=-1) 111 | else: 112 | pred = model(input, edge_index, edge_type).argmax(dim=-1) 113 | train_acc = float((pred[data.train_idx] == data.train_y).float().mean()) 114 | test_acc = float((pred[data.test_idx] == data.test_y).float().mean()) 115 | return train_acc, test_acc 116 | 117 | 118 | if args.profile == "none": 
119 | for epoch in range(1, 5): 120 | loss = train() 121 | train_acc, test_acc = test() 122 | print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {train_acc:.4f} ' 123 | f'Test: {test_acc:.4f}') 124 | 125 | elif args.profile == "profile": 126 | # warmup 127 | train() 128 | with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=False, record_shapes=False) as prof: 129 | for epoch in range(1, 5): 130 | train() 131 | 132 | print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=15)) 133 | 134 | else: # args.profile == "benchmark" 135 | def pyg_fn(): 136 | model(input, edge_index, edge_type).argmax(dim=-1) 137 | 138 | def fasten_fn(): 139 | model(input, edge_index, edge_type, tensor_slice).argmax(dim=-1) 140 | fn = pyg_fn if args.mode == "pyg" else fasten_fn 141 | inference_ms = do_bench(fn) 142 | train_ms = do_bench(train) 143 | print(f"{args.mode} inference: {inference_ms} ms") 144 | print(f"{args.mode} train: {train_ms} ms") 145 | -------------------------------------------------------------------------------- /fasten/__init__.py: -------------------------------------------------------------------------------- 1 | from .ops import * # noqa: F403,F401 2 | from .tensor_slice import * # noqa: F403,F401 3 | from .utils import * # noqa: F403,F401 4 | -------------------------------------------------------------------------------- /fasten/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa: F401 2 | from .conv import FastenHEATConv, FastenHGTConv, FastenRGATConv, FastenRGCNConv 3 | from .linear import FastenHeteroDictLinear, FastenHeteroLinear 4 | -------------------------------------------------------------------------------- /fasten/nn/conv/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa: F401 2 | from .heat_conv import FastenHEATConv 3 | from .hgt_conv import FastenHGTConv 4 | from .rgat_conv 
import FastenRGATConv 5 | from .rgcn_conv import FastenRGCNConv 6 | -------------------------------------------------------------------------------- /fasten/nn/conv/heat_conv.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import Tensor 6 | from torch.nn import Embedding 7 | from torch_geometric.nn.conv import MessagePassing 8 | from torch_geometric.nn.dense.linear import Linear 9 | from torch_geometric.typing import Adj, OptTensor 10 | from torch_geometric.utils import softmax 11 | 12 | from fasten import Engine, TensorSlice 13 | from fasten.nn.linear import FastenHeteroLinear 14 | 15 | 16 | class FastenHEATConv(MessagePassing): 17 | r"""The heterogeneous edge-enhanced graph attentional operator from the 18 | `"Heterogeneous Edge-Enhanced Graph Attention Network For Multi-Agent 19 | Trajectory Prediction" `_ paper. 20 | 21 | :class:`HEATConv` enhances :class:`~torch_geometric.nn.conv.GATConv` by: 22 | 23 | 1. type-specific transformations of nodes of different types 24 | 2. edge type and edge feature incorporation, in which edges are assumed to 25 | have different types but contain the same kind of attributes 26 | 27 | Args: 28 | in_channels (int): Size of each input sample, or :obj:`-1` to derive 29 | the size from the first input(s) to the forward method. 30 | out_channels (int): Size of each output sample. 31 | num_node_types (int): The number of node types. 32 | num_edge_types (int): The number of edge types. 33 | edge_type_emb_dim (int): The embedding size of edge types. 34 | edge_dim (int): Edge feature dimensionality. 35 | edge_attr_emb_dim (int): The embedding size of edge features. 36 | heads (int, optional): Number of multi-head-attentions. 37 | (default: :obj:`1`) 38 | concat (bool, optional): If set to :obj:`False`, the multi-head 39 | attentions are averaged instead of concatenated. 
40 | (default: :obj:`True`) 41 | negative_slope (float, optional): LeakyReLU angle of the negative 42 | slope. (default: :obj:`0.2`) 43 | dropout (float, optional): Dropout probability of the normalized 44 | attention coefficients which exposes each node to a stochastically 45 | sampled neighborhood during training. (default: :obj:`0`) 46 | root_weight (bool, optional): If set to :obj:`False`, the layer will 47 | not add transformed root node features to the output. 48 | (default: :obj:`True`) 49 | bias (bool, optional): If set to :obj:`False`, the layer will not learn 50 | an additive bias. (default: :obj:`True`) 51 | **kwargs (optional): Additional arguments of 52 | :class:`torch_geometric.nn.conv.MessagePassing`. 53 | 54 | Shapes: 55 | - **input:** 56 | node features :math:`(|\mathcal{V}|, F_{in})`, 57 | edge indices :math:`(2, |\mathcal{E}|)`, 58 | node types :math:`(|\mathcal{V}|)`, 59 | edge types :math:`(|\mathcal{E}|)`, 60 | edge features :math:`(|\mathcal{E}|, D)` *(optional)* 61 | - **output:** node features :math:`(|\mathcal{V}|, F_{out})` 62 | """ 63 | 64 | def __init__(self, in_channels: int, out_channels: int, 65 | num_node_types: int, num_edge_types: int, 66 | edge_type_emb_dim: int, edge_dim: int, edge_attr_emb_dim: int, 67 | heads: int = 1, concat: bool = True, 68 | negative_slope: float = 0.2, dropout: float = 0.0, 69 | root_weight: bool = True, bias: bool = True, engine: Engine = Engine.AUTO, **kwargs): 70 | 71 | kwargs.setdefault('aggr', 'add') 72 | super().__init__(node_dim=0, **kwargs) 73 | 74 | self.in_channels = in_channels 75 | self.out_channels = out_channels 76 | self.heads = heads 77 | self.concat = concat 78 | self.negative_slope = negative_slope 79 | self.dropout = dropout 80 | self.root_weight = root_weight 81 | self.engine = engine 82 | 83 | # self.hetero_lin = HeteroLinear(in_channels, out_channels, 84 | # num_node_types, bias=bias) 85 | self.hetero_lin = FastenHeteroLinear(in_channels, out_channels, num_node_types, bias=bias, 
engine=self.engine) # TODO: check for sortedness of node_type (is_sorted is not passed to FastenHeteroLinear here) 86 | 87 | self.edge_type_emb = Embedding(num_edge_types, edge_type_emb_dim) 88 | self.edge_attr_emb = Linear(edge_dim, edge_attr_emb_dim, bias=False) 89 | 90 | self.att = Linear( 91 | 2 * out_channels + edge_type_emb_dim + edge_attr_emb_dim, 92 | self.heads, bias=False) 93 | 94 | self.lin = Linear(out_channels + edge_attr_emb_dim, out_channels, 95 | bias=bias) 96 | 97 | self.reset_parameters() 98 | 99 | def reset_parameters(self): 100 | super().reset_parameters() 101 | self.hetero_lin.reset_parameters() 102 | self.edge_type_emb.reset_parameters() 103 | self.edge_attr_emb.reset_parameters() 104 | self.att.reset_parameters() 105 | self.lin.reset_parameters() 106 | 107 | def forward(self, x: Tensor, edge_index: Adj, node_type: Tensor, 108 | edge_type: Tensor, edge_attr: OptTensor = None, tensor_slice_hl: TensorSlice = None) -> Tensor: 109 | 110 | x = self.hetero_lin(x, node_type, tensor_slice_hl) 111 | 112 | edge_type_emb = F.leaky_relu(self.edge_type_emb(edge_type), 113 | self.negative_slope) 114 | 115 | # propagate_type: (x: Tensor, edge_type_emb: Tensor, edge_attr: OptTensor) # noqa 116 | out = self.propagate(edge_index, x=x, edge_type_emb=edge_type_emb, 117 | edge_attr=edge_attr, size=None) 118 | 119 | if self.concat: 120 | if self.root_weight: 121 | out = out + x.view(-1, 1, self.out_channels) 122 | out = out.view(-1, self.heads * self.out_channels) 123 | else: 124 | out = out.mean(dim=1) 125 | if self.root_weight: 126 | out = out + x 127 | 128 | return out 129 | 130 | def message(self, x_i: Tensor, x_j: Tensor, edge_type_emb: Tensor, 131 | edge_attr: Tensor, index: Tensor, ptr: OptTensor, 132 | size_i: Optional[int]) -> Tensor: 133 | 134 | edge_attr = F.leaky_relu(self.edge_attr_emb(edge_attr), 135 | self.negative_slope) 136 | 137 | alpha = torch.cat([x_i, x_j, edge_type_emb, edge_attr], dim=-1) 138 | alpha = F.leaky_relu(self.att(alpha), self.negative_slope) 139 | alpha = softmax(alpha, index, ptr, size_i) 140 | alpha
= F.dropout(alpha, p=self.dropout, training=self.training) 141 | 142 | out = self.lin(torch.cat([x_j, edge_attr], dim=-1)).unsqueeze(-2) 143 | return out * alpha.unsqueeze(-1) 144 | 145 | def __repr__(self) -> str: 146 | return (f'{self.__class__.__name__}({self.in_channels}, ' 147 | f'{self.out_channels}, heads={self.heads})') 148 | -------------------------------------------------------------------------------- /fasten/nn/conv/hgt_conv.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Dict, List, Optional, Tuple, Union 3 | 4 | import torch 5 | from torch import Tensor 6 | from torch.nn import Parameter 7 | from torch_geometric.nn.conv import MessagePassing 8 | from torch_geometric.nn.inits import ones 9 | from torch_geometric.nn.parameter_dict import ParameterDict 10 | from torch_geometric.typing import Adj, EdgeType, Metadata, NodeType 11 | from torch_geometric.utils import softmax 12 | from torch_geometric.utils.hetero import construct_bipartite_edge_index 13 | 14 | from fasten import Engine, TensorSlice 15 | from fasten.nn.linear import FastenHeteroDictLinear, FastenHeteroLinear 16 | 17 | 18 | class FastenHGTConv(MessagePassing): 19 | r"""The Heterogeneous Graph Transformer (HGT) operator from the 20 | `"Heterogeneous Graph Transformer" `_ 21 | paper. 22 | 23 | .. note:: 24 | 25 | For an example of using HGT, see `examples/hetero/hgt_dblp.py 26 | `_. 28 | 29 | Args: 30 | in_channels (int or Dict[str, int]): Size of each input sample of every 31 | node type, or :obj:`-1` to derive the size from the first input(s) 32 | to the forward method. 33 | out_channels (int): Size of each output sample. 34 | metadata (Tuple[List[str], List[Tuple[str, str, str]]]): The metadata 35 | of the heterogeneous graph, *i.e.* its node and edge types given 36 | by a list of strings and a list of string triplets, respectively. 37 | See :meth:`torch_geometric.data.HeteroData.metadata` for more 38 | information. 
39 | heads (int, optional): Number of multi-head-attentions. 40 | (default: :obj:`1`) 41 | **kwargs (optional): Additional arguments of 42 | :class:`torch_geometric.nn.conv.MessagePassing`. 43 | """ 44 | 45 | def __init__( 46 | self, 47 | in_channels: Union[int, Dict[str, int]], 48 | out_channels: int, 49 | metadata: Metadata, 50 | heads: int = 1, 51 | engine: Engine = Engine.AUTO, 52 | **kwargs, 53 | ): 54 | super().__init__(aggr='add', node_dim=0, **kwargs) 55 | 56 | if out_channels % heads != 0: 57 | raise ValueError(f"'out_channels' (got {out_channels}) must be " 58 | f"divisible by the number of heads (got {heads})") 59 | 60 | if not isinstance(in_channels, dict): 61 | in_channels = {node_type: in_channels for node_type in metadata[0]} 62 | 63 | self.in_channels = in_channels 64 | self.out_channels = out_channels 65 | self.heads = heads 66 | self.engine = engine 67 | self.node_types = metadata[0] 68 | self.edge_types = metadata[1] 69 | self.edge_types_map = { 70 | edge_type: i 71 | for i, edge_type in enumerate(metadata[1]) 72 | } 73 | 74 | self.dst_node_types = set([key[-1] for key in self.edge_types]) 75 | 76 | self.kqv_lin = FastenHeteroDictLinear(self.in_channels, 77 | self.out_channels * 3, engine=self.engine) 78 | 79 | self.out_lin = FastenHeteroDictLinear(self.out_channels, self.out_channels, 80 | types=self.node_types, engine=self.engine) 81 | 82 | dim = out_channels // heads 83 | num_types = heads * len(self.edge_types) 84 | 85 | self.k_rel = FastenHeteroLinear(dim, dim, num_types, bias=False, 86 | is_sorted=True, engine=self.engine) 87 | self.v_rel = FastenHeteroLinear(dim, dim, num_types, bias=False, 88 | is_sorted=True, engine=self.engine) 89 | 90 | self.skip = ParameterDict({ 91 | node_type: Parameter(torch.empty(1)) 92 | for node_type in self.node_types 93 | }) 94 | 95 | self.p_rel = ParameterDict() 96 | for edge_type in self.edge_types: 97 | edge_type = '__'.join(edge_type) 98 | self.p_rel[edge_type] = Parameter(torch.empty(1, heads)) 99 | 100 | 
self.reset_parameters() 101 | 102 | def reset_parameters(self): 103 | super().reset_parameters() 104 | self.kqv_lin.reset_parameters() 105 | self.out_lin.reset_parameters() 106 | self.k_rel.reset_parameters() 107 | self.v_rel.reset_parameters() 108 | ones(self.skip) 109 | ones(self.p_rel) 110 | 111 | def _cat(self, x_dict: Dict[str, Tensor]) -> Tuple[Tensor, Dict[str, int]]: 112 | """Concatenates a dictionary of features along dim 0 and returns each key's starting row offset.""" 113 | cumsum = 0 114 | outs: List[Tensor] = [] 115 | offset: Dict[str, int] = {} 116 | for key, x in x_dict.items(): 117 | outs.append(x) 118 | offset[key] = cumsum 119 | cumsum += x.size(0) 120 | return torch.cat(outs, dim=0), offset 121 | 122 | def _construct_src_node_feat( 123 | self, k_dict: Dict[str, Tensor], v_dict: Dict[str, Tensor], 124 | edge_index_dict: Dict[EdgeType, Adj], 125 | type_vec: Tensor, 126 | tensor_slice_hl: TensorSlice = None 127 | ) -> Tuple[Tensor, Tensor, Dict[EdgeType, int]]: 128 | """Constructs the source node keys/values transformed per (head, edge type) by k_rel/v_rel, plus each edge type's row offset.""" 129 | cumsum = 0 130 | H, D = self.heads, self.out_channels // self.heads 131 | 132 | # Flatten into a single tensor with shape [heads * (sum of source-node counts over edge types), D], rows grouped by head, then edge type: 133 | ks: List[Tensor] = [] 134 | vs: List[Tensor] = [] 135 | offset: Dict[EdgeType, int] = {} 136 | for edge_type in edge_index_dict.keys(): 137 | src = edge_type[0] 138 | N = k_dict[src].size(0) 139 | offset[edge_type] = cumsum 140 | cumsum += N 141 | ks.append(k_dict[src]) 142 | vs.append(v_dict[src]) 143 | 144 | ks = torch.cat(ks, dim=0).transpose(0, 1).reshape(-1, D) 145 | vs = torch.cat(vs, dim=0).transpose(0, 1).reshape(-1, D) 146 | 147 | k = self.k_rel(ks, type_vec, tensor_slice_hl).view(H, -1, D).transpose(0, 1) 148 | v = self.v_rel(vs, type_vec, tensor_slice_hl).view(H, -1, D).transpose(0, 1) 149 | 150 | return k, v, offset 151 | 152 | def forward( 153 | self, 154 | x_dict: Dict[NodeType, Tensor], 155 | edge_index_dict: Dict[EdgeType, Adj], 156 | tensor_slice_hl: TensorSlice = None, 157 | type_vec: Tensor = None, 158 |
tensor_slice_hdl: TensorSlice = None, 159 | slices_hdl=None 160 | ) -> Dict[NodeType, Optional[Tensor]]: 161 | r"""Runs the forward pass of the module. 162 | 163 | Args: 164 | x_dict (Dict[str, torch.Tensor]): A dictionary holding input node 165 | features for each individual node type. 166 | edge_index_dict (Dict[Tuple[str, str, str], torch.Tensor]): A 167 | dictionary holding graph connectivity information for each 168 | individual edge type, either as a :class:`torch.Tensor` of 169 | shape :obj:`[2, num_edges]` or a 170 | :class:`torch_sparse.SparseTensor`. 171 | 172 | :rtype: :obj:`Dict[str, Optional[torch.Tensor]]` - The output node 173 | embeddings for each node type. 174 | In case a node type does not receive any message, its output will 175 | be set to :obj:`None`. 176 | """ 177 | F = self.out_channels 178 | H = self.heads 179 | D = F // H 180 | 181 | k_dict, q_dict, v_dict, out_dict = {}, {}, {}, {} 182 | 183 | # Compute K, Q, V over node types: 184 | kqv_dict = self.kqv_lin(x_dict, tensor_slice_hdl, slices_hdl) 185 | for key, val in kqv_dict.items(): 186 | k, q, v = torch.tensor_split(val, 3, dim=1) 187 | k_dict[key] = k.view(-1, H, D) 188 | q_dict[key] = q.view(-1, H, D) 189 | v_dict[key] = v.view(-1, H, D) 190 | 191 | q, dst_offset = self._cat(q_dict) 192 | k, v, src_offset = self._construct_src_node_feat( 193 | k_dict, v_dict, edge_index_dict, type_vec, tensor_slice_hl) 194 | 195 | edge_index, edge_attr = construct_bipartite_edge_index( 196 | edge_index_dict, src_offset, dst_offset, edge_attr_dict=self.p_rel) 197 | 198 | out = self.propagate(edge_index, k=k, q=q, v=v, edge_attr=edge_attr, 199 | size=None) 200 | 201 | # Reconstruct output node embeddings dict: 202 | for node_type, start_offset in dst_offset.items(): 203 | end_offset = start_offset + q_dict[node_type].size(0) 204 | if node_type in self.dst_node_types: 205 | out_dict[node_type] = out[start_offset:end_offset] 206 | 207 | # Transform output node embeddings: 208 | a_dict = self.out_lin({ 209 
| k: 210 | torch.nn.functional.gelu(v) if v is not None else v 211 | for k, v in out_dict.items() 212 | }, tensor_slice_hdl, slices_hdl) 213 | 214 | # Iterate over node types: 215 | for node_type, out in out_dict.items(): 216 | out = a_dict[node_type] 217 | 218 | if out.size(-1) == x_dict[node_type].size(-1): 219 | alpha = self.skip[node_type].sigmoid() 220 | out = alpha * out + (1 - alpha) * x_dict[node_type] 221 | out_dict[node_type] = out 222 | 223 | return out_dict 224 | 225 | def message(self, k_j: Tensor, q_i: Tensor, v_j: Tensor, edge_attr: Tensor, 226 | index: Tensor, ptr: Optional[Tensor], 227 | size_i: Optional[int]) -> Tensor: 228 | alpha = (q_i * k_j).sum(dim=-1) * edge_attr 229 | alpha = alpha / math.sqrt(q_i.size(-1)) 230 | alpha = softmax(alpha, index, ptr, size_i) 231 | out = v_j * alpha.view(-1, self.heads, 1) 232 | return out.view(-1, self.out_channels) 233 | 234 | def __repr__(self) -> str: 235 | return (f'{self.__class__.__name__}(-1, {self.out_channels}, ' 236 | f'heads={self.heads})') 237 | -------------------------------------------------------------------------------- /fasten/nn/conv/rgat_conv.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import Tensor 6 | from torch.nn import Parameter, ReLU 7 | from torch_geometric.nn.conv import MessagePassing 8 | from torch_geometric.nn.dense.linear import Linear 9 | from torch_geometric.nn.inits import glorot, ones, zeros 10 | from torch_geometric.typing import Adj, OptTensor, Size, SparseTensor 11 | from torch_geometric.utils import is_torch_sparse_tensor, scatter, softmax 12 | from torch_geometric.utils.sparse import set_sparse_value 13 | 14 | from fasten import Engine, TensorSlice, ops 15 | 16 | 17 | class FastenRGATConv(MessagePassing): 18 | r"""The relational graph attentional operator from the `"Relational Graph 19 | Attention Networks" `_ paper. 
20 | 21 | Here, attention logits :math:`\mathbf{a}^{(r)}_{i,j}` are computed for each 22 | relation type :math:`r` with the help of both query and key kernels, *i.e.* 23 | 24 | .. math:: 25 | \mathbf{q}^{(r)}_i = \mathbf{W}_1^{(r)}\mathbf{x}_{i} \cdot 26 | \mathbf{Q}^{(r)} 27 | \quad \textrm{and} \quad 28 | \mathbf{k}^{(r)}_i = \mathbf{W}_1^{(r)}\mathbf{x}_{i} \cdot 29 | \mathbf{K}^{(r)}. 30 | 31 | Two schemes have been proposed to compute attention logits 32 | :math:`\mathbf{a}^{(r)}_{i,j}` for each relation type :math:`r`: 33 | 34 | **Additive attention** 35 | 36 | .. math:: 37 | \mathbf{a}^{(r)}_{i,j} = \mathrm{LeakyReLU}(\mathbf{q}^{(r)}_i + 38 | \mathbf{k}^{(r)}_j) 39 | 40 | or **multiplicative attention** 41 | 42 | .. math:: 43 | \mathbf{a}^{(r)}_{i,j} = \mathbf{q}^{(r)}_i \cdot \mathbf{k}^{(r)}_j. 44 | 45 | If the graph has multi-dimensional edge features 46 | :math:`\mathbf{e}^{(r)}_{i,j}`, the attention logits 47 | :math:`\mathbf{a}^{(r)}_{i,j}` for each relation type :math:`r` are 48 | computed as 49 | 50 | .. math:: 51 | \mathbf{a}^{(r)}_{i,j} = \mathrm{LeakyReLU}(\mathbf{q}^{(r)}_i + 52 | \mathbf{k}^{(r)}_j + \mathbf{W}_2^{(r)}\mathbf{e}^{(r)}_{i,j}) 53 | 54 | or 55 | 56 | .. math:: 57 | \mathbf{a}^{(r)}_{i,j} = \mathbf{q}^{(r)}_i \cdot \mathbf{k}^{(r)}_j 58 | \cdot \mathbf{W}_2^{(r)} \mathbf{e}^{(r)}_{i,j}, 59 | 60 | respectively. 61 | The attention coefficients :math:`\alpha^{(r)}_{i,j}` for each relation 62 | type :math:`r` are then obtained via two different attention mechanisms: 63 | The **within-relation** attention mechanism 64 | 65 | .. math:: 66 | \alpha^{(r)}_{i,j} = 67 | \frac{\exp(\mathbf{a}^{(r)}_{i,j})} 68 | {\sum_{k \in \mathcal{N}_r(i)} \exp(\mathbf{a}^{(r)}_{i,k})} 69 | 70 | or the **across-relation** attention mechanism 71 | 72 | .. 
math:: 73 | \alpha^{(r)}_{i,j} = 74 | \frac{\exp(\mathbf{a}^{(r)}_{i,j})} 75 | {\sum_{r^{\prime} \in \mathcal{R}} 76 | \sum_{k \in \mathcal{N}_{r^{\prime}}(i)} 77 | \exp(\mathbf{a}^{(r^{\prime})}_{i,k})} 78 | 79 | where :math:`\mathcal{R}` denotes the set of relations, *i.e.* edge types. 80 | Edge type needs to be a one-dimensional :obj:`torch.long` tensor which 81 | stores a relation identifier :math:`\in \{ 0, \ldots, |\mathcal{R}| - 1\}` 82 | for each edge. 83 | 84 | To enhance the discriminative power of attention-based GNNs, this layer 85 | further implements four different cardinality preservation options as 86 | proposed in the `"Improving Attention Mechanism in Graph Neural Networks 87 | via Cardinality Preservation" `_ paper: 88 | 89 | .. math:: 90 | \text{additive:}~~~\mathbf{x}^{{\prime}(r)}_i &= 91 | \sum_{j \in \mathcal{N}_r(i)} 92 | \alpha^{(r)}_{i,j} \mathbf{x}^{(r)}_j + \mathcal{W} \odot 93 | \sum_{j \in \mathcal{N}_r(i)} \mathbf{x}^{(r)}_j 94 | 95 | \text{scaled:}~~~\mathbf{x}^{{\prime}(r)}_i &= 96 | \psi(|\mathcal{N}_r(i)|) \odot 97 | \sum_{j \in \mathcal{N}_r(i)} \alpha^{(r)}_{i,j} \mathbf{x}^{(r)}_j 98 | 99 | \text{f-additive:}~~~\mathbf{x}^{{\prime}(r)}_i &= 100 | \sum_{j \in \mathcal{N}_r(i)} 101 | (\alpha^{(r)}_{i,j} + 1) \cdot \mathbf{x}^{(r)}_j 102 | 103 | \text{f-scaled:}~~~\mathbf{x}^{{\prime}(r)}_i &= 104 | |\mathcal{N}_r(i)| \odot \sum_{j \in \mathcal{N}_r(i)} 105 | \alpha^{(r)}_{i,j} \mathbf{x}^{(r)}_j 106 | 107 | * If :obj:`attention_mode="additive-self-attention"` and 108 | :obj:`concat=True`, the layer outputs :obj:`heads * out_channels` 109 | features for each node. 110 | 111 | * If :obj:`attention_mode="multiplicative-self-attention"` and 112 | :obj:`concat=True`, the layer outputs :obj:`heads * dim * out_channels` 113 | features for each node. 114 | 115 | * If :obj:`attention_mode="additive-self-attention"` and 116 | :obj:`concat=False`, the layer outputs :obj:`out_channels` features for 117 | each node. 
118 | 119 | * If :obj:`attention_mode="multiplicative-self-attention"` and 120 | :obj:`concat=False`, the layer outputs :obj:`dim * out_channels` features 121 | for each node. 122 | 123 | Please make sure to set the :obj:`in_channels` argument of the next 124 | layer accordingly if more than one instance of this layer is used. 125 | 126 | .. note:: 127 | 128 | For an example of using :class:`RGATConv`, see 129 | `examples/rgat.py `_. 131 | 132 | Args: 133 | in_channels (int): Size of each input sample. 134 | out_channels (int): Size of each output sample. 135 | num_relations (int): Number of relations. 136 | num_bases (int, optional): If set, this layer will use the 137 | basis-decomposition regularization scheme where :obj:`num_bases` 138 | denotes the number of bases to use. (default: :obj:`None`) 139 | num_blocks (int, optional): If set, this layer will use the 140 | block-diagonal-decomposition regularization scheme where 141 | :obj:`num_blocks` denotes the number of blocks to use. 142 | (default: :obj:`None`) 143 | mod (str, optional): The cardinality preservation option to use. 144 | (:obj:`"additive"`, :obj:`"scaled"`, :obj:`"f-additive"`, 145 | :obj:`"f-scaled"`, :obj:`None`). (default: :obj:`None`) 146 | attention_mechanism (str, optional): The attention mechanism to use 147 | (:obj:`"within-relation"`, :obj:`"across-relation"`). 148 | (default: :obj:`"across-relation"`) 149 | attention_mode (str, optional): The mode to calculate attention logits. 150 | (:obj:`"additive-self-attention"`, 151 | :obj:`"multiplicative-self-attention"`). 152 | (default: :obj:`"additive-self-attention"`) 153 | heads (int, optional): Number of multi-head-attentions. 154 | (default: :obj:`1`) 155 | dim (int): Number of dimensions for query and key kernels. 156 | (default: :obj:`1`) 157 | concat (bool, optional): If set to :obj:`False`, the multi-head 158 | attentions are averaged instead of concatenated. 
159 | (default: :obj:`True`) 160 | negative_slope (float, optional): LeakyReLU angle of the negative 161 | slope. (default: :obj:`0.2`) 162 | dropout (float, optional): Dropout probability of the normalized 163 | attention coefficients which exposes each node to a stochastically 164 | sampled neighborhood during training. (default: :obj:`0`) 165 | edge_dim (int, optional): Edge feature dimensionality (in case there 166 | are any). (default: :obj:`None`) 167 | bias (bool, optional): If set to :obj:`False`, the layer will not 168 | learn an additive bias. (default: :obj:`True`) 169 | **kwargs (optional): Additional arguments of 170 | :class:`torch_geometric.nn.conv.MessagePassing`. 171 | """ 172 | 173 | _alpha: OptTensor 174 | 175 | def __init__( 176 | self, 177 | in_channels: int, 178 | out_channels: int, 179 | num_relations: int, 180 | num_bases: Optional[int] = None, 181 | num_blocks: Optional[int] = None, 182 | mod: Optional[str] = None, 183 | attention_mechanism: str = "across-relation", 184 | attention_mode: str = "additive-self-attention", 185 | heads: int = 1, 186 | dim: int = 1, 187 | concat: bool = True, 188 | negative_slope: float = 0.2, 189 | dropout: float = 0.0, 190 | edge_dim: Optional[int] = None, 191 | bias: bool = True, 192 | engine: Engine = Engine.AUTO, 193 | **kwargs, 194 | ): 195 | kwargs.setdefault('aggr', 'add') 196 | super().__init__(node_dim=0, **kwargs) 197 | 198 | self.heads = heads 199 | self.negative_slope = negative_slope 200 | self.dropout = dropout 201 | self.mod = mod 202 | self.activation = ReLU() 203 | self.concat = concat 204 | self.attention_mode = attention_mode 205 | self.attention_mechanism = attention_mechanism 206 | self.dim = dim 207 | self.edge_dim = edge_dim 208 | 209 | self.in_channels = in_channels 210 | self.out_channels = out_channels 211 | self.num_relations = num_relations 212 | self.num_bases = num_bases 213 | self.num_blocks = num_blocks 214 | self.engine = engine 215 | 216 | mod_types = ['additive', 'scaled', 
'f-additive', 'f-scaled'] 217 | 218 | if (self.attention_mechanism != "within-relation" and self.attention_mechanism != "across-relation"): 219 | raise ValueError('attention mechanism must either be ' 220 | '"within-relation" or "across-relation"') 221 | 222 | if (self.attention_mode != "additive-self-attention" and self.attention_mode != "multiplicative-self-attention"): 223 | raise ValueError('attention mode must either be ' 224 | '"additive-self-attention" or ' 225 | '"multiplicative-self-attention"') 226 | 227 | if self.attention_mode == "additive-self-attention" and self.dim > 1: 228 | raise ValueError('"additive-self-attention" mode cannot be ' 229 | 'applied when value of d is greater than 1. ' 230 | 'Use "multiplicative-self-attention" instead.') 231 | 232 | if self.dropout > 0.0 and self.mod in mod_types: 233 | raise ValueError('mod must be None with dropout value greater ' 234 | 'than 0 in order to sample attention ' 235 | 'coefficients stochastically') 236 | 237 | if num_bases is not None and num_blocks is not None: 238 | raise ValueError('Can not apply both basis-decomposition and ' 239 | 'block-diagonal-decomposition at the same time.') 240 | 241 | # The learnable parameters to compute both attention logits and 242 | # attention coefficients: 243 | self.q = Parameter( 244 | torch.empty(self.heads * self.out_channels, self.heads * self.dim)) 245 | self.k = Parameter( 246 | torch.empty(self.heads * self.out_channels, self.heads * self.dim)) 247 | 248 | if bias and concat: 249 | self.bias = Parameter( 250 | torch.empty(self.heads * self.dim * self.out_channels)) 251 | elif bias and not concat: 252 | self.bias = Parameter(torch.empty(self.dim * self.out_channels)) 253 | else: 254 | self.register_parameter('bias', None) 255 | 256 | if edge_dim is not None: 257 | self.lin_edge = Linear(self.edge_dim, 258 | self.heads * self.out_channels, bias=False, 259 | weight_initializer='glorot') 260 | self.e = Parameter( 261 | torch.empty(self.heads * self.out_channels, 
262 | self.heads * self.dim)) 263 | else: 264 | self.lin_edge = None 265 | self.register_parameter('e', None) 266 | 267 | if num_bases is not None: 268 | self.att = Parameter( 269 | torch.empty(self.num_relations, self.num_bases)) 270 | self.basis = Parameter( 271 | torch.empty(self.num_bases, self.in_channels, 272 | self.heads * self.out_channels)) 273 | elif num_blocks is not None: 274 | assert ( 275 | self.in_channels % self.num_blocks == 0 and (self.heads * self.out_channels) % self.num_blocks == 0), ( 276 | "both 'in_channels' and 'heads * out_channels' must be " 277 | "multiple of 'num_blocks' used") 278 | self.weight = Parameter( 279 | torch.empty(self.num_relations, self.num_blocks, 280 | self.in_channels // self.num_blocks, 281 | (self.heads * self.out_channels) // self.num_blocks)) 282 | else: 283 | self.weight = Parameter( 284 | torch.empty(self.num_relations, self.in_channels, 285 | self.heads * self.out_channels)) 286 | 287 | self.w = Parameter(torch.ones(self.out_channels)) 288 | self.l1 = Parameter(torch.empty(1, self.out_channels)) 289 | self.b1 = Parameter(torch.empty(1, self.out_channels)) 290 | self.l2 = Parameter(torch.empty(self.out_channels, self.out_channels)) 291 | self.b2 = Parameter(torch.empty(1, self.out_channels)) 292 | 293 | self._alpha = None 294 | 295 | self.reset_parameters() 296 | 297 | def reset_parameters(self): 298 | super().reset_parameters() 299 | if self.num_bases is not None: 300 | glorot(self.basis) 301 | glorot(self.att) 302 | else: 303 | glorot(self.weight) 304 | glorot(self.q) 305 | glorot(self.k) 306 | zeros(self.bias) 307 | ones(self.l1) 308 | zeros(self.b1) 309 | torch.full(self.l2.size(), 1 / self.out_channels) 310 | zeros(self.b2) 311 | if self.lin_edge is not None: 312 | glorot(self.lin_edge) 313 | glorot(self.e) 314 | 315 | def forward( 316 | self, 317 | x: Tensor, 318 | edge_index: Adj, 319 | edge_type: OptTensor = None, 320 | edge_attr: OptTensor = None, 321 | size: Size = None, 322 | 
return_attention_weights=None, 323 | tensor_slice: TensorSlice = None, 324 | ): 325 | r"""Runs the forward pass of the module. 326 | 327 | Args: 328 | x (torch.Tensor): The input node features. 329 | Can be either a :obj:`[num_nodes, in_channels]` node feature 330 | matrix, or an optional one-dimensional node index tensor (in 331 | which case input features are treated as trainable node 332 | embeddings). 333 | edge_index (torch.Tensor or SparseTensor): The edge indices. 334 | edge_type (torch.Tensor, optional): The one-dimensional relation 335 | type/index for each edge in :obj:`edge_index`. 336 | Should be only :obj:`None` in case :obj:`edge_index` is of type 337 | :class:`torch_sparse.SparseTensor` or 338 | :class:`torch.sparse.Tensor`. (default: :obj:`None`) 339 | edge_attr (torch.Tensor, optional): The edge features. 340 | (default: :obj:`None`) 341 | size ((int, int), optional): The shape of the adjacency matrix. 342 | (default: :obj:`None`) 343 | return_attention_weights (bool, optional): If set to :obj:`True`, 344 | will additionally return the tuple 345 | :obj:`(edge_index, attention_weights)`, holding the computed 346 | attention weights for each edge. 
(default: :obj:`None`) 347 | """ 348 | # propagate_type: (x: Tensor, edge_type: OptTensor, edge_attr: OptTensor, tensor_slice=tensor_slice) # noqa 349 | out = self.propagate(edge_index=edge_index, edge_type=edge_type, x=x, 350 | size=size, edge_attr=edge_attr, tensor_slice=tensor_slice) 351 | 352 | alpha = self._alpha 353 | assert alpha is not None 354 | self._alpha = None 355 | 356 | if isinstance(return_attention_weights, bool): 357 | if isinstance(edge_index, Tensor): 358 | if is_torch_sparse_tensor(edge_index): 359 | # TODO TorchScript requires to return a tuple 360 | adj = set_sparse_value(edge_index, alpha) 361 | return out, (adj, alpha) 362 | else: 363 | return out, (edge_index, alpha) 364 | elif isinstance(edge_index, SparseTensor): 365 | return out, edge_index.set_value(alpha, layout='coo') 366 | else: 367 | return out 368 | 369 | def message(self, x_i: Tensor, x_j: Tensor, edge_type: Tensor, 370 | edge_attr: OptTensor, index: Tensor, ptr: OptTensor, 371 | size_i: Optional[int], tensor_slice: TensorSlice) -> Tensor: 372 | 373 | if self.num_bases is not None: # Basis-decomposition ================= 374 | w = torch.matmul(self.att, self.basis.view(self.num_bases, -1)) 375 | w = w.view(self.num_relations, self.in_channels, 376 | self.heads * self.out_channels) 377 | if self.num_blocks is not None: # Block-diagonal-decomposition ======= 378 | if (x_i.dtype == torch.long and x_j.dtype == torch.long and self.num_blocks is not None): 379 | raise ValueError('Block-diagonal decomposition not supported ' 380 | 'for non-continuous input features.') 381 | w = self.weight 382 | x_i = x_i.view(-1, 1, w.size(1), w.size(2)) 383 | x_j = x_j.view(-1, 1, w.size(1), w.size(2)) 384 | w = torch.index_select(w, 0, edge_type) 385 | outi = torch.einsum('abcd,acde->ace', x_i, w) 386 | outi = outi.contiguous().view(-1, self.heads * self.out_channels) 387 | outj = torch.einsum('abcd,acde->ace', x_j, w) 388 | outj = outj.contiguous().view(-1, self.heads * self.out_channels) 389 | 
else: # No regularization/Basis-decomposition ======================== 390 | if self.num_bases is None: 391 | w = self.weight 392 | outi = ops.fasten_segment_matmul(x_i, w, tensor_slice, self.engine) 393 | outj = ops.fasten_segment_matmul(x_j, w, tensor_slice, self.engine) 394 | 395 | qi = torch.matmul(outi, self.q) 396 | kj = torch.matmul(outj, self.k) 397 | 398 | alpha_edge, alpha = 0, torch.tensor([0]) 399 | if edge_attr is not None: 400 | if edge_attr.dim() == 1: 401 | edge_attr = edge_attr.view(-1, 1) 402 | assert self.lin_edge is not None, ( 403 | "Please set 'edge_dim = edge_attr.size(-1)' while calling the " 404 | "RGATConv layer") 405 | edge_attributes = self.lin_edge(edge_attr).view( 406 | -1, self.heads * self.out_channels) 407 | if edge_attributes.size(0) != edge_attr.size(0): 408 | edge_attributes = torch.index_select(edge_attributes, 0, 409 | edge_type) 410 | alpha_edge = torch.matmul(edge_attributes, self.e) 411 | 412 | if self.attention_mode == "additive-self-attention": 413 | if edge_attr is not None: 414 | alpha = torch.add(qi, kj) + alpha_edge 415 | else: 416 | alpha = torch.add(qi, kj) 417 | alpha = F.leaky_relu(alpha, self.negative_slope) 418 | elif self.attention_mode == "multiplicative-self-attention": 419 | if edge_attr is not None: 420 | alpha = (qi * kj) * alpha_edge 421 | else: 422 | alpha = qi * kj 423 | 424 | if self.attention_mechanism == "within-relation": 425 | across_out = torch.zeros_like(alpha) 426 | for r in range(self.num_relations): 427 | mask = edge_type == r 428 | across_out[mask] = softmax(alpha[mask], index[mask]) 429 | alpha = across_out 430 | elif self.attention_mechanism == "across-relation": 431 | alpha = softmax(alpha, index, ptr, size_i) 432 | 433 | self._alpha = alpha 434 | 435 | if self.mod == "additive": 436 | if self.attention_mode == "additive-self-attention": 437 | ones = torch.ones_like(alpha) 438 | h = (outj.view(-1, self.heads, self.out_channels) * ones.view(-1, self.heads, 1)) 439 | h = torch.mul(self.w, h) 
440 | 441 | return (outj.view(-1, self.heads, self.out_channels) * alpha.view(-1, self.heads, 1) + h) 442 | elif self.attention_mode == "multiplicative-self-attention": 443 | ones = torch.ones_like(alpha) 444 | h = (outj.view(-1, self.heads, 1, self.out_channels) * ones.view(-1, self.heads, self.dim, 1)) 445 | h = torch.mul(self.w, h) 446 | 447 | return (outj.view(-1, self.heads, 1, self.out_channels) * alpha.view(-1, self.heads, self.dim, 1) + h) 448 | 449 | elif self.mod == "scaled": 450 | if self.attention_mode == "additive-self-attention": 451 | ones = alpha.new_ones(index.size()) 452 | degree = scatter(ones, index, dim_size=size_i, 453 | reduce='sum')[index].unsqueeze(-1) 454 | degree = torch.matmul(degree, self.l1) + self.b1 455 | degree = self.activation(degree) 456 | degree = torch.matmul(degree, self.l2) + self.b2 457 | 458 | return torch.mul( 459 | outj.view(-1, self.heads, self.out_channels) * alpha.view(-1, self.heads, 1), 460 | degree.view(-1, 1, self.out_channels)) 461 | elif self.attention_mode == "multiplicative-self-attention": 462 | ones = alpha.new_ones(index.size()) 463 | degree = scatter(ones, index, dim_size=size_i, 464 | reduce='sum')[index].unsqueeze(-1) 465 | degree = torch.matmul(degree, self.l1) + self.b1 466 | degree = self.activation(degree) 467 | degree = torch.matmul(degree, self.l2) + self.b2 468 | 469 | return torch.mul( 470 | outj.view(-1, self.heads, 1, self.out_channels) * alpha.view(-1, self.heads, self.dim, 1), 471 | degree.view(-1, 1, 1, self.out_channels)) 472 | 473 | elif self.mod == "f-additive": 474 | alpha = torch.where(alpha > 0, alpha + 1, alpha) 475 | 476 | elif self.mod == "f-scaled": 477 | ones = alpha.new_ones(index.size()) 478 | degree = scatter(ones, index, dim_size=size_i, 479 | reduce='sum')[index].unsqueeze(-1) 480 | alpha = alpha * degree 481 | 482 | elif self.training and self.dropout > 0: 483 | alpha = F.dropout(alpha, p=self.dropout, training=True) 484 | 485 | else: 486 | alpha = alpha # original 487 | 488 
| if self.attention_mode == "additive-self-attention": 489 | return alpha.view(-1, self.heads, 1) * outj.view( 490 | -1, self.heads, self.out_channels) 491 | else: 492 | return (alpha.view(-1, self.heads, self.dim, 1) * outj.view(-1, self.heads, 1, self.out_channels)) 493 | 494 | def update(self, aggr_out: Tensor) -> Tensor: 495 | if self.attention_mode == "additive-self-attention": 496 | if self.concat is True: 497 | aggr_out = aggr_out.view(-1, self.heads * self.out_channels) 498 | else: 499 | aggr_out = aggr_out.mean(dim=1) 500 | 501 | if self.bias is not None: 502 | aggr_out = aggr_out + self.bias 503 | 504 | return aggr_out 505 | else: 506 | if self.concat is True: 507 | aggr_out = aggr_out.view( 508 | -1, self.heads * self.dim * self.out_channels) 509 | else: 510 | aggr_out = aggr_out.mean(dim=1) 511 | aggr_out = aggr_out.view(-1, self.dim * self.out_channels) 512 | 513 | if self.bias is not None: 514 | aggr_out = aggr_out + self.bias 515 | 516 | return aggr_out 517 | 518 | def __repr__(self) -> str: 519 | return '{}({}, {}, heads={})'.format(self.__class__.__name__, 520 | self.in_channels, 521 | self.out_channels, self.heads) 522 | -------------------------------------------------------------------------------- /fasten/nn/conv/rgcn_conv.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch.nn import Parameter 6 | from torch.nn import Parameter as Param 7 | from torch_geometric.nn.conv import MessagePassing 8 | from torch_geometric.nn.inits import glorot, zeros 9 | from torch_geometric.typing import Adj, OptTensor, SparseTensor, torch_sparse 10 | from torch_geometric.utils import spmm 11 | 12 | from fasten import Engine, TensorSlice, ops 13 | 14 | 15 | def masked_edge_index(edge_index, edge_mask): 16 | if isinstance(edge_index, Tensor): 17 | return edge_index[:, edge_mask] 18 | return 
torch_sparse.masked_select_nnz(edge_index, edge_mask, layout='coo') 19 | 20 | 21 | class FastenRGCNConv(MessagePassing): 22 | r"""The relational graph convolutional operator from the `"Modeling 23 | Relational Data with Graph Convolutional Networks" 24 | `_ paper 25 | 26 | .. math:: 27 | \mathbf{x}^{\prime}_i = \mathbf{\Theta}_{\textrm{root}} \cdot 28 | \mathbf{x}_i + \sum_{r \in \mathcal{R}} \sum_{j \in \mathcal{N}_r(i)} 29 | \frac{1}{|\mathcal{N}_r(i)|} \mathbf{\Theta}_r \cdot \mathbf{x}_j, 30 | 31 | where :math:`\mathcal{R}` denotes the set of relations, *i.e.* edge types. 32 | Edge type needs to be a one-dimensional :obj:`torch.long` tensor which 33 | stores a relation identifier 34 | :math:`\in \{ 0, \ldots, |\mathcal{R}| - 1\}` for each edge. 35 | 36 | .. note:: 37 | This implementation is as memory-efficient as possible by iterating 38 | over each individual relation type. 39 | Therefore, it may result in low GPU utilization in case the graph has a 40 | large number of relations. 41 | As an alternative approach, :class:`FastRGCNConv` does not iterate over 42 | each individual type, but may consume a large amount of memory to 43 | compensate. 44 | We advise to check out both implementations to see which one fits your 45 | needs. 46 | 47 | .. note:: 48 | :class:`RGCNConv` can use `dynamic shapes 49 | `_, which means that the shape of the interim 51 | tensors can be determined at runtime. 52 | If your device doesn't support dynamic shapes, use 53 | :class:`FastRGCNConv` instead. 54 | 55 | Args: 56 | in_channels (int or tuple): Size of each input sample. A tuple 57 | corresponds to the sizes of source and target dimensionalities. 58 | In case no input features are given, this argument should 59 | correspond to the number of nodes in your graph. 60 | out_channels (int): Size of each output sample. 61 | num_relations (int): Number of relations. 
62 | num_bases (int, optional): If set, this layer will use the 63 | basis-decomposition regularization scheme where :obj:`num_bases` 64 | denotes the number of bases to use. (default: :obj:`None`) 65 | num_blocks (int, optional): If set, this layer will use the 66 | block-diagonal-decomposition regularization scheme where 67 | :obj:`num_blocks` denotes the number of blocks to use. 68 | (default: :obj:`None`) 69 | aggr (str, optional): The aggregation scheme to use 70 | (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`). 71 | (default: :obj:`"mean"`) 72 | root_weight (bool, optional): If set to :obj:`False`, the layer will 73 | not add transformed root node features to the output. 74 | (default: :obj:`True`) 75 | is_sorted (bool, optional): If set to :obj:`True`, assumes that 76 | :obj:`edge_index` is sorted by :obj:`edge_type`. This avoids 77 | internal re-sorting of the data and can improve runtime and memory 78 | efficiency. (default: :obj:`False`) 79 | bias (bool, optional): If set to :obj:`False`, the layer will not learn 80 | an additive bias. (default: :obj:`True`) 81 | **kwargs (optional): Additional arguments of 82 | :class:`torch_geometric.nn.conv.MessagePassing`. 
83 | """ 84 | 85 | def __init__( 86 | self, 87 | in_channels: Union[int, Tuple[int, int]], 88 | out_channels: int, 89 | num_relations: int, 90 | num_bases: Optional[int] = None, 91 | num_blocks: Optional[int] = None, 92 | aggr: str = 'mean', 93 | root_weight: bool = True, 94 | is_sorted: bool = False, 95 | bias: bool = True, 96 | engine: Engine = Engine.AUTO, 97 | **kwargs, 98 | ): 99 | kwargs.setdefault('aggr', aggr) 100 | super().__init__(node_dim=0, **kwargs) 101 | 102 | if num_bases is not None and num_blocks is not None: 103 | raise ValueError('Can not apply both basis-decomposition and ' 104 | 'block-diagonal-decomposition at the same time.') 105 | 106 | self.in_channels = in_channels 107 | self.out_channels = out_channels 108 | self.num_relations = num_relations 109 | self.num_bases = num_bases 110 | self.num_blocks = num_blocks 111 | self.is_sorted = is_sorted 112 | self.engine = engine 113 | if isinstance(in_channels, int): 114 | in_channels = (in_channels, in_channels) 115 | self.in_channels_l = in_channels[0] 116 | 117 | if num_bases is not None: 118 | self.weight = Parameter( 119 | torch.empty(num_bases, in_channels[0], out_channels)) 120 | self.comp = Parameter(torch.empty(num_relations, num_bases)) 121 | 122 | elif num_blocks is not None: 123 | assert (in_channels[0] % num_blocks == 0 and out_channels % num_blocks == 0) 124 | self.weight = Parameter( 125 | torch.empty(num_relations, num_blocks, 126 | in_channels[0] // num_blocks, 127 | out_channels // num_blocks)) 128 | self.register_parameter('comp', None) 129 | 130 | else: 131 | self.weight = Parameter( 132 | torch.empty(num_relations, in_channels[0], out_channels)) 133 | self.register_parameter('comp', None) 134 | 135 | if root_weight: 136 | self.root = Param(torch.empty(in_channels[1], out_channels)) 137 | else: 138 | self.register_parameter('root', None) 139 | 140 | if bias: 141 | self.bias = Param(torch.empty(out_channels)) 142 | else: 143 | self.register_parameter('bias', None) 144 | 145 | 
self.reset_parameters() 146 | 147 | def reset_parameters(self): 148 | super().reset_parameters() 149 | glorot(self.weight) 150 | glorot(self.comp) 151 | glorot(self.root) 152 | zeros(self.bias) 153 | 154 | def forward(self, x: Union[OptTensor, Tuple[OptTensor, Tensor]], 155 | edge_index: Adj, edge_type: OptTensor = None, edge_tensor_slice: TensorSlice = None): 156 | r"""Runs the forward pass of the module. 157 | 158 | Args: 159 | x (torch.Tensor or tuple, optional): The input node features. 160 | Can be either a :obj:`[num_nodes, in_channels]` node feature 161 | matrix, or an optional one-dimensional node index tensor (in 162 | which case input features are treated as trainable node 163 | embeddings). 164 | Furthermore, :obj:`x` can be of type :obj:`tuple` denoting 165 | source and destination node features. 166 | edge_index (torch.Tensor or SparseTensor): The edge indices. 167 | edge_type (torch.Tensor, optional): The one-dimensional relation 168 | type/index for each edge in :obj:`edge_index`. 169 | Should be only :obj:`None` in case :obj:`edge_index` is of type 170 | :class:`torch_sparse.SparseTensor`. (default: :obj:`None`) 171 | """ 172 | # Convert input features to a pair of node features or node indices. 
173 | x_l: OptTensor = None 174 | if isinstance(x, tuple): 175 | x_l = x[0] 176 | else: 177 | x_l = x 178 | if x_l is None: 179 | x_l = torch.arange(self.in_channels_l, device=self.weight.device) 180 | 181 | x_r: Tensor = x_l 182 | if isinstance(x, tuple): 183 | x_r = x[1] 184 | 185 | size = (x_l.size(0), x_r.size(0)) 186 | if isinstance(edge_index, SparseTensor): 187 | edge_type = edge_index.storage.value() 188 | 189 | # propagate_type: (x: Tensor, edge_tensor_slice: OptTensor) 190 | out = torch.zeros(x_r.size(0), self.out_channels, device=x_r.device) 191 | 192 | weight = self.weight 193 | if self.num_bases is not None: # Basis-decomposition ================= 194 | weight = (self.comp @ weight.view(self.num_bases, -1)).view( 195 | self.num_relations, self.in_channels_l, self.out_channels) 196 | 197 | if self.num_blocks is not None: # Block-diagonal-decomposition ===== 198 | 199 | if not torch.is_floating_point( 200 | x_r) and self.num_blocks is not None: 201 | raise ValueError('Block-diagonal decomposition not supported ' 202 | 'for non-continuous input features.') 203 | 204 | for i in range(self.num_relations): 205 | tmp = masked_edge_index(edge_index, edge_type == i) 206 | h = self.propagate(tmp, x=x_l, size=size) 207 | h = h.view(-1, weight.size(1), weight.size(2)) 208 | h = torch.einsum('abc,bcd->abd', h, weight[i]) 209 | out = out + h.contiguous().view(-1, self.out_channels) 210 | 211 | else: # No regularization/Basis-decomposition ======================== 212 | if (self.num_bases is None and x_l.is_floating_point() and isinstance(edge_index, Tensor)) and edge_tensor_slice: 213 | assert self.is_sorted, "edge_tensor_slice is only supported when is_sorted=True" 214 | assert self.aggr == "add", "edge_tensor_slice is only supported when aggr=add if you want to get equivalent results as the base implementation" 215 | out = self.propagate(edge_index, x=x_l, size=size, edge_tensor_slice=edge_tensor_slice) 216 | else: 217 | for i in range(self.num_relations): 218 | 
tmp = masked_edge_index(edge_index, edge_type == i) 219 | 220 | if not torch.is_floating_point(x_r): 221 | out = out + self.propagate( 222 | tmp, 223 | x=weight[i, x_l], 224 | size=size, 225 | ) 226 | else: 227 | h = self.propagate(tmp, x=x_l, size=size) 228 | out = out + (h @ weight[i]) 229 | 230 | root = self.root 231 | if root is not None: 232 | if not torch.is_floating_point(x_r): 233 | out = out + root[x_r] 234 | else: 235 | out = out + x_r @ root 236 | 237 | if self.bias is not None: 238 | out = out + self.bias 239 | return out 240 | 241 | def message(self, x_j: Tensor, edge_tensor_slice: TensorSlice = None) -> Tensor: 242 | if edge_tensor_slice is not None: 243 | return ops.fasten_segment_matmul(x_j, self.weight, edge_tensor_slice, self.engine) 244 | return x_j 245 | 246 | def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor: 247 | adj_t = adj_t.set_value(None) 248 | return spmm(adj_t, x, reduce=self.aggr) 249 | 250 | def __repr__(self) -> str: 251 | return (f'{self.__class__.__name__}({self.in_channels}, ' 252 | f'{self.out_channels}, num_relations={self.num_relations})') 253 | -------------------------------------------------------------------------------- /fasten/nn/linear/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa: F401 2 | from .linear import FastenHeteroDictLinear, FastenHeteroLinear 3 | -------------------------------------------------------------------------------- /fasten/nn/linear/linear.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | from typing import Any, Dict, Optional, Union 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import torch_geometric.typing 8 | from torch import Tensor 9 | from torch.nn.parameter import Parameter 10 | from torch_geometric.nn import inits 11 | from torch_geometric.utils import index_sort 12 | 13 | from fasten import Engine, TensorSlice, ops 14 
15 | 16 | def is_uninitialized_parameter(x: Any) -> bool: # False on torch builds that lack UninitializedParameter 17 | if not hasattr(torch.nn.parameter, 'UninitializedParameter'): 18 | return False 19 | return isinstance(x, torch.nn.parameter.UninitializedParameter) 20 | 21 | 22 | def reset_weight_(weight: Tensor, in_channels: int, 23 | initializer: Optional[str] = None) -> Tensor: # initializes `weight` in place and returns it; no-op while in_channels <= 0 (lazy, not yet materialized) 24 | if in_channels <= 0: 25 | pass 26 | elif initializer == 'glorot': 27 | inits.glorot(weight) 28 | elif initializer == 'uniform': 29 | bound = 1.0 / math.sqrt(in_channels) 30 | torch.nn.init.uniform_(weight.data, -bound, bound) 31 | elif initializer == 'kaiming_uniform': 32 | inits.kaiming_uniform(weight, fan=in_channels, a=math.sqrt(5)) 33 | elif initializer is None: # same call as 'kaiming_uniform', mirroring the torch.nn.Linear default 34 | inits.kaiming_uniform(weight, fan=in_channels, a=math.sqrt(5)) 35 | else: 36 | raise RuntimeError(f"Weight initializer '{initializer}' not supported") 37 | 38 | return weight 39 | 40 | 41 | def reset_bias_(bias: Optional[Tensor], in_channels: int, 42 | initializer: Optional[str] = None) -> Optional[Tensor]: # in place; no-op when bias is None or in_channels <= 0 43 | if bias is None or in_channels <= 0: 44 | pass 45 | elif initializer == 'zeros': 46 | inits.zeros(bias) 47 | elif initializer is None: 48 | inits.uniform(in_channels, bias) 49 | else: 50 | raise RuntimeError(f"Bias initializer '{initializer}' not supported") 51 | 52 | return bias 53 | 54 | 55 | class Linear(torch.nn.Module): 56 | r"""Applies a linear transformation to the incoming data 57 | 58 | .. math:: 59 | \mathbf{x}^{\prime} = \mathbf{x} \mathbf{W}^{\top} + \mathbf{b} 60 | 61 | similar to :class:`torch.nn.Linear`. 62 | It supports lazy initialization and customizable weight and bias 63 | initialization. 64 | 65 | Args: 66 | in_channels (int): Size of each input sample. Will be initialized 67 | lazily in case it is given as :obj:`-1`. 68 | out_channels (int): Size of each output sample. 69 | bias (bool, optional): If set to :obj:`False`, the layer will not learn 70 | an additive bias.
(default: :obj:`True`) 71 | weight_initializer (str, optional): The initializer for the weight 72 | matrix (:obj:`"glorot"`, :obj:`"uniform"`, :obj:`"kaiming_uniform"` 73 | or :obj:`None`). 74 | If set to :obj:`None`, will match default weight initialization of 75 | :class:`torch.nn.Linear`. (default: :obj:`None`) 76 | bias_initializer (str, optional): The initializer for the bias vector 77 | (:obj:`"zeros"` or :obj:`None`). 78 | If set to :obj:`None`, will match default bias initialization of 79 | :class:`torch.nn.Linear`. (default: :obj:`None`) 80 | 81 | Shapes: 82 | - **input:** features :math:`(*, F_{in})` 83 | - **output:** features :math:`(*, F_{out})` 84 | """ 85 | 86 | def __init__(self, in_channels: int, out_channels: int, bias: bool = True, 87 | weight_initializer: Optional[str] = None, 88 | bias_initializer: Optional[str] = None): 89 | super().__init__() 90 | self.in_channels = in_channels 91 | self.out_channels = out_channels 92 | self.weight_initializer = weight_initializer 93 | self.bias_initializer = bias_initializer 94 | 95 | if in_channels > 0: 96 | self.weight = Parameter(torch.empty(out_channels, in_channels)) 97 | else: 98 | self.weight = torch.nn.parameter.UninitializedParameter() 99 | self._hook = self.register_forward_pre_hook( 100 | self.initialize_parameters) 101 | 102 | if bias: 103 | self.bias = Parameter(torch.empty(out_channels)) 104 | else: 105 | self.register_parameter('bias', None) 106 | 107 | self.reset_parameters() 108 | 109 | def __deepcopy__(self, memo): 110 | out = Linear(self.in_channels, self.out_channels, self.bias 111 | is not None, self.weight_initializer, 112 | self.bias_initializer) 113 | if self.in_channels > 0: 114 | out.weight = copy.deepcopy(self.weight, memo) 115 | if self.bias is not None: 116 | out.bias = copy.deepcopy(self.bias, memo) 117 | return out 118 | 119 | def reset_parameters(self): 120 | r"""Resets all learnable parameters of the module.""" 121 | reset_weight_(self.weight, self.in_channels,
self.weight_initializer) 122 | reset_bias_(self.bias, self.in_channels, self.bias_initializer) 123 | 124 | def forward(self, x: Tensor) -> Tensor: 125 | r""" 126 | Args: 127 | x (torch.Tensor): The input features. 128 | """ 129 | return F.linear(x, self.weight, self.bias) 130 | 131 | @torch.no_grad() 132 | def initialize_parameters(self, module, input): 133 | if is_uninitialized_parameter(self.weight): 134 | self.in_channels = input[0].size(-1) 135 | self.weight.materialize((self.out_channels, self.in_channels)) 136 | self.reset_parameters() 137 | self._hook.remove() # one-shot pre-hook: removed after the first call 138 | delattr(self, '_hook') 139 | 140 | def _save_to_state_dict(self, destination, prefix, keep_vars): 141 | if (is_uninitialized_parameter(self.weight) or torch.onnx.is_in_onnx_export() or keep_vars): 142 | destination[prefix + 'weight'] = self.weight 143 | else: 144 | destination[prefix + 'weight'] = self.weight.detach() 145 | if self.bias is not None: 146 | if torch.onnx.is_in_onnx_export() or keep_vars: 147 | destination[prefix + 'bias'] = self.bias 148 | else: 149 | destination[prefix + 'bias'] = self.bias.detach() 150 | 151 | def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): 152 | weight = state_dict.get(prefix + 'weight', None) 153 | 154 | if weight is not None and is_uninitialized_parameter(weight): 155 | self.in_channels = -1 156 | self.weight = torch.nn.parameter.UninitializedParameter() 157 | if not hasattr(self, '_hook'): 158 | self._hook = self.register_forward_pre_hook( 159 | self.initialize_parameters) 160 | 161 | elif weight is not None and is_uninitialized_parameter(self.weight): 162 | self.in_channels = weight.size(-1) 163 | self.weight.materialize((self.out_channels, self.in_channels)) 164 | if hasattr(self, '_hook'): 165 | self._hook.remove() 166 | delattr(self, '_hook') 167 | 168 | super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) 169 | 170 | def __repr__(self) -> str: 171 | return (f'{self.__class__.__name__}({self.in_channels}, '
172 | f'{self.out_channels}, bias={self.bias is not None})') 173 | 174 | 175 | class FastenHeteroLinear(torch.nn.Module): 176 | r"""Applies separate linear tranformations to the incoming data according 177 | to types 178 | 179 | .. math:: 180 | \mathbf{x}^{\prime}_{\kappa} = \mathbf{x}_{\kappa} 181 | \mathbf{W}^{\top}_{\kappa} + \mathbf{b}_{\kappa} 182 | 183 | for type :math:`\kappa`. 184 | It supports lazy initialization and customizable weight and bias 185 | initialization. 186 | 187 | Args: 188 | in_channels (int): Size of each input sample. Will be initialized 189 | lazily in case it is given as :obj:`-1`. 190 | out_channels (int): Size of each output sample. 191 | num_types (int): The number of types. 192 | is_sorted (bool, optional): If set to :obj:`True`, assumes that 193 | :obj:`type_vec` is sorted. This avoids internal re-sorting of the 194 | data and can improve runtime and memory efficiency. 195 | (default: :obj:`False`) 196 | **kwargs (optional): Additional arguments of 197 | :class:`torch_geometric.nn.Linear`.
198 | 199 | Shapes: 200 | - **input:** 201 | features :math:`(*, F_{in})`, 202 | type vector :math:`(*)` 203 | - **output:** features :math:`(*, F_{out})` 204 | """ 205 | 206 | def __init__( 207 | self, 208 | in_channels: int, 209 | out_channels: int, 210 | num_types: int, 211 | is_sorted: bool = False, 212 | engine: Engine = Engine.AUTO, 213 | **kwargs, 214 | ): 215 | super().__init__() 216 | 217 | self.in_channels = in_channels 218 | self.out_channels = out_channels 219 | self.num_types = num_types 220 | self.is_sorted = is_sorted 221 | self.engine = engine 222 | self.kwargs = kwargs 223 | 224 | self._use_segment_matmul_heuristic_output: Optional[bool] = None 225 | 226 | if self.in_channels == -1: 227 | self.weight = torch.nn.parameter.UninitializedParameter() 228 | self._hook = self.register_forward_pre_hook( 229 | self.initialize_parameters) 230 | else: 231 | self.weight = torch.nn.Parameter( 232 | torch.empty(num_types, in_channels, out_channels)) 233 | if kwargs.get('bias', True): 234 | self.bias = Parameter(torch.empty(num_types, out_channels)) 235 | else: 236 | self.register_parameter('bias', None) 237 | self.reset_parameters() 238 | 239 | def reset_parameters(self): 240 | r"""Resets all learnable parameters of the module.""" 241 | reset_weight_(self.weight, self.in_channels, 242 | self.kwargs.get('weight_initializer', None)) 243 | reset_bias_(self.bias, self.in_channels, 244 | self.kwargs.get('bias_initializer', None)) 245 | 246 | def forward(self, x: Tensor, type_vec: Tensor, tensor_slice_hl: TensorSlice) -> Tensor: 247 | r""" 248 | Args: 249 | x (torch.Tensor): The input features. 250 | type_vec (torch.Tensor): A vector that maps each entry to a type. 
251 | """ 252 | use_segment_matmul = True # Making use_segemnt_matmul True by default 253 | if use_segment_matmul and torch_geometric.typing.WITH_SEGMM: 254 | assert self.weight is not None 255 | 256 | perm: Optional[Tensor] = None 257 | if not self.is_sorted: 258 | if (type_vec[1:] < type_vec[:-1]).any(): 259 | type_vec, perm = index_sort(type_vec, self.num_types) 260 | x = x[perm] 261 | 262 | out = ops.fasten_segment_matmul(x, self.weight, tensor_slice_hl, self.engine) 263 | 264 | if self.bias is not None: 265 | out += self.bias[type_vec] 266 | 267 | if perm is not None: # Restore original order (if necessary). 268 | out_unsorted = torch.empty_like(out) 269 | out_unsorted[perm] = out 270 | out = out_unsorted 271 | 272 | else: 273 | out = x.new_empty(x.size(0), self.out_channels) 274 | for i in range(self.num_types): 275 | mask = type_vec == i 276 | if mask.numel() == 0: 277 | continue 278 | subset_out = F.linear(x[mask], self.weight[i].T) 279 | # The data type may have changed with mixed precision: 280 | out[mask] = subset_out.to(out.dtype) 281 | 282 | if self.bias is not None: 283 | out += self.bias[type_vec] 284 | return out 285 | 286 | @torch.no_grad() 287 | def initialize_parameters(self, module, input): 288 | if is_uninitialized_parameter(self.weight): 289 | self.in_channels = input[0].size(-1) 290 | self.weight.materialize( 291 | (self.num_types, self.in_channels, self.out_channels)) 292 | self.reset_parameters() 293 | self._hook.remove() 294 | delattr(self, '_hook') 295 | 296 | def __repr__(self) -> str: 297 | return (f'{self.__class__.__name__}({self.in_channels}, ' 298 | f'{self.out_channels}, num_types={self.num_types}, ' 299 | f'bias={self.kwargs.get("bias", True)})') 300 | 301 | 302 | class FastenHeteroDictLinear(torch.nn.Module): 303 | r"""Applies separate linear tranformations to the incoming data dictionary 304 | 305 | .. 
math:: 306 | \mathbf{x}^{\prime}_{\kappa} = \mathbf{x}_{\kappa} 307 | \mathbf{W}^{\top}_{\kappa} + \mathbf{b}_{\kappa} 308 | 309 | for key :math:`\kappa`. 310 | It supports lazy initialization and customizable weight and bias 311 | initialization. 312 | 313 | Args: 314 | in_channels (int or Dict[Any, int]): Size of each input sample. If 315 | passed an integer, :obj:`types` will be a mandatory argument. 316 | initialized lazily in case it is given as :obj:`-1`. 317 | out_channels (int): Size of each output sample. 318 | types (List[Any], optional): The keys of the input dictionary. 319 | (default: :obj:`None`) 320 | **kwargs (optional): Additional arguments of 321 | :class:`torch_geometric.nn.Linear`. 322 | """ 323 | 324 | def __init__( 325 | self, 326 | in_channels: Union[int, Dict[Any, int]], 327 | out_channels: int, 328 | types: Optional[Any] = None, 329 | engine: Engine = Engine.AUTO, 330 | **kwargs, 331 | ): 332 | super().__init__() 333 | 334 | if isinstance(in_channels, dict): 335 | self.types = list(in_channels.keys()) 336 | 337 | if any([i == -1 for i in in_channels.values()]): 338 | self._hook = self.register_forward_pre_hook( 339 | self.initialize_parameters) 340 | 341 | if types is not None and set(self.types) != set(types): 342 | raise ValueError("The provided 'types' do not match with the " 343 | "keys in the 'in_channels' dictionary") 344 | 345 | else: 346 | if types is None: 347 | raise ValueError("Please provide a list of 'types' if passing " 348 | "'in_channels' as an integer") 349 | 350 | if in_channels == -1: 351 | self._hook = self.register_forward_pre_hook( 352 | self.initialize_parameters) 353 | 354 | self.types = types 355 | in_channels = {node_type: in_channels for node_type in types} 356 | 357 | self.in_channels = in_channels 358 | self.out_channels = out_channels 359 | self.engine = engine 360 | self.kwargs = kwargs 361 | 362 | self.lins = torch.nn.ModuleDict({ 363 | key: 364 | Linear(channels, self.out_channels, **kwargs) 365 | for key, 
channels in self.in_channels.items() 366 | }) 367 | 368 | self.reset_parameters() 369 | 370 | def reset_parameters(self): 371 | r"""Resets all learnable parameters of the module.""" 372 | for lin in self.lins.values(): 373 | lin.reset_parameters() 374 | 375 | def forward( 376 | self, 377 | x_dict: Dict[str, Tensor], 378 | tensor_slice: TensorSlice = None, 379 | slices: list = None, **kwargs 380 | ) -> Dict[str, Tensor]: 381 | r""" 382 | Args: 383 | x_dict (Dict[Any, torch.Tensor]): A dictionary holding input 384 | features for each individual type. 385 | """ 386 | out_dict = {} 387 | 388 | # Only apply fused kernel for more than 10 types, otherwise use 389 | # sequential computation (which is generally faster for these cases). 390 | use_segment_matmul = True 391 | if (use_segment_matmul and torch_geometric.typing.WITH_GMM and not torch.jit.is_scripting()): 392 | xs, weights, biases = [], [], [] 393 | for key, lin in self.lins.items(): 394 | if key in x_dict: 395 | xs.append(x_dict[key]) 396 | weights.append(lin.weight.t()) 397 | biases.append(lin.bias) 398 | biases = None if biases[0] is None else biases 399 | # Stacking the input and weight tensor to feed it to segment matmul 400 | stacked_xs = torch.cat(xs, dim=0) 401 | stacked_weights = torch.stack(weights) 402 | out_segmm = ops.fasten_segment_matmul(stacked_xs, stacked_weights, tensor_slice, self.engine) 403 | outs = [] 404 | 405 | for s in slices: 406 | outs.append(out_segmm[s]) 407 | 408 | if biases is not None: 409 | assert (len(biases) == len(outs)) 410 | for i in range(len(biases)): 411 | outs[i] = outs[i] + biases[i] 412 | 413 | for key, out in zip(x_dict.keys(), outs): 414 | if key in x_dict: 415 | out_dict[key] = out 416 | else: 417 | for key, lin in self.lins.items(): 418 | if key in x_dict: 419 | out_dict[key] = lin(x_dict[key]) 420 | 421 | return out_dict 422 | 423 | @torch.no_grad() 424 | def initialize_parameters(self, module, input): 425 | for key, x in input[0].items(): 426 | lin = self.lins[key] 
427 | if is_uninitialized_parameter(lin.weight): 428 | self.lins[key].initialize_parameters(None, x) 429 | self.reset_parameters() 430 | self._hook.remove() 431 | self.in_channels = {key: x.size(-1) for key, x in input[0].items()} 432 | delattr(self, '_hook') 433 | 434 | def __repr__(self) -> str: 435 | return (f'{self.__class__.__name__}({self.in_channels}, ' 436 | f'{self.out_channels}, bias={self.kwargs.get("bias", True)})') 437 | -------------------------------------------------------------------------------- /fasten/operators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deep-Learning-Profiling-Tools/fasten/784bd907aa3685770fb14608fbbd620e21261937/fasten/operators/__init__.py -------------------------------------------------------------------------------- /fasten/operators/torch_ops/__init__.py: -------------------------------------------------------------------------------- 1 | from .segment_matmul import segment_matmul_backward_input # noqa F401 2 | from .segment_matmul import segment_matmul_backward_other # noqa F401 3 | from .segment_matmul import segment_matmul_forward # noqa F401 4 | -------------------------------------------------------------------------------- /fasten/operators/torch_ops/segment_matmul.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def segment_matmul_forward(input: torch.Tensor, other: torch.Tensor, input_slices: torch.Tensor, 5 | output: torch.Tensor = None): 6 | assert input.device == other.device, 'input, other and output must be on the same device' 7 | input_slices = input_slices.to('cpu') 8 | if output is None: 9 | output = torch.empty(input.shape[0], other.shape[2], device=input.device, dtype=input.dtype) 10 | for i in range(input_slices.shape[0]): 11 | t = input_slices[i, 1] 12 | a = input[input_slices[i, 2]:input_slices[i, 3]] 13 | b = other[t] 14 | c = output[input_slices[i, 
2]:input_slices[i, 3]] 15 | torch.matmul(a, b, out=c) 16 | return output 17 | 18 | 19 | def segment_matmul_backward_input(input: torch.Tensor, grad_output: torch.Tensor, other: torch.Tensor, 20 | input_slices: torch.Tensor, grad_input: torch.Tensor = None): 21 | assert input.device == other.device, 'input, other and output must be on the same device' 22 | input_slices = input_slices.to('cpu') 23 | grad_output = grad_output.contiguous() 24 | if grad_input is None: 25 | grad_input = torch.empty_like(input) 26 | for i in range(input_slices.shape[0]): 27 | t = input_slices[i, 1] 28 | a = grad_output[input_slices[i, 2]:input_slices[i, 3]] 29 | b = other[t] 30 | c = grad_input[input_slices[i, 2]:input_slices[i, 3]] 31 | torch.matmul(a, b.t(), out=c) 32 | return grad_input 33 | 34 | 35 | def segment_matmul_backward_other(input: torch.Tensor, grad_output: torch.Tensor, other: torch.Tensor, 36 | input_slices: torch.Tensor, grad_other: torch.Tensor = None): 37 | assert input.device == other.device, 'input, other and output must be on the same device' 38 | input_slices = input_slices.to('cpu') 39 | grad_output = grad_output.contiguous() 40 | if grad_other is None: 41 | # grad_other might be sparse 42 | grad_other = torch.zeros_like(other) 43 | for i in range(input_slices.shape[0]): 44 | t = input_slices[i, 1] 45 | a = input[input_slices[i, 2]:input_slices[i, 3]] 46 | b = grad_output[input_slices[i, 2]:input_slices[i, 3]] 47 | c = grad_other[t] 48 | torch.matmul(a.t(), b, out=c) 49 | return grad_other 50 | -------------------------------------------------------------------------------- /fasten/operators/triton_ops/__init__.py: -------------------------------------------------------------------------------- 1 | from .segment_matmul import segment_matmul_backward_input # noqa F401 2 | from .segment_matmul import segment_matmul_backward_other # noqa F401 3 | from .segment_matmul import segment_matmul_forward # noqa F401 4 | 
# --------------------------------------------------------------------------------
# /fasten/operators/triton_ops/kernels/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Deep-Learning-Profiling-Tools/fasten/784bd907aa3685770fb14608fbbd620e21261937/fasten/operators/triton_ops/kernels/__init__.py
# --------------------------------------------------------------------------------
# /fasten/operators/triton_ops/kernels/matmul.py:
# --------------------------------------------------------------------------------
# Triton device helpers for segment matmul. They receive program ids
# (pid_n, pid_k) and segment offsets as arguments instead of calling
# tl.program_id, so they are presumably called from launcher kernels defined
# elsewhere.
import triton
import triton.language as tl


# Computes BLOCK_SIZE consecutive TILE_M-row tiles of
# output = input @ other[type_id], starting at row start_off, for the output
# columns pid_n * TILE_N .. pid_n * TILE_N + TILE_N - 1.
# The TILE_K x TILE_N tile of other[type_id] is loaded once and reused for
# every row tile. There is no loop or mask over K and no row mask, so the
# caller presumably guarantees K == TILE_K and in-bounds rows (TODO confirm).
@triton.jit
def _reg_matmul(
    pid_n, type_id,
    start_off,
    input, other, output, N,
    stride_input_m, stride_input_k,
    stride_other_b, stride_other_k, stride_other_n,
    stride_output_m, stride_output_n,
    out_dtype: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    EVEN_N: tl.constexpr,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr
):
    offs_m = start_off + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)
    # Column indices wrapped modulo N so loads of `other` stay in bounds; out-of-range
    # columns are discarded by the masked store below when not EVEN_N.
    rn = tl.max_contiguous(tl.multiple_of(offs_n % N, TILE_N), TILE_N)
    other_ptrs = other + type_id * stride_other_b + \
        (offs_k[:, None] * stride_other_k + rn[None, :] * stride_other_n)
    b = tl.load(other_ptrs)

    # [M, K] x [K, N] -> [M, N]
    input_ptrs = input + (offs_m[:, None] * stride_input_m + offs_k[None, :] * stride_input_k)
    output_ptrs = output + stride_output_m * offs_m[:, None] + stride_output_n * offs_n[None, :]
    for _ in range(0, BLOCK_SIZE):
        a = tl.load(input_ptrs)
        acc = tl.dot(a, b, out_dtype=out_dtype).to(output.dtype.element_ty)
        if EVEN_N:
            tl.store(output_ptrs, acc)
        else:
            mask_n = offs_n[None, :] < N
            tl.store(output_ptrs, acc, mask=mask_n)
        # Advance to the next TILE_M-row tile.
        input_ptrs += TILE_M * stride_input_m
        output_ptrs += TILE_M * stride_output_m


# Computes one TILE_M x TILE_N tile of output = input @ other, with rows
# starting at start_off and columns starting at pid_n * TILE_N. It loops over K in
# TILE_K steps and accumulates in out_dtype before casting to the output dtype.
# MASK_M masks rows >= end_off; EVEN_K / EVEN_N select the unmasked K / N paths.
# `other` has no type/batch offset here; presumably the caller already points
# it at the selected type's matrix (TODO confirm).
@triton.jit
def _general_matmul(
    pid_n,
    start_off, end_off,
    input, other, output,
    K, N,
    stride_input_m, stride_input_k,
    stride_other_k, stride_other_n,
    stride_output_m, stride_output_n,
    out_dtype: tl.constexpr,
    MASK_M: tl.constexpr,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_K: tl.constexpr
):
    offs_m = start_off + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)
    rn = tl.max_contiguous(tl.multiple_of(offs_n % N, TILE_N), TILE_N)

    # [M, K] x [K, N] -> [M, N]
    input_ptrs = input + (offs_m[:, None] * stride_input_m + offs_k[None, :] * stride_input_k)
    other_ptrs = other + \
        (offs_k[:, None] * stride_other_k + rn[None, :] * stride_other_n)

    acc = tl.zeros((TILE_M, TILE_N), dtype=out_dtype)
    mask_m = offs_m[:, None] < end_off if MASK_M else True

    k_iter = K // TILE_K if EVEN_K else tl.cdiv(K, TILE_K)
    for k in range(0, k_iter):
        if EVEN_K:
            if MASK_M:
                a = tl.load(input_ptrs, mask=mask_m, other=0.0)
                b = tl.load(other_ptrs)
            else:
                a = tl.load(input_ptrs)
                b = tl.load(other_ptrs)
        else:
            # Zero-fill the K tail so it contributes nothing to the dot product.
            if MASK_M:
                a = tl.load(input_ptrs, mask=mask_m & (offs_k[None, :] + k * TILE_K < K), other=0.0)
                b = tl.load(other_ptrs, mask=(offs_k[:, None] + k * TILE_K < K), other=0.0)
            else:
                a = tl.load(input_ptrs, mask=(offs_k[None, :] + k * TILE_K < K), other=0.0)
                b = tl.load(other_ptrs, mask=(offs_k[:, None] + k * TILE_K < K), other=0.0)
        acc += tl.dot(a, b, out_dtype=out_dtype)
        input_ptrs += TILE_K * stride_input_k
        other_ptrs += TILE_K * stride_other_k

    acc = acc.to(output.dtype.element_ty)
    c_ptrs = output + stride_output_m * \
        offs_m[:, None] + stride_output_n * offs_n[None, :]
    if EVEN_N:
        if MASK_M:
            tl.store(c_ptrs, acc, mask=mask_m)
        else:
            tl.store(c_ptrs, acc)
    else:
        mask_n = offs_n[None, :] < N
        if MASK_M:
            tl.store(c_ptrs, acc, mask=mask_m & mask_n)
        else:
            tl.store(c_ptrs, acc, mask_n)


# Computes BLOCK_SIZE consecutive TILE_M-row tiles of output = input @ other
# (columns from pid_n * TILE_N) in one flattened loop of k_iters * BLOCK_SIZE
# steps. At the last K step of each row tile the accumulator is stored and
# reset, the K pointers rewind and the row pointers advance. There is no row
# mask and end_off is unused, so the rows are presumably guaranteed in bounds
# (TODO confirm).
@triton.jit
def _prefetch_matmul(
    pid_n, start_off, end_off,
    input, other, output,
    K, N,
    stride_input_m, stride_input_k,
    stride_other_k, stride_other_n,
    stride_output_m, stride_output_n,
    out_dtype: tl.constexpr,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_K: tl.constexpr,
    BLOCK_SIZE: tl.constexpr
):
    offs_m = start_off + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)
    rn = tl.max_contiguous(tl.multiple_of(offs_n % N, TILE_N), TILE_N)

    # [M, K] x [K, N] -> [M, N]
    input_ptrs = input + (offs_m[:, None] * stride_input_m + offs_k[None, :] * stride_input_k)
    other_ptrs = other + \
        (offs_k[:, None] * stride_other_k + rn[None, :] * stride_other_n)
    output_ptrs = output + stride_output_m * offs_m[:, None] + stride_output_n * offs_n[None, :]
    # Start-of-K pointers, used to rewind after each row tile.
    original_input_ptrs = input_ptrs
    original_other_ptrs = other_ptrs

    acc = tl.zeros((TILE_M, TILE_N), dtype=out_dtype)
    mask_n = offs_n[None, :] < N

    k_iters = K // TILE_K if EVEN_K else tl.cdiv(K, TILE_K)
    for k in range(0, k_iters * BLOCK_SIZE):
        # Position of this step inside the current row tile's K loop.
        i = k % k_iters
        if EVEN_K:
            a = tl.load(input_ptrs)
            b = tl.load(other_ptrs)
        else:
            a = tl.load(input_ptrs, mask=offs_k[None, :] + i * TILE_K < K, other=0.0)
            b = tl.load(other_ptrs, mask=offs_k[:, None] + i * TILE_K < K, other=0.0)
        acc += tl.dot(a, b, out_dtype=out_dtype)
        if i == k_iters - 1:
            # Row tile finished: write it out and move the output to the next tile.
            if EVEN_N:
                tl.store(output_ptrs, acc.to(output.dtype.element_ty))
            else:
                tl.store(output_ptrs, acc.to(output.dtype.element_ty), mask_n)
            output_ptrs += TILE_M * stride_output_m
        if i == k_iters - 1:
            # Reset for the next row tile: clear acc, step rows, rewind K.
            acc = tl.zeros((TILE_M, TILE_N), dtype=out_dtype)
            original_input_ptrs += TILE_M * stride_input_m
            input_ptrs = original_input_ptrs
            other_ptrs = original_other_ptrs
        else:
            input_ptrs += TILE_K * stride_input_k
            other_ptrs += TILE_K * stride_other_k


# Computes one TILE_K x TILE_N tile of input^T @ grad_output over `length`
# rows. These are the rows reachable from the given input / grad_output pointers;
# presumably the caller offsets them to the segment start (TODO confirm).
# Write-back:
#   DETERMINISTIC and M <= BLOCK_LENGTH -> store into grad_other
#   DETERMINISTIC and M >  BLOCK_LENGTH -> store partial into
#                                          grad_other_tiles[next_id]
#                                          (presumably reduced later)
#   otherwise, M <= BLOCK_LENGTH        -> store into grad_other
#   otherwise, M >  BLOCK_LENGTH        -> tl.atomic_add into grad_other
@triton.jit
def _dynamic_matmul(
    pid_k, pid_n, next_id,
    input, grad_output, grad_other, grad_other_tiles,
    stride_input_m, stride_input_k,
    stride_grad_output_m, stride_grad_output_n,
    stride_grad_other_b, stride_grad_other_k, stride_grad_other_n,
    K, N, M, length,
    out_dtype: tl.constexpr,
    BLOCK_LENGTH: tl.constexpr,
    TILE_K: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_K: tl.constexpr,
    EVEN_M: tl.constexpr,
    DETERMINISTIC: tl.constexpr
):
    offs_k = pid_k * TILE_K + tl.arange(0, TILE_K)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_m = tl.arange(0, TILE_M)
    acc = tl.zeros((TILE_K, TILE_N), dtype=out_dtype)
    # Masks collapse to the constant True when the dimension is tile-aligned.
    mask_k = offs_k[:, None] < K if not EVEN_K else True
    mask_n = offs_n[None, :] < N if not EVEN_N else True

    # [M, K] -> [K, M]
    input_ptrs = input + (offs_m[None, :] * stride_input_m + offs_k[:, None] * stride_input_k)
    # [M, N]
    grad_output_ptrs = grad_output + (offs_m[:, None] * stride_grad_output_m + offs_n[None, :] * stride_grad_output_n)

    m_iter = length // TILE_M if EVEN_M else tl.cdiv(length, TILE_M)
    for m in range(0, m_iter):
        if EVEN_K:
            if EVEN_M:
                a = tl.load(input_ptrs)
            else:
                a = tl.load(input_ptrs, mask=(offs_m[None, :] + m * TILE_M < length), other=0.0)
        else:
            if EVEN_M:
                a = tl.load(input_ptrs, mask=mask_k, other=0.0)
            else:
                a = tl.load(input_ptrs, mask=mask_k & (offs_m[None, :] + m * TILE_M < length), other=0.0)
        if EVEN_N:
            if EVEN_M:
                b = tl.load(grad_output_ptrs)
            else:
                b = tl.load(grad_output_ptrs, mask=(offs_m[:, None] + m * TILE_M < length), other=0.0)
        else:
            if EVEN_M:
                # No `other` fill: masked-out columns only affect acc columns
                # that the masked store below never writes.
                b = tl.load(grad_output_ptrs, mask=mask_n)
            else:
                b = tl.load(grad_output_ptrs, mask=mask_n & (offs_m[:, None] + m * TILE_M < length), other=0.0)

        acc += tl.dot(a, b, out_dtype=out_dtype)
        input_ptrs += TILE_M * stride_input_m
        grad_output_ptrs += TILE_M * stride_grad_output_m

    acc = acc.to(grad_other.dtype.element_ty)

    if DETERMINISTIC:
        if M <= BLOCK_LENGTH:
            c_ptrs = grad_other + \
                stride_grad_other_k * offs_k[:, None] + stride_grad_other_n * offs_n[None, :]
            if EVEN_N and EVEN_K:
                tl.store(c_ptrs, acc)
            else:
                c_mask = mask_k & mask_n
                tl.store(c_ptrs, acc, mask=c_mask)
        else:
            # Partial result goes to its own slot in grad_other_tiles, avoiding atomics.
            c_ptrs = grad_other_tiles + \
                next_id * stride_grad_other_b + stride_grad_other_k * offs_k[:, None] + stride_grad_other_n * offs_n[None, :]
            if EVEN_N and EVEN_K:
                tl.store(c_ptrs, acc)
            else:
                c_mask = mask_k & mask_n
                tl.store(c_ptrs, acc, mask=c_mask)
    else:
        c_ptrs = grad_other + \
            stride_grad_other_k * offs_k[:, None] + stride_grad_other_n * offs_n[None, :]
        if M <= BLOCK_LENGTH:
            if EVEN_N and EVEN_K:
                tl.store(c_ptrs, acc)
            else:
                c_mask = mask_k & mask_n
                tl.store(c_ptrs, acc, mask=c_mask)
        else:
            # Several programs may add into the same grad_other tile.
            if EVEN_N and EVEN_K:
                tl.atomic_add(c_ptrs, acc)
            else:
                c_mask = mask_k & mask_n
                tl.atomic_add(c_ptrs, acc, mask=c_mask)


# --------------------------------------------------------------------------------
# /fasten/ops.py:
# --------------------------------------------------------------------------------
import torch

from .tensor_slice import TensorSlice
from .utils import Engine, engine_ops


def execute_engine(*args, engine: Engine, tensor_slice: TensorSlice = None, op_name: str = None):
    """Run operator ``op_name`` on ``args`` with the selected backend engine.

    Args:
        *args: Positional tensor arguments forwarded to the operator.
        engine: ``Engine.AUTO`` / ``Engine.TRITON`` pick the implementation
            via ``tensor_slice.schedule`` (autotuning only for ``AUTO``).
            ``Engine.TORCH`` looks ``op_name`` up in ``engine_ops[engine]``.
        tensor_slice: Segment metadata. Its ``slices`` are passed to the
            operator as ``input_slices``. Required by AUTO, TRITON and TORCH.
        op_name: Operator name, e.g. ``'segment_matmul_forward'``.

    Raises:
        AssertionError: If ``tensor_slice`` is ``None`` for a supported engine.
        NotImplementedError: For any other ``engine``.
    """
    if engine == Engine.AUTO or engine == Engine.TRITON:
        assert tensor_slice is not None, 'tensor_slice must be provided when using AUTO or TRITON engine'
        autotune = engine == Engine.AUTO
        cache_entry = tensor_slice.schedule(op_name, *args, autotune=autotune)
        best_config = cache_entry.best_config
        # The config's fields are forwarded only when the schedule produced tiles.
        extra_kwargs = {} if best_config.input_tiles is None else best_config.asdict()
        return cache_entry.best_op(*args, input_slices=tensor_slice.slices, **extra_kwargs)
    elif engine == Engine.TORCH:
        # Check explicitly so a missing tensor_slice yields a clear message
        # instead of an AttributeError on ``None.slices``.
        assert tensor_slice is not None, 'tensor_slice must be provided when using TORCH engine'
        engine_op = getattr(engine_ops[engine], op_name)
        return engine_op(*args, input_slices=tensor_slice.slices)
    else:
        raise NotImplementedError(f'Engine {engine} is not implemented')


class FastenSegmentMatmul(torch.autograd.Function):
    """Autograd function for segment matmul.

    Each segment of ``input`` rows described by ``tensor_slice`` is multiplied
    by the matrix of its type in ``other``. Forward and backward dispatch
    through :func:`execute_engine`.
    """

    @staticmethod
    def forward(ctx, input: torch.Tensor, other: torch.Tensor,
                tensor_slice: TensorSlice, engine: Engine = Engine.AUTO):
        ctx.save_for_backward(input, other)
        ctx.engine = engine
        ctx.tensor_slice = tensor_slice
        return execute_engine(input, other, engine=engine, tensor_slice=tensor_slice, op_name='segment_matmul_forward')

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        input, other = ctx.saved_tensors
        grad_input = grad_other = None
        # Only launch the gradient kernels autograd actually needs; returning
        # None for an input that does not require grad is allowed.
        if ctx.needs_input_grad[0]:
            grad_input = execute_engine(
                input, grad_output, other,
                engine=ctx.engine, tensor_slice=ctx.tensor_slice, op_name='segment_matmul_backward_input')
        if ctx.needs_input_grad[1]:
            grad_other = execute_engine(
                input, grad_output, other,
                engine=ctx.engine, tensor_slice=ctx.tensor_slice, op_name='segment_matmul_backward_other')
        # No gradients for tensor_slice and engine.
        return grad_input, grad_other, None, None


fasten_segment_matmul = FastenSegmentMatmul.apply

# --------------------------------------------------------------------------------
# /fasten/runtime/__init__.py:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deep-Learning-Profiling-Tools/fasten/784bd907aa3685770fb14608fbbd620e21261937/fasten/runtime/__init__.py -------------------------------------------------------------------------------- /fasten/runtime/stream_pool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class StreamPool: 5 | _streams = [] 6 | if torch.cuda.is_available(): 7 | _streams.append(torch.cuda.current_stream()) 8 | 9 | @classmethod 10 | def add(cls, nstreams: int = 1) -> None: 11 | for _ in range(nstreams): 12 | cls._streams.append(torch.cuda.Stream()) 13 | 14 | @classmethod 15 | def reserve(cls, nstreams: int = 1) -> None: 16 | if torch.cuda.is_available(): 17 | if len(cls._streams) < nstreams: 18 | cls.add(nstreams - len(cls._streams)) 19 | 20 | @classmethod 21 | def get(cls, stream_idx: int = 1) -> torch.cuda.Stream: 22 | if torch.cuda.is_available(): 23 | return cls._streams[stream_idx] 24 | else: 25 | return None 26 | 27 | @classmethod 28 | def size(cls) -> int: 29 | return len(cls._streams) 30 | -------------------------------------------------------------------------------- /fasten/scheduler.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | import triton 5 | from dataclasses import asdict, dataclass, field 6 | 7 | from .utils import GlobalConfig, TilingMethod 8 | 9 | 10 | @dataclass 11 | class BestConfig: 12 | tile_size: int = None # the maximum size of each tile 13 | avg_tile_size: float = None # the average size of each tile 14 | stddev_tile_size: float = None # the standard deviation of tile size 15 | block_size: int = None # the number of tiles belong to a block, -1: dynamic block size 16 | num_blocks: int = None # number of blocks that group the tiles 17 | input_tiles: torch.Tensor = None 18 | slice_tile_mapping: 
torch.Tensor = None 19 | deterministic: bool = GlobalConfig.deterministic 20 | 21 | def asdict(self): 22 | return asdict(self) 23 | 24 | 25 | @dataclass 26 | class CacheEntry: 27 | best_ms: float 28 | best_config: BestConfig 29 | best_op: callable 30 | 31 | 32 | @dataclass 33 | class Scheduler: 34 | get_key: callable 35 | prune: callable = None 36 | record: callable = None 37 | cache: dict = None 38 | default_tile_size: int = 32 39 | tile_sizes: list[int] = field(default_factory=lambda: [Scheduler.default_block_size]) 40 | default_tiling_method = TilingMethod.DEFAULT 41 | tiling_methods: list[TilingMethod] = field(default_factory=lambda: [Scheduler.default_tiling_method]) 42 | default_block_size: int = 1 43 | block_sizes: list[int] = field(default_factory=lambda: [Scheduler.default_block_size]) 44 | 45 | def get_configs(self): 46 | configs = [] 47 | for tile_size in self.tile_sizes: 48 | for tiling_method in self.tiling_methods: 49 | for block_size in self.block_sizes: 50 | configs.append((tile_size, tiling_method, block_size)) 51 | return configs 52 | 53 | 54 | def _compress_slices(subslices: list[list], tile_size: int, block_size: int, num_blocks: int) -> Tuple[list[list], list[list]]: 55 | """Compress subslices into large and small blocks.""" 56 | compressed_subslices = [] 57 | small_subslices = [] 58 | for i in range(num_blocks): 59 | block_start_idx = i * block_size 60 | block_end_idx = min((i + 1) * block_size, len(subslices)) 61 | 62 | # Extract the first and last subslice for comparison 63 | first_subslice = subslices[block_start_idx] 64 | last_subslice = subslices[block_end_idx - 1] if block_end_idx - 1 < len(subslices) else None 65 | 66 | # Determine if we can create a large block 67 | if last_subslice and first_subslice[1] == last_subslice[1] \ 68 | and first_subslice[2] + tile_size * block_size == last_subslice[3]: 69 | compressed_subslices.append([first_subslice[0], first_subslice[1], first_subslice[2], last_subslice[3], 0]) 70 | else: 71 | next_id = 0 
72 | for j in range(block_start_idx, block_end_idx): 73 | subslice = subslices[j] 74 | # Set continuation index for small blocks 75 | if j == block_start_idx: 76 | subslice[4] = len(small_subslices) + num_blocks 77 | next_id = subslice[4] + 1 78 | compressed_subslices.append(subslice) 79 | else: 80 | subslice[4] = next_id 81 | next_id += 1 82 | small_subslices.append(subslice) 83 | 84 | return compressed_subslices, small_subslices 85 | 86 | 87 | def tiling(slices: list[tuple], tile_size: int, block_size: int, reorder: bool) -> Tuple[list[list], int]: 88 | """Create subslices based on the tile size and compress them into blocks.""" 89 | # Generate subslices 90 | subslices = [ 91 | [index, type, off, min(off + tile_size, end), -1] 92 | for index, type, start, end, _ in slices 93 | for off in range(start, end, tile_size) 94 | ] 95 | 96 | if reorder: 97 | # Calculate the number of blocks 98 | num_blocks = triton.cdiv(len(subslices), block_size) 99 | 100 | # Compress subslices into large and small blocks 101 | compressed_subslices, small_subslices = _compress_slices(subslices, tile_size, block_size, num_blocks) 102 | 103 | # Combine all subslices and return 104 | compressed_subslices.extend(small_subslices) 105 | return compressed_subslices, num_blocks 106 | else: 107 | blocks = [] 108 | cur_block = [] 109 | 110 | def append_block(): 111 | if cur_block[0][2] + tile_size * block_size == cur_block[-1][3]: 112 | blocks.append([cur_block[0][0], cur_block[0][1], cur_block[0][2], cur_block[-1][3], 0]) 113 | else: 114 | blocks.append([cur_block[0][0], cur_block[0][1], cur_block[0][2], cur_block[-1][3], -1]) 115 | 116 | for subslice in subslices: 117 | if len(cur_block) == block_size or (len(cur_block) > 0 and subslice[1] != cur_block[-1][1]): 118 | append_block() 119 | cur_block = [] 120 | cur_block.append(subslice) 121 | if len(cur_block) > 0: 122 | append_block() 123 | return blocks, len(blocks) 124 | 125 | 126 | def _init_segment_matmul_forward_scheduler(): 127 | def 
get_key(input: torch.Tensor, other: torch.Tensor): 128 | return (input.size(1), other.size(2)) # (K, N) 129 | 130 | def prune(input_tiles, key: Tuple, config: Tuple) -> bool: 131 | tile_size, tiling_method, block_size = config 132 | if tile_size >= 64 and block_size >= 4: 133 | # low cache utilization 134 | return True 135 | if key[1] >= 128 and tile_size <= 32: 136 | # When K is large, we should use larger tile size 137 | return True 138 | return False 139 | 140 | return Scheduler(get_key=get_key, tile_sizes=[32, 64, 128], tiling_methods=[TilingMethod.BALANCED], block_sizes=[1, 2, 4, 8], prune=prune) 141 | 142 | 143 | def _init_segment_matmul_backward_input_scheduler(): 144 | def get_key(input: torch.Tensor, grad_output: torch.Tensor, other: torch.Tensor): 145 | return (input.size(1), other.size(2)) # (K, N) 146 | 147 | def prune(input_tiles, key: Tuple, config: Tuple) -> bool: 148 | tile_size, tiling_method, block_size = config 149 | if tile_size >= 64 and block_size >= 4: 150 | # low cache utilization 151 | return True 152 | if key[1] >= 128 and tile_size <= 32: 153 | # When K is large, we should use larger tile size 154 | return True 155 | return False 156 | 157 | return Scheduler(get_key=get_key, tile_sizes=[32, 64, 128], tiling_methods=[TilingMethod.BALANCED], block_sizes=[1, 2, 4, 8], prune=prune) 158 | 159 | 160 | def _init_segment_matmul_backward_other_scheduler(): 161 | def get_key(input: torch.Tensor, grad_output: torch.Tensor, other: torch.Tensor): 162 | return (input.size(1), other.size(2), GlobalConfig.deterministic) # (K, N) 163 | 164 | def prune(input_tiles, key: Tuple, config: Tuple) -> bool: 165 | tile_size, tiling_method, block_size = config 166 | stddev_tile_size = input_tiles.stddev_tile_size 167 | avg_tile_size = input_tiles.avg_tile_size 168 | num_slices = len(input_tiles) 169 | if num_slices < 100 and block_size >= 2: 170 | # 1. low parallelism 171 | return True 172 | if block_size != 1 and stddev_tile_size / avg_tile_size >= 0.5: 173 | # 2. 
low utilization 174 | return True 175 | if block_size != 16 and num_slices >= 400: 176 | # 3. Too many slices 177 | return True 178 | return False 179 | 180 | # Only default tiling method is supported 181 | return Scheduler(get_key=get_key, tile_sizes=[32, 64, 128], tiling_methods=[TilingMethod.DEFAULT], block_sizes=[1, 2, 4, 8, 16], prune=prune) 182 | 183 | 184 | schedulers = { 185 | 'segment_matmul_forward': _init_segment_matmul_forward_scheduler(), 186 | 'segment_matmul_backward_input': _init_segment_matmul_backward_input_scheduler(), 187 | 'segment_matmul_backward_other': _init_segment_matmul_backward_other_scheduler(), 188 | } 189 | -------------------------------------------------------------------------------- /fasten/stats.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .tensor_slice import TensorSlice 4 | 5 | 6 | def get_matmul_flops(input: TensorSlice, weight: torch.Tensor): 7 | assert weight.dim() == 3, f"weight dim should be 3, got {weight.dim()}" 8 | flops = 0 9 | for i in range(len(input)): 10 | s = input.get_slice_from_index(i, is_tensor=False) 11 | if s.stop - s.start == 0: 12 | continue 13 | flops += (s.stop - s.start) * weight.shape[1] * weight.shape[2] * 2 14 | return flops 15 | 16 | 17 | def get_matmul_bytes(input: TensorSlice, weight: torch.Tensor): 18 | assert weight.dim() == 3, f"weight dim should be 3, got {weight.dim()}" 19 | bytes = 0 20 | for i in range(len(input)): 21 | s = input.get_slice_from_index(i, is_tensor=False) 22 | if s.stop - s.start == 0: 23 | continue 24 | # input 25 | bytes += (s.stop - s.start) * weight.shape[1] * weight.element_size() 26 | # weight 27 | bytes += weight.shape[1] * weight.shape[2] * weight.element_size() 28 | # output 29 | bytes += (s.stop - s.start) * weight.shape[2] * weight.element_size() 30 | return bytes 31 | -------------------------------------------------------------------------------- /fasten/tensor_slice.py: 
-------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Optional, Tuple, Union 3 | 4 | import torch 5 | from triton.runtime.autotuner import OutOfResources 6 | from triton.testing import do_bench 7 | 8 | from .operators import torch_ops, triton_ops 9 | from .scheduler import BestConfig, CacheEntry, Scheduler, schedulers, tiling 10 | from .utils import GlobalConfig, TilingMethod, is_debug 11 | 12 | 13 | class TensorSlice: 14 | ''' 15 | Construct a TensorSlice data structure 16 | 17 | Args: 18 | data: The original data tensor, could be on either the CPU or the GPU. 19 | It must have been sorted by types. 20 | slices: A 5-dim PyTorch Tensor, where each row represents [type_index, type, start, end, next]. 21 | It can also be a int or a list, then internally we transform it to a tensor. 22 | device: The device to put the slices on, default is 'cpu' 23 | block_size: The number of tiles belong to a block, default is 1 24 | num_blocks: The number of blocks that group the tiles, default is None, which means the number of blocks is the same as the number of slices. 
25 | ''' 26 | 27 | def __init__(self, data: torch.Tensor, slices: Union[torch.Tensor, list, int], device: str = 'cpu', block_size: int = 1, num_blocks: Optional[int] = None, tiling_method: TilingMethod = TilingMethod.DEFAULT) -> None: 28 | self._data = data 29 | 30 | if type(slices) is int: 31 | # each slice is a single type 32 | self._slices = torch.zeros((slices, 5), dtype=torch.int, device='cpu') 33 | for i in range(0, slices): 34 | self._slices[i][0] = i # slice index 35 | self._slices[i][1] = i # type 36 | self._slices[i][2] = i # start 37 | self._slices[i][3] = i + 1 # end 38 | self._slices[i][4] = 0 # next 39 | self._slices = self._slices.to(device) 40 | elif type(slices) is list: 41 | # 2d list, nx5 42 | self._slices = torch.as_tensor(slices, dtype=torch.int, device=device) 43 | else: 44 | self._slices = slices.to(device) 45 | # Don't backpropagate on slice tensors 46 | self._slices.requires_grad = False 47 | self._block_size = block_size 48 | self._num_blocks = num_blocks if num_blocks is not None else len(self._slices) 49 | self._cache = {} 50 | self._tiling_method = tiling_method 51 | self._contiguous_ratio = self._get_contiguous_ratio() 52 | self._slice_tile_mapping = self._get_slice_tile_mapping() 53 | self._stddev_tile_size = self._get_stddev_tile_size() 54 | self._avg_tile_size = self._get_avg_tile_size() 55 | 56 | def _init_mappings(self): 57 | if not hasattr(self, '_type_slice_dict'): 58 | self._type_slice_dict = OrderedDict() 59 | for i in range(self._slices.size(0)): 60 | self._type_slice_dict[self._slices[i, 1].item()] = i 61 | if not hasattr(self, '_slice_type_dict'): 62 | self._slice_type_dict = OrderedDict() 63 | for i in range(self._slices.size(0)): 64 | self._slice_type_dict[i] = self._slices[i, 1].item() 65 | 66 | def __len__(self) -> int: 67 | return self._slices.size(0) 68 | 69 | @property 70 | def data_size(self) -> int: 71 | return self.stop(is_tensor=False) - self.start(is_tensor=False) 72 | 73 | def start(self, is_tensor: bool = 
True): 74 | ''' 75 | Get the start index of the original tensor. 76 | 77 | Args: 78 | is_tensor: If true, return a tensor. Otherwise, return a python int. 79 | ''' 80 | return self._slices[0, 2] if is_tensor else self._slices[0, 2].item() 81 | 82 | def stop(self, is_tensor: bool = True): 83 | ''' 84 | Get the stop index of the original tensor. 85 | 86 | Args: 87 | is_tensor: If true, return a tensor. Otherwise, return a python int. 88 | ''' 89 | return self._slices[-1, 3] if is_tensor else self._slices[-1, 3].item() 90 | 91 | @property 92 | def slices(self): 93 | return self._slices 94 | 95 | @property 96 | def data(self): 97 | return self._data 98 | 99 | @property 100 | def num_blocks(self): 101 | return self._num_blocks 102 | 103 | @property 104 | def block_size(self): 105 | return self._block_size 106 | 107 | @property 108 | def contiguous_ratio(self): 109 | return self._contiguous_ratio 110 | 111 | @property 112 | def slice_tile_mapping(self): 113 | return self._slice_tile_mapping 114 | 115 | @property 116 | def avg_tile_size(self): 117 | return self._avg_tile_size 118 | 119 | @property 120 | def stddev_tile_size(self): 121 | return self._stddev_tile_size 122 | 123 | def get_slice_from_type(self, type: int, is_tensor: bool = True): 124 | ''' 125 | Get the slice of the original tensor from the type. 126 | 127 | Args: 128 | type: The type 129 | is_tensor: If true, return a tensor. Otherwise, return a python slice. 130 | ''' 131 | self._init_mappings() 132 | entry = self._slices[self._type_slice_dict[type]][2:4] 133 | return entry if is_tensor else slice(entry[0].item(), entry[1].item()) 134 | 135 | def get_slice_from_index(self, index: int, is_tensor: bool = True): 136 | ''' 137 | Get the slice of the original tensor from the slice index. 138 | 139 | Args: 140 | index: The slice index 141 | is_tensor: If true, return a tensor. Otherwise, return a python slice. 
142 | ''' 143 | self._init_mappings() 144 | entry = self._slices[index][2:4] 145 | return entry if is_tensor else slice(entry[0].item(), entry[1].item()) 146 | 147 | def get_type_from_index(self, index: int, is_tensor: bool = True) -> int: 148 | ''' 149 | Get the type from the slice index. 150 | 151 | Args: 152 | index: The slice index 153 | is_tensor: If true, return a tensor. Otherwise, return a python int. 154 | ''' 155 | self._init_mappings() 156 | return self._slices[index][1] if is_tensor else self._slices[index][1].item() 157 | 158 | def _get_slice_tile_mapping(self) -> torch.Tensor: 159 | if self._tiling_method == TilingMethod.DEFAULT: 160 | subslices = self._slices.tolist() 161 | segments = [] 162 | begin = 0 163 | for i in range(1, len(subslices)): 164 | if subslices[i][1] != subslices[i - 1][1]: 165 | segments.append((subslices[i - 1][1], begin, i)) 166 | begin = i 167 | segments.append((subslices[-1][1], begin, len(subslices))) 168 | return torch.tensor(segments, dtype=torch.int, device=self._slices.device) 169 | else: 170 | return None 171 | 172 | def _get_avg_tile_size(self) -> float: 173 | return torch.mean((self._slices[:, 3] - self._slices[:, 2]).float()).item() 174 | 175 | def _get_stddev_tile_size(self) -> float: 176 | return torch.std((self._slices[:, 3] - self._slices[:, 2]).float(), correction=0.0).item() 177 | 178 | def _get_contiguous_ratio(self) -> float: 179 | return torch.sum(self.slices[:, 4] == 0).item() / float(self.num_blocks) 180 | 181 | def _lookup_cache(self, op_name: str, key: tuple) -> CacheEntry: 182 | if op_name in self._cache and key in self._cache[op_name]: 183 | return self._cache[op_name][key] 184 | return None 185 | 186 | def _update_cache(self, op_name: str, key: tuple, entry: CacheEntry): 187 | if op_name not in self._cache: 188 | self._cache[op_name] = {} 189 | self._cache[op_name][key] = entry 190 | 191 | def tiling(self, tile_size: int = Scheduler.default_tile_size, block_size: int = Scheduler.default_block_size, 
method: TilingMethod = Scheduler.default_tiling_method): 192 | if tile_size <= 0: 193 | raise ValueError(f'Invalid tile size {tile_size}') 194 | slices = self._slices.tolist() 195 | num_blocks = None 196 | if method == TilingMethod.DEFAULT: 197 | subslices, num_blocks = tiling(slices, tile_size, block_size, reorder=False) 198 | elif method == TilingMethod.BALANCED: 199 | subslices, num_blocks = tiling(slices, tile_size, block_size, reorder=True) 200 | else: 201 | raise ValueError(f'Unsupported tiling method {method}') 202 | return TensorSlice(self.data, subslices, self._slices.device, block_size=block_size, num_blocks=num_blocks) 203 | 204 | def schedule(self, op_name: str, *args, autotune: bool = False) -> CacheEntry: 205 | scheduler = schedulers[op_name] 206 | key = scheduler.get_key(*args) 207 | cache_entry = self._lookup_cache(op_name, key) 208 | 209 | if cache_entry is not None: 210 | return cache_entry 211 | 212 | if autotune: 213 | best_ms, best_config, best_op = self.autotune(op_name, *args, scheduler=scheduler) 214 | else: 215 | best_ms, best_config, best_op = self.use_defaults(op_name, scheduler=scheduler) 216 | 217 | cache_entry = CacheEntry(best_ms, best_config, best_op) 218 | self._update_cache(op_name, key, cache_entry) 219 | return cache_entry 220 | 221 | def autotune(self, op_name: str, *args, scheduler: Scheduler) -> Tuple[float, BestConfig, callable]: 222 | best_op = getattr(torch_ops, op_name) 223 | best_ms = do_bench(lambda: best_op(*args, input_slices=self.slices), warmup=1, rep=1) 224 | best_config = BestConfig() 225 | key = scheduler.get_key(*args) 226 | debug = is_debug() 227 | 228 | triton_op = getattr(triton_ops, op_name) 229 | for config in scheduler.get_configs(): 230 | tile_size, tiling_method, block_size = config 231 | input_tiles = self.tiling(tile_size, method=tiling_method, block_size=block_size) 232 | if scheduler.prune and scheduler.prune(input_tiles, key, config): 233 | continue 234 | try: 235 | def _do_bench(input_tiles, 
tile_size): 236 | return do_bench( 237 | lambda input_tiles=input_tiles, tile_size=tile_size: triton_op( 238 | *args, 239 | input_slices=self.slices, 240 | input_tiles=input_tiles.slices, 241 | num_blocks=input_tiles.num_blocks, 242 | block_size=input_tiles.block_size, 243 | tile_size=tile_size, 244 | slice_tile_mapping=input_tiles.slice_tile_mapping, 245 | avg_tile_size=input_tiles.avg_tile_size, 246 | stddev_tile_size=input_tiles.stddev_tile_size 247 | ), 248 | warmup=1, 249 | rep=1, 250 | ) 251 | _do_bench(input_tiles, tile_size) # warmup 252 | ms = _do_bench(input_tiles, tile_size) 253 | if debug: 254 | print(f"op_name={op_name}, tile_size={tile_size}, block_size={block_size}, avg_tile_size={input_tiles.avg_tile_size}, " 255 | f"stddev_tile_size={input_tiles.stddev_tile_size}, contiguous_ratio={input_tiles.contiguous_ratio}, ms={ms}") 256 | if scheduler.record: 257 | scheduler.record(input_tiles.slices, key, config, ms) 258 | if ms < best_ms: 259 | best_ms, best_op, best_config = ms, triton_op, BestConfig(tile_size=tile_size, block_size=input_tiles.block_size, 260 | input_tiles=input_tiles.slices, num_blocks=input_tiles.num_blocks, 261 | slice_tile_mapping=input_tiles.slice_tile_mapping, 262 | avg_tile_size=input_tiles.avg_tile_size, stddev_tile_size=input_tiles.stddev_tile_size, 263 | deterministic=GlobalConfig.deterministic) 264 | except OutOfResources: 265 | if debug: 266 | print(f'op_name={op_name}, tile_size={tile_size}, block_size={block_size}, out of resources') 267 | if debug: 268 | print(f"best op_name={op_name}, tile_size={best_config.tile_size}, block_size={best_config.block_size}, " 269 | f"avg_tile_size={best_config.avg_tile_size}, stddev_tile_size={best_config.stddev_tile_size}") 270 | return best_ms, best_config, best_op 271 | 272 | def use_defaults(self, op_name: str, scheduler: Scheduler) -> Tuple[float, BestConfig, callable]: 273 | input_tiles = self.tiling(scheduler.default_tile_size, method=scheduler.default_tiling_method, 
block_size=scheduler.default_block_size) 274 | return (0.0, BestConfig(tile_size=scheduler.default_tile_size, block_size=scheduler.default_block_size, input_tiles=input_tiles.slices, num_blocks=input_tiles.num_blocks, 275 | slice_tile_mapping=input_tiles.slice_tile_mapping, avg_tile_size=input_tiles.avg_tile_size, stddev_tile_size=input_tiles.stddev_tile_size, deterministic=GlobalConfig.deterministic), 276 | getattr(triton_ops, op_name)) 277 | 278 | 279 | def compact_tensor_types(data: torch.Tensor, types: torch.Tensor, *, 280 | dim: int = 0, descending: bool = False, 281 | is_sorted: bool = False, device: str = 'cpu') -> TensorSlice: 282 | """ 283 | Sort the types and its corresponding tensor, if given 284 | 285 | Args: 286 | data (torch.Tensor): The input data to be sorted. 287 | types (torch.Tensor): The type of each record. 288 | dim (int, optional): Which dimension of the tensor represents types. Defaults to 0. 289 | descending (bool, optional): If true, sort the tensor in descending order. Defaults to False. 290 | is_sorted (bool, optional): If true, assumes types is already sorted. Defaults to False. 291 | device (str, optional): The device to put the slices. Note that tensor and sorted_types remain on the original device. Defaults to 'cpu'. 292 | 293 | Returns: 294 | TensorSlice: The sorted tensor and its corresponding TensorSlice. 
295 | """ 296 | if not is_sorted: 297 | sorted_types, type_indices = torch.sort(types, descending=descending, stable=True) 298 | else: 299 | sorted_types = types 300 | 301 | unique_types, type_counts = torch.unique_consecutive( 302 | sorted_types, return_inverse=False, return_counts=True) 303 | 304 | type_list = [] 305 | cur_index = 0 306 | for i in range(len(unique_types)): 307 | type_list.append([ 308 | i, unique_types[i].item(), cur_index, cur_index + type_counts[i].item(), -1]) 309 | cur_index += type_counts[i].item() 310 | 311 | sorted_data = torch.index_select(data, dim=dim, index=type_indices) if not is_sorted else data 312 | return TensorSlice(sorted_data, type_list, device=device) 313 | -------------------------------------------------------------------------------- /fasten/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from enum import Enum 3 | 4 | import torch 5 | import triton.language as tl 6 | 7 | from .operators import torch_ops, triton_ops 8 | 9 | 10 | class GlobalConfig: 11 | deterministic: bool = True 12 | with_autotune: bool = True 13 | with_perf_model: bool = False 14 | binning_interval: float = 32.0 15 | 16 | 17 | class TilingMethod(Enum): 18 | DEFAULT = 'default' 19 | BALANCED = 'balanced' 20 | 21 | 22 | class Engine(Enum): 23 | ''' 24 | Engine is an enum class, including 'torch', 'triton', 'auto', and the default is 'auto': 25 | - 'auto': use triton operators if available, otherwise use torch native operators or triton operators 26 | offer lower performance 27 | - 'torch': use torch native operators 28 | - 'triton': use triton operators 29 | ''' 30 | TORCH = 'torch' 31 | TRITON = 'triton' 32 | AUTO = 'auto' 33 | 34 | 35 | engine_ops = { 36 | Engine.TORCH: torch_ops, 37 | Engine.TRITON: triton_ops, 38 | } 39 | 40 | 41 | def torch_dtype_to_triton_dtype(dtype, grad: bool = False): 42 | type_dict = { 43 | torch.float32: tl.float32, 44 | torch.float16: tl.float16, 45 | torch.int32: 
tl.int32, 46 | torch.int64: tl.int64, 47 | } 48 | promo_type_dict = { 49 | torch.float16: tl.float32, 50 | torch.float32: tl.float32, 51 | } 52 | if grad: 53 | if dtype not in promo_type_dict: 54 | raise ValueError(f'Unsupported dtype {dtype}') 55 | return promo_type_dict[dtype] 56 | else: 57 | if dtype not in type_dict: 58 | raise ValueError(f'Unsupported dtype {dtype}') 59 | return type_dict[dtype] 60 | 61 | 62 | def is_debug(): 63 | FLAG = os.environ.get('FASTEN_DEBUG', '0') 64 | return FLAG == '1' or FLAG.lower() == 'true' or FLAG.lower() == 'yes' or FLAG.lower() == 'on' 65 | 66 | 67 | def binning(x, interval: float = GlobalConfig.binning_interval): 68 | return x // interval 69 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Always prefer setuptools over distutils 2 | import os 3 | 4 | from setuptools import find_packages, setup 5 | 6 | ci_only = os.environ.get("CI_ONLY", False) 7 | pkgs = ['pytest', 'flake8', 'autopep8', 'isort', 'pre-commit', 'pytest'] 8 | if not ci_only: 9 | pkgs += ['pyg-nightly', 'pyg-lib@git+https://github.com/pyg-team/pyg-lib.git'], 10 | 11 | setup( 12 | name='fasten', # Required 13 | version='0.1', # Required 14 | description='A libary for fast segment operators', # Optional 15 | 16 | author='Keren Zhou', # Optional 17 | 18 | author_email='kerenzhou@outlook.com', # Optional 19 | 20 | classifiers=[ # Optional 21 | 'Development Status :: 3 - Alpha', 22 | 'License :: OSI Approved :: BSD-3 License' 23 | ], 24 | 25 | packages=find_packages(), # Required 26 | 27 | python_requires='>=3.6, <4', 28 | install_requires=pkgs, 29 | include_package_data=True, 30 | long_description_content_type='text/markdown' 31 | ) 32 | -------------------------------------------------------------------------------- /test/datasets_csv/ACM.csv: -------------------------------------------------------------------------------- 1 | 
Start,End 2 | 0,5343 3 | 5343,15292 4 | 15292,18317 5 | 18317,273936 6 | -------------------------------------------------------------------------------- /test/datasets_csv/AIFB.csv: -------------------------------------------------------------------------------- 1 | Start,End 2 | 0,4139 3 | 4139,8278 4 | 8278,12065 5 | 12065,15852 6 | 15852,19814 7 | 19814,23776 8 | 23776,26252 9 | 26252,28728 10 | 28728,29113 11 | 29113,29498 12 | 29498,30705 13 | 30705,31912 14 | 31912,33122 15 | 33122,34332 16 | 34332,35549 17 | 35549,36766 18 | 36766,37718 19 | 37718,38670 20 | 38670,39622 21 | 39622,40574 22 | 40574,41330 23 | 41330,42086 24 | 42086,42832 25 | 42832,43578 26 | 43578,44146 27 | 44146,44714 28 | 44714,45254 29 | 45254,45794 30 | 45794,46324 31 | 46324,46854 32 | 46854,47210 33 | 47210,47566 34 | 47566,47904 35 | 47904,48242 36 | 48242,48551 37 | 48551,48860 38 | 48860,49151 39 | 49151,49442 40 | 49442,49573 41 | 49573,49704 42 | 49704,49902 43 | 49902,50100 44 | 50100,50326 45 | 50326,50552 46 | 50552,50750 47 | 50750,50948 48 | 50948,51148 49 | 51148,51348 50 | 51348,51383 51 | 51383,51418 52 | 51418,51608 53 | 51608,51798 54 | 51798,51956 55 | 51956,52114 56 | 52114,52137 57 | 52137,52160 58 | 52160,52160 59 | 52160,52160 60 | 52160,52273 61 | 52273,52386 62 | 52386,52530 63 | 52530,52674 64 | 52674,52803 65 | 52803,52932 66 | 52932,53046 67 | 53046,53160 68 | 53160,53239 69 | 53239,53318 70 | 53318,53397 71 | 53397,53476 72 | 53476,53540 73 | 53540,53604 74 | 53604,53668 75 | 53668,53732 76 | 53732,53782 77 | 53782,53832 78 | 53832,53880 79 | 53880,53928 80 | 53928,53944 81 | 53944,53960 82 | 53960,53975 83 | 53975,53990 84 | 53990,54002 85 | 54002,54014 86 | 54014,54014 87 | 54014,54014 88 | 54014,54019 89 | 54019,54024 90 | 54024,54024 91 | 54024,54024 92 | -------------------------------------------------------------------------------- /test/datasets_csv/AM.csv: -------------------------------------------------------------------------------- 1 | Start,End 
2 | 0,219777 3 | 219777,439554 4 | 439554,474471 5 | 474471,509388 6 | 509388,722309 7 | 722309,935230 8 | 935230,946166 9 | 946166,957102 10 | 957102,969238 11 | 969238,981374 12 | 981374,993088 13 | 993088,1004802 14 | 1004802,1016839 15 | 1016839,1028876 16 | 1028876,1039026 17 | 1039026,1049176 18 | 1049176,1068202 19 | 1068202,1087228 20 | 1087228,1102580 21 | 1102580,1117932 22 | 1117932,1133152 23 | 1133152,1148372 24 | 1148372,1165699 25 | 1165699,1183026 26 | 1183026,1215471 27 | 1215471,1247916 28 | 1247916,1253249 29 | 1253249,1258582 30 | 1258582,1260794 31 | 1260794,1263006 32 | 1263006,1295173 33 | 1295173,1327340 34 | 1327340,1416065 35 | 1416065,1504790 36 | 1504790,1526353 37 | 1526353,1547916 38 | 1547916,1569311 39 | 1569311,1590706 40 | 1590706,1601546 41 | 1601546,1612386 42 | 1612386,1641638 43 | 1641638,1670890 44 | 1670890,1674833 45 | 1674833,1678776 46 | 1678776,1679776 47 | 1679776,1680776 48 | 1680776,1684569 49 | 1684569,1688362 50 | 1688362,1689362 51 | 1689362,1690362 52 | 1690362,1694155 53 | 1694155,1697948 54 | 1697948,1770862 55 | 1770862,1843776 56 | 1843776,1847569 57 | 1847569,1851362 58 | 1851362,1855155 59 | 1855155,1858948 60 | 1858948,1887635 61 | 1887635,1916322 62 | 1916322,1917632 63 | 1917632,1918942 64 | 1918942,1923764 65 | 1923764,1928586 66 | 1928586,1928776 67 | 1928776,1928966 68 | 1928966,1994746 69 | 1994746,2060526 70 | 2060526,2111725 71 | 2111725,2162924 72 | 2162924,2172695 73 | 2172695,2182466 74 | 2182466,2192234 75 | 2192234,2202002 76 | 2202002,2257504 77 | 2257504,2313006 78 | 2313006,2366623 79 | 2366623,2420240 80 | 2420240,2425126 81 | 2425126,2430012 82 | 2430012,2481117 83 | 2481117,2532222 84 | 2532222,2547409 85 | 2547409,2562596 86 | 2562596,2576329 87 | 2576329,2590062 88 | 2590062,2620197 89 | 2620197,2650332 90 | 2650332,2666552 91 | 2666552,2682772 92 | 2682772,2697348 93 | 2697348,2711924 94 | 2711924,2741600 95 | 2741600,2771276 96 | 2771276,2786375 97 | 2786375,2801474 98 | 
2801474,2827119 99 | 2827119,2852764 100 | 2852764,2858835 101 | 2858835,2864906 102 | 2864906,2866789 103 | 2866789,2868672 104 | 2868672,2881814 105 | 2881814,2894956 106 | 2894956,2897035 107 | 2897035,2899114 108 | 2899114,2899854 109 | 2899854,2900594 110 | 2900594,2905179 111 | 2905179,2909764 112 | 2909764,2931903 113 | 2931903,2954042 114 | 2954042,2973647 115 | 2973647,2993252 116 | 2993252,2995939 117 | 2995939,2998626 118 | 2998626,3001197 119 | 3001197,3003768 120 | 3003768,3005946 121 | 3005946,3008124 122 | 3008124,3023253 123 | 3023253,3038382 124 | 3038382,3053727 125 | 3053727,3069072 126 | 3069072,3084384 127 | 3084384,3099696 128 | 3099696,3107054 129 | 3107054,3114412 130 | 3114412,3121732 131 | 3121732,3129052 132 | 3129052,3134767 133 | 3134767,3140482 134 | 3140482,3145839 135 | 3145839,3151196 136 | 3151196,3151852 137 | 3151852,3152508 138 | 3152508,3159642 139 | 3159642,3166776 140 | 3166776,3173910 141 | 3173910,3181044 142 | 3181044,3188009 143 | 3188009,3194974 144 | 3194974,3200608 145 | 3200608,3206242 146 | 3206242,3211411 147 | 3211411,3216580 148 | 3216580,3218688 149 | 3218688,3220796 150 | 3220796,3226497 151 | 3226497,3232198 152 | 3232198,3235028 153 | 3235028,3237858 154 | 3237858,3241117 155 | 3241117,3244376 156 | 3244376,3248396 157 | 3248396,3252416 158 | 3252416,3254980 159 | 3254980,3257544 160 | 3257544,3257856 161 | 3257856,3258168 162 | 3258168,3261682 163 | 3261682,3265196 164 | 3265196,3265762 165 | 3265762,3266328 166 | 3266328,3268607 167 | 3268607,3270886 168 | 3270886,3274766 169 | 3274766,3278646 170 | 3278646,3281853 171 | 3281853,3285060 172 | 3285060,3287503 173 | 3287503,3289946 174 | 3289946,3293371 175 | 3293371,3296796 176 | 3296796,3299675 177 | 3299675,3302554 178 | 3302554,3303956 179 | 3303956,3305358 180 | 3305358,3305446 181 | 3305446,3305534 182 | 3305534,3306266 183 | 3306266,3306998 184 | 3306998,3308131 185 | 3308131,3309264 186 | 3309264,3310398 187 | 3310398,3311532 188 | 3311532,3312106 189 
| 3312106,3312680 190 | 3312680,3313363 191 | 3313363,3314046 192 | 3314046,3314208 193 | 3314208,3314370 194 | 3314370,3314531 195 | 3314531,3314692 196 | 3314692,3314770 197 | 3314770,3314848 198 | 3314848,3314961 199 | 3314961,3315074 200 | 3315074,3315084 201 | 3315084,3315094 202 | 3315094,3316630 203 | 3316630,3318166 204 | 3318166,3318654 205 | 3318654,3319142 206 | 3319142,3319599 207 | 3319599,3320056 208 | 3320056,3320253 209 | 3320253,3320450 210 | 3320450,3320473 211 | 3320473,3320496 212 | 3320496,3320813 213 | 3320813,3321130 214 | 3321130,3321303 215 | 3321303,3321476 216 | 3321476,3321578 217 | 3321578,3321680 218 | 3321680,3321959 219 | 3321959,3322238 220 | 3322238,3322308 221 | 3322308,3322378 222 | 3322378,3322392 223 | 3322392,3322406 224 | 3322406,3322788 225 | 3322788,3323170 226 | 3323170,3323185 227 | 3323185,3323200 228 | 3323200,3323332 229 | 3323332,3323464 230 | 3323464,3323635 231 | 3323635,3323806 232 | 3323806,3323901 233 | 3323901,3323996 234 | 3323996,3324007 235 | 3324007,3324018 236 | 3324018,3324018 237 | 3324018,3324018 238 | 3324018,3324021 239 | 3324021,3324024 240 | 3324024,3324026 241 | 3324026,3324028 242 | 3324028,3324031 243 | 3324031,3324034 244 | 3324034,3324091 245 | 3324091,3324148 246 | 3324148,3324156 247 | 3324156,3324164 248 | 3324164,3324271 249 | 3324271,3324378 250 | 3324378,3324404 251 | 3324404,3324430 252 | 3324430,3324430 253 | 3324430,3324430 254 | 3324430,3324430 255 | 3324430,3324430 256 | 3324430,3324430 257 | 3324430,3324430 258 | 3324430,3324430 259 | 3324430,3324430 260 | 3324430,3324430 261 | 3324430,3324430 262 | 3324430,3324430 263 | 3324430,3324430 264 | 3324430,3324430 265 | 3324430,3324430 266 | 3324430,3324430 267 | 3324430,3324430 268 | -------------------------------------------------------------------------------- /test/datasets_csv/BGS.csv: -------------------------------------------------------------------------------- 1 | Start,End 2 | 0,83031 3 | 83031,166062 4 | 166062,166811 5 | 
166811,167560 6 | 167560,168119 7 | 168119,168678 8 | 168678,180893 9 | 180893,193108 10 | 193108,204917 11 | 204917,216726 12 | 216726,241413 13 | 241413,266100 14 | 266100,290773 15 | 290773,315446 16 | 315446,339290 17 | 339290,363134 18 | 363134,363223 19 | 363223,363312 20 | 363312,363401 21 | 363401,363490 22 | 363490,363579 23 | 363579,363668 24 | 363668,363757 25 | 363757,363846 26 | 363846,384079 27 | 384079,404312 28 | 404312,412965 29 | 412965,421618 30 | 421618,442757 31 | 442757,463896 32 | 463896,464093 33 | 464093,464290 34 | 464290,464290 35 | 464290,464290 36 | 464290,478797 37 | 478797,493304 38 | 493304,493304 39 | 493304,493304 40 | 493304,507811 41 | 507811,522318 42 | 522318,522529 43 | 522529,522740 44 | 522740,535639 45 | 535639,548538 46 | 548538,561439 47 | 561439,574340 48 | 574340,587123 49 | 587123,599906 50 | 599906,611982 51 | 611982,624058 52 | 624058,635998 53 | 635998,647938 54 | 647938,659668 55 | 659668,671398 56 | 671398,682875 57 | 682875,694352 58 | 694352,706044 59 | 706044,717736 60 | 717736,717921 61 | 717921,718106 62 | 718106,729782 63 | 729782,741458 64 | 741458,753134 65 | 753134,764810 66 | 764810,776486 67 | 776486,788162 68 | 788162,799700 69 | 799700,811238 70 | 811238,822914 71 | 822914,834590 72 | 834590,846266 73 | 846266,857942 74 | 857942,868479 75 | 868479,879016 76 | 879016,887343 77 | 887343,895670 78 | 895670,896640 79 | 896640,897610 80 | 897610,897736 81 | 897736,897862 82 | 897862,904605 83 | 904605,911348 84 | 911348,918679 85 | 918679,926010 86 | 926010,931873 87 | 931873,937736 88 | 937736,938019 89 | 938019,938302 90 | 938302,938379 91 | 938379,938456 92 | 938456,938534 93 | 938534,938612 94 | 938612,944146 95 | 944146,949680 96 | 949680,954920 97 | 954920,960160 98 | 960160,965110 99 | 965110,970060 100 | 970060,974748 101 | 974748,979436 102 | 979436,979663 103 | 979663,979890 104 | 979890,980032 105 | 980032,980174 106 | 980174,980391 107 | 980391,980608 108 | 980608,980608 109 | 980608,980608 110 
| 980608,980764 111 | 980764,980920 112 | 980920,981104 113 | 981104,981288 114 | 981288,984252 115 | 984252,987216 116 | 987216,990176 117 | 990176,993136 118 | 993136,995861 119 | 995861,998586 120 | 998586,1001537 121 | 1001537,1004488 122 | 1004488,1007440 123 | 1007440,1010392 124 | 1010392,1010746 125 | 1010746,1011100 126 | 1011100,1011101 127 | 1011101,1011102 128 | 1011102,1011628 129 | 1011628,1012154 130 | 1012154,1012154 131 | 1012154,1012154 132 | 1012154,1013696 133 | 1013696,1015238 134 | 1015238,1016468 135 | 1016468,1017698 136 | 1017698,1017720 137 | 1017720,1017742 138 | 1017742,1017853 139 | 1017853,1017964 140 | 1017964,1017964 141 | 1017964,1017964 142 | 1017964,1018370 143 | 1018370,1018776 144 | 1018776,1018787 145 | 1018787,1018798 146 | 1018798,1018875 147 | 1018875,1018952 148 | 1018952,1019033 149 | 1019033,1019114 150 | 1019114,1019520 151 | 1019520,1019926 152 | 1019926,1019937 153 | 1019937,1019948 154 | 1019948,1020354 155 | 1020354,1020760 156 | 1020760,1021066 157 | 1021066,1021372 158 | 1021372,1021372 159 | 1021372,1021372 160 | 1021372,1021381 161 | 1021381,1021390 162 | 1021390,1021390 163 | 1021390,1021390 164 | 1021390,1021390 165 | 1021390,1021390 166 | 1021390,1021390 167 | 1021390,1021390 168 | 1021390,1021413 169 | 1021413,1021436 170 | 1021436,1021454 171 | 1021454,1021472 172 | 1021472,1021472 173 | 1021472,1021472 174 | 1021472,1021475 175 | 1021475,1021478 176 | 1021478,1021479 177 | 1021479,1021480 178 | 1021480,1021483 179 | 1021483,1021486 180 | 1021486,1021488 181 | 1021488,1021490 182 | 1021490,1021492 183 | 1021492,1021494 184 | 1021494,1021494 185 | 1021494,1021494 186 | 1021494,1021494 187 | 1021494,1021494 188 | 1021494,1021494 189 | 1021494,1021494 190 | 1021494,1021494 191 | 1021494,1021494 192 | 1021494,1021494 193 | 1021494,1021494 194 | 1021494,1021494 195 | 1021494,1021494 196 | 1021494,1021494 197 | 1021494,1021494 198 | 1021494,1021495 199 | 1021495,1021496 200 | 1021496,1021496 201 | 1021496,1021496 
202 | 1021496,1021496 203 | 1021496,1021496 204 | 1021496,1021496 205 | 1021496,1021496 206 | 1021496,1021496 207 | 1021496,1021496 208 | -------------------------------------------------------------------------------- /test/datasets_csv/DBLP.csv: -------------------------------------------------------------------------------- 1 | Start,End 2 | 0,19645 3 | 19645,105455 4 | 105455,119783 5 | -------------------------------------------------------------------------------- /test/datasets_csv/Freebase.csv: -------------------------------------------------------------------------------- 1 | Start,End 2 | 0,202674 3 | 202674,240973 4 | 240973,247588 5 | 247588,274509 6 | 274509,296409 7 | 296409,384247 8 | 384247,415733 9 | 415733,427024 10 | 427024,710694 11 | 710694,719669 12 | 719669,762584 13 | 762584,769347 14 | 769347,770637 15 | 770637,771293 16 | 771293,806880 17 | 806880,824484 18 | 824484,835432 19 | 835432,850282 20 | 850282,873095 21 | 873095,888229 22 | 888229,890444 23 | 890444,895822 24 | 895822,917121 25 | 917121,964938 26 | 964938,978066 27 | 978066,988768 28 | 988768,989327 29 | 989327,992023 30 | 992023,993124 31 | 993124,994197 32 | 994197,1012822 33 | 1012822,1021219 34 | 1021219,1045983 35 | 1045983,1046593 36 | 1046593,1053240 37 | 1053240,1057688 38 | -------------------------------------------------------------------------------- /test/datasets_csv/IMDB.csv: -------------------------------------------------------------------------------- 1 | Start,End 2 | 0,4932 3 | 4932,19711 4 | 19711,43321 5 | -------------------------------------------------------------------------------- /test/datasets_csv/MUTAG.csv: -------------------------------------------------------------------------------- 1 | Start,End 2 | 0,22484 3 | 22484,44968 4 | 44968,63602 5 | 63602,82236 6 | 82236,91553 7 | 91553,100870 8 | 100870,110059 9 | 110059,119248 10 | 119248,128437 11 | 128437,137626 12 | 137626,141152 13 | 141152,144678 14 | 144678,145018 15 | 145018,145358 16 | 
145358,145663 17 | 145663,145968 18 | 145968,146261 19 | 146261,146554 20 | 146554,146837 21 | 146837,147120 22 | 147120,147323 23 | 147323,147526 24 | 147526,147546 25 | 147546,147566 26 | 147566,147627 27 | 147627,147688 28 | 147688,147743 29 | 147743,147798 30 | 147798,147833 31 | 147833,147868 32 | 147868,147893 33 | 147893,147918 34 | 147918,147943 35 | 147943,147968 36 | 147968,147985 37 | 147985,148002 38 | 148002,148005 39 | 148005,148008 40 | 148008,148021 41 | 148021,148034 42 | 148034,148047 43 | 148047,148060 44 | 148060,148066 45 | 148066,148072 46 | 148072,148077 47 | 148077,148082 48 | -------------------------------------------------------------------------------- /test/test_nn.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import pytest 4 | import torch 5 | import torch_geometric 6 | from torch import Tensor 7 | from torch_geometric.data import HeteroData 8 | from torch_geometric.nn import HEATConv, HGTConv, Linear, RGATConv, RGCNConv 9 | from torch_geometric.utils import index_sort 10 | from torch_geometric.utils.sparse import index2ptr 11 | 12 | from fasten import Engine, TensorSlice, compact_tensor_types 13 | from fasten.nn.conv import FastenHEATConv, FastenHGTConv, FastenRGATConv, FastenRGCNConv 14 | 15 | torch.backends.cuda.matmul.allow_tf32 = True 16 | torch_geometric.backend.use_segment_matmul = True 17 | 18 | 19 | def ptr_to_tensor_slice(ptr: List, data: Tensor = None, is_sorted: bool = False) -> Tuple[TensorSlice, List]: 20 | 21 | assert ptr is not None 22 | slices = [slice(ptr[i], ptr[i + 1]) for i in range(len(ptr) - 1)] 23 | types = torch.zeros((ptr[-1],), dtype=torch.int) 24 | for i, s in enumerate(slices): 25 | types[s] = i 26 | tensor_slice = compact_tensor_types(data=data, types=types, is_sorted=is_sorted, device="cuda") 27 | return tensor_slice, slices 28 | 29 | 30 | def tensor_slice_gen(x_dict, edge_index_dict, meta_data, num_heads) -> 
Tuple[TensorSlice, Tensor, TensorSlice, List]: 31 | 32 | # Generating tensor_slice for HeteroDictLinear 33 | ptr = [0] 34 | for key, _ in x_dict.items(): 35 | ptr.append(ptr[-1] + x_dict[key].shape[0]) 36 | tensor_slice_hdl, slices = ptr_to_tensor_slice(ptr, is_sorted=True) 37 | slices_hdl = slices 38 | 39 | # Generating tensor_slice for HeteroLinear 40 | edge_types = meta_data[1] 41 | num_edge_types = len(edge_types) 42 | H = num_heads # No of heads 43 | type_list = [] 44 | edge_map = {edge_type: i for i, edge_type in enumerate(meta_data[1])} 45 | 46 | for key, _ in edge_index_dict.items(): 47 | N = x_dict[key[0]].shape[0] 48 | edge_type_offset = edge_map[key] 49 | type_vec = torch.arange(H, dtype=torch.long).view(-1, 1).repeat(1, N) * num_edge_types + edge_type_offset 50 | type_list.append(type_vec) 51 | 52 | type_vec = torch.cat(type_list, dim=1).flatten() 53 | num_types = H * len(edge_types) 54 | ptr = index2ptr(type_vec, num_types) 55 | tensor_slice_hl, _ = ptr_to_tensor_slice(ptr, is_sorted=True) 56 | 57 | return tensor_slice_hl, type_vec, tensor_slice_hdl, slices_hdl 58 | 59 | 60 | def heat_tensor_slice_gen(data) -> TensorSlice: 61 | 62 | sorted_node_type, _ = index_sort(data.node_type, len(torch.unique(data.node_type))) 63 | ptr = index2ptr(sorted_node_type, len(torch.unique(data.node_type))) 64 | tensor_slice_hl, _ = ptr_to_tensor_slice(ptr, is_sorted=True) 65 | return tensor_slice_hl 66 | 67 | 68 | @pytest.mark.parametrize("device", ["cuda"]) 69 | def test_rgcn(device: str): 70 | torch.manual_seed(12345) 71 | x = torch.randn(4, 4).to(device) 72 | edge_index = torch.tensor([ 73 | [0, 1, 1, 2, 2, 3, 0, 1, 1, 2, 2, 3], 74 | [1, 1, 1, 2, 1, 1, 1, 0, 1, 3, 1, 3], 75 | ]).to(device) 76 | edge_type = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 5, 6, 7]).to(device) 77 | tensor_slice = compact_tensor_types(data=None, types=edge_type, is_sorted=True, device=device) 78 | 79 | torch.manual_seed(12345) 80 | rgcn_conv = RGCNConv(4, 16, 8, aggr="add", 
is_sorted=True).to(device) 81 | torch.manual_seed(12345) 82 | fasten_rgcn_conv = FastenRGCNConv(4, 16, 8, aggr="add", is_sorted=True, engine=Engine.TRITON).to(device) 83 | 84 | rgcn_conv_out = rgcn_conv(x, edge_index, edge_type) 85 | fasten_rgcn_conv_out = fasten_rgcn_conv(x, edge_index, edge_type, tensor_slice) 86 | 87 | assert fasten_rgcn_conv_out.shape == rgcn_conv_out.shape 88 | torch.testing.assert_close(fasten_rgcn_conv_out, rgcn_conv_out, rtol=1e-2, atol=1e-2) 89 | 90 | 91 | @pytest.mark.parametrize("device", ["cuda"]) 92 | def test_hgt(device: str): 93 | 94 | node_types = ['x', 'y', 'w', 'z'] 95 | x_dict = {'x': torch.randn(10, 15), 'y': torch.randn(15, 20), 'w': torch.randn(10, 10), 'z': torch.randn(25, 1)} 96 | edge_type = [('x', 'to', 'y'), ('y', 'to', 'x'), ('y', 'to', 'w'), ('y', 'to', 'z'), ('w', 'to', 'y'), ('z', 'to', 'y')] 97 | edge_index_dict = {('x', 'to', 'y'): torch.cat((torch.sort(torch.randint(10, (1, 25))).values, torch.sort(torch.randint(15, (1, 25))).values), dim=0), 98 | ('y', 'to', 'x'): torch.cat((torch.sort(torch.randint(15, (1, 25))).values, torch.sort(torch.randint(10, (1, 25))).values), dim=0), 99 | ('y', 'to', 'w'): torch.cat((torch.sort(torch.randint(15, (1, 30))).values, torch.sort(torch.randint(10, (1, 30))).values), dim=0), 100 | ('y', 'to', 'z'): torch.cat((torch.sort(torch.randint(15, (1, 10))).values, torch.sort(torch.randint(25, (1, 10))).values), dim=0), 101 | ('w', 'to', 'y'): torch.cat((torch.sort(torch.randint(10, (1, 15))).values, torch.sort(torch.randint(15, (1, 15))).values), dim=0), 102 | ('z', 'to', 'y'): torch.cat((torch.sort(torch.randint(25, (1, 20))).values, torch.sort(torch.randint(15, (1, 20))).values), dim=0) 103 | } 104 | 105 | meta_data = (node_types, edge_type) 106 | num_heads = 2 107 | tensor_slice_hl, type_vec, tensor_slice_hdl, slices_hdl = tensor_slice_gen(x_dict, edge_index_dict, meta_data, num_heads) 108 | 109 | torch.manual_seed(12345) 110 | hidden_channels = 32 111 | 112 | lin_dict = 
torch.nn.ModuleDict() 113 | for node_type in node_types: 114 | lin_dict[node_type] = Linear(-1, hidden_channels) 115 | 116 | x_dict = { 117 | node_type: lin_dict[node_type](x).relu_() 118 | for node_type, x in x_dict.items() 119 | } 120 | data = HeteroData() 121 | data.x_dict = x_dict 122 | data.edge_index_dict = edge_index_dict 123 | data = data.to(device) 124 | 125 | torch.manual_seed(12345) 126 | hgt_conv = HGTConv(hidden_channels, hidden_channels, meta_data, num_heads).to(device) 127 | 128 | torch.manual_seed(12345) 129 | fasten_hgt_conv = FastenHGTConv(hidden_channels, hidden_channels, meta_data, num_heads, engine=Engine.TRITON).to(device) 130 | 131 | hgt_conv_out = hgt_conv(data.x_dict, data.edge_index_dict) 132 | fasten_hgt_conv_out = fasten_hgt_conv(x_dict=data.x_dict, edge_index_dict=data.edge_index_dict, tensor_slice_hl=tensor_slice_hl, type_vec=type_vec, tensor_slice_hdl=tensor_slice_hdl, slices_hdl=slices_hdl) 133 | 134 | assert fasten_hgt_conv_out['x'].shape == hgt_conv_out['x'].shape 135 | torch.testing.assert_close(fasten_hgt_conv_out['x'], hgt_conv_out['x'], rtol=1e-2, atol=1e-2) 136 | 137 | 138 | @pytest.mark.parametrize("device", ["cuda"]) 139 | def test_heat(device: str): 140 | 141 | hidden_channels = 16 142 | num_heads = 2 143 | x = torch.randn(60, 25).to(device) 144 | edge_index = torch.cat((torch.cat((torch.sort(torch.randint(10, (1, 25))).values, torch.sort(torch.randint(15, (1, 25))).values), dim=0), 145 | torch.cat((torch.sort(torch.randint(15, (1, 25))).values, torch.sort(torch.randint(10, (1, 25))).values), dim=0), 146 | torch.cat((torch.sort(torch.randint(15, (1, 30))).values, torch.sort(torch.randint(10, (1, 30))).values), dim=0), 147 | torch.cat((torch.sort(torch.randint(15, (1, 10))).values, torch.sort(torch.randint(25, (1, 10))).values), dim=0), 148 | torch.cat((torch.sort(torch.randint(10, (1, 15))).values, torch.sort(torch.randint(15, (1, 15))).values), dim=0), 149 | torch.cat((torch.sort(torch.randint(25, (1, 20))).values, 
torch.sort(torch.randint(15, (1, 20))).values), dim=0)), dim=1) 150 | num_nodes = [10, 15, 10, 25] 151 | num_edges = [25, 25, 30, 10, 15, 20] 152 | node_type = [num for num, freq in enumerate(num_nodes) for _ in range(freq)] 153 | edge_type = [num for num, freq in enumerate(num_edges) for _ in range(freq)] 154 | lin_in = Linear(-1, hidden_channels).to(device) 155 | data = HeteroData() 156 | data.x = lin_in(x).relu_() 157 | data.edge_index = edge_index 158 | data.node_type = torch.tensor(node_type) 159 | data.edge_type = torch.tensor(edge_type) 160 | data.edge_attr = torch.randn((data.edge_index.shape[1], 2)) 161 | data = data.to(device) 162 | 163 | tensor_slice_hl = heat_tensor_slice_gen(data) 164 | 165 | torch.manual_seed(12345) 166 | heat_conv = HEATConv(hidden_channels, hidden_channels, len(torch.unique(data.node_type)), len(torch.unique(data.edge_type)), 5, 2, 6, num_heads, concat=False).to(device) 167 | torch.manual_seed(12345) 168 | fasten_heat_conv = FastenHEATConv(hidden_channels, hidden_channels, len(torch.unique(data.node_type)), len(torch.unique(data.edge_type)), 5, 2, 6, 169 | num_heads, concat=False, engine=Engine.TRITON).to(device) 170 | 171 | heat_conv_out = heat_conv(data.x, data.edge_index, data.node_type, data.edge_type, data.edge_attr) 172 | fasten_heat_conv_out = fasten_heat_conv(data.x, data.edge_index, data.node_type, data.edge_type, data.edge_attr, tensor_slice_hl=tensor_slice_hl) 173 | 174 | assert fasten_heat_conv_out.shape == heat_conv_out.shape 175 | torch.testing.assert_close(fasten_heat_conv_out, heat_conv_out, rtol=1e-2, atol=1e-2) 176 | 177 | 178 | @pytest.mark.parametrize("device", ["cuda"]) 179 | def test_rgat(device: str): 180 | 181 | torch.manual_seed(12345) 182 | x = torch.randn(4, 16).to(device) 183 | edge_index = torch.tensor([ 184 | [0, 1, 1, 2, 2, 3, 0, 1, 1, 2, 2, 3], 185 | [1, 1, 1, 2, 1, 1, 1, 0, 1, 3, 1, 3], 186 | ]).to(device) 187 | edge_type = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 5, 6, 7]).to(device) 188 | 
tensor_slice = compact_tensor_types(data=None, types=edge_type, is_sorted=True, device=device) 189 | assert tensor_slice is not None 190 | 191 | torch.manual_seed(12345) 192 | rgat_conv = RGATConv(16, 16, 8).to(device) 193 | torch.manual_seed(12345) 194 | fasten_rgat_conv = FastenRGATConv(16, 16, 8, engine=Engine.TRITON).to(device) 195 | 196 | rgat_conv_out = rgat_conv(x, edge_index, edge_type) 197 | fasten_rgat_conv_out = fasten_rgat_conv(x, edge_index, edge_type, tensor_slice=tensor_slice) 198 | 199 | assert fasten_rgat_conv_out.shape == rgat_conv_out.shape 200 | torch.testing.assert_close(fasten_rgat_conv_out, rgat_conv_out, rtol=1e-2, atol=1e-2) 201 | -------------------------------------------------------------------------------- /test/test_ops.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import pyg_lib 4 | import pytest 5 | import torch 6 | from utils import read_slices_from_csv 7 | 8 | from fasten import Engine, compact_tensor_types, ops 9 | from fasten.stats import get_matmul_bytes, get_matmul_flops 10 | from fasten.utils import GlobalConfig 11 | 12 | slices0 = [slice(0, 63), slice(63, 90), slice(90, 128)] 13 | slices1 = [slice(0, 127), slice(127, 256), slice(256, 257), slice(257, 512)] 14 | AIFB = read_slices_from_csv('datasets_csv/AIFB.csv') 15 | AM = read_slices_from_csv('datasets_csv/AM.csv') 16 | BGS = read_slices_from_csv('datasets_csv/BGS.csv') 17 | MUTAG = read_slices_from_csv('datasets_csv/MUTAG.csv') 18 | ACM = read_slices_from_csv('datasets_csv/ACM.csv') 19 | IMDB = read_slices_from_csv('datasets_csv/IMDB.csv') 20 | DBLP = read_slices_from_csv('datasets_csv/DBLP.csv') 21 | Freebase = read_slices_from_csv('datasets_csv/Freebase.csv') 22 | slices_obj = [("AIFB", AIFB), ("AM", AM), ("BGS", BGS), ("MUTAG", MUTAG), ("ACM", ACM), ("DBLP", DBLP), ("IMDB", IMDB), ("Freebase", Freebase)] 23 | 24 | # non-cudagraph tests are not stable on GPU, but pyg_lib only supports the 
cudagraph mode 25 | use_cudagraph = False 26 | 27 | 28 | @pytest.mark.parametrize("device", ["cpu", "cuda"]) 29 | @pytest.mark.parametrize("engine", [Engine.TORCH, Engine.TRITON]) 30 | @pytest.mark.parametrize("phase", ["forward", "backward"]) 31 | @pytest.mark.parametrize("dtype", ["float32"]) 32 | @pytest.mark.parametrize("slices", [slices0, slices1, AIFB, AM, BGS, MUTAG]) 33 | @pytest.mark.parametrize("K", [16, 32, 64, 80]) 34 | @pytest.mark.parametrize("deterministic", [True, False]) 35 | def test_segment_matmul(K: int, slices: list, engine: Engine, device: str, phase: str, dtype: str, deterministic: bool) -> None: 36 | if engine == Engine.TRITON and device == "cpu": 37 | pytest.skip("Triton does not support CPU inference") 38 | if device == "cpu" and dtype == "float16": 39 | pytest.skip("CPU does not support FP16") 40 | GlobalConfig.with_autotune = False 41 | GlobalConfig.deterministic = deterministic 42 | T = len(slices) 43 | dtype = getattr(torch, dtype) 44 | M = sum([s.stop - s.start for s in slices]) 45 | data = torch.randn((M, K), device=device, dtype=dtype) 46 | types = torch.zeros((M,), device=device, dtype=torch.int) 47 | rand_types = torch.randperm(T, device=device, dtype=torch.int) 48 | for i, s in enumerate(slices): 49 | types[s] = rand_types[i] 50 | tensor_slice = compact_tensor_types(data, types, device=device) 51 | other = torch.randn((T, K, K), device=device, dtype=dtype) 52 | if phase == "forward": 53 | output = ops.fasten_segment_matmul(tensor_slice.data, other, tensor_slice, engine) 54 | output_ref = torch.zeros((M, K), device=device, dtype=dtype) 55 | for i in range(len(tensor_slice)): 56 | s = tensor_slice.get_slice_from_index(i, is_tensor=False) 57 | t = tensor_slice.get_type_from_index(i, is_tensor=False) 58 | output_ref[s] = torch.matmul(tensor_slice.data[s], other[t]) 59 | torch.testing.assert_close(output, output_ref, atol=1e-1, rtol=1e-2) 60 | elif phase == "backward": 61 | tensor_slice.data.requires_grad = True 62 | 
other.requires_grad = True 63 | output = ops.fasten_segment_matmul(tensor_slice.data, other, tensor_slice, engine) 64 | output_grad = torch.randn_like(output) 65 | output.backward(output_grad) 66 | sorted_data_grad_ref = torch.zeros_like(data, dtype=dtype) 67 | other_grad_ref = torch.zeros_like(other, dtype=dtype) 68 | for i in range(len(tensor_slice)): 69 | s = tensor_slice.get_slice_from_index(i, is_tensor=False) 70 | t = tensor_slice.get_type_from_index(i, is_tensor=False) 71 | sorted_data_grad_ref[s] = torch.matmul(output_grad[s], other[t].t()) 72 | other_grad_ref[t] = torch.matmul(tensor_slice.data[s].t(), output_grad[s]) 73 | torch.testing.assert_close(tensor_slice.data.grad, sorted_data_grad_ref, atol=1e-1, rtol=1e-2) 74 | if M // T >= 2048: 75 | # gradient accumlation starts to be significantly different with large samples 76 | torch.testing.assert_close(other.grad, other_grad_ref, atol=1.0, rtol=1e-2) 77 | else: 78 | torch.testing.assert_close(other.grad, other_grad_ref, atol=1e-1, rtol=1e-2) 79 | if device == "cuda": 80 | torch.cuda.empty_cache() 81 | 82 | 83 | @pytest.fixture(scope="session") 84 | def session(): 85 | import triton.profiler as proton 86 | session_id = proton.start("benchmark_results", hook="triton") 87 | yield session_id 88 | 89 | proton.finalize() 90 | 91 | 92 | @pytest.mark.parametrize("phase", ["forward", "backward"]) 93 | @pytest.mark.parametrize("dtype", ["float32"]) 94 | @pytest.mark.parametrize("engine", ["fasten", "cutlass", "torch"]) 95 | @pytest.mark.parametrize("slices_name, slices", slices_obj) 96 | @pytest.mark.parametrize("K", [32, 64, 128]) 97 | def test_perf(phase: str, dtype: str, engine: str, slices_name: str, slices: list, K: int, session: Callable[[], None]) -> None: 98 | import triton.profiler as proton 99 | if engine == "cutlass" and dtype == "float16": 100 | pytest.skip("pyg_lib cutlass does not support float16") 101 | torch.backends.cuda.matmul.allow_tf32 = True 102 | GlobalConfig.with_perf_model = True 103 | T = 
len(slices) 104 | M = sum([s.stop - s.start for s in slices]) 105 | dtype = getattr(torch, dtype) 106 | data = torch.randn((M, K), device="cuda", dtype=dtype) 107 | types = torch.zeros((M,), device="cuda", dtype=torch.int) 108 | rand_types = torch.randperm(T, device="cuda", dtype=torch.int) 109 | for i, s in enumerate(slices): 110 | types[s] = rand_types[i] 111 | tensor_slice = compact_tensor_types(data, types, device="cuda") 112 | data = tensor_slice.data 113 | other = torch.randn((T, K, K), device="cuda", dtype=dtype) 114 | # ptr should be on CPU 115 | ptr = torch.tensor([s.start for s in slices] + [slices[-1].stop]) 116 | 117 | if phase == "backward": 118 | data.requires_grad = True 119 | other.requires_grad = True 120 | 121 | proton.deactivate(session) 122 | 123 | # warmup and get output 124 | if engine == "fasten": 125 | output = ops.fasten_segment_matmul(data, other, tensor_slice, Engine.AUTO) 126 | elif engine == "cutlass": 127 | output = pyg_lib.ops.segment_matmul(data, ptr, other) 128 | elif engine == "torch": 129 | output = ops.fasten_segment_matmul(data, other, tensor_slice, Engine.TORCH) 130 | 131 | if phase == "backward": 132 | grad = torch.randn_like(output) 133 | if engine == "cutlass": 134 | grouped_data = [] 135 | grouped_grad = [] 136 | for s in slices: 137 | if s.stop > s.start: 138 | grouped_data.append(data[s.start:s.stop, :].t()) 139 | grouped_grad.append(grad[s.start:s.stop, :]) 140 | 141 | def fasten_fn(): 142 | if phase == "forward": 143 | ops.fasten_segment_matmul(data, other, tensor_slice, Engine.AUTO if engine == "fasten" else Engine.TORCH) 144 | else: # phase == "backward" 145 | output.backward(grad, retain_graph=True) 146 | 147 | def cutlass_fn(): 148 | if phase == "forward": 149 | pyg_lib.ops.segment_matmul(data, ptr, other) 150 | else: # phase == "backward" 151 | # dx 152 | # [M, N] * [K, N]^T = [M, K]^T 153 | pyg_lib.ops.segment_matmul(grad, ptr, other.transpose(1, 2)) 154 | # dw 155 | # [M, K]^T * [M, N] = [K, N] 156 | 
pyg_lib.ops.grouped_matmul(grouped_data, grouped_grad) 157 | 158 | fn = cutlass_fn if engine == "pyg" else fasten_fn 159 | 160 | # warmup again to trigger backward kernels 161 | fn() 162 | proton.activate(session) 163 | with proton.scope(f"{slices_name}_{phase}_{engine}_{K}", metrics={"flops": get_matmul_flops(tensor_slice, other)}): 164 | fn() 165 | 166 | 167 | @pytest.mark.parametrize("phase", ["forward", "backward"]) 168 | @pytest.mark.parametrize("dtype", ["float32"]) 169 | @pytest.mark.parametrize("engine", ["fasten", "cutlass"]) 170 | @pytest.mark.parametrize("K", [32, 128]) 171 | @pytest.mark.parametrize("T", list(range(100, 2000, 200))) 172 | @pytest.mark.parametrize("M", [1000000]) 173 | def test_perf_random(phase: str, dtype: str, engine: str, K: int, T: int, M: int, session: Callable[[], None]): 174 | import triton.profiler as proton 175 | if engine == "cutlass" and dtype == "float16": 176 | pytest.skip("pyg_lib cutlass does not support float16") 177 | torch.backends.cuda.matmul.allow_tf32 = True 178 | torch.random.manual_seed(T) 179 | dtype = getattr(torch, dtype) 180 | data = torch.randn((M, K), device="cuda", dtype=dtype) 181 | types = torch.randint(0, T, (M,), device="cuda", dtype=torch.int) 182 | tensor_slice = compact_tensor_types(data, types, device="cuda") 183 | data = tensor_slice.data 184 | other = torch.randn((T, K, K), device="cuda", dtype=dtype) 185 | # ptr should be on CPU 186 | ptr = [] 187 | for i in range(len(tensor_slice)): 188 | ptr.append(tensor_slice.get_slice_from_index(i, is_tensor=False).start) 189 | ptr.append(tensor_slice.get_slice_from_index(len(tensor_slice) - 1, is_tensor=False).stop) 190 | ptr = torch.tensor(ptr) 191 | 192 | if phase == "backward": 193 | data.requires_grad = True 194 | other.requires_grad = True 195 | 196 | # warmup and get output 197 | if engine == "fasten": 198 | output = ops.fasten_segment_matmul(data, other, tensor_slice, Engine.AUTO) 199 | elif engine == "cutlass": 200 | output = 
pyg_lib.ops.segment_matmul(data, ptr, other) 201 | elif engine == "torch": 202 | output = ops.fasten_segment_matmul(data, other, tensor_slice, Engine.TORCH) 203 | 204 | if phase == "backward": 205 | grad = torch.randn_like(output) 206 | if engine == "cutlass": 207 | grouped_data = [] 208 | grouped_grad = [] 209 | for i in range(len(tensor_slice)): 210 | s = tensor_slice.get_slice_from_index(i, is_tensor=False) 211 | if s.stop > s.start: 212 | grouped_data.append(data[s.start:s.stop, :].t()) 213 | grouped_grad.append(grad[s.start:s.stop, :]) 214 | 215 | def fasten_fn(): 216 | if phase == "forward": 217 | ops.fasten_segment_matmul(data, other, tensor_slice, Engine.AUTO if engine == "fasten" else Engine.TORCH) 218 | else: # phase == "backward" 219 | output.backward(grad, retain_graph=True) 220 | 221 | def cutlass_fn(): 222 | if phase == "forward": 223 | pyg_lib.ops.segment_matmul(data, ptr, other) 224 | else: # phase == "backward" 225 | # dx 226 | # [M, N] * [K, N]^T = [M, K]^T 227 | pyg_lib.ops.segment_matmul(grad, ptr, other.transpose(1, 2)) 228 | # dw 229 | # [M, K]^T * [M, N] = [K, N] 230 | pyg_lib.ops.grouped_matmul(grouped_data, grouped_grad) 231 | 232 | fn = cutlass_fn if engine == "pyg" else fasten_fn 233 | fn() 234 | flops = get_matmul_flops(tensor_slice, other) 235 | flops = 2 * flops if phase == "backward" else flops 236 | bytes = get_matmul_bytes(tensor_slice, other) 237 | with proton.scope(f"random_{phase}_{engine}_{K}_{T}", metrics={"flops": flops, "bytes": bytes}): 238 | fn() 239 | 240 | 241 | def test_cache(): 242 | M = 128 243 | K = 16 244 | T = 16 245 | data = torch.randn((M, K), device='cuda', dtype=torch.float32) 246 | types = torch.zeros((M,), device='cuda', dtype=torch.int) 247 | slices = [slice(0, 63), slice(63, 90), slice(90, 128)] 248 | for s in slices: 249 | if s.stop > s.start: 250 | types[s] = torch.randint(0, T, (s.stop - s.start,), device='cuda', dtype=torch.int) 251 | tensor_slice = compact_tensor_types(data, types, device='cuda') 252 | 
other = torch.randn((T, K, K), device='cuda', dtype=torch.float32) 253 | ops.fasten_segment_matmul(tensor_slice.data, other, tensor_slice, Engine.TRITON) 254 | assert len(tensor_slice._cache) == 1 255 | assert len(tensor_slice._cache['segment_matmul_forward']) == 1 256 | -------------------------------------------------------------------------------- /test/test_stats.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from fasten import compact_tensor_types 5 | from fasten.stats import get_matmul_flops 6 | 7 | 8 | @pytest.mark.parametrize('device', ['cpu', 'cuda']) 9 | def test_matmul_flops(device): 10 | data = torch.tensor([[1, 2, 3, 4], [3, 4, 5, 6], [5, 6, 7, 8]], device=device) 11 | types = torch.tensor([2, 1, 2], dtype=torch.int, device=device) 12 | tensor_slice = compact_tensor_types(data, types, device=device) 13 | weight = torch.randn((3, 4, 5), device=device) 14 | flops = get_matmul_flops(tensor_slice, weight) 15 | flops_ref = 2 * 4 * 5 * 2 + 4 * 5 * 2 16 | assert flops == flops_ref, f"{flops} != {flops_ref}" 17 | -------------------------------------------------------------------------------- /test/test_tensor_slice.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import triton 4 | 5 | from fasten import compact_tensor_types 6 | 7 | 8 | @pytest.mark.parametrize('device', ['cpu', 'cuda']) 9 | @pytest.mark.parametrize('dim', [0, 1]) 10 | def test_compact_tensor_types(device: str, dim: int): 11 | data = torch.tensor([[1, 2, 3], [3, 4, 5], [5, 6, 7]], device=device) 12 | types = torch.tensor([2, 1, 2], dtype=torch.int, device=device) 13 | tensor_slice = compact_tensor_types(data, types, dim=dim, device=device) 14 | if dim == 0: 15 | assert tensor_slice.data[0].tolist() == [3, 4, 5] 16 | else: 17 | assert tensor_slice.data[:, 0].tolist() == [2, 4, 6] 18 | slice = tensor_slice.get_slice_from_type(2) 19 | assert 
slice[0] == 1 20 | assert slice[1] == 3 21 | index_slice = tensor_slice.get_slice_from_index(1) 22 | assert torch.equal(index_slice, slice) 23 | type = tensor_slice.get_type_from_index(1) 24 | assert type == 2 25 | 26 | 27 | @pytest.mark.parametrize('device', ['cpu', 'cuda']) 28 | @pytest.mark.parametrize('tile_size', [1, 2, 3, 16, 128]) 29 | @pytest.mark.parametrize('block_size', [1]) 30 | def test_tiling_default(tile_size: int, block_size: int, device: str): 31 | data = torch.ones((128, 128), device=device) 32 | types = torch.zeros(128, dtype=torch.int, device=device) 33 | types[63:90] = 2 34 | types[90:128] = 3 35 | types[0:63] = 1 36 | tensor_slice = compact_tensor_types(data, types, device=device) 37 | tensor_tile = tensor_slice.tiling(tile_size, block_size=block_size) 38 | num_slices = triton.cdiv(90 - 63, tile_size) + triton.cdiv(128 - 90, tile_size) + triton.cdiv(63, tile_size) 39 | avg_tile_size = 128 / num_slices 40 | # calculate stddev tile size 41 | stddev_tile_size = 0 42 | for i in range(len(tensor_tile.slices)): 43 | slice = tensor_tile.get_slice_from_index(i, is_tensor=False) 44 | stddev_tile_size += ((slice.stop - slice.start) - avg_tile_size) ** 2 45 | stddev_tile_size = (stddev_tile_size / num_slices) ** 0.5 46 | assert len(tensor_tile) == num_slices 47 | 48 | torch.testing.assert_close(tensor_tile.avg_tile_size, avg_tile_size) 49 | torch.testing.assert_close(tensor_tile.stddev_tile_size, stddev_tile_size) 50 | 51 | 52 | @pytest.mark.parametrize('device', ['cpu', 'cuda']) 53 | @pytest.mark.parametrize('tile_size', [1, 2, 3, 16, 128]) 54 | @pytest.mark.parametrize('block_size', [1]) 55 | def test_tile_slice_mapping(tile_size: int, block_size: int, device: str): 56 | data = torch.ones((128, 128), device=device) 57 | types = torch.zeros(128, dtype=torch.int, device=device) 58 | types[63:90] = 2 59 | types[90:128] = 3 60 | types[0:63] = 1 61 | tensor_slice = compact_tensor_types(data, types, device=device) 62 | tensor_tile = 
tensor_slice.tiling(tile_size, block_size=block_size) 63 | slice_tile_mapping = tensor_tile.slice_tile_mapping 64 | assert slice_tile_mapping[-1][2] == len(tensor_tile) 65 | -------------------------------------------------------------------------------- /test/test_triton.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from fasten import compact_tensor_types 5 | from fasten.operators import triton_ops 6 | from fasten.utils import GlobalConfig, TilingMethod 7 | 8 | GlobalConfig.with_autotune = False 9 | 10 | 11 | @pytest.mark.parametrize("phase", ["forward", "backward"]) 12 | @pytest.mark.parametrize("dtype", ["float32", "float16"]) 13 | @pytest.mark.parametrize("M", [128, 1024]) 14 | @pytest.mark.parametrize("T", [16, 33]) 15 | @pytest.mark.parametrize("tile_size", [16, 64]) 16 | @pytest.mark.parametrize("block_size", [1, 4]) 17 | @pytest.mark.parametrize("K", [16, 32, 80]) 18 | @pytest.mark.parametrize("device", ["cuda"]) 19 | @pytest.mark.parametrize("tiling_method", ["balanced", "default"]) 20 | @pytest.mark.parametrize("deterministic", [True, False]) 21 | def test_segment_matmul(M: int, K: int, T: int, phase: str, dtype: str, tile_size: int, block_size: int, device: str, tiling_method: str, deterministic: bool) -> None: 22 | if not deterministic and phase == "forward": 23 | pytest.skip("Non-deterministic test is not supported for forward pass") 24 | dtype = getattr(torch, dtype) 25 | data = torch.randn((M, K), dtype=dtype, device=device) 26 | types = torch.randint(0, T, (M,), device=device, dtype=torch.int) 27 | tensor_slice = compact_tensor_types(data, types, device=device) 28 | other = torch.randn((T, K, K), device=device, dtype=dtype) 29 | tiling_method = getattr(TilingMethod, tiling_method.upper()) 30 | if phase == "forward": 31 | input_tiles = tensor_slice.tiling(tile_size, method=tiling_method, block_size=block_size) 32 | output = 
triton_ops.segment_matmul_forward(tensor_slice.data, other, input_tiles.slices, input_slices=tensor_slice.slices, 33 | tile_size=tile_size, out_dtype=torch.float32, 34 | num_blocks=input_tiles.num_blocks, block_size=input_tiles.block_size, 35 | deterministic=deterministic, slice_tile_mapping=input_tiles.slice_tile_mapping, 36 | avg_tile_size=input_tiles.avg_tile_size, stddev_tile_size=input_tiles.stddev_tile_size) 37 | output_ref = torch.zeros((M, K), dtype=dtype, device="cuda") 38 | for i in range(len(tensor_slice)): 39 | s = tensor_slice.get_slice_from_index(i, is_tensor=False) 40 | t = tensor_slice.get_type_from_index(i, is_tensor=False) 41 | output_ref[s] = torch.matmul(tensor_slice.data[s], other[t]) 42 | torch.testing.assert_close(output, output_ref, atol=1e-1, rtol=1e-2) 43 | elif phase == "backward": 44 | input_tiles = tensor_slice.tiling(tile_size, method=tiling_method, block_size=block_size) 45 | output = triton_ops.segment_matmul_forward(tensor_slice.data, other, input_tiles.slices, input_slices=tensor_slice.slices, 46 | tile_size=tile_size, num_blocks=input_tiles.num_blocks, 47 | block_size=input_tiles.block_size, 48 | deterministic=deterministic, slice_tile_mapping=input_tiles.slice_tile_mapping, 49 | avg_tile_size=input_tiles.avg_tile_size, stddev_tile_size=input_tiles.stddev_tile_size) 50 | output_grad = torch.randn_like(output) 51 | grad_input = triton_ops.segment_matmul_backward_input(tensor_slice.data, output_grad, other, input_tiles.slices, 52 | input_slices=tensor_slice.slices, tile_size=tile_size, 53 | num_blocks=input_tiles.num_blocks, block_size=input_tiles.block_size, 54 | avg_tile_size=input_tiles.avg_tile_size, stddev_tile_size=input_tiles.stddev_tile_size) 55 | grad_tiles = tensor_slice.tiling(tile_size, method=TilingMethod.DEFAULT, block_size=block_size) 56 | grad_other = triton_ops.segment_matmul_backward_other(tensor_slice.data, output_grad, other, grad_tiles.slices, 57 | input_slices=tensor_slice.slices, tile_size=tile_size, 58 | 
num_blocks=grad_tiles.num_blocks, block_size=grad_tiles.block_size, 59 | deterministic=deterministic, slice_tile_mapping=grad_tiles.slice_tile_mapping, 60 | avg_tile_size=input_tiles.avg_tile_size, stddev_tile_size=input_tiles.stddev_tile_size) 61 | sorted_data_grad_ref = torch.zeros_like(data, dtype=dtype) 62 | other_grad_ref = torch.zeros_like(other, dtype=dtype) 63 | for i in range(len(tensor_slice)): 64 | s = tensor_slice.get_slice_from_index(i, is_tensor=False) 65 | t = tensor_slice.get_type_from_index(i, is_tensor=False) 66 | sorted_data_grad_ref[s] = torch.matmul(output_grad[s], other[t].t()) 67 | other_grad_ref[t] = torch.matmul(tensor_slice.data[s].t(), output_grad[s]) 68 | torch.testing.assert_close(grad_input, sorted_data_grad_ref, atol=1e-1, rtol=1e-2) 69 | torch.testing.assert_close(grad_other, other_grad_ref, atol=1e-1, rtol=1e-2) 70 | -------------------------------------------------------------------------------- /test/utils.py: -------------------------------------------------------------------------------- 1 | import csv 2 | 3 | 4 | def read_slices_from_csv(csv_file): 5 | slices = [] 6 | 7 | with open(csv_file, mode='r') as file: 8 | reader = csv.DictReader(file) 9 | for row in reader: 10 | start = int(row["Start"]) 11 | end = int(row["End"]) 12 | slices.append(slice(start, end)) 13 | 14 | return slices 15 | --------------------------------------------------------------------------------