├── tests
    ├── __init__.py
    ├── conftest.py
    ├── test_dropout_layer.py
    ├── test_softmax_layers.py
    ├── test_cross_entropy_loss_layer.py
    ├── test_nll_loss_layer.py
    ├── test_p_loss_layers.py
    ├── test_glu_layer.py
    ├── test_pooling_layer.py
    ├── test_act_layers.py
    ├── test_rms_norm_layer.py
    ├── test_layer_norm_layer.py
    ├── test_linear_layer.py
    ├── test_conv_layer.py
    ├── test_multi_head_attention_layer.py
    ├── utils.py
    └── test_batch_norm_layer.py
├── examples
    ├── __init__.py
    ├── mnist
    │   ├── README.md
    │   ├── mlp.py
    │   └── main.py
    ├── regression
    │   ├── README.md
    │   └── main.py
    ├── README.md
    ├── wikitext-2
    │   ├── README.md
    │   ├── gpt.py
    │   └── main.py
    ├── imagenette
    │   └── README.md
    └── utils.py
├── attorch
    ├── nn.py
    ├── types.py
    ├── __init__.py
    ├── utils.py
    ├── dropout_kernels.py
    ├── dropout_layer.py
    ├── glu_kernels.py
    ├── softmax_layers.py
    ├── glu_layer.py
    ├── pooling_layer.py
    ├── p_loss_kernels.py
    ├── rms_norm_layer.py
    ├── multi_head_attention_layer.py
    ├── cross_entropy_loss_layer.py
    ├── nll_loss_layer.py
    ├── p_loss_layers.py
    ├── cross_entropy_loss_kernels.py
    ├── softmax_kernels.py
    ├── layer_norm_layer.py
    ├── linear_layer.py
    ├── linear_kernels.py
    ├── rms_norm_kernels.py
    └── math.py
├── LICENSE
└── README.md

/tests/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Deep learning examples using attorch.
3 | """
4 | 
--------------------------------------------------------------------------------
/attorch/nn.py:
--------------------------------------------------------------------------------
1 | """
2 | Interface for attorch with PyTorch fallback.
3 | """
4 | 
5 | 
6 | from torch.nn import *
7 | from attorch import *
8 | from torch.nn import AvgPool1d, AvgPool2d, Conv1d, Conv2d
--------------------------------------------------------------------------------
/attorch/types.py:
--------------------------------------------------------------------------------
1 | """
2 | Type aliases for certain objects.
3 | """
4 | 
5 | 
6 | from typing import Any, Optional, Union
7 | 
8 | import torch
9 | 
10 | 
11 | Context = Any
12 | Device = Optional[Union[torch.device, str]]
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | """
2 | Fixture for subset, a switch for running tests on a small subset of shapes.
3 | """
4 | 
5 | 
6 | import pytest
7 | from _pytest.config.argparsing import Parser
8 | from pytest import FixtureRequest
9 | 
10 | 
11 | def pytest_addoption(parser: Parser) -> None:
12 |     parser.addoption('--subset',
13 |                      action='store_true',
14 |                      help='Flag to test on a small subset of shapes.')
15 | 
16 | 
17 | @pytest.fixture
18 | def subset(request: FixtureRequest) -> bool:
19 |     return request.config.getoption('--subset')
--------------------------------------------------------------------------------
/examples/mnist/README.md:
--------------------------------------------------------------------------------
1 | # MNIST Classification
2 | 
3 | This example trains and benchmarks a multilayer perceptron (MLP) on the MNIST dataset.
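The model itself is defined in ```mlp.py``` and can be built with either backend by toggling its ```use_attorch``` flag. A minimal sketch of that switch is shown below; the import path (assumed to be run from the repository root), the dummy batch, and the device handling are illustrative assumptions rather than code taken from ```main.py```.

```python
import torch

from examples.mnist.mlp import MLP

# Same architecture, two backends: with use_attorch=True, each hidden layer is an
# attorch.Linear with act_func='relu', whereas the PyTorch variant chains
# nn.Linear and nn.ReLU. attorch requires a CUDA device.
attorch_mlp = MLP(use_attorch=True).cuda()
pytorch_mlp = MLP(use_attorch=False).cuda()

images = torch.randn(32, 1, 28, 28, device='cuda')    # dummy MNIST-shaped batch
targets = torch.randint(0, 10, (32,), device='cuda')  # dummy labels

logits = attorch_mlp(images)           # shape: (32, 10)
loss = attorch_mlp(images, targets)    # cross-entropy loss when targets are passed
```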
4 | 5 | ## Requirements 6 | 7 | The requirement for this example, aside from attorch and its dependencies, is, 8 | 9 | * ```torchvision==0.19.0``` 10 | 11 | ## Training 12 | 13 | To run this example, please run ```python -m examples.mnist.main``` from the root directory. The arguments are as follows. 14 | * ```--hidden_dim```: Number of hidden features in the MLP. 15 | * ```--depth```: Number of hidden layers in the MLP. 16 | * ```--epochs```: Number of epochs to train for. 17 | * ```--batch_size```: Batch size. 18 | -------------------------------------------------------------------------------- /examples/regression/README.md: -------------------------------------------------------------------------------- 1 | # Synthetic Regression 2 | 3 | This example trains and benchmarks a multilayer perceptron (MLP) on synthetic regression data. It relies on ```attorch.nn``` to demonstrate how attorch can be used as a drop-in replacement for PyTorch without kernel fusion or other enhancements. 4 | 5 | ## Requirements 6 | 7 | This example has no requirements aside from attorch and its dependencies. 8 | 9 | ## Training 10 | 11 | To run this example, please run ```python -m examples.regression.main``` from the root directory. The arguments are as follows. 12 | * ```--n_samples```: Number of samples to generate. 13 | * ```--dim```: Dimensionality of synthetic data. 14 | * ```--epochs```: Number of epochs to train for. 15 | * ```--batch_size```: Batch size. 16 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | This folder - inspired by [```pytorch/examples```](https://github.com/pytorch/examples/) - contains a few common deep learning examples whose backends can be switched between PyTorch and attorch to demonstrate attorch's usage and time it against PyTorch. Each example implements relevant models with the option to switch between PyTorch and attorch, and backend-agnostic training can be conducted by running the designated ```main.py``` file, which also performs benchmarks. 4 | 5 | The examples are listed below. 6 | 7 | * [Imagenette image classification](https://github.com/bobmcdear/attorch/examples/imagenette) 8 | * [WikiText-2 language modelling](https://github.com/bobmcdear/attorch/examples/wikitext-2) 9 | * [MNIST classification](https://github.com/bobmcdear/attorch/examples/mnist) 10 | * [Synthetic regression](https://github.com/bobmcdear/attorch/examples/regression) 11 | -------------------------------------------------------------------------------- /examples/wikitext-2/README.md: -------------------------------------------------------------------------------- 1 | # WikiText-2 Language Modelling 2 | 3 | This example trains and benchmarks a language model on the WikiText-2 dataset. 4 | 5 | ## Requirements 6 | 7 | The requirements for this example, aside from attorch and its dependencies, are, 8 | 9 | * ```datasets==2.18.0``` 10 | * ```transformers==4.39.0``` 11 | 12 | ## Training 13 | 14 | To run this example, please run ```python -m examples.wikitext-2.main``` from the root directory. The arguments are as follows. 15 | 16 | * ```--model```: Name of language model to train. The only option is ```gpt2```. 17 | * `--downsize`: The depth and width of the model are calculated by dividing GPT2's original depth and width by this factor. 18 | * ```--scheduler```: Learning rate scheduler. Options are `one-cycle` and `cosine`. 
19 | * ```--epochs```: Number of epochs to train for. 20 | * ```--batch_size```: Batch size. 21 | * ```--seq_len```: Sequence length. 22 | * ```--num_workers```: Number of workers for data loading. 23 | -------------------------------------------------------------------------------- /attorch/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | A subset of PyTorch's neural network modules, 3 | written in Python using OpenAI's Triton. 4 | """ 5 | 6 | 7 | from . import math, nn 8 | from .act_layers import CELU, ELU, GELU, Hardshrink, Hardsigmoid, Hardswish, Hardtanh, LeakyReLU, \ 9 | LogSigmoid, Mish, ReLU, ReLU6, SELU, SiLU, Sigmoid, Softplus, Softshrink, Softsign, Tanh, Tanhshrink 10 | from .batch_norm_layer import BatchNorm1d, BatchNorm2d 11 | from .conv_layer import Conv1d, Conv2d 12 | from .cross_entropy_loss_layer import CrossEntropyLoss 13 | from .dropout_layer import Dropout 14 | from .glu_layer import GLU 15 | from .layer_norm_layer import LayerNorm 16 | from .linear_layer import Linear 17 | from .multi_head_attention_layer import MultiheadAttention 18 | from .nll_loss_layer import NLLLoss 19 | from .p_loss_layers import HuberLoss, L1Loss, MSELoss, SmoothL1Loss 20 | from .pooling_layer import AvgPool1d, AvgPool2d 21 | from .rms_norm_layer import RMSNorm 22 | from .softmax_layers import LogSoftmax, Softmax, Softmin 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Borna Ahmadzadeh 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /tests/test_dropout_layer.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import pytest 4 | import torch 5 | 6 | import attorch 7 | from .utils import assert_close, create_input, create_input_like, default_shapes 8 | 9 | 10 | @pytest.mark.parametrize('shape', default_shapes()) 11 | @pytest.mark.parametrize('drop_p', [0.0, 0.15, 0.3, 0.5, 0.75, 0.9, 1.0]) 12 | def test_dropout_layer(shape: Tuple[int, ...], drop_p: float, subset: bool) -> None: 13 | if subset and (shape not in default_shapes(subset=True)): 14 | return 15 | 16 | input = create_input(shape) 17 | dropout = attorch.Dropout(drop_p) 18 | output = dropout(input) 19 | n_zeroed = (torch.count_nonzero(input) - torch.count_nonzero(output)).item() 20 | 21 | if drop_p == 0: 22 | assert n_zeroed == 0 23 | 24 | elif drop_p == 1: 25 | assert torch.count_nonzero(output).item() == 0 26 | 27 | else: 28 | assert_close((output, torch.where(output == 0, output, input / (1 - drop_p)))) 29 | assert_close((n_zeroed / input.numel(), drop_p), rtol=1e-1, atol=5e-2) 30 | 31 | output_grad = create_input_like(output) 32 | output.backward(output_grad) 33 | input_grad = torch.where(output == 0, output, output_grad / (1 - drop_p)) 34 | 35 | assert_close((input.grad, input_grad)) 36 | -------------------------------------------------------------------------------- /examples/imagenette/README.md: -------------------------------------------------------------------------------- 1 | # Imagenette Image Classification 2 | 3 | This example trains and benchmarks a vision model on the Imagenette classification dataset. 4 | 5 | ## Requirements 6 | 7 | The requirement for this example, aside from attorch and its dependencies, is, 8 | 9 | * ```torchvision==0.19.0``` 10 | 11 | ## Training 12 | 13 | To run this example, please run ```python -m examples.imagenette.main``` from the root directory. The arguments are as follows. 14 | * ```--model```: Name of vision model to train. Options are `resnet18`, `resnet34`, ```resnet14```, ```resnet26```, ```resnet50```, ```resnet101```, ```resnet152```, ```convnext_atto```, ```convnext_femto```, ```convnext_pico```, ```convnext_nano```, ```convnext_tiny```, ```convnext_small```, ```convnext_base```, ```convnext_large```, ```convnext_xlarge```, `vit_tiny_patch16`, `vit_xsmall_patch16`, `vit_small_patch32`, `vit_small_patch16`, `vit_small_patch8`, `vit_medium_patch32`, `vit_medium_patch16`, `vit_base_patch32`, `vit_base_patch16`, `vit_base_patch8`, `vit_large_patch32`, `vit_large_patch16`, and `vit_large_patch14`. 15 | * ```--scheduler```: Learning rate scheduler. Options are `one-cycle` and `cosine`. 16 | * ```--epochs```: Number of epochs to train for. 17 | * ```--batch_size```: Batch size. 18 | * ```--center_crop_size```: Center crop size for validation. 19 | * ```--image_size```: Input image size. 20 | * ```--num_workers```: Number of workers for data loading. 
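For instance, a complete invocation might look like ```python -m examples.imagenette.main --model resnet18 --scheduler cosine --epochs 5 --batch_size 64 --center_crop_size 224 --image_size 224 --num_workers 8```, where the argument values are purely illustrative rather than tuned recommendations.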
21 | -------------------------------------------------------------------------------- /tests/test_softmax_layers.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import pytest 4 | import torch 5 | from torch import autocast, nn 6 | 7 | import attorch 8 | from .utils import assert_close, create_input, create_input_like, default_shapes 9 | 10 | 11 | @pytest.mark.parametrize('shape', default_shapes()) 12 | @pytest.mark.parametrize('softmax', ['Softmax', 'LogSoftmax', 'Softmin']) 13 | @pytest.mark.parametrize('input_dtype', [torch.float32, torch.float16]) 14 | @pytest.mark.parametrize('amp', [False, True]) 15 | def test_softmax_layers( 16 | shape: Tuple[int, ...], 17 | softmax: str, 18 | input_dtype: bool, 19 | amp: bool, 20 | subset: bool, 21 | ) -> None: 22 | if subset and (shape not in default_shapes(subset=True)): 23 | return 24 | 25 | if input_dtype is torch.float16 and not amp: 26 | return 27 | 28 | attorch_input = create_input(shape, dtype=input_dtype) 29 | pytorch_input = create_input(shape, dtype=input_dtype) 30 | 31 | attorch_softmax = getattr(attorch, softmax)(dim=-1) 32 | pytorch_softmax = getattr(nn, softmax)(dim=-1) 33 | 34 | with autocast('cuda', enabled=amp): 35 | attorch_output = attorch_softmax(attorch_input) 36 | pytorch_output = pytorch_softmax(pytorch_input) 37 | 38 | assert_close((attorch_output, pytorch_output)) 39 | 40 | attorch_output.backward(create_input_like(attorch_output)) 41 | pytorch_output.backward(create_input_like(pytorch_output)) 42 | 43 | assert_close((attorch_input.grad, pytorch_input.grad)) 44 | -------------------------------------------------------------------------------- /examples/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for examples. 3 | """ 4 | 5 | 6 | from typing import Tuple 7 | 8 | import torch 9 | from torch import autocast, nn 10 | from triton.testing import do_bench 11 | 12 | 13 | def benchmark_fw_and_bw( 14 | model: nn.Module, 15 | amp: bool = True, 16 | **input, 17 | ) -> Tuple[float, float, float]: 18 | """ 19 | Benchmarks the forward and backward pass of a model. 20 | 21 | Args: 22 | model: Model to run. 23 | amp: Flag for running the forward pass using automatic mixed precision. 24 | **input: Input to the model. 25 | """ 26 | with autocast('cuda', enabled=amp): 27 | fw = do_bench(lambda: model(**input)) 28 | 29 | with autocast('cuda', enabled=amp): 30 | output = model(**input) 31 | output_grad = torch.randn_like(output) 32 | bw = do_bench(lambda: output.backward(output_grad, retain_graph=True)) 33 | 34 | print(f'Forward pass mean execution time: {fw}') 35 | print(f'Backward pass mean execution time: {bw}') 36 | print(f'Forward plus backward pass mean execution time: {fw+bw}') 37 | 38 | 39 | class AvgMeter: 40 | """ 41 | Keeps track of the running average of a series of values. 
42 | """ 43 | def __init__(self) -> None: 44 | self.reset() 45 | 46 | def reset(self) -> None: 47 | self.sum = 0 48 | self.count = 0 49 | 50 | def update(self, val: float, count: int = 1) -> None: 51 | self.sum += count * val 52 | self.count += count 53 | 54 | @property 55 | def avg(self): 56 | return self.sum / self.count 57 | -------------------------------------------------------------------------------- /tests/test_cross_entropy_loss_layer.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import pytest 4 | import torch 5 | from torch import autocast, nn 6 | 7 | import attorch 8 | from .utils import assert_close, create_input, default_shapes 9 | 10 | 11 | @pytest.mark.parametrize('shape', default_shapes(min_dim=2, max_dim=2)) 12 | @pytest.mark.parametrize('weighted', [False, True]) 13 | @pytest.mark.parametrize('input_dtype', [torch.float32, torch.float16]) 14 | @pytest.mark.parametrize('amp', [False, True]) 15 | def test_cross_entropy_loss_layer( 16 | shape: Tuple[int, ...], 17 | weighted: bool, 18 | input_dtype: bool, 19 | amp: bool, 20 | subset: bool, 21 | ) -> None: 22 | if subset and (shape not in default_shapes(subset=True)): 23 | return 24 | 25 | if input_dtype is torch.float16 and not amp: 26 | return 27 | 28 | attorch_input = create_input(shape, dtype=input_dtype) 29 | pytorch_input = create_input(shape, dtype=input_dtype) 30 | target = torch.randint(0, shape[1], 31 | size=(shape[0],), 32 | device='cuda') 33 | weight = (torch.randn(shape[1], device='cuda') 34 | if weighted else None) 35 | 36 | attorch_loss = attorch.CrossEntropyLoss(weight=weight) 37 | pytorch_loss = nn.CrossEntropyLoss(weight=weight) 38 | 39 | with autocast('cuda', enabled=amp): 40 | attorch_output = attorch_loss(attorch_input, target) 41 | pytorch_output = pytorch_loss(pytorch_input, target) 42 | 43 | assert_close((attorch_output, pytorch_output), rtol=1e-3, atol=1e-3) 44 | 45 | attorch_output.backward() 46 | pytorch_output.backward() 47 | 48 | assert_close((attorch_input.grad, pytorch_input.grad), rtol=1e-3, atol=1e-3) 49 | -------------------------------------------------------------------------------- /examples/mnist/mlp.py: -------------------------------------------------------------------------------- 1 | """ 2 | Multilayer perceptron (MLP) for MNIST classification. 3 | """ 4 | 5 | 6 | from typing import Optional 7 | 8 | from torch import Tensor 9 | from torch import nn 10 | 11 | import attorch 12 | 13 | 14 | class MLP(nn.Module): 15 | """ 16 | Transforms the input using a multilayer perceptron (MLP) with an arbitrary 17 | number of hidden layers, optionally computing the loss if targets are passed. 18 | 19 | Args: 20 | use_attorch: Flag to use attorch in lieu of PyTorch as the backend. 21 | in_dim: Number of input features. 22 | hidden_dim: Number of hidden features. 23 | depth: Number of hidden layers. 24 | num_classes: Number of output classes. 
25 | """ 26 | def __init__( 27 | self, 28 | use_attorch: bool, 29 | in_dim: int = 784, 30 | hidden_dim: int = 128, 31 | depth: int = 1, 32 | num_classes: int = 10, 33 | ) -> None: 34 | super().__init__() 35 | 36 | backend = attorch if use_attorch else nn 37 | layer_fn = lambda dim: ([attorch.Linear(dim, hidden_dim, act_func='relu')] 38 | if use_attorch else [nn.Linear(dim, hidden_dim), nn.ReLU()]) 39 | 40 | layers = layer_fn(in_dim) 41 | for _ in range(depth - 1): 42 | layers += layer_fn(hidden_dim) 43 | layers.append(backend.Linear(hidden_dim, num_classes)) 44 | 45 | self.layers = nn.Sequential(*layers) 46 | self.loss_func = backend.CrossEntropyLoss() 47 | 48 | def forward(self, input: Tensor, target: Optional[Tensor] = None) -> Tensor: 49 | output = self.layers(input.flatten(start_dim=1)) 50 | return output if target is None else self.loss_func(output, target) 51 | -------------------------------------------------------------------------------- /tests/test_nll_loss_layer.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import pytest 4 | import torch 5 | from torch import autocast, nn 6 | 7 | import attorch 8 | from .utils import assert_close, create_input, create_input_like, default_shapes 9 | 10 | 11 | @pytest.mark.parametrize('shape', default_shapes(min_dim=2)) 12 | @pytest.mark.parametrize('reduction', ['none', 'mean', 'sum']) 13 | @pytest.mark.parametrize('weighted', [False, True]) 14 | @pytest.mark.parametrize('input_dtype', [torch.float32, torch.float16]) 15 | @pytest.mark.parametrize('amp', [False, True]) 16 | def test_nll_loss_layer( 17 | shape: Tuple[int, ...], 18 | reduction: str, 19 | weighted: bool, 20 | input_dtype: bool, 21 | amp: bool, 22 | subset: bool, 23 | ) -> None: 24 | if subset and (shape not in default_shapes(subset=True)): 25 | return 26 | 27 | if input_dtype is torch.float16 and not amp: 28 | return 29 | 30 | attorch_input = create_input(shape) 31 | pytorch_input = create_input(shape) 32 | target = torch.randint(0, shape[1], 33 | size=(shape[0], *shape[2:]), 34 | device='cuda') 35 | weight = (torch.randn(shape[1], device='cuda', dtype=torch.float32) 36 | if weighted else None) 37 | 38 | attorch_loss = attorch.NLLLoss(reduction=reduction, weight=weight) 39 | pytorch_loss = nn.NLLLoss(reduction=reduction, weight=weight) 40 | 41 | with autocast('cuda', enabled=amp): 42 | attorch_output = attorch_loss(attorch_input, target) 43 | pytorch_output = pytorch_loss(pytorch_input, target) 44 | 45 | assert_close((attorch_output, pytorch_output)) 46 | 47 | attorch_output.backward(create_input_like(attorch_output)) 48 | pytorch_output.backward(create_input_like(pytorch_output)) 49 | 50 | assert_close((attorch_input.grad, pytorch_input.grad)) 51 | -------------------------------------------------------------------------------- /tests/test_p_loss_layers.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import pytest 4 | import torch 5 | from torch import autocast, nn 6 | 7 | import attorch 8 | from .utils import assert_close, create_input, create_input_like, default_shapes 9 | 10 | 11 | @pytest.mark.parametrize('shape', default_shapes()) 12 | @pytest.mark.parametrize('p_loss', ['SmoothL1Loss', 'L1Loss', 'MSELoss', 'HuberLoss']) 13 | @pytest.mark.parametrize('reduction', ['none', 'mean', 'sum']) 14 | @pytest.mark.parametrize('input_dtype', [torch.float32, torch.float16]) 15 | @pytest.mark.parametrize('amp', [False, True]) 16 | def 
test_p_loss_layers( 17 | shape: Tuple[int, ...], 18 | p_loss: str, 19 | reduction: str, 20 | input_dtype: bool, 21 | amp: bool, 22 | subset: bool, 23 | ) -> None: 24 | if subset and (shape not in default_shapes(subset=True)): 25 | return 26 | 27 | if input_dtype is torch.float16 and not amp: 28 | return 29 | 30 | attorch_input = create_input(shape, dtype=input_dtype) 31 | attorch_target = create_input(shape, dtype=input_dtype, seed=1) 32 | 33 | pytorch_input = create_input(shape, dtype=input_dtype) 34 | pytorch_target = create_input(shape, dtype=input_dtype, seed=1) 35 | 36 | attorch_loss = getattr(attorch, p_loss)(reduction=reduction) 37 | pytorch_loss = getattr(nn, p_loss)(reduction=reduction) 38 | 39 | with autocast('cuda', enabled=amp): 40 | attorch_output = attorch_loss(attorch_input, attorch_target) 41 | pytorch_output = pytorch_loss(pytorch_input, pytorch_target) 42 | 43 | assert_close((attorch_output, pytorch_output), rtol=1e-3, atol=1e-3) 44 | 45 | attorch_output.backward(create_input_like(attorch_output)) 46 | pytorch_output.backward(create_input_like(pytorch_output)) 47 | 48 | assert_close((attorch_input.grad, pytorch_input.grad), 49 | (attorch_target.grad, pytorch_target.grad), 50 | rtol=1e-3, atol=1e-3) 51 | -------------------------------------------------------------------------------- /tests/test_glu_layer.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import pytest 4 | import torch 5 | from torch import autocast 6 | from torch.nn import functional as F 7 | 8 | import attorch 9 | from .utils import assert_close, create_input, create_input_like, default_shapes 10 | 11 | 12 | @pytest.mark.parametrize('shape', default_shapes(max_dim=3)) 13 | @pytest.mark.parametrize('act_func', ['sigmoid', 'logsigmoid', 'tanh', 'relu', 'gelu', 'silu', 14 | 'relu6', 'hardsigmoid', 'hardtanh', 'hardswish', 'selu', 15 | 'mish', 'softplus', 'softsign', 'tanhshrink', 'leaky_relu_0.01', 16 | 'elu_1', 'celu_1', 'hardshrink_0.5', 'softshrink_0.5']) 17 | @pytest.mark.parametrize('input_dtype', [torch.float32, torch.float16]) 18 | @pytest.mark.parametrize('amp', [False, True]) 19 | def test_glu_layer( 20 | shape: Tuple[int, ...], 21 | act_func: str, 22 | input_dtype: bool, 23 | amp: bool, 24 | subset: bool, 25 | ) -> None: 26 | if subset and (shape not in default_shapes(subset=True)): 27 | return 28 | 29 | if input_dtype is torch.float16 and not amp: 30 | return 31 | 32 | attorch_input = create_input(shape, dtype=input_dtype) 33 | pytorch_input = create_input(shape, dtype=input_dtype) 34 | 35 | attorch_glu = attorch.GLU(act_func=act_func) 36 | pytorch_glu = lambda input1, input2: input1 * getattr(F, act_func.rsplit('_', 1)[0])(input2) 37 | 38 | with autocast('cuda', enabled=amp): 39 | attorch_output = attorch_glu(attorch_input) 40 | pytorch_output = pytorch_glu(*pytorch_input.chunk(2, dim=-1)) 41 | 42 | assert_close((attorch_output, pytorch_output), rtol=1e-3, atol=1e-3) 43 | 44 | attorch_output.backward(create_input_like(attorch_output)) 45 | pytorch_output.backward(create_input_like(pytorch_output)) 46 | 47 | assert_close((attorch_input.grad, pytorch_input.grad), rtol=1e-3, atol=1e-3) 48 | -------------------------------------------------------------------------------- /tests/test_pooling_layer.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union 2 | 3 | import pytest 4 | import torch 5 | from torch import autocast, nn 6 | 7 | import attorch 8 | from .utils 
import assert_close, create_input, create_input_like, default_shapes 9 | 10 | 11 | @pytest.mark.parametrize('shape', default_shapes(min_dim=3, max_dim=4)) 12 | @pytest.mark.parametrize('kernel_size', [2, 3, 4]) 13 | @pytest.mark.parametrize('stride', [None, 1, 2]) 14 | @pytest.mark.parametrize('padding', [0, 1, -1]) 15 | @pytest.mark.parametrize('input_dtype', [torch.float32, torch.float16]) 16 | @pytest.mark.parametrize('amp', [False, True]) 17 | def test_pooling_layer( 18 | shape: Tuple[int, ...], 19 | kernel_size: Union[int, Tuple[int, int]], 20 | stride: Union[int, Tuple[int, int]], 21 | padding: Union[int, Tuple[int, int]], 22 | input_dtype: bool, 23 | amp: bool, 24 | subset: bool, 25 | ) -> None: 26 | if subset and (shape not in default_shapes(subset=True)): 27 | return 28 | 29 | if input_dtype is torch.float16 and not amp: 30 | return 31 | 32 | if padding == -1: 33 | padding = kernel_size // 2 34 | 35 | attorch_input = create_input(shape, dtype=input_dtype) 36 | pytorch_input = create_input(shape, dtype=input_dtype) 37 | 38 | pooling_name = 'AvgPool2d' if len(shape) == 4 else 'AvgPool1d' 39 | attorch_pool = getattr(attorch, pooling_name)(kernel_size, stride, padding) 40 | pytorch_pool = getattr(nn, pooling_name)(kernel_size, stride, padding) 41 | 42 | with autocast('cuda', enabled=amp): 43 | attorch_output = attorch_pool(attorch_input) 44 | pytorch_output = pytorch_pool(pytorch_input) 45 | 46 | assert_close((attorch_output, pytorch_output), rtol=1e-3, atol=1e-2) 47 | 48 | attorch_output.backward(create_input_like(attorch_output)) 49 | pytorch_output.backward(create_input_like(pytorch_output)) 50 | 51 | assert_close((attorch_input.grad, pytorch_input.grad), rtol=1e-3, atol=1e-2) 52 | -------------------------------------------------------------------------------- /tests/test_act_layers.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import pytest 4 | import torch 5 | from torch import autocast, nn 6 | 7 | import attorch 8 | from .utils import assert_close, create_input, create_input_like, default_shapes 9 | 10 | 11 | @pytest.mark.parametrize('shape', default_shapes()) 12 | @pytest.mark.parametrize('act_func', ['Sigmoid', 'LogSigmoid', 'Tanh', 'ReLU', 'GELU', 'SiLU', 13 | 'ReLU6', 'Hardsigmoid', 'Hardtanh', 'Hardswish', 'SELU', 14 | 'Mish', 'Softplus', 'Softsign', 'Tanhshrink', 'LeakyReLU', 15 | 'ELU', 'CELU', 'Hardshrink', 'Softshrink']) 16 | @pytest.mark.parametrize('drop_p', [0.0, 0.5]) 17 | @pytest.mark.parametrize('input_dtype', [torch.float32, torch.float16]) 18 | @pytest.mark.parametrize('amp', [False, True]) 19 | def test_act_layers( 20 | shape: Tuple[int, ...], 21 | act_func: str, 22 | drop_p: float, 23 | input_dtype: bool, 24 | amp: bool, 25 | subset: bool, 26 | ) -> None: 27 | if subset and (shape not in default_shapes(subset=True)): 28 | return 29 | 30 | if input_dtype is torch.float16 and not amp: 31 | return 32 | 33 | attorch_input = create_input(shape, dtype=input_dtype) 34 | pytorch_input = create_input(shape, dtype=input_dtype) 35 | 36 | attorch_act_func = getattr(attorch, act_func)(drop_p=drop_p) 37 | pytorch_act_func = getattr(nn, act_func)() 38 | 39 | with autocast('cuda', enabled=amp): 40 | attorch_output = attorch_act_func(attorch_input) 41 | pytorch_output = pytorch_act_func(pytorch_input) 42 | 43 | if drop_p > 0.0: 44 | mask = attorch_output == 0 45 | pytorch_output = torch.where(mask, 0, pytorch_output / (1 - drop_p)) 46 | 47 | assert_close((attorch_output, pytorch_output), rtol=1e-3, 
atol=1e-3) 48 | 49 | attorch_output.backward(create_input_like(attorch_output)) 50 | pytorch_output.backward(create_input_like(pytorch_output)) 51 | 52 | assert_close((attorch_input.grad, pytorch_input.grad), rtol=1e-3, atol=1e-3) 53 | -------------------------------------------------------------------------------- /tests/test_rms_norm_layer.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import pytest 4 | import torch 5 | from torch import autocast, nn 6 | from torch.nn import init 7 | 8 | import attorch 9 | from .utils import assert_close, create_input, create_input_like, default_shapes 10 | 11 | 12 | @pytest.mark.parametrize('shape', default_shapes(max_dim=3)) 13 | @pytest.mark.parametrize('eps', [1e-5, 1e-6]) 14 | @pytest.mark.parametrize('elementwise_affine', [False, True]) 15 | @pytest.mark.parametrize('input_dtype', [torch.float32, torch.float16]) 16 | @pytest.mark.parametrize('amp', [False, True]) 17 | def test_rms_norm_layer( 18 | shape: Tuple[int, ...], 19 | eps: float, 20 | elementwise_affine: bool, 21 | input_dtype: bool, 22 | amp: bool, 23 | subset: bool, 24 | ) -> None: 25 | if subset and (shape not in default_shapes(subset=True)): 26 | return 27 | 28 | if input_dtype is torch.float16 and not amp: 29 | return 30 | 31 | attorch_input = create_input(shape, dtype=input_dtype) 32 | pytorch_input = create_input(shape, dtype=input_dtype) 33 | 34 | attorch_rms_norm = attorch.RMSNorm(shape[-1], eps, elementwise_affine) 35 | pytorch_rms_norm = nn.RMSNorm(shape[-1], eps, elementwise_affine, 36 | device='cuda') 37 | 38 | if elementwise_affine: 39 | torch.manual_seed(0) 40 | init.normal_(attorch_rms_norm.weight) 41 | 42 | torch.manual_seed(0) 43 | init.normal_(pytorch_rms_norm.weight) 44 | 45 | with autocast('cuda', enabled=amp): 46 | attorch_output = attorch_rms_norm(attorch_input) 47 | pytorch_output = pytorch_rms_norm(pytorch_input) 48 | 49 | assert_close((attorch_output, pytorch_output), 50 | rtol=1e-3, atol=1e-3) 51 | 52 | attorch_output.backward(create_input_like(attorch_output)) 53 | pytorch_output.backward(create_input_like(pytorch_output)) 54 | 55 | weight_grad_pair = ((attorch_rms_norm.weight.grad, pytorch_rms_norm.weight.grad) 56 | if elementwise_affine else (None, None)) 57 | assert_close((attorch_input.grad, pytorch_input.grad), 58 | weight_grad_pair, 59 | rtol=1e-3, atol=1e-3) 60 | -------------------------------------------------------------------------------- /tests/test_layer_norm_layer.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import pytest 4 | import torch 5 | from torch import autocast, nn 6 | from torch.nn import init 7 | 8 | import attorch 9 | from .utils import assert_close, create_input, create_input_like, default_shapes 10 | 11 | 12 | @pytest.mark.parametrize('shape', default_shapes(max_dim=3)) 13 | @pytest.mark.parametrize('eps', [1e-5, 1e-6]) 14 | @pytest.mark.parametrize('elementwise_affine', [False, True]) 15 | @pytest.mark.parametrize('bias', [False, True]) 16 | @pytest.mark.parametrize('input_dtype', [torch.float32, torch.float16]) 17 | @pytest.mark.parametrize('amp', [False, True]) 18 | def test_layer_norm_layer( 19 | shape: Tuple[int, ...], 20 | eps: float, 21 | elementwise_affine: bool, 22 | bias: bool, 23 | input_dtype: bool, 24 | amp: bool, 25 | subset: bool, 26 | ) -> None: 27 | if subset and (shape not in default_shapes(subset=True)): 28 | return 29 | 30 | if input_dtype is torch.float16 and not 
amp: 31 | return 32 | 33 | attorch_input = create_input(shape, dtype=input_dtype) 34 | pytorch_input = create_input(shape, dtype=input_dtype) 35 | 36 | attorch_layer_norm = attorch.LayerNorm(shape[-1], eps, elementwise_affine, bias) 37 | pytorch_layer_norm = nn.LayerNorm(shape[-1], eps, elementwise_affine, bias, 38 | device='cuda') 39 | 40 | if elementwise_affine: 41 | torch.manual_seed(0) 42 | init.normal_(attorch_layer_norm.weight) 43 | if bias: 44 | init.normal_(attorch_layer_norm.bias) 45 | 46 | torch.manual_seed(0) 47 | init.normal_(pytorch_layer_norm.weight) 48 | if bias: 49 | init.normal_(pytorch_layer_norm.bias) 50 | 51 | with autocast('cuda', enabled=amp): 52 | attorch_output = attorch_layer_norm(attorch_input) 53 | pytorch_output = pytorch_layer_norm(pytorch_input) 54 | 55 | assert_close((attorch_output, pytorch_output), 56 | rtol=1e-3, atol=1e-3) 57 | 58 | attorch_output.backward(create_input_like(attorch_output)) 59 | pytorch_output.backward(create_input_like(pytorch_output)) 60 | 61 | weight_grad_pair = ((attorch_layer_norm.weight.grad, pytorch_layer_norm.weight.grad) 62 | if elementwise_affine else (None, None)) 63 | bias_grad_pair = ((attorch_layer_norm.bias.grad, pytorch_layer_norm.bias.grad) 64 | if elementwise_affine and bias else (None, None)) 65 | assert_close((attorch_input.grad, pytorch_input.grad), 66 | weight_grad_pair, bias_grad_pair, 67 | rtol=1e-3, atol=1e-3) 68 | -------------------------------------------------------------------------------- /tests/test_linear_layer.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import pytest 4 | import torch 5 | from torch import autocast, nn 6 | from torch.nn import functional as F 7 | 8 | import attorch 9 | from .utils import assert_close, create_input, create_input_like, default_shapes 10 | 11 | 12 | @pytest.mark.parametrize('shape', default_shapes(min_dim=2, max_dim=3)) 13 | @pytest.mark.parametrize('out_dim', [16, 96, 128, 196, 384, 512, 768, 1024]) 14 | @pytest.mark.parametrize('bias', [False, True]) 15 | @pytest.mark.parametrize('act_func', [None, 'sigmoid', 'logsigmoid', 'tanh', 'relu', 'gelu', 'silu', 16 | 'relu6', 'hardsigmoid', 'hardtanh', 'hardswish', 'selu', 17 | 'mish', 'softplus', 'softsign', 'tanhshrink', 'leaky_relu_0.01', 18 | 'elu_1', 'celu_1', 'hardshrink_0.5', 'softshrink_0.5']) 19 | @pytest.mark.parametrize('input_dtype', [torch.float32, torch.float16]) 20 | @pytest.mark.parametrize('amp', [False, True]) 21 | def test_linear_layer( 22 | shape: Tuple[int, ...], 23 | out_dim: int, 24 | bias: bool, 25 | act_func: Optional[str], 26 | input_dtype: bool, 27 | amp: bool, 28 | subset: bool, 29 | ) -> None: 30 | if subset and (shape not in default_shapes(subset=True)): 31 | return 32 | 33 | if input_dtype is torch.float16 and not amp: 34 | return 35 | 36 | attorch_input = create_input(shape, dtype=input_dtype) 37 | pytorch_input = create_input(shape, dtype=input_dtype) 38 | 39 | torch.manual_seed(0) 40 | attorch_linear = attorch.Linear(shape[-1], out_dim, 41 | bias=bias, 42 | act_func=act_func) 43 | 44 | torch.manual_seed(0) 45 | pytorch_linear = nn.Linear(shape[-1], out_dim, 46 | bias=bias, device='cuda') 47 | pytorch_act = nn.Identity() if act_func is None else getattr(F, act_func.rsplit('_', 1)[0]) 48 | 49 | with autocast('cuda', enabled=amp): 50 | attorch_output = attorch_linear(attorch_input) 51 | pytorch_output = pytorch_act(pytorch_linear(pytorch_input)) 52 | 53 | assert_close((attorch_output, pytorch_output), rtol=1e-3, 
atol=1e-3) 54 | 55 | attorch_output.backward(create_input_like(attorch_output)) 56 | pytorch_output.backward(create_input_like(pytorch_output)) 57 | 58 | bias_grad_pair = ((attorch_linear.bias.grad, pytorch_linear.bias.grad) 59 | if bias else (None, None)) 60 | assert_close((attorch_input.grad, pytorch_input.grad), 61 | (attorch_linear.weight.grad, pytorch_linear.weight.grad.T.contiguous()), 62 | bias_grad_pair, rtol=1e-3, atol=1e-3) 63 | -------------------------------------------------------------------------------- /attorch/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for attorch kernels and layers. 3 | """ 4 | 5 | 6 | from typing import List, Optional 7 | 8 | import torch 9 | import triton 10 | 11 | 12 | def allow_tf32() -> bool: 13 | """ 14 | Returns whether the current GPU architecture supports TF32. 15 | """ 16 | return torch.cuda.get_device_capability()[0] >= 8 17 | 18 | 19 | def get_n_stages(n_stages: int = 2) -> int: 20 | """ 21 | Receives number of stages for software pipelining and returns it as-is 22 | if the GPU architecture is Ampere or newer and 2 otherwise. 23 | """ 24 | return 2 if torch.cuda.get_device_capability()[0] < 8 else n_stages 25 | 26 | 27 | def get_output_dtype( 28 | input_dtype: torch.dtype = torch.float32, 29 | autocast: Optional[str] = None, 30 | ) -> torch.dtype: 31 | """ 32 | Returns the appropriate output dtype for automatic mixed precision 33 | given the input dtype and the operation's autocast behaviour. 34 | 35 | Args: 36 | input_dtype: Input dtype. 37 | autocast: The relevent operation's autocast behaviour. 38 | None signifies the input dtype should flow through, 39 | 'fp16' signifies autocasting to FP16 when AMP is enabled, 40 | and 'fp32' signifies autocasting to FP32 when AMP is enabled. 41 | """ 42 | dtype = torch.get_autocast_dtype('cuda') 43 | assert dtype, \ 44 | f'Only autocast to float16 is supported, received {dtype}' 45 | 46 | if torch.is_autocast_enabled(): 47 | if autocast is None: 48 | return input_dtype 49 | 50 | elif autocast == 'fp16': 51 | return torch.float16 52 | 53 | elif autocast == 'fp32': 54 | return torch.float32 55 | 56 | else: 57 | raise RuntimeError(f'Autocast type {autocast} is invalid. ' 58 | 'Options are None, fp16, and fp32') 59 | 60 | else: 61 | return input_dtype 62 | 63 | 64 | def element_wise_kernel_configs( 65 | block_name: str = 'BLOCK_SIZE', 66 | ) -> List[triton.Config]: 67 | """ 68 | Returns kernel configurations for element-wise operations. 69 | 70 | Args: 71 | block_name: Name of block argument rows are distributed over. 72 | """ 73 | return [triton.Config({block_name: 64}, num_warps=2), 74 | triton.Config({block_name: 128}, num_warps=2), 75 | triton.Config({block_name: 256}, num_warps=4), 76 | triton.Config({block_name: 512}, num_warps=4), 77 | triton.Config({block_name: 1024}, num_warps=4)] 78 | 79 | 80 | def warps_kernel_configs() -> List[triton.Config]: 81 | """ 82 | Returns kernel configurations with all possible number of warps. 
83 | """ 84 | return [triton.Config({}, num_warps=2**i) for i in range(6)] 85 | -------------------------------------------------------------------------------- /tests/test_conv_layer.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union 2 | 3 | import pytest 4 | import torch 5 | from torch import autocast, nn 6 | 7 | import attorch 8 | from .utils import assert_close, create_input, create_input_like, default_shapes 9 | 10 | 11 | @pytest.mark.parametrize('shape', default_shapes(min_dim=3, max_dim=4)) 12 | @pytest.mark.parametrize('out_dim', [16, 96, 128, 196, 384, 512, 768, 1024]) 13 | @pytest.mark.parametrize('kernel_size', [1, 2, 3, 4, 5, 7, (5, 3), (3, 5)]) 14 | @pytest.mark.parametrize('stride', [1, 2, (1, 2), (2, 1)]) 15 | @pytest.mark.parametrize('padding', [0, 1, 3, -1]) 16 | @pytest.mark.parametrize('groups', [1, 2, 4, -1]) 17 | @pytest.mark.parametrize('bias', [False, True]) 18 | @pytest.mark.parametrize('input_dtype', [torch.float32, torch.float16]) 19 | @pytest.mark.parametrize('amp', [False, True]) 20 | def test_conv2d_layer( 21 | shape: Tuple[int, ...], 22 | out_dim: int, 23 | kernel_size: Union[int, Tuple[int, int]], 24 | stride: Union[int, Tuple[int, int]], 25 | padding: Union[int, Tuple[int, int]], 26 | groups: int, 27 | bias: bool, 28 | input_dtype: bool, 29 | amp: bool, 30 | subset: bool, 31 | ) -> None: 32 | if subset and (shape not in default_shapes(subset=True)): 33 | return 34 | 35 | if input_dtype is torch.float16 and not amp: 36 | return 37 | 38 | if len(shape) == 3 and (isinstance(kernel_size, tuple) or isinstance(padding, tuple)): 39 | return 40 | 41 | if padding == -1: 42 | padding = kernel_size // 2 43 | 44 | if groups == -1: 45 | groups = shape[1] 46 | 47 | if shape[1] % groups != 0: 48 | groups = 1 49 | 50 | conv_name = 'Conv2d' if len(shape) == 4 else 'Conv1d' 51 | attorch_input = create_input(shape, dtype=input_dtype) 52 | pytorch_input = create_input(shape, dtype=input_dtype) 53 | 54 | torch.manual_seed(0) 55 | attorch_conv = getattr(attorch, conv_name)(shape[1], out_dim, 56 | kernel_size, 57 | stride=stride, padding=padding, 58 | groups=groups, bias=bias) 59 | 60 | torch.manual_seed(0) 61 | pytorch_conv = getattr(nn, conv_name)(shape[1], out_dim, 62 | kernel_size, 63 | stride=stride, padding=padding, 64 | groups=groups, bias=bias, 65 | device='cuda') 66 | 67 | with autocast('cuda', enabled=amp): 68 | attorch_output = attorch_conv(attorch_input) 69 | pytorch_output = pytorch_conv(pytorch_input) 70 | 71 | assert_close((attorch_output, pytorch_output), rtol=1e-3, atol=1e-2) 72 | 73 | attorch_output.backward(create_input_like(attorch_output)) 74 | pytorch_output.backward(create_input_like(pytorch_output)) 75 | 76 | bias_grad_pair = ((attorch_conv.bias.grad, pytorch_conv.bias.grad) 77 | if bias else (torch.tensor(0), torch.tensor(0))) 78 | assert_close((attorch_input.grad, pytorch_input.grad), 79 | (attorch_conv.weight.grad, pytorch_conv.weight.grad), 80 | bias_grad_pair, rtol=1e-3, atol=1e-2) 81 | -------------------------------------------------------------------------------- /attorch/dropout_kernels.py: -------------------------------------------------------------------------------- 1 | """ 2 | Kernels for dropout. 3 | """ 4 | 5 | 6 | import triton 7 | import triton.language as tl 8 | 9 | from .utils import element_wise_kernel_configs 10 | 11 | 12 | @triton.jit 13 | def apply_dropout(input, drop_p, seed, offset): 14 | """ 15 | Randomly zeroes elements in the input. 
16 | 17 | Args: 18 | input: Input. The input must be loaded and cannot be a pointer. 19 | drop_p: Probability of dropping an element. 20 | seed: Seed for generating the dropout mask. 21 | offset: Offset to generate the mask for. 22 | 23 | Returns: 24 | Input with elements randomly zeroed out. 25 | """ 26 | random = tl.rand(seed, offset) 27 | return tl.where(random < drop_p, 0, input / (1 - drop_p)) 28 | 29 | 30 | @triton.jit 31 | def apply_dropout_grad(output_grad, drop_p, seed, offset): 32 | """ 33 | Calculates the input gradient of dropout. 34 | 35 | Args: 36 | output_grad: Output gradients. The output gradients must be 37 | loaded and cannot be a pointer. 38 | drop_p: Probability of dropping an element. 39 | seed: Seed for generating the dropout mask. 40 | offset: Offset to generate the mask for. 41 | 42 | Returns: 43 | Gradient of dropout. 44 | """ 45 | random = tl.rand(seed, offset) 46 | return tl.where(random < drop_p, 0, output_grad / (1 - drop_p)) 47 | 48 | 49 | @triton.autotune( 50 | configs=element_wise_kernel_configs(), 51 | key=['size'], 52 | ) 53 | @triton.jit 54 | def dropout_forward_kernel( 55 | input_pointer, output_pointer, size, 56 | drop_p, seed, 57 | BLOCK_SIZE: tl.constexpr, 58 | ): 59 | """ 60 | Randomly zeroes elements in the input. 61 | 62 | Args: 63 | input_pointer: Pointer to the input to perform dropout on. 64 | The input must be of shape [size]. 65 | output_pointer: Pointer to a container the result is written to. 66 | The container must be of shape [size]. 67 | size: Number of elements in the input. 68 | drop_p: Probability of dropping an element. 69 | seed: Seed for generating the dropout mask. 70 | BLOCK_SIZE: Block size. 71 | """ 72 | # This program processes BLOCK_SIZE rows. 73 | pid = tl.program_id(axis=0) 74 | offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 75 | mask = offset < size 76 | 77 | input = tl.load(input_pointer + offset, mask=mask) 78 | output = apply_dropout(input, drop_p, seed, offset) 79 | tl.store(output_pointer + offset, output, mask=mask) 80 | 81 | 82 | @triton.autotune( 83 | configs=element_wise_kernel_configs(), 84 | key=['size'], 85 | ) 86 | @triton.jit 87 | def dropout_backward_kernel( 88 | output_grad_pointer, input_grad_pointer, size, 89 | drop_p, seed, 90 | BLOCK_SIZE: tl.constexpr, 91 | ): 92 | """ 93 | Calculates the input gradient of dropout. 94 | 95 | Args: 96 | output_grad_pointer: Pointer to dropout's output gradients. 97 | The output gradients must be of shape [size]. 98 | input_grad_pointer: Pointer to a container the input's gradients are written to. 99 | The container must be of shape [size]. 100 | size: Number of elements in the input. 101 | drop_p: Probability of dropping an element used in dropout. 102 | seed: Seed for generating the dropout mask. 103 | BLOCK_SIZE: Block size. 104 | """ 105 | # This program processes BLOCK_SIZE rows. 106 | pid = tl.program_id(axis=0) 107 | offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 108 | mask = offset < size 109 | 110 | output_grad = tl.load(output_grad_pointer + offset, mask=mask) 111 | input_grad = apply_dropout_grad(output_grad, drop_p, seed, offset) 112 | tl.store(input_grad_pointer + offset, input_grad, mask=mask) 113 | -------------------------------------------------------------------------------- /attorch/dropout_layer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dropout layer with PyTorch autodiff support. 
3 | """ 4 | 5 | 6 | import warnings 7 | from random import randint 8 | from typing import Optional, Tuple 9 | 10 | import torch 11 | from torch import Tensor 12 | from torch import nn 13 | from torch.amp import custom_bwd, custom_fwd 14 | from triton import cdiv 15 | 16 | from .dropout_kernels import dropout_backward_kernel, dropout_forward_kernel 17 | from .types import Context 18 | 19 | 20 | class DropoutAutoGrad(torch.autograd.Function): 21 | """ 22 | Autodiff for dropout. 23 | """ 24 | @staticmethod 25 | @custom_fwd(device_type='cuda') 26 | def forward( 27 | ctx: Context, 28 | input: Tensor, 29 | drop_p: float, 30 | training: bool, 31 | ) -> Tensor: 32 | """ 33 | Randomly zeroes elements in the input. 34 | 35 | Args: 36 | ctx: Context for variable storage. 37 | input: Input to perform dropout on. 38 | Can have arbitrary shape. 39 | drop_p: Probability of dropping an element. 40 | training: Flag indicating if the model is in training mode. 41 | If False, no dropout is applied. 42 | 43 | Returns: 44 | Input with some elements zeroed out. 45 | """ 46 | ctx.do_dropout = True 47 | if not training or drop_p == 0: 48 | ctx.do_dropout = False 49 | return input 50 | 51 | ctx.drop_all = False 52 | if drop_p == 1: 53 | ctx.drop_all = True 54 | return torch.zeros_like(input) 55 | 56 | flattened_input = input.flatten() 57 | size = len(flattened_input) 58 | output = torch.empty_like(flattened_input) 59 | 60 | seed = randint(0, 65535) 61 | ctx.seed = seed 62 | ctx.drop_p = drop_p 63 | 64 | # Launches 1D grid where each program operates over 65 | # BLOCK_SIZE elements. 66 | grid = lambda META: (cdiv(size, META['BLOCK_SIZE']),) 67 | dropout_forward_kernel[grid](flattened_input, output, 68 | size, drop_p, seed) 69 | 70 | return output.view_as(input) 71 | 72 | @staticmethod 73 | @custom_bwd(device_type='cuda') 74 | def backward( 75 | ctx: Context, 76 | output_grad: Tensor, 77 | ) -> Tuple[Optional[Tensor], ...]: 78 | """ 79 | Calculates the input gradient of dropout. 80 | 81 | Args: 82 | ctx: Context containing stored variables. 83 | output_grad: Output gradients. 84 | Must be the same shape as the output. 85 | 86 | Returns: 87 | Input gradient of dropout. 88 | """ 89 | if not ctx.do_dropout: 90 | return output_grad, None, None 91 | 92 | if ctx.drop_all: 93 | return torch.zeros_like(output_grad), None, None 94 | 95 | orig_shape = output_grad.shape 96 | output_grad = output_grad.flatten() 97 | size = len(output_grad) 98 | input_grad = torch.empty_like(output_grad) 99 | 100 | # Launches 1D grid where each program operates over 101 | # BLOCK_SIZE elements. 102 | grid = lambda META: (cdiv(size, META['BLOCK_SIZE']),) 103 | dropout_backward_kernel[grid](output_grad, input_grad, 104 | size, ctx.drop_p, ctx.seed) 105 | 106 | # Pads output with None because a gradient is necessary for 107 | # all input arguments. 108 | return input_grad.view(orig_shape), None, None 109 | 110 | 111 | class Dropout(nn.Dropout): 112 | """ 113 | Randomly zeroes elements in the input during training. 114 | See also base class. 115 | 116 | Args: 117 | p: Probability of dropping an element. 118 | inplace: This is a dummy argument and has no effects, 119 | as in-place is currently not supported. 
120 | """ 121 | def __init__(self, p: float = 0.5, inplace: bool = False) -> None: 122 | super().__init__(p=p, inplace=False) 123 | 124 | if inplace is True: 125 | warnings.warn('In-place dropout currently not supported; ' 126 | 'falling back to out-of-place.') 127 | 128 | def forward(self, input: Tensor) -> Tensor: 129 | return DropoutAutoGrad.apply(input, self.p, self.training) 130 | -------------------------------------------------------------------------------- /tests/test_multi_head_attention_layer.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import pytest 4 | import torch 5 | from torch import autocast, nn 6 | from torch.nn import init 7 | 8 | import attorch 9 | from .utils import assert_close, create_input, create_input_like, default_shapes 10 | 11 | 12 | @pytest.mark.parametrize('shape', default_shapes(min_dim=3, max_dim=3)) 13 | @pytest.mark.parametrize('self_attention', [False, True]) 14 | @pytest.mark.parametrize('num_heads', [1, 2, 4, 6, 8]) 15 | @pytest.mark.parametrize('bias', [False, True]) 16 | @pytest.mark.parametrize('causal', [False, True]) 17 | @pytest.mark.parametrize('input_dtype', [torch.float32, torch.float16]) 18 | @pytest.mark.parametrize('amp', [False, True]) 19 | def test_multi_head_attention_layer( 20 | shape: Tuple[int, ...], 21 | self_attention: bool, 22 | num_heads: int, 23 | bias: bool, 24 | causal: bool, 25 | input_dtype: bool, 26 | amp: bool, 27 | subset: bool, 28 | ) -> None: 29 | if subset and (shape not in default_shapes(subset=True)): 30 | return 31 | 32 | if (input_dtype is torch.float16 and not amp) or (shape[-1] % num_heads != 0): 33 | return 34 | 35 | attorch_input_q = create_input(shape, dtype=input_dtype) 36 | pytorch_input_q = create_input(shape, dtype=input_dtype) 37 | 38 | if self_attention: 39 | attorch_input_k = attorch_input_q 40 | pytorch_input_k = pytorch_input_q 41 | 42 | else: 43 | attorch_input_k = create_input(shape, dtype=input_dtype, seed=1) 44 | pytorch_input_k = create_input(shape, dtype=input_dtype, seed=1) 45 | 46 | attorch_input_v = attorch_input_k 47 | pytorch_input_v = pytorch_input_k 48 | 49 | torch.manual_seed(0) 50 | attorch_multi_head_attention = attorch.MultiheadAttention(shape[-1], num_heads, 51 | bias=bias, 52 | batch_first=True) 53 | 54 | torch.manual_seed(0) 55 | pytorch_multi_head_attention = nn.MultiheadAttention(shape[-1], num_heads, 56 | bias=bias, 57 | batch_first=True, 58 | device='cuda') 59 | 60 | if bias: 61 | torch.manual_seed(0) 62 | init.normal_(attorch_multi_head_attention.in_proj_bias) 63 | init.normal_(attorch_multi_head_attention.out_proj.bias) 64 | 65 | torch.manual_seed(0) 66 | init.normal_(pytorch_multi_head_attention.in_proj_bias) 67 | init.normal_(pytorch_multi_head_attention.out_proj.bias) 68 | 69 | with autocast('cuda', enabled=amp): 70 | attorch_output = attorch_multi_head_attention(attorch_input_q, 71 | attorch_input_k, 72 | attorch_input_v, 73 | causal=causal) 74 | pytorch_output = pytorch_multi_head_attention(pytorch_input_q, 75 | pytorch_input_k, 76 | pytorch_input_v, 77 | attn_mask=torch.empty(2*(shape[1],)) if causal else None, 78 | is_causal=causal, 79 | need_weights=False)[0] 80 | 81 | assert_close((attorch_output, pytorch_output), 82 | rtol=1e-2, atol=1e-2) 83 | 84 | attorch_output.backward(create_input_like(attorch_output)) 85 | pytorch_output.backward(create_input_like(attorch_output)) 86 | 87 | bias_grad_pair = ((attorch_multi_head_attention.in_proj_bias, 88 | attorch_multi_head_attention.in_proj_bias), 89 
| (attorch_multi_head_attention.out_proj.bias, 90 | attorch_multi_head_attention.out_proj.bias) 91 | if bias else ((None, None), (None, None))) 92 | assert_close((attorch_input_q.grad, pytorch_input_q.grad), 93 | (attorch_input_k.grad, pytorch_input_k.grad), 94 | (attorch_input_v.grad, pytorch_input_v.grad), 95 | (attorch_multi_head_attention.in_proj_weight.grad, 96 | pytorch_multi_head_attention.in_proj_weight.grad), 97 | (attorch_multi_head_attention.out_proj.weight.grad, 98 | pytorch_multi_head_attention.out_proj.weight.grad), 99 | *bias_grad_pair, 100 | rtol=1e-2, atol=1e-2) 101 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for tests. 3 | """ 4 | 5 | 6 | from typing import List, Optional, Tuple 7 | 8 | import torch 9 | from torch import Tensor 10 | 11 | from attorch.types import Device 12 | 13 | 14 | def default_shapes( 15 | min_dim: int = 0, 16 | max_dim: int = 4, 17 | subset: bool = False, 18 | ) -> List[Tuple[int, ...]]: 19 | """ 20 | Returns typical data shapes for testing. 21 | 22 | Args: 23 | min_dim: Minimum dimensionality of the shapes returned. 24 | max_dim: Maximum dimensionality of the shapes returned. 25 | subset: Returns a subset of data shapes, useful for rapid testing. 26 | """ 27 | if subset: 28 | shapes = [(96,), 29 | (768,), 30 | (4, 2000), 31 | (8, 1024), 32 | (8, 960, 196), 33 | (64, 768, 128), 34 | (64, 64, 56, 56), 35 | (256, 128, 28, 28)] 36 | else: 37 | shapes = [(96,), 38 | (128,), 39 | (196,), 40 | (384,), 41 | (768,), 42 | (1024,), 43 | (3200,), 44 | (4800,), 45 | (8000,), 46 | (12288,), 47 | (1, 8000), 48 | (4, 2000), 49 | (8, 1024), 50 | (32, 1024), 51 | (128, 1024), 52 | (2048, 768), 53 | (6144, 256), 54 | (8096, 32), 55 | (12288, 1), 56 | (1, 1024, 3072), 57 | (8, 960, 196), 58 | (64, 768, 128), 59 | (128, 960, 196), 60 | (2048, 64, 16), 61 | (1, 3, 224, 224), 62 | (8, 3, 224, 224), 63 | (64, 64, 56, 56), 64 | (256, 128, 28, 28), 65 | (256, 2048, 7, 7)] 66 | return list(filter(lambda shape: min_dim <= len(shape) <= max_dim, shapes)) 67 | 68 | 69 | def create_input( 70 | shape: Tuple[int, ...], 71 | dtype: torch.dtype = torch.float32, 72 | device: Device = 'cuda', 73 | requires_grad: bool = True, 74 | seed: Optional[int] = 0, 75 | ) -> Tensor: 76 | """ 77 | Creates a tensor filled with random numbers. 78 | 79 | Args: 80 | shape: Shape of tensor. 81 | dtype: Dtype of tensor. 82 | device: Device of tensor. 83 | requires_grad: Flag for recording operations for autodiff. 84 | seed: Seed for generating random numbers. If None, no seed is set. 85 | 86 | Returns: 87 | Tensor with random numbers. 88 | """ 89 | if seed is not None: 90 | torch.manual_seed(seed) 91 | 92 | return torch.randn(shape, dtype=dtype, device=device, 93 | requires_grad=requires_grad) 94 | 95 | 96 | def create_input_like( 97 | input: Tensor, 98 | requires_grad: bool = False, 99 | seed: Optional[int] = 0, 100 | ) -> Tensor: 101 | """ 102 | Creates a tensor filled with random numbers with the same size, dtype, and 103 | device as input. 104 | 105 | Args: 106 | input: Input. 107 | requires_grad: Flag for recording operations for autodiff. 108 | seed: Seed for generating random numbers. If None, no seed is set. 109 | 110 | Returns: 111 | Tensor with random numbers. 
112 | """ 113 | if seed is not None: 114 | torch.manual_seed(seed) 115 | 116 | return torch.randn_like(input, requires_grad=requires_grad) 117 | 118 | 119 | def assert_close( 120 | *tensor_pairs: Tuple[Tensor, Tensor], 121 | rtol: Optional[float] = None, 122 | atol: Optional[float] = None, 123 | ) -> None: 124 | """ 125 | Asserts that the two tensors in each pair of tensor_pairs are close. 126 | See also torch.testing.assert_close. 127 | 128 | Args: 129 | *tensor_pairs: Pairs of tensors that are asserted to be close. 130 | rtol: Relative tolerance. If specified, atol must also be specified. 131 | Otherwise, it is selected according to the tensors' dtypes. 132 | atol: Absolute tolerance. If specified, rtol must also be specified. 133 | Otherwise, it is selected according to the tensors' dtypes. 134 | """ 135 | for pair in tensor_pairs: 136 | try: 137 | torch.testing.assert_close(*pair, rtol=rtol, atol=atol, equal_nan=True) 138 | 139 | except AssertionError as error: 140 | sum_diffs = torch.abs(pair[0] - pair[1]).sum() 141 | elems = pair[0].numel() 142 | raise RuntimeError(f'Tensors not equal; ' 143 | f'average difference per element is {sum_diffs / elems}. ' 144 | f'Original error: {error}') 145 | -------------------------------------------------------------------------------- /attorch/glu_kernels.py: -------------------------------------------------------------------------------- 1 | """ 2 | Kernels for gated linear units with arbitrary activation functions. 3 | """ 4 | 5 | 6 | import triton 7 | import triton.language as tl 8 | 9 | from .act_kernels import apply_act_func, apply_act_func_grad 10 | from .utils import element_wise_kernel_configs 11 | 12 | 13 | @triton.autotune( 14 | configs=element_wise_kernel_configs(), 15 | key=['size'], 16 | ) 17 | @triton.jit 18 | def glu_forward_kernel( 19 | input1_pointer, input2_pointer, output_pointer, size, param, 20 | act_func: tl.constexpr, 21 | BLOCK_SIZE: tl.constexpr, 22 | ): 23 | """ 24 | Applies the gated linear unit with an arbitrary activation function 25 | to the input. 26 | 27 | Args: 28 | input1_pointer: Pointer to the first half of the input to gate. 29 | The first half must be contiguous and contain size elements. 30 | input2_pointer: Pointer to the second half of the input to gate. 31 | The second half must be contiguous and contain size elements. 32 | output_pointer: Pointer to a container the result is written to. 33 | The container must be contiguous and contain size elements. 34 | size: Number of elements in each half of the input. 35 | param: Parameter in the case of parameterized activation functions. 36 | act_func: Name of activation function to apply. 37 | Options are 'sigmoid', 'logsigmoid', 'tanh', 'relu', 'gelu', 'geluapprox', 'silu', 38 | 'relu6', 'hardsigmoid', 'hardtanh', 'hardswish', 'selu', 'mish', 39 | 'softplus', 'softsign', 'tanhshrink', 'leaky_relu', 'elu', 'celu', and 'hardshrink'. 40 | BLOCK_SIZE: Block size. 41 | """ 42 | # This program processes BLOCK_SIZE elements. 
43 | pid = tl.program_id(axis=0) 44 | offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 45 | mask = offset < size 46 | 47 | input1 = tl.load(input1_pointer + offset, mask=mask) 48 | input2 = tl.load(input2_pointer + offset, mask=mask) 49 | 50 | output = input1 * apply_act_func(input2, None, None, None, param, 51 | act_func, False) 52 | tl.store(output_pointer + offset, output, mask=mask) 53 | 54 | 55 | @triton.autotune( 56 | configs=element_wise_kernel_configs(), 57 | key=['size'], 58 | ) 59 | @triton.jit 60 | def glu_backward_kernel( 61 | output_grad_pointer, input1_pointer, input2_pointer, 62 | input1_grad_pointer, input2_grad_pointer, size, param, 63 | act_func: tl.constexpr, 64 | BLOCK_SIZE: tl.constexpr, 65 | ): 66 | """ 67 | Calculates the input gradient of the gated linear unit. 68 | 69 | Args: 70 | output_grad_pointer: Pointer to the unit's output gradients. 71 | The output gradients must be contiguous and contain size elements. 72 | input1_pointer: Pointer to the first half of the input that was gated. 73 | The first half must be contiguous and contain size elements. 74 | input2_pointer: Pointer to the second half of the input that was gated. 75 | The second half must be contiguous and contain size elements. 76 | input1_grad_pointer: Pointer to a container the first half's gradients are written to. 77 | The container must be contiguous and contain size elements. 78 | input2_grad_pointer: Pointer to a container the second half's gradients are written to. 79 | The container must be contiguous and contain size elements. 80 | size: Number of elements in each half of the input. 81 | param: Parameter in the case of parameterized activation functions. 82 | act_func: Name of activation function to apply. 83 | Options are 'sigmoid', 'logsigmoid', 'tanh', 'relu', 'gelu', 'geluapprox', 'silu', 84 | 'softplus', 'softsign', 'tanhshrink', 'leaky_relu', 'elu', 'celu', 'hardshrink', 85 | and 'softshrink'. 86 | BLOCK_SIZE: Block size. 87 | """ 88 | # This program processes BLOCK_SIZE elements. 89 | pid = tl.program_id(axis=0) 90 | offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 91 | mask = offset < size 92 | 93 | output_grad = tl.load(output_grad_pointer + offset, mask=mask) 94 | input1 = tl.load(input1_pointer + offset, mask=mask) 95 | input2 = tl.load(input2_pointer + offset, mask=mask) 96 | 97 | input1_grad = output_grad * apply_act_func(input2, None, None, None, param, 98 | act_func, False) 99 | input2_grad = output_grad * input1 * apply_act_func_grad(1, input2, 100 | None, None, None, 101 | param, act_func, 102 | False) 103 | 104 | tl.store(input1_grad_pointer + offset, input1_grad, mask=mask) 105 | tl.store(input2_grad_pointer + offset, input2_grad, mask=mask) 106 | -------------------------------------------------------------------------------- /attorch/softmax_layers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Softmax and related layers with PyTorch autodiff support. 3 | """ 4 | 5 | 6 | from typing import Optional, Tuple 7 | 8 | import torch 9 | from torch import Tensor 10 | from torch import nn 11 | from triton import cdiv 12 | 13 | from .softmax_kernels import softmax_backward_kernel, softmax_forward_kernel 14 | from .types import Context 15 | from .utils import get_output_dtype 16 | 17 | 18 | class SoftmaxAutoGrad(torch.autograd.Function): 19 | """ 20 | Autodiff for softmax and related functions. 
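The same pair of kernels serves Softmax, LogSoftmax, and Softmin: the neg flag negates the input (yielding softmin) and the log flag takes the logarithm of the result (yielding log-softmax).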
21 | """ 22 | @staticmethod 23 | def forward( 24 | ctx: Context, 25 | input: Tensor, 26 | neg: bool, 27 | log: bool, 28 | ) -> Tensor: 29 | """ 30 | Normalizes the input using softmax. 31 | 32 | Args: 33 | ctx: Context for variable storage. 34 | input: Input to normalize. 35 | Can have arbitrary shape. 36 | neg: Flag indicating if the input should be negated to get softmin. 37 | log: Flag indicating if the log of softmax should be taken. 38 | 39 | Returns: 40 | Input normalized by softmax. 41 | """ 42 | flattened_input = input.unsqueeze(0) if input.ndim == 1 else input 43 | flattened_input = flattened_input.flatten(0, -2) 44 | batch_dim, feat_dim = flattened_input.shape 45 | 46 | output_dtype = get_output_dtype(input.dtype, autocast='fp32') 47 | output = torch.empty_like(flattened_input, dtype=output_dtype) 48 | 49 | # Launches 1D grid where each program operates over BLOCK_SIZE_BATCH rows. 50 | grid = lambda META: (cdiv(batch_dim, META['BLOCK_SIZE_BATCH']),) 51 | softmax_forward_kernel[grid](flattened_input, output, batch_dim, feat_dim, 52 | *flattened_input.stride(), *output.stride(), 53 | neg=neg, log=log) 54 | 55 | ctx.neg = neg 56 | ctx.log = log 57 | if input.requires_grad: 58 | ctx.save_for_backward(output) 59 | 60 | return output.view_as(input) 61 | 62 | @staticmethod 63 | def backward( 64 | ctx: Context, 65 | output_grad: Tensor, 66 | ) -> Tuple[Optional[Tensor], ...]: 67 | """ 68 | Calculates the input gradient of softmax. 69 | 70 | Args: 71 | ctx: Context containing stored variables. 72 | output_grad: Output gradients. 73 | Must be the same shape as the output. 74 | 75 | Returns: 76 | Input gradient of softmax. 77 | """ 78 | (output,) = ctx.saved_tensors 79 | flattened_output_grad = output_grad.view_as(output) 80 | 81 | batch_dim, feat_dim = output.shape 82 | input_grad = torch.empty_like(output) 83 | 84 | # Launches 1D grid where each program operates over BLOCK_SIZE_BATCH rows. 85 | grid = lambda META: (cdiv(batch_dim, META['BLOCK_SIZE_BATCH']),) 86 | softmax_backward_kernel[grid](flattened_output_grad, output, input_grad, 87 | batch_dim, feat_dim, 88 | *flattened_output_grad.stride(), 89 | *output.stride(), *input_grad.stride(), 90 | neg=ctx.neg, log=ctx.log) 91 | 92 | # Pads output with None because a gradient is necessary for 93 | # all input arguments. 94 | return input_grad.view_as(output_grad), None, None 95 | 96 | 97 | class Softmax(nn.Softmax): 98 | """ 99 | Normalizes the input using softmax. 100 | See also base class. 101 | 102 | Args: 103 | dim: Dimension along which softmax will be computed. 104 | Only softmax along the last dimension is supported. 105 | """ 106 | def forward(self, input: Tensor) -> Tensor: 107 | if self.dim != -1 and self.dim != input.ndim - 1: 108 | raise RuntimeError(f'Only softmax along the last dimension is supported.') 109 | 110 | return SoftmaxAutoGrad.apply(input, False, False) 111 | 112 | 113 | class LogSoftmax(nn.LogSoftmax): 114 | """ 115 | Normalizes the input using softmax and takes its log. 116 | See also base class. 117 | 118 | Args: 119 | dim: Dimension along which softmax will be computed. 120 | Only softmax along the last dimension is supported. 121 | """ 122 | def forward(self, input: Tensor) -> Tensor: 123 | if self.dim != -1 and self.dim != input.ndim - 1: 124 | raise RuntimeError(f'Only softmax along the last dimension is supported.') 125 | 126 | return SoftmaxAutoGrad.apply(input, False, True) 127 | 128 | 129 | class Softmin(nn.Softmin): 130 | """ 131 | Normalizes the input using softmin. 132 | See also base class. 
133 | 134 | Args: 135 | dim: Dimension along which softmin will be computed. 136 | Only softmin along the last dimension is supported. 137 | """ 138 | def forward(self, input: Tensor) -> Tensor: 139 | if self.dim != -1 and self.dim != input.ndim - 1: 140 | raise RuntimeError(f'Only softmin along the last dimension is supported.') 141 | 142 | return SoftmaxAutoGrad.apply(input, True, False) 143 | -------------------------------------------------------------------------------- /attorch/glu_layer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Gated linear unit layer with arbitrary activation functions 3 | with PyTorch autodiff support. 4 | """ 5 | 6 | 7 | from typing import Optional, Tuple 8 | 9 | import torch 10 | from torch import Tensor 11 | from torch import nn 12 | from torch.amp import custom_bwd, custom_fwd 13 | from triton import cdiv 14 | 15 | from .glu_kernels import glu_backward_kernel, glu_forward_kernel 16 | from .types import Context 17 | 18 | 19 | class GLUAutoGrad(torch.autograd.Function): 20 | """ 21 | Autodiff for gated linear unit. 22 | """ 23 | @staticmethod 24 | @custom_fwd(device_type='cuda') 25 | def forward( 26 | ctx: Context, 27 | input: Tensor, 28 | dim: int, 29 | act_func: str, 30 | ) -> Tensor: 31 | """ 32 | Applies the gated linear unit with an arbitrary activation function 33 | to the input. 34 | 35 | Args: 36 | ctx: Context for variable storage. 37 | input: Input to gate. 38 | Can have arbitrary shape but dimension dim must be even. 39 | dim: Dimension over which to gate. 40 | act_func: Name of activation function to apply. 41 | Options are 'sigmoid', 'logsigmoid', 'tanh', 'relu', 'gelu', 'geluapprox', 'silu', 42 | 'relu6', 'hardsigmoid', 'hardtanh', 'hardswish', 'selu', 'mish', 43 | 'softplus', 'softsign', 'tanhshrink', 'leaky_relu_PARAM', 44 | 'elu_PARAM', 'celu_PARAM', 'hardshrink_PARAM', and 'softshrink_PARAM' 45 | where PARAM stands for the parameter in the case of parameterized 46 | activation functions (e.g., 'leaky_relu_0.01' for leaky ReLU with a 47 | negative slope of 0.01). 48 | 49 | Returns: 50 | Input transformed by the gated linear unit 51 | with an arbitrary activation function. 52 | """ 53 | param = None 54 | if '_' in act_func: 55 | comps = act_func.split('_') 56 | act_func = '_'.join(comps[:-1]) 57 | param = float(comps[-1]) 58 | 59 | input1, input2 = input.chunk(2, dim=dim) 60 | input1 = input1.contiguous() 61 | input2 = input2.contiguous() 62 | 63 | requires_grad = input.requires_grad 64 | size = input1.numel() 65 | output = torch.empty_like(input1) 66 | 67 | ctx.param = param 68 | ctx.act_func = act_func 69 | ctx.dim = dim 70 | ctx.size = size 71 | if requires_grad: 72 | ctx.save_for_backward(input1, input2) 73 | 74 | # Launches 1D grid where each program operates over 75 | # BLOCK_SIZE elements. 76 | grid = lambda META: (cdiv(size, META['BLOCK_SIZE']),) 77 | glu_forward_kernel[grid](input1, input2, output, size, param, act_func) 78 | 79 | return output 80 | 81 | @staticmethod 82 | @custom_bwd(device_type='cuda') 83 | def backward( 84 | ctx: Context, 85 | output_grad: Tensor, 86 | ) -> Tuple[Optional[Tensor], ...]: 87 | """ 88 | Calculates the input gradient of the gated linear unit. 89 | 90 | Args: 91 | ctx: Context containing stored variables. 92 | output_grad: Output gradients. 93 | Must be the same shape as the output. 94 | 95 | Returns: 96 | Input gradient of the gated linear unit. 
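For output = a * act(b), where a and b are the two halves of the input along dim, the kernel computes grad_a = grad_out * act(b) and grad_b = grad_out * a * act'(b); the two gradient halves are then concatenated back along dim.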
97 | """ 98 | (input1, input2) = ctx.saved_tensors 99 | input1_grad = torch.empty_like(input1) 100 | input2_grad = torch.empty_like(input2) 101 | 102 | # Launches 1D grid where each program operates over 103 | # BLOCK_SIZE elements. 104 | grid = lambda META: (cdiv(ctx.size, META['BLOCK_SIZE']),) 105 | glu_backward_kernel[grid](output_grad, input1, input2, 106 | input1_grad, input2_grad, 107 | ctx.size, ctx.param, ctx.act_func) 108 | 109 | # Pads output with None because a gradient is necessary for 110 | # all input arguments. 111 | return torch.concat([input1_grad, input2_grad], dim=ctx.dim), None, None 112 | 113 | 114 | class GLU(nn.GLU): 115 | """ 116 | Applies the gated linear unit with an arbitrary activation function 117 | to the input. 118 | See also base class. 119 | 120 | Args: 121 | dim: Dimension over which to gate. 122 | act_func: Name of activation function to apply. 123 | Options are 'sigmoid', 'logsigmoid', 'tanh', 'relu', 'gelu', 'geluapprox', 'silu', 124 | 'relu6', 'hardsigmoid', 'hardtanh', 'hardswish', 'selu', 'mish', 125 | 'softplus', 'softsign', 'tanhshrink', 'leaky_relu_PARAM', 126 | 'elu_PARAM', 'celu_PARAM', 'hardshrink_PARAM', and 'softshrink_PARAM' 127 | where PARAM stands for the parameter in the case of parameterized 128 | activation functions (e.g., 'leaky_relu_0.01' for leaky ReLU with a 129 | negative slope of 0.01). 130 | """ 131 | def __init__(self, dim: int = -1, act_func: str = 'sigmoid') -> None: 132 | super().__init__(dim) 133 | self.act_func = act_func 134 | 135 | def forward(self, input: Tensor) -> Tensor: 136 | return GLUAutoGrad.apply(input, self.dim, self.act_func) 137 | -------------------------------------------------------------------------------- /tests/test_batch_norm_layer.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import pytest 4 | import torch 5 | from torch import autocast, nn 6 | from torch.nn import functional as F, init 7 | 8 | import attorch 9 | from .utils import assert_close, create_input, create_input_like, default_shapes 10 | 11 | 12 | @pytest.mark.parametrize('shape', default_shapes(min_dim=2, max_dim=4)) 13 | @pytest.mark.parametrize('eps', [1e-5, 1e-6]) 14 | @pytest.mark.parametrize('momentum', [0.1, 0.2]) 15 | @pytest.mark.parametrize('affine', [False, True]) 16 | @pytest.mark.parametrize('track_running_stats', [False, True]) 17 | @pytest.mark.parametrize('add_pre_act', [False, True]) 18 | @pytest.mark.parametrize('act_func', [None, 'sigmoid', 'logsigmoid', 'tanh', 'relu', 'gelu', 'silu', 19 | 'relu6', 'hardsigmoid', 'hardtanh', 'hardswish', 'selu', 20 | 'mish', 'softplus', 'softsign', 'tanhshrink', 'leaky_relu_0.01', 21 | 'elu_1', 'celu_1', 'hardshrink_0.5', 'softshrink_0.5']) 22 | @pytest.mark.parametrize('input_dtype', [torch.float32, torch.float16]) 23 | @pytest.mark.parametrize('amp', [False, True]) 24 | def test_batch_norm_layer( 25 | shape: Tuple[int, ...], 26 | eps: float, 27 | momentum: float, 28 | affine: bool, 29 | track_running_stats: bool, 30 | add_pre_act: bool, 31 | act_func: Optional[str], 32 | input_dtype: bool, 33 | amp: bool, 34 | subset: bool, 35 | ) -> None: 36 | if subset and (shape not in default_shapes(subset=True)): 37 | return 38 | 39 | if shape[0] == 1 or (input_dtype is torch.float16 and not amp): 40 | return 41 | 42 | bn_name = 'BatchNorm2d' if len(shape) == 4 else 'BatchNorm1d' 43 | attorch_input = create_input(shape, dtype=input_dtype) 44 | pytorch_input = create_input(shape, dtype=input_dtype) 45 | 46 | if 
add_pre_act: 47 | attorch_residual = create_input(shape, dtype=input_dtype, seed=1) 48 | pytorch_residual = create_input(shape, dtype=input_dtype, seed=1) 49 | 50 | else: 51 | attorch_residual = pytorch_residual = None 52 | 53 | attorch_batch_norm = getattr(attorch, bn_name)(num_features=shape[1], 54 | eps=eps, momentum=momentum, 55 | affine=affine, 56 | track_running_stats=track_running_stats, 57 | act_func=act_func) 58 | pytorch_batch_norm = getattr(nn, bn_name)(num_features=shape[1], 59 | eps=eps, momentum=momentum, 60 | affine=affine, 61 | track_running_stats=track_running_stats, 62 | device='cuda') 63 | pytorch_act = nn.Identity() if act_func is None else getattr(F, act_func.rsplit('_', 1)[0]) 64 | 65 | if affine: 66 | torch.manual_seed(0) 67 | init.normal_(attorch_batch_norm.weight) 68 | init.normal_(attorch_batch_norm.bias) 69 | 70 | torch.manual_seed(0) 71 | init.normal_(pytorch_batch_norm.weight) 72 | init.normal_(pytorch_batch_norm.bias) 73 | 74 | with autocast('cuda', enabled=amp): 75 | if add_pre_act: 76 | attorch_output = attorch_batch_norm(attorch_input, attorch_residual) 77 | pytorch_output = pytorch_act(pytorch_batch_norm(pytorch_input) + 78 | pytorch_residual) 79 | 80 | else: 81 | attorch_output = attorch_batch_norm(attorch_input) 82 | pytorch_output = pytorch_act(pytorch_batch_norm(pytorch_input)) 83 | 84 | assert_close((attorch_output, pytorch_output), 85 | (attorch_batch_norm.running_mean, pytorch_batch_norm.running_mean), 86 | (attorch_batch_norm.running_var, pytorch_batch_norm.running_var)) 87 | 88 | attorch_output.backward(create_input_like(attorch_output)) 89 | pytorch_output.backward(create_input_like(pytorch_output)) 90 | 91 | residual_grad_pair = ((attorch_residual.grad, pytorch_residual.grad) 92 | if add_pre_act else (None, None)) 93 | weight_grad_pair = ((attorch_batch_norm.weight.grad, pytorch_batch_norm.weight.grad) 94 | if affine else (None, None)) 95 | bias_grad_pair = ((attorch_batch_norm.bias.grad, pytorch_batch_norm.bias.grad) 96 | if affine else (None, None)) 97 | assert_close((attorch_input.grad, pytorch_input.grad), 98 | residual_grad_pair, weight_grad_pair, bias_grad_pair, 99 | rtol=1e-2, atol=1e-3) 100 | 101 | attorch_batch_norm.eval() 102 | pytorch_batch_norm.eval() 103 | 104 | with autocast('cuda', enabled=amp): 105 | if add_pre_act: 106 | attorch_output = attorch_batch_norm(attorch_input, attorch_residual) 107 | pytorch_output = pytorch_act(pytorch_batch_norm(pytorch_input) + 108 | pytorch_residual) 109 | 110 | else: 111 | attorch_output = attorch_batch_norm(attorch_input) 112 | pytorch_output = pytorch_act(pytorch_batch_norm(pytorch_input)) 113 | 114 | assert_close((attorch_output, pytorch_output), 115 | (attorch_batch_norm.running_mean, pytorch_batch_norm.running_mean), 116 | (attorch_batch_norm.running_var, pytorch_batch_norm.running_var)) 117 | -------------------------------------------------------------------------------- /examples/wikitext-2/gpt.py: -------------------------------------------------------------------------------- 1 | """ 2 | GPT2 for language modelling. 3 | """ 4 | 5 | 6 | from typing import Optional 7 | 8 | import torch 9 | from torch import Tensor 10 | from torch import nn 11 | 12 | import attorch 13 | 14 | 15 | class MLP(nn.Module): 16 | """ 17 | Transforms the input using a multilayer perceptron with one hidden layer 18 | and the GELU activation function. 19 | 20 | Args: 21 | use_attorch: Flag to use attorch in lieu of PyTorch as the backend. 22 | in_dim: Number of input features. 
23 | hidden_dim: Number of hidden features. 24 | out_dim: Number of output features. 25 | If None, it is set to the number of input features. 26 | """ 27 | def __init__( 28 | self, 29 | use_attorch: bool, 30 | in_dim: int, 31 | hidden_dim: int, 32 | out_dim: Optional[int] = None, 33 | ) -> None: 34 | super().__init__() 35 | 36 | self.fc1 = (attorch.Linear(in_dim, hidden_dim, act_func='gelu') 37 | if use_attorch else nn.Linear(in_dim, hidden_dim)) 38 | self.act = nn.Identity() if use_attorch else nn.GELU() 39 | self.fc2 = (attorch.Linear(hidden_dim, out_dim or in_dim) 40 | if use_attorch else nn.Linear(hidden_dim, out_dim or in_dim)) 41 | 42 | def forward(self, input: Tensor) -> Tensor: 43 | return self.fc2(self.act(self.fc1(input))) 44 | 45 | 46 | class TransformerBlock(nn.Module): 47 | """ 48 | Passes the input through a transformer block. 49 | 50 | Args: 51 | use_attorch: Flag to use attorch in lieu of PyTorch as the backend. 52 | dim: Embedding dimension. 53 | num_heads: Number of heads for multi-headed self-attention. 54 | """ 55 | def __init__( 56 | self, 57 | use_attorch: bool, 58 | dim: int, 59 | num_heads: int, 60 | ) -> None: 61 | super().__init__() 62 | self.use_attorch = use_attorch 63 | backend = attorch if use_attorch else nn 64 | 65 | self.ln1 = backend.LayerNorm(dim) 66 | self.attn = backend.MultiheadAttention(dim, num_heads, 67 | batch_first=True) 68 | 69 | self.ln2 = backend.LayerNorm(dim) 70 | self.mlp = MLP(use_attorch, dim, 4 * dim) 71 | 72 | def forward(self, input: Tensor) -> Tensor: 73 | if self.use_attorch: 74 | output = input + self.attn(self.ln1(input), causal=True) 75 | 76 | else: 77 | output = self.ln1(input) 78 | output = input + self.attn(output, output, output, 79 | attn_mask=torch.empty(2*(input.shape[1],)), 80 | is_causal=True, 81 | need_weights=False)[0] 82 | 83 | output = output + self.mlp(self.ln2(output)) 84 | return output 85 | 86 | 87 | class GPT2(nn.Module): 88 | """ 89 | Performs language modelling using GPT2, 90 | optionally computing the loss if return_loss is True. 91 | 92 | Args: 93 | use_attorch: Flag to use attorch in lieu of PyTorch as the backend. 94 | vocab_size: Vocabulary size. 95 | depth: Depth of the transformer. 96 | dim: Embedding dimension. 97 | num_heads: Number of heads for multi-headed self-attention. 98 | max_seq_len: Maximum sequence length of the incoming inputs. 
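Example (illustrative, assuming the standard GPT2 BPE vocabulary of 50257 tokens): GPT2(use_attorch=True, vocab_size=50257, depth=12, dim=768, num_heads=12) roughly matches the base GPT2 configuration, with attorch's LayerNorm, MultiheadAttention, Linear, and CrossEntropyLoss in place of their PyTorch counterparts.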
99 | """ 100 | def __init__( 101 | self, 102 | use_attorch: bool, 103 | vocab_size: int, 104 | depth: int, 105 | dim: int, 106 | num_heads: int, 107 | max_seq_len: int = 512, 108 | ) -> None: 109 | super().__init__() 110 | backend = attorch if use_attorch else nn 111 | 112 | self.tok_embed = nn.Embedding(vocab_size, dim) 113 | self.pos_embed = nn.Embedding(max_seq_len, dim) 114 | self.transformer = nn.Sequential(*[TransformerBlock(use_attorch, dim, num_heads) 115 | for _ in range(depth)]) 116 | self.norm = backend.LayerNorm(dim) 117 | self.fc = backend.Linear(dim, vocab_size) 118 | self.loss_func = backend.CrossEntropyLoss() 119 | 120 | def forward( 121 | self, 122 | input: Tensor, 123 | return_loss: bool = False, 124 | ) -> Tensor: 125 | tok_embed = self.tok_embed(input) 126 | pos_embed = self.pos_embed(torch.arange(0, input.shape[1], 127 | dtype=torch.long, 128 | device=input.device)) 129 | 130 | output = self.transformer(tok_embed + pos_embed) 131 | output = self.norm(output) 132 | output = self.fc(output) 133 | 134 | return (self.loss_func(output[:, :-1].contiguous().view(-1, output.shape[-1]), 135 | input[:, 1:].contiguous().view(-1)) 136 | if return_loss else output) 137 | 138 | 139 | def gpt2( 140 | use_attorch: bool, 141 | vocab_size: int, 142 | max_seq_len: 512, 143 | downsize: int = 1, 144 | ) -> GPT2: 145 | """ 146 | Returns a GPT2 model with optional cross entropy loss. 147 | 148 | Args: 149 | use_attorch: Flag to use attorch in lieu of PyTorch as the backend. 150 | vocab_size: Vocabulary size. 151 | max_seq_len: Maximum sequence length of the incoming inputs. 152 | downsize: The depth and width of the model are calculated by dividing 153 | GPT2's original depth and width by this factor. 154 | """ 155 | return GPT2(use_attorch, vocab_size=vocab_size, 156 | depth=12 // downsize, dim=768 // downsize, num_heads=12 // downsize, 157 | max_seq_len=max_seq_len) 158 | -------------------------------------------------------------------------------- /attorch/pooling_layer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Pooling layers with PyTorch autodiff support. 3 | """ 4 | 5 | 6 | from typing import Optional, Tuple, Union 7 | 8 | import torch 9 | from torch import Tensor 10 | from torch import nn 11 | from torch.nn.modules.utils import _pair 12 | 13 | from .conv_layer import Conv2dAutoGrad 14 | 15 | 16 | class AvgPool1d(nn.AvgPool1d): 17 | """ 18 | Averages 1D windows of pixels in the input. 19 | See also base class. 20 | 21 | Note: The Triton compiler does not perform well with convolutional kernels, 22 | which underlie this pooling layer, and a significant speed disparity between 23 | this module and its PyTorch equivalent should be expected. Use at your own discretion. 24 | 25 | Args: 26 | kernel_size: Kernel size. 27 | If an int, this value is used along both spatial dimensions. 28 | stride: Stride of kernel. 29 | If an int, this value is used along both spatial dimensions. 30 | If None, the kernel size is used as the stride. 31 | padding: Padding applied to the input. 32 | If an int, this value is used along both spatial dimensions. 33 | ceil_mode: Flag for using ceil instead of floor to calculate the output shape. 34 | Must be False. 35 | count_include_pad: Flag for including padding when averaging pixels. 36 | Must be True. 37 | 38 | Raises: 39 | RuntimeError: 1. Ceil was requested in place of floor to calculate the otuput shape. 40 | 2. Padding was requested to be excluded in average computations. 
41 | """ 42 | def __init__( 43 | self, 44 | kernel_size: int, 45 | stride: Optional[int] = None, 46 | padding: int = 0, 47 | ceil_mode: bool = False, 48 | count_include_pad: bool = True, 49 | ) -> None: 50 | if ceil_mode: 51 | raise RuntimeError('The output shape can only be calculated using floor, not ceil.') 52 | 53 | if not count_include_pad: 54 | raise RuntimeError('Padding must be included when averaging pixels.') 55 | 56 | super().__init__(kernel_size, stride, padding, ceil_mode, 57 | count_include_pad) 58 | 59 | self.kernel = None 60 | 61 | def forward(self, input: Tensor) -> Tensor: 62 | if self.kernel is None: 63 | self.kernel = torch.full((input.shape[1], 1, self.kernel_size[0], 1), 64 | 1 / self.kernel_size[0], 65 | device='cuda') 66 | 67 | return Conv2dAutoGrad.apply(input.unsqueeze(-1), self.kernel, 68 | None, 69 | self.stride[0], 1, self.padding[0], 0, 70 | input.shape[1]).squeeze(-1) 71 | 72 | 73 | class AvgPool2d(nn.AvgPool2d): 74 | """ 75 | Averages 2D windows of pixels in the input. 76 | See also base class. 77 | 78 | Note: The Triton compiler does not perform well with convolutional kernels, 79 | which underlie this pooling layer, and a significant speed disparity between 80 | this module and its PyTorch equivalent should be expected. Use at your own discretion. 81 | 82 | Args: 83 | kernel_size: Kernel size. 84 | If an int, this value is used along both spatial dimensions. 85 | stride: Stride of kernel. 86 | If an int, this value is used along both spatial dimensions. 87 | If None, the kernel size is used as the stride. 88 | padding: Padding applied to the input. 89 | If an int, this value is used along both spatial dimensions. 90 | ceil_mode: Flag for using ceil instead of floor to calculate the output shape. 91 | Must be False. 92 | count_include_pad: Flag for including padding when averaging pixels. 93 | Must be True. 94 | divisor_override: Average divisor. 95 | Must be None. 96 | 97 | Raises: 98 | RuntimeError: 1. Ceil was requested in place of floor to calculate the otuput shape. 99 | 2. Padding was requested to be excluded in average computations. 100 | 3. Average divisor was overriden. 
101 | """ 102 | def __init__( 103 | self, 104 | kernel_size: Union[int, Tuple[int, int]], 105 | stride: Optional[Union[int, Tuple[int, int]]] = None, 106 | padding: Union[int, Tuple[int, int]] = 0, 107 | ceil_mode: bool = False, 108 | count_include_pad: bool = True, 109 | divisor_override: Optional[int] = None, 110 | ) -> None: 111 | if ceil_mode: 112 | raise RuntimeError('The output shape can only be calculated using floor, not ceil.') 113 | 114 | if not count_include_pad: 115 | raise RuntimeError('Padding must be included when averaging pixels.') 116 | 117 | if divisor_override is not None: 118 | raise RuntimeError('The average divisor must be the size of the window.') 119 | 120 | super().__init__(kernel_size, stride, padding, ceil_mode, 121 | count_include_pad, divisor_override) 122 | 123 | self.kernel_size = _pair(self.kernel_size) 124 | self.stride = _pair(self.stride) 125 | self.padding = _pair(self.padding) 126 | self.kernel = None 127 | 128 | def forward(self, input: Tensor) -> Tensor: 129 | if self.kernel is None: 130 | self.kernel = torch.full((input.shape[1], 1, *self.kernel_size), 131 | 1 / (self.kernel_size[0] * self.kernel_size[1]), 132 | device='cuda') 133 | 134 | return Conv2dAutoGrad.apply(input, self.kernel, None, 135 | *self.stride, *self.padding, 136 | input.shape[1]) 137 | -------------------------------------------------------------------------------- /examples/mnist/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Trains and benchmarks a multilayer perceptron (MLP) on the MNIST dataset. 3 | """ 4 | 5 | 6 | import time 7 | from argparse import ArgumentParser 8 | from functools import partial 9 | from pathlib import Path 10 | from typing import Callable, Tuple 11 | 12 | import torch 13 | from torch import nn 14 | from torch.optim import SGD 15 | from torch.utils.data import DataLoader 16 | from torchvision import transforms as T 17 | from torchvision.datasets import MNIST 18 | 19 | from .mlp import MLP 20 | from ..utils import AvgMeter, benchmark_fw_and_bw 21 | 22 | 23 | def create_dls(batch_size: int = 1_024) -> Tuple[DataLoader, DataLoader]: 24 | """ 25 | Creates data loaders for MNIST with normalization. 26 | 27 | args: 28 | batch_size: Batch size. 29 | 30 | Returns: 31 | Training and validation MNIST data loaders. 32 | """ 33 | transform = T.Compose([T.ToTensor(), 34 | T.Normalize((0.1307,), (0.3081,))]) 35 | 36 | train_dataset = MNIST(root='.', train=True, transform=transform, 37 | download=not Path('MNIST/').exists()) 38 | valid_dataset = MNIST(root='.', train=False, transform=transform) 39 | 40 | train_dl = DataLoader(dataset=train_dataset, batch_size=batch_size, 41 | shuffle=True, drop_last=True) 42 | valid_dl = DataLoader(dataset=valid_dataset, batch_size=batch_size, 43 | shuffle=False, drop_last=True) 44 | 45 | return train_dl, valid_dl 46 | 47 | 48 | def train( 49 | model: nn.Module, 50 | train_dl: DataLoader, 51 | valid_dl: DataLoader, 52 | epochs: int = 10, 53 | batch_size: int = 1_024, 54 | ) -> float: 55 | """ 56 | Trains and validates a model for classification. 57 | 58 | Args: 59 | model: Model to train. Its forward pass must optionally accept targets 60 | to compute the loss. 61 | train_dl: Data loader for training. 62 | valid_dl: Data loader for validation. 63 | epochs: Number of epochs to train for. 64 | batch_size: Batch size. 65 | 66 | Returns: 67 | Total training and validation time. 
68 | """ 69 | model = model.to('cuda') 70 | optim = SGD(model.parameters(), lr=batch_size / 64 * 0.005) 71 | optim.zero_grad() 72 | 73 | avg_meter = AvgMeter() 74 | start = time.time() 75 | 76 | for epoch in range(epochs): 77 | print(f'Epoch {epoch+1}/{epochs}') 78 | 79 | model.train() 80 | avg_meter.reset() 81 | for input, target in train_dl: 82 | input = input.to('cuda') 83 | target = target.to('cuda') 84 | 85 | loss = model(input, target) 86 | optim.zero_grad() 87 | 88 | avg_meter.update(loss.item(), len(input)) 89 | print(f'Training loss: {avg_meter.avg}') 90 | 91 | model.eval() 92 | avg_meter.reset() 93 | with torch.no_grad(): 94 | for input, target in valid_dl: 95 | input = input.to('cuda') 96 | target = target.to('cuda') 97 | 98 | output = model(input) 99 | acc = (output.argmax(dim=-1) == target).float().mean() 100 | avg_meter.update(acc.item(), len(input)) 101 | print(f'Validation accuracy: {avg_meter.avg}') 102 | 103 | return time.time() - start 104 | 105 | 106 | def main(model_cls: Callable, epochs: int = 10, batch_size: int = 1_024) -> None: 107 | """ 108 | Trains and benchmarks a vision model on the MNIST dataset. 109 | 110 | Args: 111 | model_cls: Model class to train, with a 'num_classes' argument for 112 | specifying the number of output classes. 113 | epochs: Number of epochs to train for. 114 | batch_size: Batch size. 115 | """ 116 | train_dl, valid_dl = create_dls(batch_size) 117 | model = model_cls(num_classes=len(MNIST.classes)).to('cuda') 118 | 119 | input, target = next(iter(train_dl)) 120 | input = input.to('cuda') 121 | target = target.to('cuda') 122 | 123 | for _ in range(10): 124 | model.train() 125 | with torch.autocast('cuda'): 126 | loss = model(input, target) 127 | loss.backward() 128 | 129 | model.eval() 130 | with torch.no_grad() and torch.autocast('cuda'): 131 | model(input) 132 | 133 | model.train() 134 | benchmark_fw_and_bw(model, input=input, target=target) 135 | 136 | print('Total training and validation time: ' 137 | f'{train(model, train_dl, valid_dl, epochs, batch_size)}') 138 | 139 | 140 | if __name__ == '__main__': 141 | parser = ArgumentParser(description='Trains and benchmarks a multilayer perceptron (MLP) on the MNIST dataset.') 142 | parser.add_argument('--hidden_dim', 143 | type=int, 144 | default=128, 145 | help='Number of hidden features in the MLP.') 146 | parser.add_argument('--depth', 147 | type=int, 148 | default=1, 149 | help='Number of hidden layers in the MLP.') 150 | parser.add_argument('--epochs', 151 | type=int, 152 | default=10, 153 | help='Number of epochs to train for') 154 | parser.add_argument('--batch_size', 155 | type=int, 156 | default=1_024, 157 | help='Batch size') 158 | args = parser.parse_args() 159 | 160 | model_cls = partial(MLP, hidden_dim=args.hidden_dim, depth=args.depth) 161 | 162 | print('attorch run:') 163 | main(partial(model_cls, use_attorch=True), 164 | epochs=args.epochs, batch_size=args.batch_size) 165 | 166 | print('PyTorch run:') 167 | main(partial(model_cls, use_attorch=False), 168 | epochs=args.epochs, batch_size=args.batch_size) 169 | -------------------------------------------------------------------------------- /attorch/p_loss_kernels.py: -------------------------------------------------------------------------------- 1 | """ 2 | Kernels for p-norm-induced losses. 
3 | """ 4 | 5 | 6 | import triton 7 | import triton.language as tl 8 | 9 | from .utils import element_wise_kernel_configs 10 | 11 | 12 | @triton.autotune( 13 | configs=element_wise_kernel_configs(), 14 | key=['size'], 15 | ) 16 | @triton.jit 17 | def p_loss_forward_kernel( 18 | input_pointer, target_pointer, output_pointer, 19 | param, size, p_loss: tl.constexpr, reduction: tl.constexpr, 20 | BLOCK_SIZE: tl.constexpr, 21 | ): 22 | """ 23 | Measures the smooth L1, L1, squared L2, or Huber loss of the difference 24 | between the input and target. 25 | 26 | Args: 27 | input_pointer: Pointer to the input. 28 | The input must be of shape [size]. 29 | target_pointer: Pointer to the target. 30 | The target must be of shape [size]. 31 | output_pointer: Pointer to a container the error is written to. 32 | The container must be of shape [size] if reduction is 'none', 33 | and otherwise of shape [size/BLOCK_SIZE]. 34 | param: Parameter of loss function (i.e., beta or delta for smooth L1 and Huber). 35 | size: Number of elements in the input and target. 36 | p_loss: p-norm used to compute the error. 37 | Options are 0 for smooth L1, 1 for L1, 2 for squared L2, and 3 for Huber loss. 38 | reduction: Reduction strategy for the output. 39 | Options are 'none' for no reduction, 'mean' for averaging the error 40 | across all entries, and 'sum' for summing the error across all entries. 41 | If a reduction method is specified, the reduced result of each 42 | program is written to a separate index in the output container, 43 | which should later be summed. 44 | BLOCK_SIZE: Block size. 45 | """ 46 | # This program processes BLOCK_SIZE rows. 47 | pid = tl.program_id(axis=0) 48 | offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 49 | mask = offset < size 50 | 51 | input = tl.load(input_pointer + offset, mask=mask).to(tl.float32) 52 | target = tl.load(target_pointer + offset, mask=mask).to(tl.float32) 53 | diff = input - target 54 | 55 | if p_loss == 0: 56 | error = tl.where(diff < param, 0.5 * diff * diff / param, tl.abs(diff) - 0.5 * param) 57 | 58 | elif p_loss == 1: 59 | error = tl.abs(diff) 60 | 61 | elif p_loss == 2: 62 | error = diff * diff 63 | 64 | elif p_loss == 3: 65 | error = tl.where(diff < param, 0.5 * diff * diff, param * (tl.abs(diff) - 0.5 * param)) 66 | 67 | if reduction == 'none': 68 | tl.store(output_pointer + offset, error, mask=mask) 69 | 70 | elif reduction == 'mean': 71 | tl.store(output_pointer + pid, tl.sum(error) / size) 72 | 73 | elif reduction == 'sum': 74 | tl.store(output_pointer + pid, tl.sum(error)) 75 | 76 | 77 | @triton.autotune( 78 | configs=element_wise_kernel_configs(), 79 | key=['size'], 80 | ) 81 | @triton.jit 82 | def p_loss_backward_kernel( 83 | output_grad_pointer, input_pointer, target_pointer, 84 | input_grad_pointer, target_grad_pointer, param, size, 85 | p_loss: tl.constexpr, reduction: tl.constexpr, 86 | BLOCK_SIZE: tl.constexpr, 87 | ): 88 | """ 89 | Calculates the input gradient of the smooth L1 norm, L1 norm, L2 norm, or Huber loss. 90 | 91 | Args: 92 | output_grad_pointer: Pointer to the error's output gradients. 93 | The output gradients must be a scalar or of shape [size]. 94 | input_pointer: Pointer to the input. 95 | The input must be of shape [size]. 96 | target_pointer: Pointer to the target. 97 | The target must be of shape [size]. 98 | input_grad_pointer: Pointer to a container the input's gradients are written to. 99 | The container must be of shape [size]. 100 | target_grad_pointer: Pointer to a container the target's gradients are written to. 
101 | The container must be of shape [size]. 102 | param: Parameter of loss function (i.e., beta or delta for smooth L1 and Huber). 103 | size: Number of elements in the input and target. 104 | p_loss: p-norm used to compute the error whose gradient is calculated. 105 | Options are 0 for smooth L1, 1 for L1, 2 for squared L2, and 3 for Huber loss. 106 | reduction: Reduction strategy for the output whose gradient is calculated. 107 | Options are 'none' for no reduction, 'mean' for averaging the error 108 | across all entries, and 'sum' for summing the error across all entries. 109 | BLOCK_SIZE: Block size. 110 | """ 111 | # This program processes BLOCK_SIZE rows. 112 | pid = tl.program_id(axis=0) 113 | offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 114 | mask = offset < size 115 | 116 | output_grad_mask = None 117 | if reduction == 'none': 118 | output_grad_pointer += offset 119 | output_grad_mask = mask 120 | 121 | input = tl.load(input_pointer + offset, mask=mask).to(tl.float32) 122 | target = tl.load(target_pointer + offset, mask=mask).to(tl.float32) 123 | diff = input - target 124 | output_grad = tl.load(output_grad_pointer, mask=output_grad_mask).to(tl.float32) 125 | 126 | if p_loss == 0: 127 | input_grad = tl.where(diff < param, diff / param, tl.where(0 <= diff, 1, -1)) 128 | 129 | elif p_loss == 1: 130 | input_grad = tl.where(0 <= diff, 1, -1) 131 | 132 | elif p_loss == 2: 133 | input_grad = 2 * diff 134 | 135 | elif p_loss == 3: 136 | input_grad = tl.where(diff < param, diff, param * tl.where(0 <= diff, 1, -1)) 137 | 138 | if reduction == 'mean': 139 | input_grad /= size 140 | 141 | input_grad *= output_grad 142 | tl.store(input_grad_pointer + offset, input_grad, mask=mask) 143 | tl.store(target_grad_pointer + offset, -input_grad, mask=mask) 144 | -------------------------------------------------------------------------------- /examples/regression/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Trains and benchmarks a multilayer perceptron (MLP) on synthetic regression data. 3 | """ 4 | 5 | 6 | import time 7 | from argparse import ArgumentParser 8 | from types import ModuleType 9 | from typing import Callable, Tuple 10 | 11 | import torch 12 | from torch.optim import SGD 13 | from torch.utils.data import DataLoader, TensorDataset 14 | 15 | import attorch 16 | from ..utils import AvgMeter, benchmark_fw_and_bw 17 | 18 | 19 | def create_dls( 20 | n_samples: int = 50_000, 21 | dim: int = 10, 22 | batch_size: int = 1_024, 23 | ) -> Tuple[DataLoader, DataLoader]: 24 | """ 25 | Creates synthetic regression data loaders. 26 | 27 | Args: 28 | n_samples: Number of samples to generate. 29 | dim: Dimensionality of synthetic data. 30 | batch_size: Batch size. 31 | 32 | Returns: 33 | Training and validation data loaders. 
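The synthetic targets follow a linear model with Gaussian noise, target = input @ w + 0.1 * noise, generated with a fixed seed and split 80/20 into training and validation sets.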
34 | """ 35 | torch.manual_seed(0) 36 | input = torch.randn(n_samples, dim) 37 | target = input @ torch.randn(dim, 1) + 0.1 * torch.randn(n_samples, 1) 38 | n_train = int(0.8 * n_samples) 39 | 40 | train_dataset = TensorDataset(input[:n_train], target[:n_train]) 41 | valid_dataset = TensorDataset(input[n_train:], target[n_train:]) 42 | 43 | train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, 44 | drop_last=True) 45 | valid_dl = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, 46 | drop_last=True) 47 | 48 | return train_dl, valid_dl 49 | 50 | 51 | def train( 52 | model: torch.nn.Module, 53 | loss_fn: Callable, 54 | train_dl: DataLoader, 55 | valid_dl: DataLoader, 56 | epochs: int = 10, 57 | batch_size: int = 1_024, 58 | ) -> float: 59 | """ 60 | Trains and validates a regression model. 61 | 62 | Args: 63 | model: Model to train. 64 | loss_fn: Loss function. 65 | train_dl: Data loader for training. 66 | valid_dl: Data loader for validation. 67 | epochs: Number of epochs to train for. 68 | batch_size: Batch size. 69 | 70 | Returns: 71 | Total training and validation time. 72 | """ 73 | model = model.to('cuda') 74 | optim = SGD(model.parameters(), lr=batch_size / 64 * 0.005) 75 | optim.zero_grad() 76 | 77 | avg_meter = AvgMeter() 78 | start = time.time() 79 | 80 | for epoch in range(epochs): 81 | print(f'Epoch {epoch+1}/{epochs}') 82 | 83 | model.train() 84 | avg_meter.reset() 85 | for input, target in train_dl: 86 | input = input.to('cuda') 87 | target = target.to('cuda') 88 | 89 | loss = loss_fn(model(input), target) 90 | optim.zero_grad() 91 | 92 | avg_meter.update(loss.item(), len(input)) 93 | print(f'Training loss: {avg_meter.avg}') 94 | 95 | model.eval() 96 | avg_meter.reset() 97 | with torch.no_grad(): 98 | for input, target in valid_dl: 99 | input = input.to('cuda') 100 | target = target.to('cuda') 101 | 102 | loss = loss_fn(model(input), target) 103 | avg_meter.update(loss.item(), len(input)) 104 | print(f'Validation loss: {avg_meter.avg}') 105 | 106 | return time.time() - start 107 | 108 | 109 | def main( 110 | nn: ModuleType, 111 | n_samples: int = 50_000, 112 | dim: int = 10, 113 | epochs: int = 10, 114 | batch_size: int = 1_024, 115 | ) -> None: 116 | """ 117 | Trains and benchmarks a regression MLP on synthetic data. 118 | 119 | Args: 120 | nn: Neural network module used to construct the MLP. 121 | n_samples: Number of samples to generate. 122 | dim: Dimensionality of synthetic data. 123 | epochs: Number of epochs to train for. 124 | batch_size: Batch size. 
125 | """ 126 | train_dl, valid_dl = create_dls(n_samples, dim, batch_size) 127 | model = nn.Sequential(nn.Linear(dim, dim // 2), 128 | nn.ReLU(), 129 | nn.Linear(dim // 2, 1)).to('cuda') 130 | loss_fn = nn.MSELoss() 131 | 132 | input, target = next(iter(train_dl)) 133 | input = input.to('cuda') 134 | target = target.to('cuda') 135 | 136 | for _ in range(10): 137 | model.train() 138 | with torch.autocast('cuda'): 139 | loss = loss_fn(model(input), target) 140 | loss.backward() 141 | 142 | model.eval() 143 | with torch.no_grad() and torch.autocast('cuda'): 144 | model(input) 145 | 146 | model.train() 147 | benchmark_fw_and_bw(model, input=input) 148 | 149 | print('Total training and validation time: ' 150 | f'{train(model, loss_fn, train_dl, valid_dl, epochs, batch_size)}') 151 | 152 | 153 | if __name__ == '__main__': 154 | parser = ArgumentParser(description='Trains and benchmarks a multilayer perceptron (MLP) on synthetic regression data.') 155 | parser.add_argument('--n_samples', 156 | type=int, 157 | default=50_000, 158 | help='Number of samples to generate') 159 | parser.add_argument('--dim', 160 | type=int, 161 | default=10, 162 | help='Dimensionality of synthetic data') 163 | parser.add_argument('--epochs', 164 | type=int, 165 | default=10, 166 | help='Number of training epochs') 167 | parser.add_argument('--batch_size', 168 | type=int, 169 | default=1_024, 170 | help='Batch size') 171 | args = parser.parse_args() 172 | 173 | print('attorch run:') 174 | main(attorch.nn, n_samples=args.n_samples, dim=args.dim, 175 | epochs=args.epochs, batch_size=args.batch_size) 176 | 177 | print('PyTorch run:') 178 | main(torch.nn, n_samples=args.n_samples, dim=args.dim, 179 | epochs=args.epochs, batch_size=args.batch_size) 180 | -------------------------------------------------------------------------------- /attorch/rms_norm_layer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Root mean square normalization with PyTorch autodiff support. 3 | """ 4 | 5 | 6 | from typing import Optional, Tuple 7 | 8 | import torch 9 | from torch import Tensor 10 | from torch import nn 11 | from torch.amp import custom_bwd, custom_fwd 12 | from triton import cdiv 13 | 14 | from .rms_norm_kernels import rms_norm_backward_kernel, rms_norm_forward_kernel 15 | from .softmax_kernels import BLOCK_SIZE_BATCH_heuristic 16 | from .types import Context, Device 17 | 18 | 19 | class RMSNormAutoGrad(torch.autograd.Function): 20 | """ 21 | Autodiff for root mean square normalization. 22 | """ 23 | @staticmethod 24 | @custom_fwd(device_type='cuda') 25 | def forward( 26 | ctx: Context, 27 | input: Tensor, 28 | weight: Optional[Tensor] = None, 29 | eps: Optional[float] = None, 30 | ) -> Tensor: 31 | """ 32 | Root-mean-square-normalizes the input. 33 | 34 | Args: 35 | ctx: Context for variable storage. 36 | input: Input to root-mean-square-normalize. 37 | Can have arbitrary shape. 38 | weight: Optional weights for linear transform. 39 | If provided, must be of shape [feat_dim]. 40 | eps: Epsilon added in the square root in the denominator 41 | to avoid division by zero. If None, it defaults to 42 | torch.finfo(input.dtype).eps. 43 | 44 | Returns: 45 | Root-mean-square-normalized input. 
46 | """ 47 | flattened_input = input.unsqueeze(0) if input.ndim == 1 else input 48 | flattened_input = flattened_input.flatten(0, -2) 49 | batch_dim, feat_dim = flattened_input.shape 50 | eps = torch.finfo(input.dtype).eps if eps is None else eps 51 | 52 | output = torch.empty_like(flattened_input) 53 | 54 | scale_by_weight = weight is not None 55 | requires_grad = (input.requires_grad or 56 | (scale_by_weight and weight.requires_grad)) 57 | 58 | if requires_grad: 59 | inv_rms = torch.empty(batch_dim, 60 | device=input.device, 61 | dtype=torch.float32) 62 | 63 | else: 64 | inv_rms = None 65 | 66 | # Launches 1D grid where each program operates over BLOCK_SIZE_BATCH rows. 67 | grid = lambda META: (cdiv(batch_dim, META['BLOCK_SIZE_BATCH']),) 68 | rms_norm_forward_kernel[grid](flattened_input, weight, 69 | inv_rms, output, 70 | batch_dim, feat_dim, 71 | *flattened_input.stride(), *output.stride(), 72 | eps, 73 | scale_by_weight=scale_by_weight, 74 | save_stats=requires_grad) 75 | 76 | ctx.scale_by_weight = scale_by_weight 77 | if requires_grad: 78 | ctx.save_for_backward(flattened_input, inv_rms, weight) 79 | 80 | return output.view_as(input) 81 | 82 | @staticmethod 83 | @custom_bwd(device_type='cuda') 84 | def backward( 85 | ctx: Context, 86 | output_grad: Tensor, 87 | ) -> Tuple[Optional[Tensor], ...]: 88 | """ 89 | Calculates the input gradient of root mean square normalization. 90 | 91 | Args: 92 | ctx: Context containing stored variables. 93 | output_grad: Output gradients. 94 | Must be the same shape as the output. 95 | 96 | Returns: 97 | Input gradient of root mean square normalization. 98 | """ 99 | scale_by_weight = ctx.scale_by_weight 100 | (flattened_input, inv_rms, weight) = ctx.saved_tensors 101 | flattened_output_grad = output_grad.view_as(flattened_input) 102 | 103 | batch_dim, feat_dim = flattened_output_grad.shape 104 | input_grad = torch.empty_like(flattened_output_grad) 105 | 106 | if scale_by_weight: 107 | BLOCK_SIZE_BATCH = BLOCK_SIZE_BATCH_heuristic({'batch_dim': batch_dim, 108 | 'feat_dim': feat_dim}) 109 | out_batch_dim = batch_dim // BLOCK_SIZE_BATCH 110 | 111 | weight_grad = torch.empty((out_batch_dim, feat_dim), 112 | device=flattened_input.device) 113 | 114 | else: 115 | weight_grad = None 116 | 117 | # Launches 1D grid where each program operates over BLOCK_SIZE_BATCH rows. 118 | grid = lambda META: (cdiv(batch_dim, META['BLOCK_SIZE_BATCH']),) 119 | rms_norm_backward_kernel[grid](flattened_output_grad, flattened_input, 120 | inv_rms, weight, 121 | input_grad, weight_grad, 122 | batch_dim, feat_dim, 123 | *flattened_output_grad.stride(), 124 | *flattened_input.stride(), 125 | *input_grad.stride(), 126 | *weight_grad.stride() if scale_by_weight else (1, 1), 127 | scale_by_weight=scale_by_weight) 128 | 129 | if scale_by_weight: 130 | weight_grad = weight_grad.sum(dim=0) 131 | 132 | # Pads output with None because a gradient is necessary for 133 | # all input arguments. 134 | return input_grad.view_as(output_grad), weight_grad, None 135 | 136 | 137 | class RMSNorm(nn.RMSNorm): 138 | """ 139 | Root-mean-square-normalizes the input. 140 | See also base class. 141 | 142 | Args: 143 | normalized_shape: Dimensionality of last feature that is normalized. 144 | eps: Epsilon added in the square root in the denominator 145 | to avoid division by zero. If None, it defaults to 146 | torch.finfo(input.dtype).eps. 147 | elementwise_affine: Flag for scaling the normalized output by weights. 148 | device: Device to use. 149 | dtype: Dtype of layer. 
150 | 151 | Raises: 152 | RuntimeError: Normalized shape was not an integer. 153 | """ 154 | def __init__( 155 | self, 156 | normalized_shape: int, 157 | eps: Optional[float] = None, 158 | elementwise_affine: bool = True, 159 | device: Device = 'cuda', 160 | dtype: torch.dtype = torch.float32, 161 | ) -> None: 162 | if not isinstance(normalized_shape, int): 163 | raise RuntimeError('Normalized shape must be an integer.') 164 | 165 | super().__init__(normalized_shape, eps, elementwise_affine, device, dtype) 166 | 167 | def forward(self, input: Tensor) -> Tensor: 168 | return RMSNormAutoGrad.apply(input, self.weight, self.eps) 169 | -------------------------------------------------------------------------------- /attorch/multi_head_attention_layer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Multi-headed attention with PyTorch autodiff support. 3 | """ 4 | 5 | 6 | from functools import partial 7 | from typing import Optional 8 | 9 | import torch 10 | from torch import Tensor 11 | from torch import nn 12 | from torch.nn import functional as F 13 | 14 | from .multi_head_attention_kernels import _attention 15 | from .types import Device 16 | 17 | 18 | def extract_heads(input: Tensor, num_heads: int) -> Tensor: 19 | """ 20 | Reshapes the projected input to extract heads for multi-headed attention. 21 | 22 | Args: 23 | input: Input to reshape and extract heads from. 24 | num_heads: Number of heads to extract. 25 | 26 | Returns: 27 | Input reshaped with its heads extracted. 28 | """ 29 | batch_dim, n_tokens, _ = input.shape 30 | return input.reshape(batch_dim, n_tokens, num_heads, -1).transpose(1, 2) 31 | 32 | 33 | class MultiheadAttention(nn.MultiheadAttention): 34 | """ 35 | Applies multi-headed scaled dot-product attention to the inputs. 36 | See also base class. 37 | 38 | Args: 39 | embed_dim: Dimensionality of the query, key, and value inputs. 40 | num_heads: Number of heads. 41 | dropout: Dropout probability on the attention scores, 42 | currently not supported. 43 | bias: Flag for adding bias to the query-key-value-output projections. 44 | add_bias_kv: Flag for appending a bias vector to the key and value sequences, 45 | currently not supported. 46 | add_zero_attn: Flag for appending a zero vector to the key and value sequences, 47 | currently not supported. 48 | kdim: Dimensionality of the key input, which has to be None or equal to embed_dim. 49 | vdim: Dimensionality of the value input, which has to be None or equal to embed_dim. 50 | batch_first: Flag to indicate if the batch dimension comes first in the input, 51 | currently otherwise is not supported. 52 | device: Device to use. 53 | dtype: Dtype of layer. 54 | 55 | Raises: 56 | RuntimeError: 1. Dropout on the attention scores is requested. 57 | 2. Appending a bias vector to the key and value sequences is requested. 58 | 3. Appending a zero vector to the key and value sequences is requested. 59 | 4. The query and key dimensionalities are unequal. 60 | 5. The query and value dimensionalities are unequal. 61 | 6. The input is not batch-first. 
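Unlike the base class, forward accepts optional key and value inputs (defaulting to the query input for self-attention) along with causal and sequence_parallel flags, and returns only the attention output rather than an (output, weights) tuple.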
62 | """ 63 | def __init__( 64 | self, 65 | embed_dim: int, 66 | num_heads: int, 67 | dropout: float = 0.0, 68 | bias: bool = True, 69 | add_bias_kv: bool = False, 70 | add_zero_attn: bool = False, 71 | kdim: Optional[int] = None, 72 | vdim: Optional[int] = None, 73 | batch_first: bool = True, 74 | device: Device = 'cuda', 75 | dtype: torch.dtype = torch.float32, 76 | ) -> None: 77 | if dropout > 0.0: 78 | raise RuntimeError('Dropout on the attention scores is not supported.') 79 | 80 | if add_bias_kv: 81 | raise RuntimeError('Appending a bias vector to the key and value ' 82 | 'sequences is not supported.') 83 | 84 | if add_zero_attn: 85 | raise RuntimeError('Appending a zero vector to the key and value ' 86 | 'sequences is not supported.') 87 | 88 | if kdim is not None and kdim != embed_dim: 89 | raise RuntimeError(f'The key dimensionality ({kdim}) is not equal to ' 90 | f'the query dimensionality ({embed_dim}).') 91 | 92 | if vdim is not None and vdim != embed_dim: 93 | raise RuntimeError(f'The value dimensionality ({vdim}) is not equal to ' 94 | f'the query dimensionality ({embed_dim}).') 95 | 96 | if not batch_first: 97 | raise RuntimeError('The input must be batch-first.') 98 | 99 | super().__init__(embed_dim, num_heads, dropout, bias, 100 | add_bias_kv, add_zero_attn, 101 | kdim, vdim, batch_first, 102 | device, dtype) 103 | 104 | def forward( 105 | self, 106 | input_q: Tensor, 107 | input_k: Optional[Tensor] = None, 108 | input_v: Optional[Tensor] = None, 109 | causal: bool = False, 110 | sequence_parallel: bool = False, 111 | ) -> Tensor: 112 | assert input_v is None or input_k is not None, \ 113 | f'Key inputs must be provided if value inputs have been passed' 114 | 115 | input_k = input_q if input_k is None else input_k 116 | input_v = input_k if input_v is None else input_v 117 | 118 | if input_k is input_v: 119 | if input_q is input_k: 120 | qkv = F.linear(input_q, self.in_proj_weight, self.in_proj_bias) 121 | q, k, v = map(partial(extract_heads, num_heads=self.num_heads), 122 | qkv.chunk(3, dim=-1)) 123 | 124 | else: 125 | weight_q, weight_kv = torch.split(self.in_proj_weight, 126 | [self.embed_dim, 2 * self.embed_dim], 127 | dim=0) 128 | 129 | if self.in_proj_bias is None: 130 | bias_q = bias_kv = None 131 | 132 | else: 133 | bias_q, bias_kv = torch.split(self.in_proj_bias, 134 | [self.embed_dim, 2 * self.embed_dim], 135 | dim=0) 136 | 137 | q = F.linear(input_q, weight_q, bias_q) 138 | kv = F.linear(input_k, weight_kv, bias_kv) 139 | q, k, v = map(partial(extract_heads, num_heads=self.num_heads), 140 | [q, *kv.chunk(2, dim=-1)]) 141 | 142 | else: 143 | weight_q, weight_k, weight_v = torch.split(self.in_proj_weight, 144 | 3*[self.embed_dim], 145 | dim=0) 146 | 147 | if self.in_proj_bias is None: 148 | bias_q = bias_k = bias_v = None 149 | 150 | else: 151 | bias_q, bias_k, bias_v = torch.split(self.in_proj_bias, 152 | 3*[self.embed_dim], 153 | dim=0) 154 | 155 | q = F.linear(input_q, weight_q, bias_q) 156 | k = F.linear(input_k, weight_k, bias_k) 157 | v = F.linear(input_v, weight_v, bias_v) 158 | q, k, v = map(partial(extract_heads, num_heads=self.num_heads), 159 | [q, k, v]) 160 | 161 | output = _attention.apply(q, k, v, causal, 0.5, sequence_parallel) 162 | output = output.transpose(1, 2).reshape(len(input_q), -1, self.embed_dim) 163 | return self.out_proj(output) 164 | -------------------------------------------------------------------------------- /attorch/cross_entropy_loss_layer.py: -------------------------------------------------------------------------------- 1 | 
""" 2 | Cross entropy loss with PyTorch autodiff support. 3 | """ 4 | 5 | 6 | from typing import Optional, Tuple 7 | 8 | import torch 9 | from torch import Tensor 10 | from torch import nn 11 | from triton import cdiv 12 | 13 | from .cross_entropy_loss_kernels import cross_entropy_loss_backward_kernel, \ 14 | cross_entropy_loss_forward_kernel 15 | from .softmax_kernels import BLOCK_SIZE_BATCH_heuristic 16 | from .types import Context 17 | from .utils import get_output_dtype 18 | 19 | 20 | class CrossEntropyLossAutoGrad(torch.autograd.Function): 21 | """ 22 | Autodiff for cross entropy loss. 23 | """ 24 | @staticmethod 25 | def forward( 26 | ctx: Context, 27 | input: Tensor, 28 | target: Tensor, 29 | weight: Optional[Tensor] = None, 30 | ) -> Tensor: 31 | """ 32 | Measures the mean cross entropy loss between the input and target, 33 | with optional reweighing of each class. 34 | 35 | Args: 36 | ctx: Context for variable storage. 37 | input: Input. 38 | Must be of shape [batch_dim, feat_dim]. 39 | target: Target. 40 | Must be of shape [batch_dim]. 41 | weight: Optional class weight vector, with None for no reweighing. 42 | If provided, must be of shape [feat_dim]. 43 | 44 | Returns: 45 | Loss. 46 | """ 47 | assert input.ndim == 2, f'Inputs of rank other than 2 not valid' 48 | assert len(input) == len(target), \ 49 | f'Incompatible input shape ({input.shape}) and target shape ({target.shape})' 50 | assert weight is None or len(weight) == input.shape[1], \ 51 | f'Dimensionality of weight vector ({len(weight)}) and input features ({input.shape[1]}) not equal' 52 | 53 | batch_dim, feat_dim = input.shape 54 | BLOCK_SIZE_BATCH = BLOCK_SIZE_BATCH_heuristic({'batch_dim': batch_dim, 55 | 'feat_dim': feat_dim}) 56 | out_batch_dim = batch_dim // BLOCK_SIZE_BATCH 57 | weighted = weight is not None 58 | 59 | output_dtype = get_output_dtype(input.dtype, autocast='fp32') 60 | output = torch.empty(out_batch_dim, 61 | dtype=output_dtype, 62 | device=input.device) 63 | 64 | if weighted: 65 | sum_weights = torch.empty_like(output, dtype=torch.float32) 66 | 67 | else: 68 | sum_weights = None 69 | 70 | # Launches 1D grid where each program operates over BLOCK_SIZE_BATCH rows. 71 | grid = lambda META: (cdiv(len(input), META['BLOCK_SIZE_BATCH']),) 72 | cross_entropy_loss_forward_kernel[grid](input, target, weight, sum_weights, output, 73 | batch_dim, feat_dim, 74 | *input.stride(), 75 | weighted=weighted) 76 | output = output.sum() 77 | 78 | if weighted: 79 | sum_weights = sum_weights.sum() 80 | output /= sum_weights 81 | 82 | ctx.sum_weights = sum_weights 83 | ctx.weight = weight 84 | ctx.output_dtype = output_dtype 85 | if input.requires_grad: 86 | ctx.save_for_backward(input, target) 87 | 88 | return output 89 | 90 | @staticmethod 91 | def backward( 92 | ctx: Context, 93 | output_grad: Tensor, 94 | ) -> Tuple[Optional[Tensor], ...]: 95 | """ 96 | Calculates the input gradient of the loss. 97 | 98 | Args: 99 | ctx: Context containing stored variables. 100 | output_grad: Output gradients. 101 | Must be a scalar. 102 | 103 | Returns: 104 | Input gradient of the loss. 105 | """ 106 | (input, target) = ctx.saved_tensors 107 | batch_dim, feat_dim = input.shape 108 | input_grad = torch.empty_like(input, dtype=ctx.output_dtype) 109 | 110 | # Launches 1D grid where each program operates over BLOCK_SIZE_BATCH rows. 
111 | grid = lambda META: (cdiv(len(input), META['BLOCK_SIZE_BATCH']),) 112 | cross_entropy_loss_backward_kernel[grid](output_grad, target, input, ctx.weight, 113 | ctx.sum_weights, input_grad, 114 | batch_dim, feat_dim, 115 | *input.stride(), 116 | *input_grad.stride(), 117 | weighted=ctx.weight is not None) 118 | 119 | # Pads output with None because a gradient is necessary for 120 | # all input arguments. 121 | return input_grad, None, None 122 | 123 | 124 | class CrossEntropyLoss(nn.CrossEntropyLoss): 125 | """ 126 | Measures the mean cross entropy loss between the input and target, 127 | with optional reweighing of each class. 128 | See also base class. 129 | 130 | Note: To keep its implementation compact, this module does not support 131 | advanced features such as inputs with spatial dimensions or label smoothing. 132 | For greater flexibility, a combination of attorch.LogSoftmax and 133 | attorch.NLLLoss can be used. 134 | 135 | Args: 136 | reduction: Reduction strategy for the output. Only 'mean' is supported. 137 | Providing size_average and reduce overrides this argument. 138 | size_average: Flag for averaging instead of summing the loss entries 139 | when reduce is True. Only averaging is supported. 140 | reduce: Flag for averaging or summing all the loss entries instead of 141 | returning a loss per element. Only averaging is supported. 142 | weight: Optional class weight vector, with None for no reweighing. 143 | If provided, must be of shape [feat_dim]. 144 | ignore_index: This argument is not supported. 145 | label_smoothing: This argument is not supported. 146 | 147 | Raises: 148 | RuntimeError: 1. Reduction method was not set to 'mean'. 149 | 2. Label smoothing is requested. 150 | """ 151 | def __init__( 152 | self, 153 | reduction: str = 'mean', 154 | size_average: Optional[bool] = None, 155 | reduce: Optional[bool] = None, 156 | weight: Optional[Tensor] = None, 157 | ignore_index: int = -100, 158 | label_smoothing: float = 0.0, 159 | ) -> None: 160 | super().__init__(weight, size_average, ignore_index, reduce, 161 | reduction, label_smoothing) 162 | 163 | if self.reduction != 'mean': 164 | raise RuntimeError('Cross entropy only supports averaging the loss.') 165 | 166 | if label_smoothing > 0.0: 167 | raise RuntimeError('Cross entropy does not support label smoothing.') 168 | 169 | def forward(self, input: Tensor, target: Tensor) -> Tensor: 170 | return CrossEntropyLossAutoGrad.apply(input, target, self.weight) 171 | -------------------------------------------------------------------------------- /attorch/nll_loss_layer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Negative log likelihood loss with PyTorch autodiff support. 3 | """ 4 | 5 | 6 | from typing import Optional, Tuple 7 | 8 | import torch 9 | from torch import Tensor 10 | from torch import nn 11 | from triton import cdiv 12 | 13 | from .nll_loss_kernels import nll_loss_backward_kernel, nll_loss_forward_kernel, \ 14 | BLOCK_SIZE_BATCH_heuristic 15 | from .types import Context 16 | from .utils import get_output_dtype 17 | 18 | 19 | class NLLLossAutoGrad(torch.autograd.Function): 20 | """ 21 | Autodiff for negative log likelihood loss. 22 | """ 23 | @staticmethod 24 | def forward( 25 | ctx: Context, 26 | input: Tensor, 27 | target: Tensor, 28 | reduction: str, 29 | weight: Optional[Tensor] = None, 30 | ) -> Tensor: 31 | """ 32 | Measures the negative log likelihood loss between the input and target, 33 | with optional reweighing of each class. 
34 | 35 | Args: 36 | ctx: Context for variable storage. 37 | input: Input. 38 | Must be of shape [batch_dim, feat_dim, ...], 39 | where ... denotes an arbitrary number of spatial dimensions. 40 | target: Target. 41 | Must be of shape [batch_dim, ...], 42 | where ... denotes the same spatial dimensions as the input. 43 | reduction: Reduction strategy for the output. 44 | Options are 'none' for no reduction, 'mean' for averaging the loss 45 | across all entries, and 'sum' for summing the loss across all entries. 46 | weight: Optional class weight vector, with None for no reweighing. 47 | If provided, must be of shape [feat_dim]. 48 | 49 | Returns: 50 | Loss. 51 | """ 52 | assert len(input) == len(target) and input.shape[2:] == target.shape[1:], \ 53 | f'Incompatible input shape ({input.shape}) and target shape ({target.shape})' 54 | assert weight is None or len(weight) == input.shape[1], \ 55 | f'Dimensionality of weight vector ({len(weight)}) and input features ({input.shape[1]}) not equal' 56 | 57 | flattened_input = input.unsqueeze(-1) if input.ndim == 2 else input 58 | flattened_input = flattened_input.flatten(2, -1) 59 | 60 | flattened_target = target.unsqueeze(-1) if target.ndim == 1 else target 61 | flattened_target = flattened_target.flatten(1, -1) 62 | 63 | batch_dim, _, spatial_dim = flattened_input.shape 64 | BLOCK_SIZE_BATCH = BLOCK_SIZE_BATCH_heuristic({'batch_dim': batch_dim, 65 | 'spatial_dim': spatial_dim}) 66 | out_batch_dim = batch_dim // BLOCK_SIZE_BATCH 67 | 68 | output_dtype = get_output_dtype(input.dtype, autocast='fp32') 69 | sum_weights = (torch.empty(out_batch_dim, dtype=torch.float32, device=input.device) 70 | if reduction == 'mean' else None) 71 | output = (torch.empty_like(flattened_target, dtype=output_dtype) 72 | if reduction == 'none' else 73 | torch.empty(out_batch_dim, dtype=output_dtype, device=input.device)) 74 | 75 | # Launches 1D grid where each program operates over BLOCK_SIZE_BATCH rows. 76 | grid = lambda META: (cdiv(len(input), META['BLOCK_SIZE_BATCH']),) 77 | nll_loss_forward_kernel[grid](input, target, weight, sum_weights, output, 78 | batch_dim, spatial_dim, 79 | *flattened_input.stride(), 80 | *flattened_target.stride(), 81 | *output.stride() if reduction == 'none' else (1, 1), 82 | reduction=reduction, 83 | weighted=weight is not None) 84 | 85 | if reduction != 'none': 86 | output = output.sum() 87 | 88 | if reduction == 'mean' and weight is not None: 89 | sum_weights = sum_weights.sum() 90 | output /= sum_weights 91 | 92 | else: 93 | output = output.view_as(target) 94 | 95 | ctx.sum_weights = sum_weights 96 | ctx.reduction = reduction 97 | ctx.weight = weight 98 | ctx.output_dtype = output_dtype 99 | if input.requires_grad: 100 | ctx.save_for_backward(input, flattened_target) 101 | 102 | return output 103 | 104 | @staticmethod 105 | def backward( 106 | ctx: Context, 107 | output_grad: Tensor, 108 | ) -> Tuple[Optional[Tensor], ...]: 109 | """ 110 | Calculates the input gradient of the loss. 111 | 112 | Args: 113 | ctx: Context containing stored variables. 114 | output_grad: Output gradients. 115 | Must be the same shape as the output. 116 | 117 | Returns: 118 | Input gradient of the loss. 
119 | """ 120 | (input, flattened_target) = ctx.saved_tensors 121 | flattened_input = input.view(len(flattened_target), -1, 122 | flattened_target.shape[-1]) 123 | output_grad = (output_grad.view_as(flattened_target) 124 | if output_grad.ndim > 0 else output_grad) 125 | 126 | batch_dim, _, spatial_dim = flattened_input.shape 127 | input_grad = torch.zeros_like(flattened_input, dtype=ctx.output_dtype) 128 | 129 | # Launches 1D grid where each program operates over BLOCK_SIZE_BATCH rows. 130 | grid = lambda META: (cdiv(len(input), META['BLOCK_SIZE_BATCH']),) 131 | nll_loss_backward_kernel[grid](output_grad, flattened_target, ctx.weight, 132 | ctx.sum_weights, input_grad, 133 | batch_dim, spatial_dim, 134 | *output_grad.stride() if ctx.reduction == 'none' else (1, 1), 135 | *flattened_target.stride(), 136 | *input_grad.stride(), 137 | reduction=ctx.reduction, 138 | weighted=ctx.weight is not None) 139 | 140 | # Pads output with None because a gradient is necessary for 141 | # all input arguments. 142 | return input_grad.view_as(input), None, None, None 143 | 144 | 145 | class NLLLoss(nn.NLLLoss): 146 | """ 147 | Measures the negative log likelihood loss between the input and target, 148 | with optional reweighing of each class. 149 | See also base class. 150 | 151 | Args: 152 | reduction: Reduction strategy for the output. 153 | Options are 'none' for no reduction, 'mean' for averaging the loss 154 | across all entries, and 'sum' for summing the loss across all entries. 155 | Providing size_average and reduce overrides this argument. 156 | size_average: Flag for averaging instead of summing the loss entries 157 | when reduce is True. 158 | reduce: Flag for averaging or summing all the loss entries instead of 159 | returning a loss per element. 160 | weight: Optional class weight vector, with None for no reweighing. 161 | If provided, must be of shape [feat_dim]. 162 | ignore_index: This argument is not supported. 163 | """ 164 | def forward(self, input: Tensor, target: Tensor) -> Tensor: 165 | return NLLLossAutoGrad.apply(input, target, self.reduction, self.weight) 166 | -------------------------------------------------------------------------------- /attorch/p_loss_layers.py: -------------------------------------------------------------------------------- 1 | """ 2 | p-norm-induced losses with PyTorch autodiff support. 3 | """ 4 | 5 | 6 | from typing import Optional, Tuple 7 | 8 | import torch 9 | from torch import Tensor 10 | from torch import nn 11 | from triton import cdiv 12 | 13 | from .p_loss_kernels import p_loss_backward_kernel, p_loss_forward_kernel 14 | from .types import Context 15 | from .utils import get_output_dtype 16 | 17 | 18 | class PLossAutoGrad(torch.autograd.Function): 19 | """ 20 | Autodiff for p-losses. 21 | """ 22 | @staticmethod 23 | def forward( 24 | ctx: Context, 25 | input: Tensor, 26 | target: Tensor, 27 | p_loss: int, 28 | reduction: str, 29 | param: float = 1.0, 30 | ) -> Tensor: 31 | """ 32 | Measures the smooth L1, L1, squared L2, or Huber loss of the difference 33 | between the input and target. 34 | 35 | Args: 36 | ctx: Context for variable storage. 37 | input: Input. 38 | Can have arbitrary shape. 39 | target: Target. 40 | Must be the same shape as input. 41 | p_loss: p-norm used to compute the error. 42 | Options are 0 for smooth L1, 1 for L1, 2 for squared L2, and 3 for Huber loss. 43 | reduction: Reduction strategy for the output. 
44 | Options are 'none' for no reduction, 'mean' for averaging the error 45 | across all entries, and 'sum' for summing the error across all entries. 46 | param: Parameter of loss function (i.e., beta or delta for smooth L1 and Huber). 47 | 48 | Returns: 49 | Error. 50 | """ 51 | assert input.shape == target.shape, \ 52 | f'Input shape {input.shape} and target shape {target.shape} not equal' 53 | 54 | output_dtype = get_output_dtype(input.dtype, autocast='fp32') 55 | 56 | ctx.p_loss = p_loss 57 | ctx.reduction = reduction 58 | ctx.param = param 59 | ctx.output_dtype = output_dtype 60 | if input.requires_grad or target.requires_grad: 61 | ctx.save_for_backward(input, target) 62 | 63 | flattened_input = input.flatten() 64 | flattened_target = target.flatten() 65 | size = len(flattened_input) 66 | 67 | output = (torch.empty_like(flattened_input, dtype=output_dtype) if reduction == 'none' 68 | else torch.empty(cdiv(size, 32), dtype=output_dtype, device=input.device)) 69 | 70 | # Launches 1D grid where each program operates over 71 | # BLOCK_SIZE elements. 72 | grid = lambda META: (cdiv(size, META['BLOCK_SIZE']),) 73 | p_loss_forward_kernel[grid](flattened_input, flattened_target, output, 74 | param, size, p_loss=p_loss, reduction=reduction) 75 | 76 | if reduction != 'none': 77 | BLOCK_SIZE = p_loss_forward_kernel.best_config.kwargs['BLOCK_SIZE'] 78 | output = output[:cdiv(size, BLOCK_SIZE)].sum() 79 | 80 | else: 81 | output = output.view_as(input) 82 | 83 | return output 84 | 85 | @staticmethod 86 | def backward( 87 | ctx: Context, 88 | output_grad: Tensor, 89 | ) -> Tuple[Optional[Tensor], ...]: 90 | """ 91 | Calculates the input gradient of the error. 92 | 93 | Args: 94 | ctx: Context containing stored variables. 95 | output_grad: Output gradients. 96 | Must be the same shape as the output. 97 | 98 | Returns: 99 | Input gradient of the error. 100 | """ 101 | (input, target) = ctx.saved_tensors 102 | flattened_input = input.flatten() 103 | flattened_target = target.flatten() 104 | output_grad = output_grad.flatten() 105 | 106 | size = len(flattened_input) 107 | input_grad = torch.empty_like(flattened_input, dtype=ctx.output_dtype) 108 | target_grad = torch.empty_like(flattened_target, dtype=ctx.output_dtype) 109 | 110 | # Launches 1D grid where each program operates over 111 | # BLOCK_SIZE elements. 112 | grid = lambda META: (cdiv(size, META['BLOCK_SIZE']),) 113 | p_loss_backward_kernel[grid](output_grad, flattened_input, flattened_target, 114 | input_grad, target_grad, ctx.param, size, 115 | p_loss=ctx.p_loss, reduction=ctx.reduction) 116 | 117 | # Pads output with None because a gradient is necessary for 118 | # all input arguments. 119 | return input_grad.view_as(input), target_grad.view_as(input), None, None 120 | 121 | 122 | class L1Loss(nn.L1Loss): 123 | """ 124 | Measures the L1 error (mean absolute error) between the input and target. 125 | See also base class. 126 | 127 | Args: 128 | reduction: Reduction strategy for the output. 129 | Options are 'none' for no reduction, 'mean' for averaging the error 130 | across all entries, and 'sum' for summing the error across all entries. 131 | Providing size_average and reduce overrides this argument. 132 | size_average: Flag for averaging instead of summing the error entries 133 | when reduce is True. 134 | reduce: Flag for averaging or summing all the error entries instead of 135 | returning a loss per element. 
136 | """ 137 | def forward(self, input: Tensor, target: Tensor) -> Tensor: 138 | return PLossAutoGrad.apply(input, target, 1, self.reduction) 139 | 140 | 141 | class MSELoss(nn.MSELoss): 142 | """ 143 | Measures the squared L2 error (mean squared error) between the input and target. 144 | See also base class. 145 | 146 | Args: 147 | reduction: Reduction strategy for the output. 148 | Options are 'none' for no reduction, 'mean' for averaging the error 149 | across all entries, and 'sum' for summing the error across all entries. 150 | Providing size_average and reduce overrides this argument. 151 | size_average: Flag for averaging instead of summing the error entries 152 | when reduce is True. 153 | reduce: Flag for averaging or summing all the error entries instead of 154 | returning a loss per element. 155 | """ 156 | def forward(self, input: Tensor, target: Tensor) -> Tensor: 157 | return PLossAutoGrad.apply(input, target, 2, self.reduction) 158 | 159 | 160 | class SmoothL1Loss(nn.SmoothL1Loss): 161 | """ 162 | Measures the smooth L1 error between the input and target. 163 | See also base class. 164 | 165 | Args: 166 | reduction: Reduction strategy for the output. 167 | Options are 'none' for no reduction, 'mean' for averaging the error 168 | across all entries, and 'sum' for summing the error across all entries. 169 | Providing size_average and reduce overrides this argument. 170 | size_average: Flag for averaging instead of summing the error entries 171 | when reduce is True. 172 | reduce: Flag for averaging or summing all the error entries instead of 173 | returning a loss per element. 174 | beta: Beta value for the softening threshold. 175 | """ 176 | def forward(self, input: Tensor, target: Tensor) -> Tensor: 177 | return PLossAutoGrad.apply(input, target, 0, self.reduction, self.beta) 178 | 179 | 180 | class HuberLoss(nn.HuberLoss): 181 | """ 182 | Measures the Huber loss between the input and target. 183 | See also base class. 184 | 185 | Args: 186 | reduction: Reduction strategy for the output. 187 | Options are 'none' for no reduction, 'mean' for averaging the error 188 | across all entries, and 'sum' for summing the error across all entries. 189 | beta: Beta value for the softening threshold. 190 | """ 191 | def forward(self, input: Tensor, target: Tensor) -> Tensor: 192 | return PLossAutoGrad.apply(input, target, 3, self.reduction, self.delta) 193 | -------------------------------------------------------------------------------- /attorch/cross_entropy_loss_kernels.py: -------------------------------------------------------------------------------- 1 | """ 2 | Kernels for cross entropy loss. 
3 | """ 4 | 5 | 6 | import triton 7 | import triton.language as tl 8 | from triton import next_power_of_2 9 | 10 | from .softmax_kernels import BLOCK_SIZE_BATCH_heuristic 11 | from .utils import warps_kernel_configs 12 | 13 | 14 | @triton.autotune( 15 | configs=warps_kernel_configs(), 16 | key=['batch_dim', 'feat_dim'], 17 | ) 18 | @triton.heuristics({'BLOCK_SIZE_BATCH': BLOCK_SIZE_BATCH_heuristic, 19 | 'BLOCK_SIZE_FEAT': lambda args: next_power_of_2(args['feat_dim'])}) 20 | @triton.jit 21 | def cross_entropy_loss_forward_kernel( 22 | input_pointer, target_pointer, weight_pointer, 23 | sum_weights_pointer, output_pointer, 24 | batch_dim, feat_dim, 25 | input_batch_stride, input_feat_stride, 26 | weighted: tl.constexpr, 27 | BLOCK_SIZE_BATCH: tl.constexpr, BLOCK_SIZE_FEAT: tl.constexpr, 28 | ): 29 | """ 30 | Measures the mean cross entropy loss between the input and target, 31 | with optional reweighing of each class. 32 | 33 | Args: 34 | input_pointer: Pointer to the input. 35 | The input must be of shape [batch_dim, feat_dim]. 36 | target_pointer: Pointer to the target. 37 | The target must be of shape [batch_dim]. 38 | weight_pointer: Pointer to an optional class weight vector. 39 | The class weight vector, if provided, must be of shape [feat_dim]. 40 | sum_weights_pointer: Pointer to a container the sum of the class weights is written to. 41 | The container must be of shape [batch_dim/BLOCK_SIZE_BATCH]. 42 | output_pointer: Pointer to a container the loss is written to. 43 | The container must be of shape [batch_dim/BLOCK_SIZE_BATCH]. 44 | batch_dim: Batch dimension. 45 | feat_dim: Dimensionality of the features. 46 | input_batch_stride: Stride necessary to jump one element along the 47 | input's batch dimension. 48 | input_feat_stride: Stride necessary to jump one element along the 49 | input's feature dimension. 50 | weighted: Flag for weighing each class. 51 | BLOCK_SIZE_BATCH: Block size across the batch dimension. 52 | BLOCK_SIZE_FEAT: Block size across the feature dimension. 53 | """ 54 | # This program processes BLOCK_SIZE_BATCH rows and BLOCK_SIZE_FEAT columns. 
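    # Per row, the loss computed below is logsumexp(input) - input[target],
    # evaluated stably by subtracting the row maximum first. With class weights,
    # each row's loss is scaled by weight[target] and the block's weight sum is
    # saved so the wrapper can divide by it; otherwise the loss is pre-divided
    # by batch_dim so that summing the per-block outputs yields the mean.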
55 | batch_pid = tl.program_id(axis=0) 56 | 57 | batch_offset = batch_pid * BLOCK_SIZE_BATCH + tl.arange(0, BLOCK_SIZE_BATCH) 58 | feat_offset = tl.arange(0, BLOCK_SIZE_FEAT) 59 | 60 | batch_mask = batch_offset < batch_dim 61 | feat_mask = feat_offset < feat_dim 62 | 63 | target = tl.load(target_pointer + batch_offset, mask=batch_mask) 64 | 65 | pred_pointer = (input_pointer + 66 | input_feat_stride * target + 67 | input_batch_stride * batch_offset) 68 | input_pointer += (input_batch_stride * batch_offset[:, None] + 69 | input_feat_stride * feat_offset[None, :]) 70 | 71 | input = tl.load(input_pointer, mask=batch_mask[:, None] & feat_mask[None, :], 72 | other=-float('inf')).to(tl.float32) 73 | pred = tl.load(pred_pointer, mask=batch_mask).to(tl.float32) 74 | mx = tl.max(input, axis=1) 75 | input -= mx[:, None] 76 | loss = tl.log(tl.sum(tl.exp(input), axis=1)) - pred + mx 77 | 78 | if weighted: 79 | weight = tl.load(weight_pointer + target, mask=batch_mask).to(tl.float32) 80 | loss *= weight 81 | tl.store(sum_weights_pointer + batch_pid, tl.sum(weight)) 82 | 83 | else: 84 | loss /= batch_dim 85 | 86 | tl.store(output_pointer + batch_pid, tl.sum(loss)) 87 | 88 | 89 | @triton.autotune( 90 | configs=warps_kernel_configs(), 91 | key=['batch_dim', 'feat_dim'], 92 | ) 93 | @triton.heuristics({'BLOCK_SIZE_BATCH': BLOCK_SIZE_BATCH_heuristic, 94 | 'BLOCK_SIZE_FEAT': lambda args: next_power_of_2(args['feat_dim'])}) 95 | @triton.jit 96 | def cross_entropy_loss_backward_kernel( 97 | output_grad_pointer, target_pointer, input_pointer, weight_pointer, 98 | sum_weights_pointer, input_grad_pointer, 99 | batch_dim, feat_dim, 100 | input_batch_stride, input_feat_stride, 101 | input_grad_batch_stride, input_grad_feat_stride, 102 | weighted: tl.constexpr, 103 | BLOCK_SIZE_BATCH: tl.constexpr, BLOCK_SIZE_FEAT: tl.constexpr, 104 | ): 105 | """ 106 | Calculates the input gradient of cross entropy loss. 107 | 108 | Args: 109 | output_grad_pointer: Pointer to the loss's output gradients. 110 | The output gradient must be a scalar. 111 | target_pointer: Pointer to the target. 112 | The target must be of shape [batch_dim]. 113 | input_pointer: Pointer to the input. 114 | The input must be of shape [batch_dim, feat_dim]. 115 | weight_pointer: Pointer to an optional class weight vector. 116 | The class weight vector, if provided, must be of shape [feat_dim]. 117 | sum_weights_pointer: Pointer to the sum of the class weights if the classes were weighed. 118 | The sum of weights must be a scalar. 119 | input_grad_pointer: Pointer to a container the input's gradients are written to. 120 | The container must be of shape [batch_dim, feat_dim]. 121 | batch_dim: Batch dimension. 122 | feat_dim: Dimensionality of the features. 123 | input_batch_stride: Stride necessary to jump one element along the 124 | input's batch dimension. 125 | input_feat_stride: Stride necessary to jump one element along the 126 | input's feature dimension. 127 | input_grad_batch_stride: Stride necessary to jump one element along the 128 | input gradient container's batch dimension. 129 | input_grad_feat_stride: Stride necessary to jump one element along the 130 | input gradient container's feature dimension. 131 | weighted: Flag for weighing each class. 132 | BLOCK_SIZE_BATCH: Block size across the batch dimension. 133 | BLOCK_SIZE_FEAT: Block size across the feature dimension. 134 | """ 135 | # This program processes BLOCK_SIZE_BATCH rows and BLOCK_SIZE_FEAT columns. 
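    # The gradient of cross entropy with respect to the logits is
    # softmax(input) - one_hot(target), scaled by the incoming output gradient
    # and by weight[target] / sum_weights when classes are weighted
    # (1 / batch_dim otherwise), which is what the block below computes.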
136 | batch_pid = tl.program_id(axis=0) 137 | 138 | batch_offset = batch_pid * BLOCK_SIZE_BATCH + tl.arange(0, BLOCK_SIZE_BATCH) 139 | feat_offset = tl.arange(0, BLOCK_SIZE_FEAT) 140 | 141 | batch_mask = batch_offset < batch_dim 142 | feat_mask = feat_offset < feat_dim 143 | 144 | input_pointer += (input_batch_stride * batch_offset[:, None] + 145 | input_feat_stride * feat_offset[None, :]) 146 | input_grad_pointer += (input_grad_batch_stride * batch_offset[:, None] + 147 | input_grad_feat_stride * feat_offset[None, :]) 148 | 149 | input = tl.load(input_pointer, mask=batch_mask[:, None] & feat_mask[None, :], 150 | other=-float('inf')).to(tl.float32) 151 | input -= tl.max(input, axis=1)[:, None] 152 | numerator = tl.exp(input) 153 | softmax = numerator / tl.sum(numerator, axis=1)[:, None] 154 | 155 | output_grad = tl.load(output_grad_pointer).to(tl.float32) 156 | target = tl.load(target_pointer + batch_offset, mask=batch_mask) 157 | broadcasted_feat_offset = tl.broadcast_to(feat_offset[None, :], 158 | (BLOCK_SIZE_BATCH, BLOCK_SIZE_FEAT)) 159 | broadcasted_target = tl.broadcast_to(target[:, None], 160 | (BLOCK_SIZE_BATCH, BLOCK_SIZE_FEAT)) 161 | input_grad = output_grad * (softmax - (broadcasted_feat_offset == broadcasted_target)) 162 | 163 | if weighted: 164 | weight = tl.load(weight_pointer + target, mask=batch_mask).to(tl.float32) 165 | sum_weights = tl.load(sum_weights_pointer) 166 | input_grad *= weight[:, None] / sum_weights 167 | 168 | else: 169 | input_grad /= batch_dim 170 | 171 | tl.store(input_grad_pointer, input_grad, 172 | mask=batch_mask[:, None] & feat_mask[None, :]) 173 | -------------------------------------------------------------------------------- /attorch/softmax_kernels.py: -------------------------------------------------------------------------------- 1 | """ 2 | Kernels for softmax and related functions. 3 | """ 4 | 5 | 6 | from typing import Dict 7 | 8 | import triton 9 | import triton.language as tl 10 | from triton import next_power_of_2 11 | 12 | from .utils import warps_kernel_configs 13 | 14 | 15 | def BLOCK_SIZE_BATCH_heuristic(args: Dict) -> int: 16 | """ 17 | Approximates an appropriate batch block size for softmax using a heuristic. 18 | 19 | Args: 20 | args: Arguments to softmax kernel. 21 | 22 | Returns: 23 | Appropriate batch block size. 24 | """ 25 | # This heuristic was derived manually. 26 | # Essentially, if the batch dimension is greater than 1024, 27 | # for small feature sizes (less than 64), it is much more efficient 28 | # to process multiple rows at once in a given program. 29 | # Specifically, each time the number of samples is doubled, 30 | # the block size across the batch dimension should be doubled too, 31 | # with an upper bound of 128. 32 | return (min(max(1, next_power_of_2(args['batch_dim'] // 2 ** 10)), 128) 33 | if args['feat_dim'] < 64 else 1) 34 | 35 | 36 | @triton.autotune( 37 | configs=warps_kernel_configs(), 38 | key=['batch_dim', 'feat_dim'], 39 | ) 40 | @triton.heuristics({'BLOCK_SIZE_BATCH': BLOCK_SIZE_BATCH_heuristic, 41 | 'BLOCK_SIZE_FEAT': lambda args: next_power_of_2(args['feat_dim'])}) 42 | @triton.jit 43 | def softmax_forward_kernel( 44 | input_pointer, output_pointer, 45 | batch_dim, feat_dim, 46 | input_batch_stride, input_feat_stride, 47 | output_batch_stride, output_feat_stride, 48 | neg: tl.constexpr, log: tl.constexpr, 49 | BLOCK_SIZE_BATCH: tl.constexpr, BLOCK_SIZE_FEAT: tl.constexpr, 50 | ): 51 | """ 52 | Normalizes the input using softmax. 53 | 54 | Args: 55 | input_pointer: Pointer to the input to normalize. 
56 | The input must be of shape [batch_dim, feat_dim]. 57 | output_pointer: Pointer to a container the result is written to. 58 | The container must be of shape [batch_dim, feat_dim]. 59 | batch_dim: Batch dimension. 60 | feat_dim: Dimensionality of the features. 61 | input_batch_stride: Stride necessary to jump one element along the 62 | input's batch dimension. 63 | input_feat_stride: Stride necessary to jump one element along the 64 | input's feature dimension. 65 | output_batch_stride: Stride necessary to jump one element along the 66 | output container's batch dimension. 67 | output_feat_stride: Stride necessary to jump one element along the 68 | output container's feature dimension. 69 | neg: Flag indicating if the input should be negated to get softmin. 70 | log: Flag indicating if the log of softmax should be taken. 71 | BLOCK_SIZE_BATCH: Block size across the batch dimension. 72 | BLOCK_SIZE_FEAT: Block size across the feature dimension. 73 | """ 74 | # This program processes BLOCK_SIZE_BATCH rows and BLOCK_SIZE_FEAT columns. 75 | batch_pid = tl.program_id(axis=0) 76 | 77 | batch_offset = batch_pid * BLOCK_SIZE_BATCH + tl.arange(0, BLOCK_SIZE_BATCH) 78 | feat_offset = tl.arange(0, BLOCK_SIZE_FEAT) 79 | 80 | batch_mask = batch_offset < batch_dim 81 | feat_mask = feat_offset < feat_dim 82 | 83 | input_pointer += (input_batch_stride * batch_offset[:, None] + 84 | input_feat_stride * feat_offset[None, :]) 85 | output_pointer += (output_batch_stride * batch_offset[:, None] + 86 | output_feat_stride * feat_offset[None, :]) 87 | 88 | input = tl.load(input_pointer, mask=batch_mask[:, None] & feat_mask[None, :], 89 | other=float('inf') if neg else -float('inf')).to(tl.float32) 90 | if neg: 91 | input = -input 92 | 93 | input -= tl.max(input, axis=1)[:, None] 94 | numerator = tl.exp(input) 95 | denominator = tl.sum(numerator, axis=1)[:, None] 96 | 97 | if log: 98 | output = input - tl.log(denominator) 99 | 100 | else: 101 | output = numerator / denominator 102 | 103 | tl.store(output_pointer, output, mask=batch_mask[:, None] & feat_mask[None, :]) 104 | 105 | 106 | @triton.autotune( 107 | configs=warps_kernel_configs(), 108 | key=['batch_dim', 'feat_dim'], 109 | ) 110 | @triton.heuristics({'BLOCK_SIZE_BATCH': BLOCK_SIZE_BATCH_heuristic, 111 | 'BLOCK_SIZE_FEAT': lambda args: next_power_of_2(args['feat_dim'])}) 112 | @triton.jit 113 | def softmax_backward_kernel( 114 | output_grad_pointer, output_pointer, input_grad_pointer, 115 | batch_dim, feat_dim, 116 | output_grad_batch_stride, output_grad_feat_stride, 117 | output_batch_stride, output_feat_stride, 118 | input_grad_batch_stride, input_grad_feat_stride, 119 | neg: tl.constexpr, log: tl.constexpr, 120 | BLOCK_SIZE_BATCH: tl.constexpr, BLOCK_SIZE_FEAT: tl.constexpr, 121 | ): 122 | """ 123 | Calculates the input gradient of softmax. 124 | 125 | Args: 126 | output_grad_pointer: Pointer to softmax's output gradients. 127 | The output gradients must be of shape [batch_dim, feat_dim]. 128 | output_pointer: Pointer to softmax's output. 129 | The output must be of shape [batch_dim, feat_dim]. 130 | input_grad_pointer: Pointer to a container the input's gradients are written to. 131 | The container must be of shape [batch_dim, feat_dim]. 132 | batch_dim: Batch dimension. 133 | feat_dim: Dimensionality of the features. 134 | output_grad_batch_stride: Stride necessary to jump one element along the 135 | output gradients' batch dimension. 
136 | output_grad_feat_stride: Stride necessary to jump one element along the 137 | output gradients' feature dimension. 138 | output_batch_stride: Stride necessary to jump one element along the 139 | output's batch dimension. 140 | output_feat_stride: Stride necessary to jump one element along the 141 | output's feature dimension. 142 | input_grad_batch_stride: Stride necessary to jump one element along the 143 | input gradient container's batch dimension. 144 | input_grad_feat_stride: Stride necessary to jump one element along the 145 | input gradient container's feature dimension. 146 | neg: Flag indicating if the input was negated to get softmin. 147 | log: Flag indicating if log of softmax was taken. 148 | BLOCK_SIZE_BATCH: Block size across the batch dimension. 149 | BLOCK_SIZE_FEAT: Block size across the feature dimension. 150 | """ 151 | # This program processes a single row and BLOCK_SIZE_FEAT columns. 152 | batch_pid = tl.program_id(axis=0) 153 | 154 | batch_offset = batch_pid * BLOCK_SIZE_BATCH + tl.arange(0, BLOCK_SIZE_BATCH) 155 | feat_offset = tl.arange(0, BLOCK_SIZE_FEAT) 156 | 157 | batch_mask = batch_offset < batch_dim 158 | feat_mask = feat_offset < feat_dim 159 | 160 | output_grad_pointer += (output_grad_batch_stride * batch_offset[:, None] + 161 | output_grad_feat_stride * feat_offset[None, :]) 162 | output_pointer += (output_batch_stride * batch_offset[:, None] + 163 | output_feat_stride * feat_offset[None, :]) 164 | input_grad_pointer += (input_grad_batch_stride * batch_offset[:, None] + 165 | input_grad_feat_stride * feat_offset[None, :]) 166 | 167 | output_grad = tl.load(output_grad_pointer, 168 | mask=batch_mask[:, None] & feat_mask[None, :]).to(tl.float32) 169 | output = tl.load(output_pointer, 170 | mask=batch_mask[:, None] & feat_mask[None, :]).to(tl.float32) 171 | 172 | if log: 173 | input_grad = (output_grad - 174 | tl.exp(output) * tl.sum(output_grad, axis=1)[:, None]) 175 | 176 | else: 177 | input_grad = output * (output_grad - 178 | tl.sum(output_grad * output, axis=1)[:, None]) 179 | 180 | tl.store(input_grad_pointer, -input_grad if neg else input_grad, 181 | mask=batch_mask[:, None] & feat_mask[None, :]) 182 | -------------------------------------------------------------------------------- /examples/wikitext-2/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Trains and benchmarks a language model on the WikiText-2 dataset. 3 | """ 4 | 5 | 6 | import time 7 | from argparse import ArgumentParser 8 | from functools import partial 9 | from math import exp, sqrt 10 | from typing import Callable, Tuple 11 | 12 | import torch 13 | from datasets import load_dataset 14 | from torch import nn 15 | from torch.optim import AdamW 16 | from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR 17 | from torch.utils.data import DataLoader 18 | from transformers import AutoTokenizer 19 | 20 | from .gpt import gpt2 21 | from ..utils import AvgMeter, benchmark_fw_and_bw 22 | 23 | 24 | def create_dls( 25 | batch_size: int = 32, 26 | seq_len: int = 512, 27 | num_workers: int = 4, 28 | ) -> Tuple[DataLoader, DataLoader]: 29 | """ 30 | Creates data loaders for WikiText-2 with tokenization. 31 | 32 | args: 33 | batch_size: Batch size. 34 | seq_len: Sequence length. 35 | num_workers: Number of workers for data loading. 36 | 37 | Returns: 38 | Training and validation WikiText-2 data loaders. 
39 | """ 40 | tokenizer = AutoTokenizer.from_pretrained('gpt2') 41 | tokenizer.pad_token = tokenizer.eos_token 42 | 43 | dataset = load_dataset('wikitext', 'wikitext-2-v1') 44 | dataset = dataset.map(lambda input: tokenizer(input['text'], 45 | max_length=seq_len, 46 | padding='max_length', 47 | truncation=True), 48 | batched=True) 49 | dataset.set_format('torch') 50 | 51 | train_dl = DataLoader(dataset=dataset['train'], batch_size=batch_size, 52 | shuffle=True, num_workers=num_workers, pin_memory=True, 53 | drop_last=True) 54 | valid_dl = DataLoader(dataset=dataset['validation'], batch_size=batch_size, 55 | shuffle=False, num_workers=num_workers, pin_memory=True, 56 | drop_last=True) 57 | 58 | return train_dl, valid_dl 59 | 60 | 61 | def train( 62 | model: nn.Module, 63 | train_dl: DataLoader, 64 | valid_dl: DataLoader, 65 | scheduler: str = 'one-cycle', 66 | epochs: int = 10, 67 | batch_size: int = 32, 68 | ) -> float: 69 | """ 70 | Trains and validates a model for language modelling. 71 | 72 | Args: 73 | model: Model to train. Its forward pass must optionally accept return_loss 74 | to compute the loss. 75 | scheduler: Learning rate scheduler. 76 | Options are 'one-cycle' and 'cosine'. 77 | train_dl: Data loader for training. 78 | valid_dl: Data loader for validation. 79 | epochs: Number of epochs to train for. 80 | batch_size: Batch size. 81 | 82 | Returns: 83 | Total training and validation time. 84 | """ 85 | model = model.to('cuda') 86 | optim = AdamW(model.parameters(), lr=1e-4) 87 | optim.zero_grad() 88 | scaler = torch.GradScaler('cuda') 89 | 90 | if scheduler == 'one-cycle': 91 | sched = OneCycleLR(optim, max_lr=sqrt(batch_size / 32) * 4e-4, 92 | steps_per_epoch=len(train_dl), epochs=epochs) 93 | 94 | elif scheduler == 'cosine': 95 | sched = CosineAnnealingLR(optim, T_max=len(train_dl) * epochs) 96 | 97 | else: 98 | raise RuntimeError(f'Scheduler {scheduler} not supported.') 99 | 100 | avg_meter = AvgMeter() 101 | start = time.time() 102 | 103 | for epoch in range(epochs): 104 | print(f'Epoch {epoch+1}/{epochs}') 105 | 106 | model.train() 107 | avg_meter.reset() 108 | for ind, batch in enumerate(train_dl): 109 | if (ind + 1) % 5 == 0: 110 | print(f'Training iteration {ind+1}/{len(train_dl)}', end='\r') 111 | 112 | input = batch['input_ids'].to('cuda') 113 | 114 | with torch.autocast('cuda'): 115 | loss = model(input, return_loss=True) 116 | 117 | scaler.scale(loss).backward() 118 | scaler.step(optim) 119 | scaler.update() 120 | sched.step() 121 | optim.zero_grad() 122 | 123 | avg_meter.update(loss.item(), len(input)) 124 | print(f'Training loss: {avg_meter.avg}') 125 | 126 | model.eval() 127 | avg_meter.reset() 128 | with torch.no_grad(): 129 | for ind, batch in enumerate(valid_dl): 130 | if (ind + 1) % 5 == 0: 131 | print(f'Validation iteration {ind+1}/{len(valid_dl)}', end='\r') 132 | 133 | input = batch['input_ids'].to('cuda') 134 | 135 | with torch.autocast('cuda'): 136 | loss = model(input, return_loss=True) 137 | avg_meter.update(loss.item(), len(input)) 138 | print(f'Validation perplexity: {exp(avg_meter.avg)}') 139 | 140 | return time.time() - start 141 | 142 | 143 | def main( 144 | model_cls: Callable, 145 | scheduler: str = 'one-cycle', 146 | epochs: int = 10, 147 | batch_size: int = 32, 148 | seq_len: int = 512, 149 | num_workers: int = 4, 150 | ) -> None: 151 | """ 152 | Trains and benchmarks a language model on the WikiText-2 dataset. 
153 | 154 | Args: 155 | model_cls: Model class to train, with a 'vocab_size' argument for 156 | specifying the vocabulary size and a 'max_seq_len' argument for 157 | specifying the maximum sequence length of the incoming inputs. 158 | scheduler: Learning rate scheduler. 159 | Options are 'one-cycle' and 'cosine'. 160 | epochs: Number of epochs to train for. 161 | batch_size: Batch size. 162 | seq_len: Sequence length. 163 | num_workers: Number of workers for data loading. 164 | """ 165 | train_dl, valid_dl = create_dls(batch_size, seq_len, num_workers) 166 | vocab_size = AutoTokenizer.from_pretrained('gpt2').vocab_size 167 | model = model_cls(vocab_size=vocab_size, max_seq_len=seq_len).to('cuda') 168 | 169 | batch = next(iter(train_dl)) 170 | input = batch['input_ids'].to('cuda') 171 | 172 | for _ in range(10): 173 | model.train() 174 | with torch.autocast('cuda'): 175 | loss = model(input, return_loss=True) 176 | loss.backward() 177 | 178 | model.eval() 179 | with torch.no_grad() and torch.autocast('cuda'): 180 | model(input, return_loss=True) 181 | 182 | model.train() 183 | benchmark_fw_and_bw(model, input=input, return_loss=True) 184 | 185 | print('Total training and validation time: ' 186 | f'{train(model, train_dl, valid_dl, scheduler, epochs, batch_size)}') 187 | 188 | 189 | if __name__ == '__main__': 190 | parser = ArgumentParser(description='Trains and benchmarks a language model on the WikiText-2 dataset.') 191 | parser.add_argument('--model', 192 | type=str, 193 | default='gpt2', 194 | choices=['gpt2'], 195 | help='Name of language model to train') 196 | parser.add_argument('--downsize', 197 | type=int, 198 | default=1, 199 | help='The depth and width of the model are calculated by dividing GPT2\'s original depth and width by this factor.') 200 | parser.add_argument('--scheduler', 201 | type=str, 202 | default='one-cycle', 203 | choices=['one-cycle', 'cosine'], 204 | help='Learning rate scheduler.') 205 | parser.add_argument('--epochs', 206 | type=int, 207 | default=10, 208 | help='Number of epochs to train for') 209 | parser.add_argument('--batch_size', 210 | type=int, 211 | default=32, 212 | help='Batch size') 213 | parser.add_argument('--seq_len', 214 | type=int, 215 | default=512, 216 | help='Sequence length') 217 | parser.add_argument('--num_workers', 218 | type=int, 219 | default=4, 220 | help='Number of workers for data loading') 221 | args = parser.parse_args() 222 | 223 | print('attorch run:') 224 | main(partial(locals()[args.model], use_attorch=True, downsize=args.downsize), 225 | scheduler=args.scheduler, epochs=args.epochs, batch_size=args.batch_size, 226 | seq_len=args.seq_len, num_workers=args.num_workers) 227 | 228 | print('PyTorch run:') 229 | main(partial(locals()[args.model], use_attorch=False, downsize=args.downsize), 230 | scheduler=args.scheduler, epochs=args.epochs, batch_size=args.batch_size, 231 | seq_len=args.seq_len, num_workers=args.num_workers) 232 | -------------------------------------------------------------------------------- /attorch/layer_norm_layer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Layer normalization with PyTorch autodiff support. 
3 | """ 4 | 5 | 6 | from typing import Optional, Tuple 7 | 8 | import torch 9 | from torch import Tensor 10 | from torch import nn 11 | from triton import cdiv 12 | 13 | from .layer_norm_kernels import layer_norm_backward_kernel, layer_norm_forward_kernel 14 | from .softmax_kernels import BLOCK_SIZE_BATCH_heuristic 15 | from .types import Context, Device 16 | from .utils import get_output_dtype 17 | 18 | 19 | class LayerNormAutoGrad(torch.autograd.Function): 20 | """ 21 | Autodiff for layer normalization. 22 | """ 23 | @staticmethod 24 | def forward( 25 | ctx: Context, 26 | input: Tensor, 27 | weight: Optional[Tensor] = None, 28 | bias: Optional[Tensor] = None, 29 | eps: float = 1e-5, 30 | autocast_to_fp32: bool = True, 31 | ) -> Tensor: 32 | """ 33 | Layer-normalizes the input. 34 | 35 | Args: 36 | ctx: Context for variable storage. 37 | input: Input to layer-normalize. 38 | Can have arbitrary shape. 39 | weight: Optional weights for affine transform. 40 | If provided, must be of shape [feat_dim]. 41 | bias: Optional bias vector for affine transform when weight is provided. 42 | If provided, must be of shape [feat_dim]. 43 | eps: Epsilon added in the square root in the denominator 44 | to avoid division by zero. 45 | autocast_to_fp32: Flag for autocasting the output dtype to fp32. 46 | If False, the input dtype flows through. 47 | 48 | Returns: 49 | Layer-normalized input. 50 | """ 51 | flattened_input = input.unsqueeze(0) if input.ndim == 1 else input 52 | flattened_input = flattened_input.flatten(0, -2) 53 | batch_dim, feat_dim = flattened_input.shape 54 | 55 | output_dtype = get_output_dtype(input.dtype, 56 | autocast='fp32' if autocast_to_fp32 else None) 57 | output = torch.empty_like(flattened_input, dtype=output_dtype) 58 | 59 | scale_by_weight = weight is not None 60 | add_bias = scale_by_weight and bias is not None 61 | requires_grad = (input.requires_grad or 62 | (scale_by_weight and weight.requires_grad) or 63 | (add_bias and bias.requires_grad)) 64 | 65 | if requires_grad: 66 | mean = torch.empty(batch_dim, 67 | device=input.device, 68 | dtype=torch.float32) 69 | inv_std = torch.empty(batch_dim, 70 | device=input.device, 71 | dtype=torch.float32) 72 | 73 | else: 74 | mean = inv_std = None 75 | 76 | # Launches 1D grid where each program operates over BLOCK_SIZE_BATCH rows. 77 | grid = lambda META: (cdiv(batch_dim, META['BLOCK_SIZE_BATCH']),) 78 | layer_norm_forward_kernel[grid](flattened_input, weight, bias, 79 | mean, inv_std, output, 80 | batch_dim, feat_dim, 81 | *flattened_input.stride(), *output.stride(), 82 | eps, 83 | scale_by_weight=scale_by_weight, 84 | add_bias=add_bias, 85 | save_stats=requires_grad) 86 | 87 | ctx.scale_by_weight = scale_by_weight 88 | ctx.add_bias = add_bias 89 | ctx.output_dtype = output_dtype 90 | if requires_grad: 91 | ctx.save_for_backward(flattened_input, mean, inv_std, weight) 92 | 93 | return output.view_as(input) 94 | 95 | @staticmethod 96 | def backward( 97 | ctx: Context, 98 | output_grad: Tensor, 99 | ) -> Tuple[Optional[Tensor], ...]: 100 | """ 101 | Calculates the input gradient of layer normalization. 102 | 103 | Args: 104 | ctx: Context containing stored variables. 105 | output_grad: Output gradients. 106 | Must be the same shape as the output. 107 | 108 | Returns: 109 | Input gradient of layer normalization. 
110 | """ 111 | scale_by_weight, add_bias = ctx.scale_by_weight, ctx.add_bias 112 | (flattened_input, mean, inv_std, weight) = ctx.saved_tensors 113 | flattened_output_grad = output_grad.view_as(flattened_input) 114 | 115 | batch_dim, feat_dim = flattened_output_grad.shape 116 | input_grad = torch.empty_like(flattened_output_grad, 117 | dtype=ctx.output_dtype) 118 | 119 | if scale_by_weight: 120 | BLOCK_SIZE_BATCH = BLOCK_SIZE_BATCH_heuristic({'batch_dim': batch_dim, 121 | 'feat_dim': feat_dim}) 122 | out_batch_dim = batch_dim // BLOCK_SIZE_BATCH 123 | 124 | weight_grad = torch.empty((out_batch_dim, feat_dim), 125 | device=flattened_input.device) 126 | if add_bias: 127 | bias_grad = torch.empty((out_batch_dim, feat_dim), 128 | device=flattened_input.device) 129 | 130 | else: 131 | bias_grad = None 132 | 133 | else: 134 | weight_grad = bias_grad = None 135 | 136 | # Launches 1D grid where each program operates over BLOCK_SIZE_BATCH rows. 137 | grid = lambda META: (cdiv(batch_dim, META['BLOCK_SIZE_BATCH']),) 138 | layer_norm_backward_kernel[grid](flattened_output_grad, flattened_input, 139 | mean, inv_std, weight, 140 | input_grad, weight_grad, bias_grad, 141 | batch_dim, feat_dim, 142 | *flattened_output_grad.stride(), 143 | *flattened_input.stride(), 144 | *input_grad.stride(), 145 | *weight_grad.stride() if scale_by_weight else (1, 1), 146 | *bias_grad.stride() if scale_by_weight and add_bias else (1, 1), 147 | scale_by_weight=scale_by_weight, 148 | add_bias=add_bias) 149 | 150 | if scale_by_weight: 151 | weight_grad = weight_grad.sum(dim=0) 152 | if add_bias: 153 | bias_grad = bias_grad.sum(dim=0) 154 | 155 | # Pads output with None because a gradient is necessary for 156 | # all input arguments. 157 | return input_grad.view_as(output_grad), weight_grad, bias_grad, None, None 158 | 159 | 160 | class LayerNorm(nn.LayerNorm): 161 | """ 162 | Layer-normalizes the input. 163 | See also base class. 164 | 165 | Note: During automatic mixed precision training, PyTorch autocasts 166 | the output dtype of layer normalization to fp32. This conversion is not necessary 167 | and retaining the input dtype is generally numerically stable in addition to 168 | being slightly more efficient. Layer normalization in attorch thus offers 169 | an additional argument, autocast_to_fp32, to manually disable this 170 | autocasting behaviour. 171 | 172 | Args: 173 | normalized_shape: Dimensionality of last feature that is normalized. 174 | eps: Epsilon added in the square root in the denominator 175 | to avoid division by zero. 176 | elementwise_affine: Flag for scaling the normalized output by weights. 177 | bias: Flag for adding a bias vector to the normalized output 178 | if elementwise_affine is True. 179 | autocast_to_fp32: Flag for autocasting the output dtype to fp32. 180 | If False, the input dtype flows through. 181 | device: Device to use. 182 | dtype: Dtype of layer. 183 | 184 | Raises: 185 | RuntimeError: Normalized shape was not an integer. 
186 | """ 187 | def __init__( 188 | self, 189 | normalized_shape: int, 190 | eps: float = 1e-5, 191 | elementwise_affine: bool = True, 192 | bias: bool = True, 193 | autocast_to_fp32: bool = True, 194 | device: Device = 'cuda', 195 | dtype: torch.dtype = torch.float32, 196 | ) -> None: 197 | if not isinstance(normalized_shape, int): 198 | raise RuntimeError('Normalized shape must be an integer.') 199 | 200 | super().__init__(normalized_shape, eps, elementwise_affine, bias, 201 | device, dtype) 202 | self.autocast_to_fp32 = autocast_to_fp32 203 | 204 | def forward(self, input: Tensor) -> Tensor: 205 | return LayerNormAutoGrad.apply(input, self.weight, self.bias, self.eps, 206 | self.autocast_to_fp32) 207 | -------------------------------------------------------------------------------- /attorch/linear_layer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Linear layer with fused activation with PyTorch autodiff support. 3 | """ 4 | 5 | 6 | from typing import Optional, Tuple 7 | 8 | import torch 9 | from torch import Tensor 10 | from torch import nn 11 | from triton import cdiv 12 | 13 | from .act_kernels import act_func_backward_kernel 14 | from .linear_kernels import linear_forward_kernel 15 | from .types import Context, Device 16 | from .utils import get_output_dtype 17 | 18 | 19 | class LinearAutoGrad(torch.autograd.Function): 20 | """ 21 | Autodiff for linear layer. 22 | """ 23 | @staticmethod 24 | def forward( 25 | ctx: Context, 26 | input: Tensor, 27 | weight: Tensor, 28 | bias: Optional[Tensor] = None, 29 | act_func: Optional[str] = None, 30 | ) -> Tensor: 31 | """ 32 | Linearly transforms the input using weights, optionally adding bias 33 | and fusing an activation function. 34 | 35 | Args: 36 | input: Input to transform. 37 | Must be of shape [..., in_feat_dim]. 38 | weight: Weights input is transformed by. 39 | Must be of shape [in_feat_dim, out_feat_dim]. 40 | bias: Optional additive bias vector, with None for no bias. 41 | If provided, must be of shape [out_feat_dim]. 42 | act_func: Name of activation function to apply, with None for identity. 43 | Options are 'sigmoid', 'logsigmoid', 'tanh', 'relu', 'gelu', 'geluapprox', 'silu', 44 | 'relu6', 'hardsigmoid', 'hardtanh', 'hardswish', 'selu', 'mish', 45 | 'softplus', 'softsign', 'tanhshrink', 'leaky_relu_PARAM', 46 | 'elu_PARAM', 'celu_PARAM', 'hardshrink_PARAM', and 'softshrink_PARAM' 47 | where PARAM stands for the parameter in the case of parameterized 48 | activation functions (e.g., 'leaky_relu_0.01' for leaky ReLU with a 49 | negative slope of 0.01). 50 | 51 | Returns: 52 | Input linearly transformed, potentially with added biased and 53 | fused activation. 
54 | """ 55 | assert weight.ndim == 2, \ 56 | f'Weights must be 2D, received shape {weight.shape}' 57 | assert bias is None or bias.ndim == 1, \ 58 | f'Bias must be 1D, received shape {bias.shape}' 59 | 60 | assert input.shape[-1] == weight.shape[0], \ 61 | f'Incompatible input ({input.shape}) and weights ({weight.shape}) shape' 62 | assert bias is None or weight.shape[1] == bias.shape[0], \ 63 | f'Incompatible weights ({weight.shape}) and bias ({bias.shape}) shape' 64 | 65 | param = None 66 | if act_func is not None and '_' in act_func: 67 | comps = act_func.split('_') 68 | act_func = '_'.join(comps[:-1]) 69 | param = float(comps[-1]) 70 | 71 | flattened_input = input.flatten(0, -2) 72 | batch_dim, in_feat_dim = flattened_input.shape 73 | _, out_feat_dim = weight.shape 74 | 75 | requires_grad = (input.requires_grad or 76 | weight.requires_grad or 77 | (bias is not None and bias.requires_grad)) 78 | save_pre_act = requires_grad and (act_func is not None) 79 | 80 | output_dtype = get_output_dtype(input.dtype, autocast='fp16') 81 | output = torch.empty((batch_dim, out_feat_dim), 82 | device=input.device, 83 | dtype=output_dtype) 84 | pre_act = torch.empty_like(output) if save_pre_act else output 85 | 86 | # Launches a 1D grid, where each program outputs blocks of 87 | # BLOCK_SIZE_BATCH rows and BLOCK_SIZE_OUT_FEAT columns. 88 | grid = lambda META: (cdiv(batch_dim, META['BLOCK_SIZE_BATCH']) * 89 | cdiv(out_feat_dim, META['BLOCK_SIZE_OUT_FEAT']),) 90 | linear_forward_kernel[grid](flattened_input, weight, 91 | input if bias is None else bias, 92 | pre_act, output, 93 | batch_dim, in_feat_dim, out_feat_dim, 94 | *flattened_input.stride(), *weight.stride(), 95 | *pre_act.stride(), *output.stride(), param, 96 | add_bias=bias is not None, act_func=act_func, 97 | save_pre_act=save_pre_act, 98 | fp16=output_dtype is torch.float16) 99 | 100 | ctx.param = param 101 | ctx.act_func = act_func 102 | ctx.bias_requires_grad = False if bias is None else bias.requires_grad 103 | ctx.output_dtype = output_dtype 104 | if requires_grad: 105 | ctx.save_for_backward(input, pre_act if save_pre_act else None, weight) 106 | 107 | return output.view(*input.shape[:-1], out_feat_dim) 108 | 109 | @staticmethod 110 | def backward( 111 | ctx: Context, 112 | output_grad: Tensor, 113 | ) -> Tuple[Optional[Tensor], ...]: 114 | """ 115 | Calculates the input gradient of the linear layer. 116 | 117 | Args: 118 | ctx: Context containing stored variables. 119 | output_grad: Output gradients. 120 | Must be the same shape as the output. 121 | 122 | Returns: 123 | Input gradient of the linear layer. 124 | """ 125 | input, pre_act, weight = ctx.saved_tensors 126 | 127 | output_grad = output_grad.flatten(0, -2) 128 | flattened_input = input.flatten(0, -2) 129 | 130 | batch_dim, _ = flattened_input.shape 131 | _, out_feat_dim = weight.shape 132 | 133 | if ctx.act_func is None: 134 | pre_act_grad = output_grad 135 | 136 | else: 137 | size = batch_dim * out_feat_dim 138 | pre_act_grad = torch.empty(size, dtype=pre_act.dtype, 139 | device=pre_act.device) 140 | 141 | # Launches 1D grid where each program operates over 142 | # BLOCK_SIZE elements. 143 | grid = lambda META: (cdiv(size, META['BLOCK_SIZE']),) 144 | act_func_backward_kernel[grid](output_grad, pre_act, pre_act_grad, 145 | size, None, None, ctx.param, 146 | ctx.act_func, False) 147 | 148 | pre_act_grad = pre_act_grad.view_as(pre_act) 149 | 150 | # Using PyTorch's matmul, but linear_forward_kernel 151 | # could have also been used. 
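        # For output = input @ weight + bias with weight of shape
        # [in_feat_dim, out_feat_dim], the gradients are
        # input_grad = pre_act_grad @ weight.T,
        # weight_grad = flattened_input.T @ pre_act_grad, and
        # bias_grad = pre_act_grad.sum(dim=0), matching the lines below.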
152 | with torch.autocast('cuda', dtype=ctx.output_dtype): 153 | input_grad = pre_act_grad @ weight.T if input.requires_grad else None 154 | weight_grad = (flattened_input.T @ pre_act_grad 155 | if weight.requires_grad else None) 156 | bias_grad = pre_act_grad.sum(dim=0) if ctx.bias_requires_grad else None 157 | 158 | # Pads output with None because a gradient is necessary for 159 | # all input arguments. 160 | return (input_grad.view_as(input) if input_grad is not None else None, 161 | weight_grad, bias_grad, None) 162 | 163 | 164 | class Linear(nn.Linear): 165 | """ 166 | Linearly transforms the input using weights, optionally adding bias 167 | and fusing an activation function. 168 | See also base class. 169 | 170 | Note: Unlike PyTorch's linear layer, the weight matrix in this module is 171 | of shape [in_features, out_features] instead of [out_features, in_features]. 172 | This may cause unexpected issues when manipulating the weights (e.g., porting 173 | parameters, initializing them, and so forth). 174 | 175 | Args: 176 | in_features: Number of input features. 177 | out_features: Number of output features. 178 | bias: Flag for additive bias. 179 | act_func: Name of activation function to apply, with None for identity. 180 | Options are 'sigmoid', 'logsigmoid', 'tanh', 'relu', 'gelu', 'geluapprox', 'silu', 181 | 'relu6', 'hardsigmoid', 'hardtanh', 'hardswish', 'selu', 'mish', 182 | 'softplus', 'softsign', 'tanhshrink', 'leaky_relu_PARAM', 183 | 'elu_PARAM', 'celu_PARAM', 'hardshrink_PARAM', and 'softshrink_PARAM' 184 | where PARAM stands for the parameter in the case of parameterized 185 | activation functions (e.g., 'leaky_relu_0.01' for leaky ReLU with a 186 | negative slope of 0.01). 187 | device: Device to use. 188 | dtype: Dtype of layer. 189 | """ 190 | def __init__( 191 | self, 192 | in_features: int, 193 | out_features: int, 194 | bias: bool = True, 195 | act_func: Optional[str] = None, 196 | device: Device = 'cuda', 197 | dtype: torch.dtype = torch.float32, 198 | ) -> None: 199 | super().__init__(in_features, out_features, bias, device, dtype) 200 | self.weight = nn.Parameter(self.weight.T.contiguous()) 201 | self.act_func = act_func 202 | 203 | def forward(self, input: Tensor) -> Tensor: 204 | return LinearAutoGrad.apply(input, self.weight, self.bias, 205 | self.act_func) 206 | -------------------------------------------------------------------------------- /attorch/linear_kernels.py: -------------------------------------------------------------------------------- 1 | """ 2 | Kernels for linear layer with fused activation. 3 | """ 4 | 5 | 6 | import triton 7 | import triton.language as tl 8 | 9 | from .act_kernels import apply_act_func 10 | from .utils import allow_tf32, get_n_stages 11 | 12 | 13 | def linear_forward_config( 14 | BLOCK_SIZE_BATCH: int, 15 | BLOCK_SIZE_IN_FEAT: int, 16 | BLOCK_SIZE_OUT_FEAT: int, 17 | GROUP_SIZE_BATCH: int = 8, 18 | n_warps: int = 4, 19 | n_stages: int = 2, 20 | ) -> triton.Config: 21 | """ 22 | Creates a triton.Config object for linear_forward_kernel 23 | given meta-parameters for auto-tuning. 24 | 25 | Args: 26 | BLOCK_SIZE_BATCH: Block size across the batch dimension. 27 | BLOCK_SIZE_IN_FEAT: Block size across the input feature dimension. 28 | BLOCK_SIZE_OUT_FEAT: Block size across the output feature dimension. 29 | GROUP_SIZE_BATCH: Group size across the batch dimension. 30 | n_warps: Number of warps to use for the kernel when compiled for GPUs. 31 | n_stages: Number of stages the compiler uses to software-pipeline. 
32 | On GPU architectures older than Ampere, this is fixed at 2. 33 | 34 | Returns: 35 | Kernel configuration. 36 | """ 37 | return triton.Config({'BLOCK_SIZE_BATCH': BLOCK_SIZE_BATCH, 38 | 'BLOCK_SIZE_IN_FEAT': BLOCK_SIZE_IN_FEAT, 39 | 'BLOCK_SIZE_OUT_FEAT': BLOCK_SIZE_OUT_FEAT, 40 | 'GROUP_SIZE_BATCH': GROUP_SIZE_BATCH}, 41 | num_warps=n_warps, num_stages=get_n_stages(n_stages)) 42 | 43 | 44 | @triton.autotune( 45 | configs=[ 46 | linear_forward_config(32, 32, 32, n_warps=2, n_stages=2), 47 | linear_forward_config(64, 32, 32, n_warps=2, n_stages=5), 48 | linear_forward_config(64, 32, 128, n_warps=4, n_stages=4), 49 | linear_forward_config(64, 32, 256, n_warps=4, n_stages=4), 50 | linear_forward_config(128, 32, 32, n_warps=4, n_stages=4), 51 | linear_forward_config(128, 32, 64, n_warps=4, n_stages=4), 52 | linear_forward_config(128, 32, 128, n_warps=4, n_stages=4), 53 | linear_forward_config(128, 64, 256, n_warps=8, n_stages=3), 54 | ], 55 | key=['batch_dim', 'in_feat_dim', 'out_feat_dim', 'fp16'], 56 | ) 57 | @triton.heuristics({'tf32': lambda _: allow_tf32()}) 58 | @triton.jit 59 | def linear_forward_kernel( 60 | input_pointer, weight_pointer, bias_pointer, pre_act_pointer, output_pointer, 61 | batch_dim, in_feat_dim, out_feat_dim, 62 | input_batch_stride, input_in_feat_stride, 63 | weight_in_feat_stride, weight_out_feat_stride, 64 | pre_act_batch_stride, pre_act_out_feat_stride, 65 | output_batch_stride, output_out_feat_stride, param, 66 | add_bias: tl.constexpr, act_func: tl.constexpr, save_pre_act: tl.constexpr, 67 | fp16: tl.constexpr, tf32: tl.constexpr, 68 | BLOCK_SIZE_BATCH: tl.constexpr, BLOCK_SIZE_IN_FEAT: tl.constexpr, 69 | BLOCK_SIZE_OUT_FEAT: tl.constexpr, GROUP_SIZE_BATCH: tl.constexpr, 70 | ): 71 | """ 72 | Linearly transforms the input using weights, optionally adding bias 73 | and fusing an activation function. 74 | 75 | Args: 76 | input_pointer: Pointer to the input to transform. 77 | The input must be of shape [batch_dim, in_feat_dim]. 78 | weight_pointer: Pointer to the weights input is transformed by. 79 | The weights must be of shape [in_feat_dim, out_feat_dim]. 80 | bias_pointer: Pointer to an optional additive bias vector. 81 | The bias vector, if provided, must be of shape [out_feat_dim]. 82 | pre_act_pointer: Pointer to an optional container the pre-activation input 83 | is written to if act_func is not None and save_pre_act is True. 84 | The container, if provided, must be of shape [batch_dim, out_feat_dim]. 85 | output_pointer: Pointer to a container the result is written to. 86 | The container must be of shape [batch_dim, out_feat_dim]. 87 | batch_dim: Batch dimension of the input and output. 88 | in_feat_dim: Dimensionality of the input features. 89 | out_feat_dim: Dimensionality of the output features. 90 | input_batch_stride: Stride necessary to jump one element along the 91 | input's batch dimension. 92 | input_in_feat_stride: Stride necessary to jump one element along the 93 | input's feature dimension. 94 | weight_in_feat_stride: Stride necessary to jump one element along the 95 | weights' input feature dimension. 96 | weight_out_feat_stride: Stride necessary to jump one element along the 97 | weights' output feature dimension. 98 | pre_act_batch_stride: Stride necessary to jump one element along the 99 | pre-activation input container's batch dimension. 100 | pre_act_out_feat_stride: Stride necessary to jump one element along the 101 | pre-activation input container's feature dimension. 
102 | output_batch_stride: Stride necessary to jump one element along the 103 | output container's batch dimension. 104 | output_out_feat_stride: Stride necessary to jump one element along the 105 | output container's feature dimension. 106 | param: Parameter in the case of parameterized activation functions. 107 | add_bias: Flag for adding a bias vector. 108 | act_func: Name of activation function to apply, with None for identity. 109 | Options are 'sigmoid', 'logsigmoid', 'tanh', 'relu', 'gelu', 'geluapprox', 'silu', 110 | 'softplus', 'softsign', 'tanhshrink', 'leaky_relu', 'elu', 'celu', 'hardshrink', 111 | and 'softshrink'. 112 | save_pre_act: Flag for saving the pre-activation input. 113 | fp16: Flag for loading the input, weights, and bias in FP16. 114 | tf32: Flag for performing matrix products in TF32. 115 | BLOCK_SIZE_BATCH: Block size across the batch dimension. 116 | BLOCK_SIZE_IN_FEAT: Block size across the input feature dimension. 117 | BLOCK_SIZE_OUT_FEAT: Block size across the output feature dimension. 118 | GROUP_SIZE_BATCH: Group size across the batch dimension. 119 | """ 120 | # Programs are blocked together, GROUP_SIZE_BATCH rows at a time, 121 | # to alleviate L2 miss rates. 122 | pid = tl.program_id(axis=0) 123 | n_batch_pids = tl.cdiv(batch_dim, BLOCK_SIZE_BATCH) 124 | n_out_feat_pids = tl.cdiv(out_feat_dim, BLOCK_SIZE_OUT_FEAT) 125 | pids_per_group = GROUP_SIZE_BATCH * n_out_feat_pids 126 | group_id = pid // pids_per_group 127 | first_batch_pid = group_id * GROUP_SIZE_BATCH 128 | GROUP_SIZE_BATCH = min(n_batch_pids - first_batch_pid, GROUP_SIZE_BATCH) 129 | batch_pid = first_batch_pid + (pid % GROUP_SIZE_BATCH) 130 | out_feat_pid = (pid % pids_per_group) // GROUP_SIZE_BATCH 131 | 132 | batch_offset = (batch_pid * BLOCK_SIZE_BATCH + 133 | tl.arange(0, BLOCK_SIZE_BATCH)) 134 | out_feat_offset = (out_feat_pid * BLOCK_SIZE_OUT_FEAT + 135 | tl.arange(0, BLOCK_SIZE_OUT_FEAT)) 136 | 137 | batch_mask = batch_offset < batch_dim 138 | out_feat_mask = out_feat_offset < out_feat_dim 139 | 140 | input_pointer += input_batch_stride * batch_offset[:, None] 141 | weight_pointer += weight_out_feat_stride * out_feat_offset[None, :] 142 | 143 | accum = tl.zeros((BLOCK_SIZE_BATCH, BLOCK_SIZE_OUT_FEAT), 144 | dtype=tl.float32) 145 | 146 | for block_ind in range(0, tl.cdiv(in_feat_dim, BLOCK_SIZE_IN_FEAT)): 147 | in_feat_offset = (block_ind * BLOCK_SIZE_IN_FEAT + 148 | tl.arange(0, BLOCK_SIZE_IN_FEAT)) 149 | in_feat_mask = in_feat_offset < in_feat_dim 150 | 151 | curr_input_pointer = (input_pointer + 152 | input_in_feat_stride * in_feat_offset[None, :]) 153 | curr_weight_pointer = (weight_pointer + 154 | weight_in_feat_stride * in_feat_offset[:, None]) 155 | 156 | input_block = tl.load(curr_input_pointer, 157 | mask=batch_mask[:, None] & in_feat_mask[None, :]) 158 | weight_block = tl.load(curr_weight_pointer, 159 | mask=out_feat_mask[None, :] & in_feat_mask[:, None]) 160 | 161 | if fp16: 162 | input_block = input_block.to(tl.float16) 163 | weight_block = weight_block.to(tl.float16) 164 | 165 | accum += tl.dot(input_block, weight_block, allow_tf32=tf32) 166 | 167 | if add_bias: 168 | bias = tl.load(bias_pointer + out_feat_offset, 169 | mask=out_feat_mask) 170 | 171 | if fp16: 172 | bias = bias.to(tl.float16) 173 | 174 | accum += bias[None, :] 175 | 176 | if act_func is not None: 177 | if save_pre_act: 178 | pre_act_pointer += (pre_act_batch_stride * batch_offset[:, None] + 179 | pre_act_out_feat_stride * out_feat_offset[None, :]) 180 | tl.store(pre_act_pointer, accum, 181 | 
mask=batch_mask[:, None] & out_feat_mask[None, :]) 182 | 183 | accum = apply_act_func(accum, None, None, None, param, act_func, False) 184 | 185 | output_pointer += (output_batch_stride * batch_offset[:, None] + 186 | output_out_feat_stride * out_feat_offset[None, :]) 187 | tl.store(output_pointer, accum, 188 | mask=batch_mask[:, None] & out_feat_mask[None, :]) 189 | -------------------------------------------------------------------------------- /attorch/rms_norm_kernels.py: -------------------------------------------------------------------------------- 1 | """ 2 | Kernels for root mean square normalization. 3 | """ 4 | 5 | 6 | import triton 7 | import triton.language as tl 8 | from triton import next_power_of_2 9 | 10 | from .softmax_kernels import BLOCK_SIZE_BATCH_heuristic 11 | from .utils import warps_kernel_configs 12 | 13 | 14 | @triton.autotune( 15 | configs=warps_kernel_configs(), 16 | key=['batch_dim', 'feat_dim'], 17 | ) 18 | @triton.heuristics({'BLOCK_SIZE_BATCH': BLOCK_SIZE_BATCH_heuristic, 19 | 'BLOCK_SIZE_FEAT': lambda args: next_power_of_2(args['feat_dim'])}) 20 | @triton.jit 21 | def rms_norm_forward_kernel( 22 | input_pointer, weight_pointer, 23 | inv_rms_pointer, output_pointer, 24 | batch_dim, feat_dim, 25 | input_batch_stride, input_feat_stride, 26 | output_batch_stride, output_feat_stride, 27 | eps, 28 | scale_by_weight: tl.constexpr, save_stats: tl.constexpr, 29 | BLOCK_SIZE_BATCH: tl.constexpr, BLOCK_SIZE_FEAT: tl.constexpr, 30 | ): 31 | """ 32 | Root-mean-square-normalizes the input. 33 | 34 | Args: 35 | input_pointer: Pointer to the input to root-mean-square-normalize. 36 | The input must be of shape [batch_dim, feat_dim]. 37 | weight_pointer: Pointer to optional weights for linear transform. 38 | The weights, if provided, must be of shape [feat_dim]. 39 | inv_rms_pointer: Pointer to an optional container the input's inverse 40 | root mean square is written to if save_stats is True. 41 | The container, if provided, must be of shape [batch_dim]. 42 | output_pointer: Pointer to a container the result is written to. 43 | The container must be of shape [batch_dim, feat_dim]. 44 | batch_dim: Batch dimension. 45 | feat_dim: Dimensionality of the features. 46 | input_batch_stride: Stride necessary to jump one element along the 47 | input's batch dimension. 48 | input_feat_stride: Stride necessary to jump one element along the 49 | input's feature dimension. 50 | output_batch_stride: Stride necessary to jump one element along the 51 | output container's batch dimension. 52 | output_feat_stride: Stride necessary to jump one element along the 53 | output container's feature dimension. 54 | eps: Epsilon added in the square root in the denominator 55 | to avoid division by zero. 56 | scale_by_weight: Flag for scaling the normalized output by weights. 57 | save_stats: Flag for saving the root mean square. 58 | BLOCK_SIZE_BATCH: Block size across the batch dimension. 59 | BLOCK_SIZE_FEAT: Block size across the feature dimension. 60 | """ 61 | # This program processes BLOCK_SIZE_BATCH rows and BLOCK_SIZE_FEAT columns. 
62 | batch_pid = tl.program_id(axis=0) 63 | 64 | batch_offset = batch_pid * BLOCK_SIZE_BATCH + tl.arange(0, BLOCK_SIZE_BATCH) 65 | feat_offset = tl.arange(0, BLOCK_SIZE_FEAT) 66 | 67 | batch_mask = batch_offset < batch_dim 68 | feat_mask = feat_offset < feat_dim 69 | 70 | input_pointer += (input_batch_stride * batch_offset[:, None] + 71 | input_feat_stride * feat_offset[None, :]) 72 | output_pointer += (output_batch_stride * batch_offset[:, None] + 73 | output_feat_stride * feat_offset[None, :]) 74 | 75 | input = tl.load(input_pointer, 76 | mask=batch_mask[:, None] & feat_mask[None, :]).to(tl.float32) 77 | inv_rms = tl.rsqrt(tl.sum(input * input, axis=1) / feat_dim + eps) 78 | output = input * inv_rms[:, None] 79 | 80 | if save_stats: 81 | tl.store(inv_rms_pointer + batch_offset, inv_rms, mask=batch_mask) 82 | 83 | if scale_by_weight: 84 | weight = tl.load(weight_pointer + feat_offset, mask=feat_mask) 85 | output *= weight 86 | 87 | tl.store(output_pointer, output, 88 | mask=batch_mask[:, None] & feat_mask[None, :]) 89 | 90 | 91 | @triton.autotune( 92 | configs=warps_kernel_configs(), 93 | key=['batch_dim', 'feat_dim'], 94 | ) 95 | @triton.heuristics({'BLOCK_SIZE_BATCH': BLOCK_SIZE_BATCH_heuristic, 96 | 'BLOCK_SIZE_FEAT': lambda args: next_power_of_2(args['feat_dim'])}) 97 | @triton.jit 98 | def rms_norm_backward_kernel( 99 | output_grad_pointer, input_pointer, inv_rms_pointer, weight_pointer, 100 | input_grad_pointer, weight_grad_pointer, 101 | batch_dim, feat_dim, 102 | output_grad_batch_stride, output_grad_feat_stride, 103 | input_batch_stride, input_feat_stride, 104 | input_grad_batch_stride, input_grad_feat_stride, 105 | weight_grad_batch_stride, weight_grad_feat_stride, 106 | scale_by_weight: tl.constexpr, 107 | BLOCK_SIZE_BATCH: tl.constexpr, BLOCK_SIZE_FEAT: tl.constexpr, 108 | ): 109 | """ 110 | Calculates the input gradient of root mean square normalization. 111 | 112 | Args: 113 | output_grad_pointer: Pointer to root mean square normalization's output gradients. 114 | The output gradients must be of shape [batch_dim, feat_dim]. 115 | input_pointer: Pointer to the input. 116 | The input must be of shape [batch_dim, feat_dim]. 117 | inv_rms_pointer: Pointer to the input's inverse root mean square. 118 | The inverse root mean square should be of shape [batch_dim]. 119 | weight_pointer: Pointer to optional weights if affine transform occurred. 120 | The weights, if provided, must be of shape [feat_dim]. 121 | input_grad_pointer: Pointer to a container the input's gradients are written to. 122 | The container must be of shape [batch_dim, feat_dim]. 123 | weight_grad_pointer: Pointer to an optional container the weights' row-wise gradients 124 | are written to if scale_by_weight is True, which should later be summed. 125 | The container, if provided, must be of shape [batch_dim/BLOCK_SIZE_BATCH, feat_dim]. 126 | bias_grad_pointer: Pointer to an optional container the bias vector's row-wise gradients 127 | are written to if scale_by_weight and add_bias are True, which should later be summed. 128 | The container, if provided, must be of shape [batch_dim/BLOCK_SIZE_BATCH, feat_dim]. 129 | batch_dim: Batch dimension. 130 | feat_dim: Dimensionality of the features. 131 | output_grad_batch_stride: Stride necessary to jump one element along the 132 | output gradients' batch dimension. 133 | output_grad_feat_stride: Stride necessary to jump one element along the 134 | output gradients' feature dimension. 
135 |         input_batch_stride: Stride necessary to jump one element along the
136 |             input's batch dimension.
137 |         input_feat_stride: Stride necessary to jump one element along the
138 |             input's feature dimension.
139 |         input_grad_batch_stride: Stride necessary to jump one element along the
140 |             input gradient container's batch dimension.
141 |         input_grad_feat_stride: Stride necessary to jump one element along the
142 |             input gradient container's feature dimension.
143 |         weight_grad_batch_stride: Stride necessary to jump one element along the
144 |             weight gradient container's batch dimension.
145 |         weight_grad_feat_stride: Stride necessary to jump one element along the
146 |             weight gradient container's feature dimension.
147 |         scale_by_weight: Flag for scaling the normalized output by weights.
148 |         BLOCK_SIZE_BATCH: Block size across the batch dimension.
149 |         BLOCK_SIZE_FEAT: Block size across the feature dimension.
150 |     """
151 |     # This program processes BLOCK_SIZE_BATCH rows and BLOCK_SIZE_FEAT columns.
152 |     batch_pid = tl.program_id(axis=0)
153 | 
154 |     batch_offset = batch_pid * BLOCK_SIZE_BATCH + tl.arange(0, BLOCK_SIZE_BATCH)
155 |     feat_offset = tl.arange(0, BLOCK_SIZE_FEAT)
156 | 
157 |     batch_mask = batch_offset < batch_dim
158 |     feat_mask = feat_offset < feat_dim
159 | 
160 |     output_grad_pointer += (output_grad_batch_stride * batch_offset[:, None] +
161 |                             output_grad_feat_stride * feat_offset[None, :])
162 |     input_pointer += (input_batch_stride * batch_offset[:, None] +
163 |                       input_feat_stride * feat_offset[None, :])
164 |     input_grad_pointer += (input_grad_batch_stride * batch_offset[:, None] +
165 |                            input_grad_feat_stride * feat_offset[None, :])
166 | 
167 |     output_grad = tl.load(output_grad_pointer,
168 |                           mask=batch_mask[:, None] & feat_mask[None, :]).to(tl.float32)
169 |     input = tl.load(input_pointer, mask=batch_mask[:, None] & feat_mask[None, :]).to(tl.float32)
170 |     inv_rms = tl.load(inv_rms_pointer + batch_offset, mask=batch_mask)
171 |     pre_lin = input * inv_rms[:, None]
172 | 
173 |     if scale_by_weight:
174 |         weight = tl.load(weight_pointer + feat_offset, mask=feat_mask)
175 |         weight_output_grad_prod = weight * output_grad
176 | 
177 |     else:
178 |         weight_output_grad_prod = output_grad
179 | 
180 |     term1 = input * tl.sum(input * weight_output_grad_prod, axis=1)[:, None]
181 |     term2 = inv_rms[:, None] * inv_rms[:, None]
182 |     input_grad = (inv_rms[:, None] *
183 |                   (weight_output_grad_prod - term1 * term2 / feat_dim))
184 | 
185 |     tl.store(input_grad_pointer, input_grad,
186 |              mask=batch_mask[:, None] & feat_mask[None, :])
187 | 
188 |     if scale_by_weight:
189 |         weight_grad_pointer += (weight_grad_batch_stride * batch_pid +
190 |                                 weight_grad_feat_stride * feat_offset)
191 |         tl.store(weight_grad_pointer,
192 |                  tl.sum(output_grad * pre_lin, axis=0),
193 |                  mask=feat_mask)
194 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # attorch
2 | 
3 | • **[Introduction](#introduction)**
4 | • **[Installation](#installation)**
5 | • **[Layers](#layers)**
6 | • **[Math Functions](#math-functions)**
7 | • **[PyTorch Fallback](#pytorch-fallback)**
8 | • **[Tests](#tests)**
9 | • **[Examples](#examples)**
10 | • **[Citations](#citations)**
11 | 12 | ## Introduction 13 | 14 | attorch is a subset of PyTorch's [```nn```](https://pytorch.org/docs/stable/nn.html) module, written purely in Python using OpenAI's [Triton](https://github.com/openai/triton). Its goal is to be an easily hackable, self-contained, and readable collection of neural network modules whilst maintaining or improving upon the efficiency of PyTorch. In other words, it intends to be a forkable project endowed with a simple, intuitive design that can serve as an accessible starting point for those who are seeking to develop custom deep learning operations but are not satisfied with the speed of a pure PyTorch implementation and do not have the technical expertise or resources to write CUDA kernels. 15 | 16 | There already exist a number of wonderful PyTorch-like frameworks powered by Triton, including [kernl](https://github.com/ELS-RD/kernl/tree/main), [xFormers](https://github.com/facebookresearch/xformers), [Unsloth](https://github.com/unslothai/unsloth), and [```fla```](https://github.com/sustcsonglin/flash-linear-attention), but most concentrate mainly on Transformers and NLP applications, whereas attorch aims to be more inclusive by also presenting a variety of layers pertaining to areas beyond NLP, such as computer vision. Moreover, attorch is not an inference-only package and fully supports both forward and backward passes, meaning it can be used during training as well as inference, though its performance for the latter is generally not on par with dedicated inference engines. 17 | 18 | ## Installation 19 | 20 | The only dependencies of attorch are ```torch==2.4.0``` and ```triton==3.0.0```. Please install the specified versions of these two libraries and clone this repository to get started. 21 | 22 | ## Layers 23 | 24 | Currently implemented layers, with automatic mixed precision (AMP) support, are: 25 | 26 | * ```attorch.Conv1d```: 1D-convolves over the input using weights, optionally adding bias. 27 | * ```attorch.Conv2d```: 2D-convolves over the input using weights, optionally adding bias. 28 | * ```attorch.AvgPool1d```: Averages 1D windows of pixels in the input. 29 | * ```attorch.AvgPool2d```: Averages 2D windows of pixels in the input. 30 | * ```attorch.MultiheadAttention```: Applies multi-headed scaled dot-product attention to the inputs. 31 | * ```attorch.ELU```: Applies ELU to the input, optionally fusing dropout. 32 | * ```attorch.Hardshrink```: Applies hard shrink to the input, optionally fusing dropout. 33 | * ```attorch.Hardsigmoid```: Applies hard sigmoid to the input, optionally fusing dropout. 34 | * ```attorch.Hardswish```: Applies hard Swish to the input, optionally fusing dropout. 35 | * ```attorch.Hardtanh```: Applies hard tanh to the input, optionally fusing dropout. 36 | * ```attorch.LeakyReLU```: Applies leaky ReLU to the input, optionally fusing dropout. 37 | * ```attorch.LogSigmoid```: Applies the log of sigmoid to the input, optionally fusing dropout. 38 | * ```attorch.CELU```: Applies CELU to the input, optionally fusing dropout. 39 | * ```attorch.GELU```: Applies GELU to the input, optionally fusing dropout. 40 | * ```attorch.ReLU```: Applies ReLU to the input, optionally fusing dropout. 41 | * ```attorch.ReLU6```: Applies ReLU6 to the input, optionally fusing dropout. 42 | * ```attorch.SELU```: Applies SELU to the input, optionally fusing dropout. 43 | * ```attorch.SiLU```: Applies SiLU to the input, optionally fusing dropout. 44 | * ```attorch.Mish```: Applies Mish to the input, optionally fusing dropout. 
45 | * ```attorch.Softplus```: Applies softplus to the input, optionally fusing dropout.
46 | * ```attorch.Softshrink```: Applies softshrink to the input, optionally fusing dropout.
47 | * ```attorch.Softsign```: Applies softsign to the input, optionally fusing dropout.
48 | * ```attorch.Sigmoid```: Applies sigmoid to the input, optionally fusing dropout.
49 | * ```attorch.Tanh```: Applies tanh to the input, optionally fusing dropout.
50 | * ```attorch.Tanhshrink```: Applies tanhshrink to the input, optionally fusing dropout.
51 | * ```attorch.GLU```: Applies the gated linear unit with an arbitrary activation function to the input.
52 | * ```attorch.LogSoftmax```: Normalizes the input using softmax and takes its log.
53 | * ```attorch.Softmax```: Normalizes the input using softmax.
54 | * ```attorch.Softmin```: Normalizes the input using softmin.
55 | * ```attorch.BatchNorm1d```: Batch-normalizes the 2D or 3D input, optionally fusing an activation function and adding a residual to the pre-activation result.
56 | * ```attorch.BatchNorm2d```: Batch-normalizes the 4D input, optionally fusing an activation function and adding a residual to the pre-activation result.
57 | * ```attorch.LayerNorm```: Layer-normalizes the input.
58 | * ```attorch.RMSNorm```: Root-mean-square-normalizes the input.
59 | * ```attorch.Linear```: Linearly transforms the input using weights, optionally adding bias and fusing an activation function.
60 | * ```attorch.Dropout```: Randomly zeroes elements in the input during training.
61 | * ```attorch.L1Loss```: Measures the L1 error (mean absolute error) between the input and target.
62 | * ```attorch.MSELoss```: Measures the squared L2 error (mean squared error) between the input and target.
63 | * ```attorch.CrossEntropyLoss```: Measures the mean cross entropy loss between the input and target, with optional reweighting of each class.
64 | * ```attorch.NLLLoss```: Measures the negative log likelihood loss between the input and target, with optional reweighting of each class.
65 | * ```attorch.HuberLoss```: Measures the Huber loss between the input and target.
66 | * ```attorch.SmoothL1Loss```: Measures the smooth L1 error between the input and target.
67 | 
68 | Unless otherwise noted in their docstrings, the aforementioned layers behave identically to their PyTorch equivalents.
69 | 
70 | ## Math Functions
71 | Triton kernels are generally composed of two parts: one handles loading and storing the relevant tensors, while the other transforms the data using the appropriate mathematical functions. For instance, a layer normalization kernel reads one or several rows from the input (load), standardizes the features (math), and writes the results into a container (store). A selection of these pure math functions is supplied by ```attorch.math```, the objective being to facilitate the implementation of custom kernels and operation fusion. Although only the forward passes of these functions are available in ```attorch.math```, thanks to their purity and absence of I/O operations, their gradients can be derived automatically via the [```triton-autodiff```](https://github.com/srush/triton-autodiff) library. Significant portions of attorch's kernels could be refactored by replacing their math components with the corresponding ```attorch.math``` transformations or their derivatives, but doing so would sacrifice the single-file and self-contained design of attorch, so ```attorch.math``` and the rest of the library will remain separate.
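As an illustration, below is a minimal sketch of how one of these functions can be reused inside a custom kernel. The kernel and wrapper are not part of attorch; the names, block sizes, and the assumption of a contiguous 2D CUDA tensor are purely for demonstration.

```python
import torch
import triton
import triton.language as tl

from attorch.math import softmax


@triton.jit
def softmax_kernel(input_pointer, output_pointer, batch_dim, feat_dim,
                   BLOCK_SIZE_BATCH: tl.constexpr, BLOCK_SIZE_FEAT: tl.constexpr):
    # Each program loads a [BLOCK_SIZE_BATCH, BLOCK_SIZE_FEAT] tile (load),
    # normalizes it with attorch.math.softmax (math), and writes it back (store).
    batch_offset = tl.program_id(0) * BLOCK_SIZE_BATCH + tl.arange(0, BLOCK_SIZE_BATCH)
    feat_offset = tl.arange(0, BLOCK_SIZE_FEAT)
    mask = (batch_offset[:, None] < batch_dim) & (feat_offset[None, :] < feat_dim)

    # Assumes a row-major, contiguous input whose row stride equals feat_dim.
    offsets = feat_dim * batch_offset[:, None] + feat_offset[None, :]
    input = tl.load(input_pointer + offsets, mask=mask, other=-float('inf'))
    output = softmax(input, False, False)  # neg=False, log=False
    tl.store(output_pointer + offsets, output, mask=mask)


def row_softmax(input: torch.Tensor) -> torch.Tensor:
    # Thin wrapper that launches the kernel over the batch dimension.
    batch_dim, feat_dim = input.shape
    output = torch.empty_like(input)
    BLOCK_SIZE_BATCH = 4
    BLOCK_SIZE_FEAT = triton.next_power_of_2(feat_dim)
    grid = (triton.cdiv(batch_dim, BLOCK_SIZE_BATCH),)
    softmax_kernel[grid](input, output, batch_dim, feat_dim,
                         BLOCK_SIZE_BATCH=BLOCK_SIZE_BATCH,
                         BLOCK_SIZE_FEAT=BLOCK_SIZE_FEAT)
    return output
```

Because ```attorch.math.softmax``` handles the numerically stable normalization itself, the custom kernel only has to deal with indexing, masking, and I/O.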
72 | 
73 | ## PyTorch Fallback
74 | 
75 | To enable easier integration of attorch and PyTorch layers, ```attorch.nn``` is offered, which provides an interface to attorch's modules with PyTorch fallback should a desired layer not be available, as seen below:
76 | 
77 | ```python
78 | from attorch import nn
79 | 
80 | 
81 | lin = nn.Linear(10, 20)  # Uses attorch's linear layer
82 | gap = nn.AdaptiveAvgPool2d(1)  # Uses PyTorch's global pooling since GAP is not available in attorch
83 | ```
84 | 
85 | Even though attorch comes with convolutional and pooling layers, these modules are extremely slow compared to their PyTorch counterparts. Therefore, ```attorch.nn``` exposes PyTorch's, and not attorch's, convolutions and pools.
86 | 
87 | ## Tests
88 | 
89 | Each module can be tested against its PyTorch counterpart to ensure correctness. These tests are included under ```tests/``` and can be executed using ```pytest```. A switch, `--subset`, is provided that runs the tests on a smaller subset of data shapes for faster yet less thorough evaluation. It should be noted that some tests may fail owing to numerical precision issues, but in most practical use cases, that should not be a problem.
90 | 
91 | ## Examples
92 | 
93 | [`examples/`](https://github.com/BobMcDear/attorch/tree/main/examples) contains a handful of common deep learning workflows implemented in both attorch and PyTorch. Although attorch's modules can be used as drop-in replacements for those of `torch.nn`, some offer optional kernel fusion and therefore require additional arguments that are not compatible with their PyTorch analogs. Such distinctions are illustrated in these examples.
94 | 
95 | ## Citations
96 | 
97 | ```bibtex
98 | @inproceedings{tillet2019triton,
99 |   title={Triton: an intermediate language and compiler for tiled neural network computations},
100 |   author={Tillet, Philippe and Kung, Hsiang-Tsung and Cox, David},
101 |   booktitle={Proceedings of the 3rd ACM SIGPLAN International Workshop on Machine Learning and Programming Languages},
102 |   pages={10--19},
103 |   year={2019}
104 | }
105 | ```
106 | 
107 | ```bibtex
108 | @Misc{xFormers2022,
109 |   author = {Benjamin Lefaudeux and Francisco Massa and Diana Liskovich and Wenhan Xiong and Vittorio Caggiano and Sean Naren and Min Xu and Jieru Hu and Marta Tintore and Susan Zhang and Patrick Labatut and Daniel Haziza and Luca Wehrstedt and Jeremy Reizenstein and Grigory Sizov},
110 |   title = {xFormers: A modular and hackable Transformer modelling library},
111 |   howpublished = {\url{https://github.com/facebookresearch/xformers}},
112 |   year = {2022}
113 | }
114 | ```
115 | 
116 | ```bibtex
117 | @software{yang2024fla,
118 |   title = {FLA: A Triton-Based Library for Hardware-Efficient Implementations of Linear Attention Mechanism},
119 |   author = {Yang, Songlin and Zhang, Yu},
120 |   url = {https://github.com/fla-org/flash-linear-attention},
121 |   month = jan,
122 |   year = {2024}
123 | }
124 | ```
125 | 
--------------------------------------------------------------------------------
/attorch/math.py:
--------------------------------------------------------------------------------
1 | """
2 | Pure math operations to be performed on loaded Triton tensors.
3 | """
4 | 
5 | 
6 | import triton
7 | import triton.language as tl
8 | 
9 | from .act_kernels import apply_act_func
10 | 
11 | 
12 | @triton.jit
13 | def accum_linear(accum, input1, input2,
14 |                  fp16: tl.constexpr, tf32: tl.constexpr):
15 |     """
16 |     Accumulates matrix multiplications of input tensors for linear functions.
17 | 18 | Args: 19 | accum: Accumulator holding aggregation of matrix multiplications. 20 | The accumulator must be of shape [BLOCK_SIZE1, BLOCK_SIZE3]. 21 | input1: First operand of matrix multiplication. 22 | The operand must be of shape [BLOCK_SIZE1, BLOCK_SIZE2]. 23 | input2: Second operand of matrix multiplication. 24 | The operand must be of shape [BLOCK_SIZE2, BLOCK_SIZE3]. 25 | fp16: Flag for converting operands to FP16. 26 | tf32: Flag for performing matrix multiplication in TF32. 27 | 28 | Returns: 29 | Accumulator with the result of the new matrix multiplication added to it. 30 | """ 31 | if fp16: 32 | input1 = input1.to(tl.float16) 33 | input2 = input2.to(tl.float16) 34 | 35 | return accum + tl.dot(input1, input2, allow_tf32=tf32) 36 | 37 | 38 | @triton.jit 39 | def glu(input1, input2, param, act_func: tl.constexpr): 40 | """ 41 | Applies the gated linear unit with an arbitrary activation function 42 | to the input. 43 | 44 | Args: 45 | input1: First half of input to gate. 46 | The first half must be of the same shape as the second half. 47 | input2: Second half of input to gate. 48 | The second half must be of the same shape as the first half. 49 | param: Parameter in the case of parameterized activation functions. 50 | act_func: Name of activation function to apply. 51 | Options are 'sigmoid', 'logsigmoid', 'tanh', 'relu', 'gelu', 'geluapprox', 'silu', 52 | 'softplus', 'softsign', 'tanhshrink', 'leaky_relu', 'elu', 'celu', 'hardshrink', 53 | and 'softshrink'. 54 | param: Parameter in the case of parameterized activation functions. 55 | 56 | Args: 57 | Input transformed by the gated linear unit 58 | with an arbitrary activation function. 59 | """ 60 | return input1 * apply_act_func(input2, None, None, None, param, act_func, False) 61 | 62 | 63 | @triton.jit 64 | def softmax(input, 65 | neg: tl.constexpr, 66 | log: tl.constexpr): 67 | """ 68 | Normalizes the input using softmax along the last dimension. 69 | 70 | Args: 71 | input: Input to normalize. 72 | The input must be of shape [BLOCK_SIZE1, BLOCK_SIZE2]. 73 | neg: Flag indicating if the input should be negated to get softmin. 74 | log: Flag indicating if the log of softmax should be taken. 75 | 76 | Returns: 77 | Input normalized by softmax. 78 | """ 79 | input = input.to(tl.float32) 80 | if neg: 81 | input = -input 82 | 83 | input = input - tl.max(input, axis=1)[:, None] 84 | numerator = tl.exp(input) 85 | denominator = tl.sum(numerator, axis=1)[:, None] 86 | 87 | if log: 88 | output = input - tl.log(denominator) 89 | 90 | else: 91 | output = numerator / denominator 92 | 93 | return output 94 | 95 | 96 | @triton.jit 97 | def calc_mean_and_inv_std(input, last_dim, eps, 98 | last_dim_mask: tl.constexpr): 99 | """ 100 | Calculates the mean and inverse standard deviation of the input 101 | along the last dimension. 102 | 103 | Args: 104 | input: Input whose mean and inverse standard deviation are calculated. 105 | The input must be of shape [BLOCK_SIZE1, BLOCK_SIZE2]. 106 | last_dim: Size of the last dimension of input. 107 | eps: Epsilon added in the square root in the denominator 108 | to avoid division by zero. 109 | last_dim_mask: Mask for the last dimension indicating 110 | which elements should be included in the calculations. 111 | The mask must be of shape [BLOCK_SIZE2]. 112 | 113 | Returns: 114 | Mean and inverse standard deviation of the input. 
115 | """ 116 | input = input.to(tl.float32) 117 | 118 | mean = tl.sum(input, axis=1) / last_dim 119 | diff = tl.where(last_dim_mask[None, :], input - mean[:, None], 0) 120 | inv_std = tl.rsqrt(tl.sum(diff * diff, axis=1) / last_dim + eps) 121 | 122 | return mean, inv_std 123 | 124 | 125 | @triton.jit 126 | def update_welford(input, prev_count, prev_mean, prev_var, curr_count, 127 | mask: tl.constexpr): 128 | """ 129 | Updates count, mean, and variance (M2) statistics for Welford's algorithm. 130 | 131 | Args: 132 | input: Input used to update statistics. 133 | The input must be of the same shape as the mask. 134 | prev_count: Previous count statistic to update. 135 | prev_mean: Previous mean statistic to update. 136 | prev_var: Previous variance (M2) statistic to update. 137 | curr_count: Count of elements in current input. 138 | mask: Mask indicating which elements should be included in the calculations. 139 | The mask must be of the same shape as the input. 140 | 141 | Returns: 142 | Updated count, mean, and variance (M2) statistics 143 | """ 144 | input = input.to(tl.float32) 145 | 146 | count = prev_count + curr_count 147 | mean = (tl.sum(input) - curr_count * prev_mean) / count 148 | deltas = tl.where(mask, (input - mean) * (input - prev_mean), 0.) 149 | var = prev_var + tl.sum(deltas) 150 | 151 | return count, mean, var 152 | 153 | 154 | @triton.jit 155 | def update_ema(prev_ema, new_val, momentum): 156 | """ 157 | Updates exponential moving average. 158 | 159 | Args: 160 | prev_ema: Previous exponential moving average. 161 | new_val: Value used to update the exponential moving average. 162 | momentum: Momentum. 163 | 164 | Returns: 165 | Updated running statistic. 166 | """ 167 | return (1 - momentum) * prev_ema + momentum * new_val 168 | 169 | 170 | @triton.jit 171 | def standardize(input, mean, inv_std, weight, bias): 172 | """ 173 | Standardizes the input given its mean and inverse standard deviation, 174 | multiplies the result by weights, and adds a bias vector. 175 | 176 | Args: 177 | input: Input to standardize. 178 | mean: Mean of input. 179 | inv_std: Inverse standard deviation of input. 180 | weight: Weight multiplied by the standardized input. 181 | bias: Bias added to the result of the weight multiplication. 182 | 183 | Returns: 184 | Standardized input. 185 | """ 186 | return weight * inv_std * (input - mean) + bias 187 | 188 | 189 | @triton.jit 190 | def calc_p_loss(input, target, param, size, 191 | p_loss: tl.constexpr, reduction: tl.constexpr): 192 | """ 193 | Measures the smooth L1, L1, squared L2, or Huber loss of the difference 194 | between the input and target. 195 | 196 | Args: 197 | input: Input. 198 | The input must be of shape [BLOCK_SIZE]. 199 | target: Target. 200 | The target must be of shape [BLOCK_SIZE]. 201 | param: Parameter of loss function (i.e., beta or delta for smooth L1 and Huber). 202 | size: Number of elements in the input and target. 203 | This value is used only if reduction is 'mean'. 204 | p_loss: p-norm used to compute the error. 205 | Options are 0 for smooth L1, 1 for L1, 2 for squared L2, and 3 for Huber loss. 206 | reduction: Reduction strategy for the output. 207 | Options are 'none' for no reduction, 'mean' for averaging the error 208 | across all entries, and 'sum' for summing the error across all entries. 209 | 210 | Returns: 211 | Error. 
212 | """ 213 | input = input.to(tl.float32) 214 | target = target.to(tl.float32) 215 | 216 | diff = input - target 217 | 218 | if p_loss == 0: 219 | error = tl.where(diff < param, 0.5 * diff * diff / param, tl.abs(diff) - 0.5 * param) 220 | 221 | elif p_loss == 1: 222 | error = tl.abs(diff) 223 | 224 | elif p_loss == 2: 225 | error = diff * diff 226 | 227 | elif p_loss == 3: 228 | error = tl.where(diff < param, 0.5 * diff * diff, param * (tl.abs(diff) - 0.5 * param)) 229 | 230 | if reduction == 'none': 231 | output = error 232 | 233 | elif reduction == 'mean': 234 | output = tl.sum(error) / size 235 | 236 | elif reduction == 'sum': 237 | output = tl.sum(error) 238 | 239 | return output 240 | 241 | 242 | @triton.jit 243 | def nll_loss(input, size, 244 | reduction: tl.constexpr): 245 | """ 246 | Measures the negative log likelihood loss given log-probabilities of target class. 247 | 248 | Args: 249 | input: Input containing predicted log-probabilities corresponding to target class. 250 | The input can have arbitrary shape. 251 | size: Number of elements in the input. 252 | This value is used only if reduction is 'mean'. 253 | reduction: Reduction strategy for the output. 254 | Options are 'none' for no reduction, 'mean' for averaging the loss 255 | across all entries, and 'sum' for summing the loss across all entries. 256 | 257 | Returns: 258 | Loss. 259 | """ 260 | input = input.to(tl.float32) 261 | 262 | if reduction == 'none': 263 | output = -input 264 | 265 | elif reduction == 'mean': 266 | output = -tl.sum(input) / size 267 | 268 | elif reduction == 'sum': 269 | output = -tl.sum(input) 270 | 271 | return output 272 | 273 | 274 | @triton.jit 275 | def cross_entropy_loss(input, pred): 276 | """ 277 | Measures the per-row cross entropy loss given 278 | input and predicted logits corresponding to target class. 279 | 280 | Args: 281 | input: Input. 282 | The input must be of shape [BLOCK_SIZE1, BLOCK_SIZE2]. 283 | pred: Predicted logits corresponding to target class. 284 | The predictions must be of shape [BLOCK_SIZE1]. 285 | 286 | Returns: 287 | Loss. 288 | """ 289 | input = input.to(tl.float32) 290 | pred = pred.to(tl.float32) 291 | 292 | mx = tl.max(input, axis=1) 293 | input -= mx[:, None] 294 | loss = tl.log(tl.sum(tl.exp(input), axis=1)) - pred + mx 295 | 296 | return loss 297 | --------------------------------------------------------------------------------