├── tests
│   ├── __init__.py
│   ├── conftest.py
│   ├── test_dropout_layer.py
│   ├── test_softmax_layers.py
│   ├── test_cross_entropy_loss_layer.py
│   ├── test_nll_loss_layer.py
│   ├── test_p_loss_layers.py
│   ├── test_glu_layer.py
│   ├── test_pooling_layer.py
│   ├── test_act_layers.py
│   ├── test_rms_norm_layer.py
│   ├── test_layer_norm_layer.py
│   ├── test_linear_layer.py
│   ├── test_conv_layer.py
│   ├── test_multi_head_attention_layer.py
│   ├── utils.py
│   └── test_batch_norm_layer.py
├── examples
│   ├── __init__.py
│   ├── mnist
│   │   ├── README.md
│   │   ├── mlp.py
│   │   └── main.py
│   ├── regression
│   │   ├── README.md
│   │   └── main.py
│   ├── README.md
│   ├── wikitext-2
│   │   ├── README.md
│   │   ├── gpt.py
│   │   └── main.py
│   ├── imagenette
│   │   └── README.md
│   └── utils.py
├── attorch
│   ├── nn.py
│   ├── types.py
│   ├── __init__.py
│   ├── utils.py
│   ├── dropout_kernels.py
│   ├── dropout_layer.py
│   ├── glu_kernels.py
│   ├── softmax_layers.py
│   ├── glu_layer.py
│   ├── pooling_layer.py
│   ├── p_loss_kernels.py
│   ├── rms_norm_layer.py
│   ├── multi_head_attention_layer.py
│   ├── cross_entropy_loss_layer.py
│   ├── nll_loss_layer.py
│   ├── p_loss_layers.py
│   ├── cross_entropy_loss_kernels.py
│   ├── softmax_kernels.py
│   ├── layer_norm_layer.py
│   ├── linear_layer.py
│   ├── linear_kernels.py
│   ├── rms_norm_kernels.py
│   └── math.py
├── LICENSE
└── README.md
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Deep learning examples using attorch.
3 | """
4 |
--------------------------------------------------------------------------------
/attorch/nn.py:
--------------------------------------------------------------------------------
1 | """
2 | Interface for attorch with PyTorch fallback.
3 | """
4 |
5 |
6 | from torch.nn import *
7 | from attorch import *
8 | from torch.nn import AvgPool1d, AvgPool2d, Conv1d, Conv2d
9 |
--------------------------------------------------------------------------------
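As the file shows, ```attorch.nn``` star-imports ```torch.nn``` and then ```attorch```, so attorch's Triton-backed modules shadow their PyTorch namesakes where available while everything else falls back to PyTorch. A minimal usage sketch, assuming a CUDA device; the layer sizes are arbitrary illustration values:

```python
# Sketch: swap `from torch import nn` for attorch's interface with PyTorch fallback.
import torch
from attorch import nn

# Sequential comes from torch.nn; Linear and ReLU resolve to attorch's Triton-backed layers.
model = nn.Sequential(
    nn.Linear(64, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
).cuda()

logits = model(torch.randn(32, 64, device='cuda'))
```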
/attorch/types.py:
--------------------------------------------------------------------------------
1 | """
2 | Type aliases for certain objects.
3 | """
4 |
5 |
6 | from typing import Any, Optional, Union
7 |
8 | import torch
9 |
10 |
11 | Context = Any
12 | Device = Optional[Union[torch.device, str]]
13 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | """
2 | Fixture for subset, a switch for running tests on a small subset of shapes.
3 | """
4 |
5 |
6 | import pytest
7 | from _pytest.config.argparsing import Parser
8 | from pytest import FixtureRequest
9 |
10 |
11 | def pytest_addoption(parser: Parser) -> None:
12 | parser.addoption('--subset',
13 | action='store_true',
14 | help='Flag to test on a small subset of shapes.')
15 |
16 |
17 | @pytest.fixture
18 | def subset(request: FixtureRequest) -> bool:
19 | return request.config.getoption('--subset')
20 |
--------------------------------------------------------------------------------
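The ```--subset``` switch registered here trims the shape grid for quick runs. A minimal sketch of invoking it, assuming pytest and a CUDA device are available; it is equivalent to running ```pytest tests --subset``` from the repository root:

```python
# Run the test suite on the reduced set of shapes.
import pytest

exit_code = pytest.main(['tests', '--subset'])
```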
/examples/mnist/README.md:
--------------------------------------------------------------------------------
1 | # MNIST Classification
2 |
3 | This example trains and benchmarks a multilayer perceptron (MLP) on the MNIST dataset.
4 |
5 | ## Requirements
6 |
7 | The only requirement for this example, aside from attorch and its dependencies, is:
8 |
9 | * ```torchvision==0.19.0```
10 |
11 | ## Training
12 |
13 | To run this example, please run ```python -m examples.mnist.main``` from the root directory. The arguments are as follows.
14 | * ```--hidden_dim```: Number of hidden features in the MLP.
15 | * ```--depth```: Number of hidden layers in the MLP.
16 | * ```--epochs```: Number of epochs to train for.
17 | * ```--batch_size```: Batch size.
18 |
--------------------------------------------------------------------------------
/examples/regression/README.md:
--------------------------------------------------------------------------------
1 | # Synthetic Regression
2 |
3 | This example trains and benchmarks a multilayer perceptron (MLP) on synthetic regression data. It relies on ```attorch.nn``` to demonstrate how attorch can be used as a drop-in replacement for PyTorch without kernel fusion or other enhancements.
4 |
5 | ## Requirements
6 |
7 | This example has no requirements aside from attorch and its dependencies.
8 |
9 | ## Training
10 |
11 | To run this example, please run ```python -m examples.regression.main``` from the root directory. The arguments are as follows.
12 | * ```--n_samples```: Number of samples to generate.
13 | * ```--dim```: Dimensionality of synthetic data.
14 | * ```--epochs```: Number of epochs to train for.
15 | * ```--batch_size```: Batch size.
16 |
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | # Examples
2 |
3 | This folder - inspired by [```pytorch/examples```](https://github.com/pytorch/examples/) - contains a few common deep learning examples whose backends can be switched between PyTorch and attorch, demonstrating attorch's usage and benchmarking it against PyTorch. Each example implements the relevant models with a backend switch, and backend-agnostic training and benchmarking can be conducted by running the designated ```main.py``` file.
4 |
5 | The examples are listed below.
6 |
7 | * [Imagenette image classification](https://github.com/bobmcdear/attorch/examples/imagenette)
8 | * [WikiText-2 language modelling](https://github.com/bobmcdear/attorch/examples/wikitext-2)
9 | * [MNIST classification](https://github.com/bobmcdear/attorch/examples/mnist)
10 | * [Synthetic regression](https://github.com/bobmcdear/attorch/examples/regression)
11 |
--------------------------------------------------------------------------------
/examples/wikitext-2/README.md:
--------------------------------------------------------------------------------
1 | # WikiText-2 Language Modelling
2 |
3 | This example trains and benchmarks a language model on the WikiText-2 dataset.
4 |
5 | ## Requirements
6 |
7 | The requirements for this example, aside from attorch and its dependencies, are:
8 |
9 | * ```datasets==2.18.0```
10 | * ```transformers==4.39.0```
11 |
12 | ## Training
13 |
14 | To run this example, please run ```python -m examples.wikitext-2.main``` from the root directory. The arguments are as follows.
15 |
16 | * ```--model```: Name of language model to train. The only option is ```gpt2```.
17 | * `--downsize`: The depth and width of the model are calculated by dividing GPT2's original depth and width by this factor.
18 | * ```--scheduler```: Learning rate scheduler. Options are `one-cycle` and `cosine`.
19 | * ```--epochs```: Number of epochs to train for.
20 | * ```--batch_size```: Batch size.
21 | * ```--seq_len```: Sequence length.
22 | * ```--num_workers```: Number of workers for data loading.
23 |
--------------------------------------------------------------------------------
/attorch/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | A subset of PyTorch's neural network modules,
3 | written in Python using OpenAI's Triton.
4 | """
5 |
6 |
7 | from . import math, nn
8 | from .act_layers import CELU, ELU, GELU, Hardshrink, Hardsigmoid, Hardswish, Hardtanh, LeakyReLU, \
9 | LogSigmoid, Mish, ReLU, ReLU6, SELU, SiLU, Sigmoid, Softplus, Softshrink, Softsign, Tanh, Tanhshrink
10 | from .batch_norm_layer import BatchNorm1d, BatchNorm2d
11 | from .conv_layer import Conv1d, Conv2d
12 | from .cross_entropy_loss_layer import CrossEntropyLoss
13 | from .dropout_layer import Dropout
14 | from .glu_layer import GLU
15 | from .layer_norm_layer import LayerNorm
16 | from .linear_layer import Linear
17 | from .multi_head_attention_layer import MultiheadAttention
18 | from .nll_loss_layer import NLLLoss
19 | from .p_loss_layers import HuberLoss, L1Loss, MSELoss, SmoothL1Loss
20 | from .pooling_layer import AvgPool1d, AvgPool2d
21 | from .rms_norm_layer import RMSNorm
22 | from .softmax_layers import LogSoftmax, Softmax, Softmin
23 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Borna Ahmadzadeh
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/tests/test_dropout_layer.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import pytest
4 | import torch
5 |
6 | import attorch
7 | from .utils import assert_close, create_input, create_input_like, default_shapes
8 |
9 |
10 | @pytest.mark.parametrize('shape', default_shapes())
11 | @pytest.mark.parametrize('drop_p', [0.0, 0.15, 0.3, 0.5, 0.75, 0.9, 1.0])
12 | def test_dropout_layer(shape: Tuple[int, ...], drop_p: float, subset: bool) -> None:
13 | if subset and (shape not in default_shapes(subset=True)):
14 | return
15 |
16 | input = create_input(shape)
17 | dropout = attorch.Dropout(drop_p)
18 | output = dropout(input)
19 | n_zeroed = (torch.count_nonzero(input) - torch.count_nonzero(output)).item()
20 |
21 | if drop_p == 0:
22 | assert n_zeroed == 0
23 |
24 | elif drop_p == 1:
25 | assert torch.count_nonzero(output).item() == 0
26 |
27 | else:
28 | assert_close((output, torch.where(output == 0, output, input / (1 - drop_p))))
29 | assert_close((n_zeroed / input.numel(), drop_p), rtol=1e-1, atol=5e-2)
30 |
31 | output_grad = create_input_like(output)
32 | output.backward(output_grad)
33 | input_grad = torch.where(output == 0, output, output_grad / (1 - drop_p))
34 |
35 | assert_close((input.grad, input_grad))
36 |
--------------------------------------------------------------------------------
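The test above relies on the inverted-dropout identity: surviving elements are scaled by 1 / (1 - p), so roughly a fraction p of the elements is zeroed while the expected value is preserved. A small numeric sketch of that identity, assuming a CUDA device; the tensor size and drop probability are arbitrary:

```python
# Inverted dropout: kept elements are scaled by 1 / (1 - p).
import torch
import attorch

drop_p = 0.25
x = torch.ones(10_000, device='cuda')
out = attorch.Dropout(drop_p)(x)  # a freshly built module is in training mode

print(out.mean().item())           # close to 1.0, since the expectation is preserved
print((out == 0).float().mean())   # close to drop_p
print(out[out != 0].unique())      # kept values equal 1 / (1 - drop_p)
```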
/examples/imagenette/README.md:
--------------------------------------------------------------------------------
1 | # Imagenette Image Classification
2 |
3 | This example trains and benchmarks a vision model on the Imagenette classification dataset.
4 |
5 | ## Requirements
6 |
7 | The only requirement for this example, aside from attorch and its dependencies, is:
8 |
9 | * ```torchvision==0.19.0```
10 |
11 | ## Training
12 |
13 | To run this example, please run ```python -m examples.imagenette.main``` from the root directory. The arguments are as follows.
14 | * ```--model```: Name of vision model to train. Options are `resnet18`, `resnet34`, ```resnet14```, ```resnet26```, ```resnet50```, ```resnet101```, ```resnet152```, ```convnext_atto```, ```convnext_femto```, ```convnext_pico```, ```convnext_nano```, ```convnext_tiny```, ```convnext_small```, ```convnext_base```, ```convnext_large```, ```convnext_xlarge```, `vit_tiny_patch16`, `vit_xsmall_patch16`, `vit_small_patch32`, `vit_small_patch16`, `vit_small_patch8`, `vit_medium_patch32`, `vit_medium_patch16`, `vit_base_patch32`, `vit_base_patch16`, `vit_base_patch8`, `vit_large_patch32`, `vit_large_patch16`, and `vit_large_patch14`.
15 | * ```--scheduler```: Learning rate scheduler. Options are `one-cycle` and `cosine`.
16 | * ```--epochs```: Number of epochs to train for.
17 | * ```--batch_size```: Batch size.
18 | * ```--center_crop_size```: Center crop size for validation.
19 | * ```--image_size```: Input image size.
20 | * ```--num_workers```: Number of workers for data loading.
21 |
--------------------------------------------------------------------------------
/tests/test_softmax_layers.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import pytest
4 | import torch
5 | from torch import autocast, nn
6 |
7 | import attorch
8 | from .utils import assert_close, create_input, create_input_like, default_shapes
9 |
10 |
11 | @pytest.mark.parametrize('shape', default_shapes())
12 | @pytest.mark.parametrize('softmax', ['Softmax', 'LogSoftmax', 'Softmin'])
13 | @pytest.mark.parametrize('input_dtype', [torch.float32, torch.float16])
14 | @pytest.mark.parametrize('amp', [False, True])
15 | def test_softmax_layers(
16 | shape: Tuple[int, ...],
17 | softmax: str,
18 | input_dtype: torch.dtype,
19 | amp: bool,
20 | subset: bool,
21 | ) -> None:
22 | if subset and (shape not in default_shapes(subset=True)):
23 | return
24 |
25 | if input_dtype is torch.float16 and not amp:
26 | return
27 |
28 | attorch_input = create_input(shape, dtype=input_dtype)
29 | pytorch_input = create_input(shape, dtype=input_dtype)
30 |
31 | attorch_softmax = getattr(attorch, softmax)(dim=-1)
32 | pytorch_softmax = getattr(nn, softmax)(dim=-1)
33 |
34 | with autocast('cuda', enabled=amp):
35 | attorch_output = attorch_softmax(attorch_input)
36 | pytorch_output = pytorch_softmax(pytorch_input)
37 |
38 | assert_close((attorch_output, pytorch_output))
39 |
40 | attorch_output.backward(create_input_like(attorch_output))
41 | pytorch_output.backward(create_input_like(pytorch_output))
42 |
43 | assert_close((attorch_input.grad, pytorch_input.grad))
44 |
--------------------------------------------------------------------------------
/examples/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for examples.
3 | """
4 |
5 |
6 | from typing import Tuple
7 |
8 | import torch
9 | from torch import autocast, nn
10 | from triton.testing import do_bench
11 |
12 |
13 | def benchmark_fw_and_bw(
14 | model: nn.Module,
15 | amp: bool = True,
16 | **input,
17 | ) -> Tuple[float, float, float]:
18 | """
19 | Benchmarks the forward and backward pass of a model.
20 |
21 | Args:
22 | model: Model to run.
23 | amp: Flag for running the forward pass using automatic mixed precision.
24 | **input: Input to the model.
25 | """
26 | with autocast('cuda', enabled=amp):
27 | fw = do_bench(lambda: model(**input))
28 |
29 | with autocast('cuda', enabled=amp):
30 | output = model(**input)
31 | output_grad = torch.randn_like(output)
32 | bw = do_bench(lambda: output.backward(output_grad, retain_graph=True))
33 |
34 | print(f'Forward pass mean execution time: {fw} ms')
35 | print(f'Backward pass mean execution time: {bw} ms')
36 | print(f'Forward plus backward pass mean execution time: {fw+bw} ms')
37 | return fw, bw, fw + bw
38 |
39 | class AvgMeter:
40 | """
41 | Keeps track of the running average of a series of values.
42 | """
43 | def __init__(self) -> None:
44 | self.reset()
45 |
46 | def reset(self) -> None:
47 | self.sum = 0
48 | self.count = 0
49 |
50 | def update(self, val: float, count: int = 1) -> None:
51 | self.sum += count * val
52 | self.count += count
53 |
54 | @property
55 | def avg(self) -> float:
56 | return self.sum / self.count
57 |
--------------------------------------------------------------------------------
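A hypothetical call to ```benchmark_fw_and_bw```, assuming a CUDA device and execution from the repository root; the model and sizes are made-up illustration values. Keyword arguments are forwarded directly to the model's forward pass:

```python
# Benchmark the forward and backward passes of a small PyTorch module.
import torch
from torch import nn

from examples.utils import benchmark_fw_and_bw

model = nn.Linear(512, 512).cuda()
benchmark_fw_and_bw(model, amp=True, input=torch.randn(64, 512, device='cuda'))
```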
/tests/test_cross_entropy_loss_layer.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import pytest
4 | import torch
5 | from torch import autocast, nn
6 |
7 | import attorch
8 | from .utils import assert_close, create_input, default_shapes
9 |
10 |
11 | @pytest.mark.parametrize('shape', default_shapes(min_dim=2, max_dim=2))
12 | @pytest.mark.parametrize('weighted', [False, True])
13 | @pytest.mark.parametrize('input_dtype', [torch.float32, torch.float16])
14 | @pytest.mark.parametrize('amp', [False, True])
15 | def test_cross_entropy_loss_layer(
16 | shape: Tuple[int, ...],
17 | weighted: bool,
18 | input_dtype: torch.dtype,
19 | amp: bool,
20 | subset: bool,
21 | ) -> None:
22 | if subset and (shape not in default_shapes(subset=True)):
23 | return
24 |
25 | if input_dtype is torch.float16 and not amp:
26 | return
27 |
28 | attorch_input = create_input(shape, dtype=input_dtype)
29 | pytorch_input = create_input(shape, dtype=input_dtype)
30 | target = torch.randint(0, shape[1],
31 | size=(shape[0],),
32 | device='cuda')
33 | weight = (torch.randn(shape[1], device='cuda')
34 | if weighted else None)
35 |
36 | attorch_loss = attorch.CrossEntropyLoss(weight=weight)
37 | pytorch_loss = nn.CrossEntropyLoss(weight=weight)
38 |
39 | with autocast('cuda', enabled=amp):
40 | attorch_output = attorch_loss(attorch_input, target)
41 | pytorch_output = pytorch_loss(pytorch_input, target)
42 |
43 | assert_close((attorch_output, pytorch_output), rtol=1e-3, atol=1e-3)
44 |
45 | attorch_output.backward()
46 | pytorch_output.backward()
47 |
48 | assert_close((attorch_input.grad, pytorch_input.grad), rtol=1e-3, atol=1e-3)
49 |
--------------------------------------------------------------------------------
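As the test indicates, ```attorch.CrossEntropyLoss``` mirrors ```torch.nn.CrossEntropyLoss```, including optional class weights. A minimal parity sketch, assuming a CUDA device; the batch size and class count are arbitrary:

```python
# attorch's cross entropy loss alongside PyTorch's, with class weights.
import torch
from torch import nn
import attorch

batch, classes = 32, 10
logits = torch.randn(batch, classes, device='cuda')
target = torch.randint(0, classes, (batch,), device='cuda')
weight = torch.rand(classes, device='cuda')

print(attorch.CrossEntropyLoss(weight=weight)(logits, target).item())
print(nn.CrossEntropyLoss(weight=weight)(logits, target).item())  # should agree closely
```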
/examples/mnist/mlp.py:
--------------------------------------------------------------------------------
1 | """
2 | Multilayer perceptron (MLP) for MNIST classification.
3 | """
4 |
5 |
6 | from typing import Optional
7 |
8 | from torch import Tensor
9 | from torch import nn
10 |
11 | import attorch
12 |
13 |
14 | class MLP(nn.Module):
15 | """
16 | Transforms the input using a multilayer perceptron (MLP) with an arbitrary
17 | number of hidden layers, optionally computing the loss if targets are passed.
18 |
19 | Args:
20 | use_attorch: Flag to use attorch in lieu of PyTorch as the backend.
21 | in_dim: Number of input features.
22 | hidden_dim: Number of hidden features.
23 | depth: Number of hidden layers.
24 | num_classes: Number of output classes.
25 | """
26 | def __init__(
27 | self,
28 | use_attorch: bool,
29 | in_dim: int = 784,
30 | hidden_dim: int = 128,
31 | depth: int = 1,
32 | num_classes: int = 10,
33 | ) -> None:
34 | super().__init__()
35 |
36 | backend = attorch if use_attorch else nn
37 | layer_fn = lambda dim: ([attorch.Linear(dim, hidden_dim, act_func='relu')]
38 | if use_attorch else [nn.Linear(dim, hidden_dim), nn.ReLU()])
39 |
40 | layers = layer_fn(in_dim)
41 | for _ in range(depth - 1):
42 | layers += layer_fn(hidden_dim)
43 | layers.append(backend.Linear(hidden_dim, num_classes))
44 |
45 | self.layers = nn.Sequential(*layers)
46 | self.loss_func = backend.CrossEntropyLoss()
47 |
48 | def forward(self, input: Tensor, target: Optional[Tensor] = None) -> Tensor:
49 | output = self.layers(input.flatten(start_dim=1))
50 | return output if target is None else self.loss_func(output, target)
51 |
--------------------------------------------------------------------------------
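A short sketch of driving this MLP with the attorch backend, assuming a CUDA device and execution from the repository root; the batch of images and labels is random stand-in data:

```python
# Instantiate the MLP with the attorch backend and compute a loss on fake MNIST-shaped data.
import torch

from examples.mnist.mlp import MLP

model = MLP(use_attorch=True, hidden_dim=128, depth=2).cuda()
images = torch.randn(64, 1, 28, 28, device='cuda')
labels = torch.randint(0, 10, (64,), device='cuda')

loss = model(images, labels)  # returns the cross entropy loss when targets are passed
logits = model(images)        # returns raw class scores otherwise
```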
/tests/test_nll_loss_layer.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import pytest
4 | import torch
5 | from torch import autocast, nn
6 |
7 | import attorch
8 | from .utils import assert_close, create_input, create_input_like, default_shapes
9 |
10 |
11 | @pytest.mark.parametrize('shape', default_shapes(min_dim=2))
12 | @pytest.mark.parametrize('reduction', ['none', 'mean', 'sum'])
13 | @pytest.mark.parametrize('weighted', [False, True])
14 | @pytest.mark.parametrize('input_dtype', [torch.float32, torch.float16])
15 | @pytest.mark.parametrize('amp', [False, True])
16 | def test_nll_loss_layer(
17 | shape: Tuple[int, ...],
18 | reduction: str,
19 | weighted: bool,
20 | input_dtype: torch.dtype,
21 | amp: bool,
22 | subset: bool,
23 | ) -> None:
24 | if subset and (shape not in default_shapes(subset=True)):
25 | return
26 |
27 | if input_dtype is torch.float16 and not amp:
28 | return
29 |
30 | attorch_input = create_input(shape)
31 | pytorch_input = create_input(shape)
32 | target = torch.randint(0, shape[1],
33 | size=(shape[0], *shape[2:]),
34 | device='cuda')
35 | weight = (torch.randn(shape[1], device='cuda', dtype=torch.float32)
36 | if weighted else None)
37 |
38 | attorch_loss = attorch.NLLLoss(reduction=reduction, weight=weight)
39 | pytorch_loss = nn.NLLLoss(reduction=reduction, weight=weight)
40 |
41 | with autocast('cuda', enabled=amp):
42 | attorch_output = attorch_loss(attorch_input, target)
43 | pytorch_output = pytorch_loss(pytorch_input, target)
44 |
45 | assert_close((attorch_output, pytorch_output))
46 |
47 | attorch_output.backward(create_input_like(attorch_output))
48 | pytorch_output.backward(create_input_like(pytorch_output))
49 |
50 | assert_close((attorch_input.grad, pytorch_input.grad))
51 |
--------------------------------------------------------------------------------
/tests/test_p_loss_layers.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import pytest
4 | import torch
5 | from torch import autocast, nn
6 |
7 | import attorch
8 | from .utils import assert_close, create_input, create_input_like, default_shapes
9 |
10 |
11 | @pytest.mark.parametrize('shape', default_shapes())
12 | @pytest.mark.parametrize('p_loss', ['SmoothL1Loss', 'L1Loss', 'MSELoss', 'HuberLoss'])
13 | @pytest.mark.parametrize('reduction', ['none', 'mean', 'sum'])
14 | @pytest.mark.parametrize('input_dtype', [torch.float32, torch.float16])
15 | @pytest.mark.parametrize('amp', [False, True])
16 | def test_p_loss_layers(
17 | shape: Tuple[int, ...],
18 | p_loss: str,
19 | reduction: str,
20 | input_dtype: torch.dtype,
21 | amp: bool,
22 | subset: bool,
23 | ) -> None:
24 | if subset and (shape not in default_shapes(subset=True)):
25 | return
26 |
27 | if input_dtype is torch.float16 and not amp:
28 | return
29 |
30 | attorch_input = create_input(shape, dtype=input_dtype)
31 | attorch_target = create_input(shape, dtype=input_dtype, seed=1)
32 |
33 | pytorch_input = create_input(shape, dtype=input_dtype)
34 | pytorch_target = create_input(shape, dtype=input_dtype, seed=1)
35 |
36 | attorch_loss = getattr(attorch, p_loss)(reduction=reduction)
37 | pytorch_loss = getattr(nn, p_loss)(reduction=reduction)
38 |
39 | with autocast('cuda', enabled=amp):
40 | attorch_output = attorch_loss(attorch_input, attorch_target)
41 | pytorch_output = pytorch_loss(pytorch_input, pytorch_target)
42 |
43 | assert_close((attorch_output, pytorch_output), rtol=1e-3, atol=1e-3)
44 |
45 | attorch_output.backward(create_input_like(attorch_output))
46 | pytorch_output.backward(create_input_like(pytorch_output))
47 |
48 | assert_close((attorch_input.grad, pytorch_input.grad),
49 | (attorch_target.grad, pytorch_target.grad),
50 | rtol=1e-3, atol=1e-3)
51 |
--------------------------------------------------------------------------------
/tests/test_glu_layer.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import pytest
4 | import torch
5 | from torch import autocast
6 | from torch.nn import functional as F
7 |
8 | import attorch
9 | from .utils import assert_close, create_input, create_input_like, default_shapes
10 |
11 |
12 | @pytest.mark.parametrize('shape', default_shapes(max_dim=3))
13 | @pytest.mark.parametrize('act_func', ['sigmoid', 'logsigmoid', 'tanh', 'relu', 'gelu', 'silu',
14 | 'relu6', 'hardsigmoid', 'hardtanh', 'hardswish', 'selu',
15 | 'mish', 'softplus', 'softsign', 'tanhshrink', 'leaky_relu_0.01',
16 | 'elu_1', 'celu_1', 'hardshrink_0.5', 'softshrink_0.5'])
17 | @pytest.mark.parametrize('input_dtype', [torch.float32, torch.float16])
18 | @pytest.mark.parametrize('amp', [False, True])
19 | def test_glu_layer(
20 | shape: Tuple[int, ...],
21 | act_func: str,
22 | input_dtype: torch.dtype,
23 | amp: bool,
24 | subset: bool,
25 | ) -> None:
26 | if subset and (shape not in default_shapes(subset=True)):
27 | return
28 |
29 | if input_dtype is torch.float16 and not amp:
30 | return
31 |
32 | attorch_input = create_input(shape, dtype=input_dtype)
33 | pytorch_input = create_input(shape, dtype=input_dtype)
34 |
35 | attorch_glu = attorch.GLU(act_func=act_func)
36 | pytorch_glu = lambda input1, input2: input1 * getattr(F, act_func.rsplit('_', 1)[0])(input2)
37 |
38 | with autocast('cuda', enabled=amp):
39 | attorch_output = attorch_glu(attorch_input)
40 | pytorch_output = pytorch_glu(*pytorch_input.chunk(2, dim=-1))
41 |
42 | assert_close((attorch_output, pytorch_output), rtol=1e-3, atol=1e-3)
43 |
44 | attorch_output.backward(create_input_like(attorch_output))
45 | pytorch_output.backward(create_input_like(pytorch_output))
46 |
47 | assert_close((attorch_input.grad, pytorch_input.grad), rtol=1e-3, atol=1e-3)
48 |
--------------------------------------------------------------------------------
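```attorch.GLU``` gates one half of the input's final dimension with an activation of the other half, matching the reference expression in the test. A minimal sketch, assuming a CUDA device; the feature size is arbitrary but must be even:

```python
# GLU with a SiLU gate: output = first_half * silu(second_half) along the last dimension.
import torch
import attorch

glu = attorch.GLU(act_func='silu')
x = torch.randn(8, 128, device='cuda', requires_grad=True)
out = glu(x)          # shape: (8, 64)
out.sum().backward()  # gradients flow back to x
```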
/tests/test_pooling_layer.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Union
2 |
3 | import pytest
4 | import torch
5 | from torch import autocast, nn
6 |
7 | import attorch
8 | from .utils import assert_close, create_input, create_input_like, default_shapes
9 |
10 |
11 | @pytest.mark.parametrize('shape', default_shapes(min_dim=3, max_dim=4))
12 | @pytest.mark.parametrize('kernel_size', [2, 3, 4])
13 | @pytest.mark.parametrize('stride', [None, 1, 2])
14 | @pytest.mark.parametrize('padding', [0, 1, -1])
15 | @pytest.mark.parametrize('input_dtype', [torch.float32, torch.float16])
16 | @pytest.mark.parametrize('amp', [False, True])
17 | def test_pooling_layer(
18 | shape: Tuple[int, ...],
19 | kernel_size: Union[int, Tuple[int, int]],
20 | stride: Union[int, Tuple[int, int]],
21 | padding: Union[int, Tuple[int, int]],
22 | input_dtype: torch.dtype,
23 | amp: bool,
24 | subset: bool,
25 | ) -> None:
26 | if subset and (shape not in default_shapes(subset=True)):
27 | return
28 |
29 | if input_dtype is torch.float16 and not amp:
30 | return
31 |
32 | if padding == -1:
33 | padding = kernel_size // 2
34 |
35 | attorch_input = create_input(shape, dtype=input_dtype)
36 | pytorch_input = create_input(shape, dtype=input_dtype)
37 |
38 | pooling_name = 'AvgPool2d' if len(shape) == 4 else 'AvgPool1d'
39 | attorch_pool = getattr(attorch, pooling_name)(kernel_size, stride, padding)
40 | pytorch_pool = getattr(nn, pooling_name)(kernel_size, stride, padding)
41 |
42 | with autocast('cuda', enabled=amp):
43 | attorch_output = attorch_pool(attorch_input)
44 | pytorch_output = pytorch_pool(pytorch_input)
45 |
46 | assert_close((attorch_output, pytorch_output), rtol=1e-3, atol=1e-2)
47 |
48 | attorch_output.backward(create_input_like(attorch_output))
49 | pytorch_output.backward(create_input_like(pytorch_output))
50 |
51 | assert_close((attorch_input.grad, pytorch_input.grad), rtol=1e-3, atol=1e-2)
52 |
--------------------------------------------------------------------------------
/tests/test_act_layers.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import pytest
4 | import torch
5 | from torch import autocast, nn
6 |
7 | import attorch
8 | from .utils import assert_close, create_input, create_input_like, default_shapes
9 |
10 |
11 | @pytest.mark.parametrize('shape', default_shapes())
12 | @pytest.mark.parametrize('act_func', ['Sigmoid', 'LogSigmoid', 'Tanh', 'ReLU', 'GELU', 'SiLU',
13 | 'ReLU6', 'Hardsigmoid', 'Hardtanh', 'Hardswish', 'SELU',
14 | 'Mish', 'Softplus', 'Softsign', 'Tanhshrink', 'LeakyReLU',
15 | 'ELU', 'CELU', 'Hardshrink', 'Softshrink'])
16 | @pytest.mark.parametrize('drop_p', [0.0, 0.5])
17 | @pytest.mark.parametrize('input_dtype', [torch.float32, torch.float16])
18 | @pytest.mark.parametrize('amp', [False, True])
19 | def test_act_layers(
20 | shape: Tuple[int, ...],
21 | act_func: str,
22 | drop_p: float,
23 | input_dtype: torch.dtype,
24 | amp: bool,
25 | subset: bool,
26 | ) -> None:
27 | if subset and (shape not in default_shapes(subset=True)):
28 | return
29 |
30 | if input_dtype is torch.float16 and not amp:
31 | return
32 |
33 | attorch_input = create_input(shape, dtype=input_dtype)
34 | pytorch_input = create_input(shape, dtype=input_dtype)
35 |
36 | attorch_act_func = getattr(attorch, act_func)(drop_p=drop_p)
37 | pytorch_act_func = getattr(nn, act_func)()
38 |
39 | with autocast('cuda', enabled=amp):
40 | attorch_output = attorch_act_func(attorch_input)
41 | pytorch_output = pytorch_act_func(pytorch_input)
42 |
43 | if drop_p > 0.0:
44 | mask = attorch_output == 0
45 | pytorch_output = torch.where(mask, 0, pytorch_output / (1 - drop_p))
46 |
47 | assert_close((attorch_output, pytorch_output), rtol=1e-3, atol=1e-3)
48 |
49 | attorch_output.backward(create_input_like(attorch_output))
50 | pytorch_output.backward(create_input_like(pytorch_output))
51 |
52 | assert_close((attorch_input.grad, pytorch_input.grad), rtol=1e-3, atol=1e-3)
53 |
--------------------------------------------------------------------------------
/tests/test_rms_norm_layer.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import pytest
4 | import torch
5 | from torch import autocast, nn
6 | from torch.nn import init
7 |
8 | import attorch
9 | from .utils import assert_close, create_input, create_input_like, default_shapes
10 |
11 |
12 | @pytest.mark.parametrize('shape', default_shapes(max_dim=3))
13 | @pytest.mark.parametrize('eps', [1e-5, 1e-6])
14 | @pytest.mark.parametrize('elementwise_affine', [False, True])
15 | @pytest.mark.parametrize('input_dtype', [torch.float32, torch.float16])
16 | @pytest.mark.parametrize('amp', [False, True])
17 | def test_rms_norm_layer(
18 | shape: Tuple[int, ...],
19 | eps: float,
20 | elementwise_affine: bool,
21 | input_dtype: torch.dtype,
22 | amp: bool,
23 | subset: bool,
24 | ) -> None:
25 | if subset and (shape not in default_shapes(subset=True)):
26 | return
27 |
28 | if input_dtype is torch.float16 and not amp:
29 | return
30 |
31 | attorch_input = create_input(shape, dtype=input_dtype)
32 | pytorch_input = create_input(shape, dtype=input_dtype)
33 |
34 | attorch_rms_norm = attorch.RMSNorm(shape[-1], eps, elementwise_affine)
35 | pytorch_rms_norm = nn.RMSNorm(shape[-1], eps, elementwise_affine,
36 | device='cuda')
37 |
38 | if elementwise_affine:
39 | torch.manual_seed(0)
40 | init.normal_(attorch_rms_norm.weight)
41 |
42 | torch.manual_seed(0)
43 | init.normal_(pytorch_rms_norm.weight)
44 |
45 | with autocast('cuda', enabled=amp):
46 | attorch_output = attorch_rms_norm(attorch_input)
47 | pytorch_output = pytorch_rms_norm(pytorch_input)
48 |
49 | assert_close((attorch_output, pytorch_output),
50 | rtol=1e-3, atol=1e-3)
51 |
52 | attorch_output.backward(create_input_like(attorch_output))
53 | pytorch_output.backward(create_input_like(pytorch_output))
54 |
55 | weight_grad_pair = ((attorch_rms_norm.weight.grad, pytorch_rms_norm.weight.grad)
56 | if elementwise_affine else (None, None))
57 | assert_close((attorch_input.grad, pytorch_input.grad),
58 | weight_grad_pair,
59 | rtol=1e-3, atol=1e-3)
60 |
--------------------------------------------------------------------------------
/tests/test_layer_norm_layer.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import pytest
4 | import torch
5 | from torch import autocast, nn
6 | from torch.nn import init
7 |
8 | import attorch
9 | from .utils import assert_close, create_input, create_input_like, default_shapes
10 |
11 |
12 | @pytest.mark.parametrize('shape', default_shapes(max_dim=3))
13 | @pytest.mark.parametrize('eps', [1e-5, 1e-6])
14 | @pytest.mark.parametrize('elementwise_affine', [False, True])
15 | @pytest.mark.parametrize('bias', [False, True])
16 | @pytest.mark.parametrize('input_dtype', [torch.float32, torch.float16])
17 | @pytest.mark.parametrize('amp', [False, True])
18 | def test_layer_norm_layer(
19 | shape: Tuple[int, ...],
20 | eps: float,
21 | elementwise_affine: bool,
22 | bias: bool,
23 | input_dtype: torch.dtype,
24 | amp: bool,
25 | subset: bool,
26 | ) -> None:
27 | if subset and (shape not in default_shapes(subset=True)):
28 | return
29 |
30 | if input_dtype is torch.float16 and not amp:
31 | return
32 |
33 | attorch_input = create_input(shape, dtype=input_dtype)
34 | pytorch_input = create_input(shape, dtype=input_dtype)
35 |
36 | attorch_layer_norm = attorch.LayerNorm(shape[-1], eps, elementwise_affine, bias)
37 | pytorch_layer_norm = nn.LayerNorm(shape[-1], eps, elementwise_affine, bias,
38 | device='cuda')
39 |
40 | if elementwise_affine:
41 | torch.manual_seed(0)
42 | init.normal_(attorch_layer_norm.weight)
43 | if bias:
44 | init.normal_(attorch_layer_norm.bias)
45 |
46 | torch.manual_seed(0)
47 | init.normal_(pytorch_layer_norm.weight)
48 | if bias:
49 | init.normal_(pytorch_layer_norm.bias)
50 |
51 | with autocast('cuda', enabled=amp):
52 | attorch_output = attorch_layer_norm(attorch_input)
53 | pytorch_output = pytorch_layer_norm(pytorch_input)
54 |
55 | assert_close((attorch_output, pytorch_output),
56 | rtol=1e-3, atol=1e-3)
57 |
58 | attorch_output.backward(create_input_like(attorch_output))
59 | pytorch_output.backward(create_input_like(pytorch_output))
60 |
61 | weight_grad_pair = ((attorch_layer_norm.weight.grad, pytorch_layer_norm.weight.grad)
62 | if elementwise_affine else (None, None))
63 | bias_grad_pair = ((attorch_layer_norm.bias.grad, pytorch_layer_norm.bias.grad)
64 | if elementwise_affine and bias else (None, None))
65 | assert_close((attorch_input.grad, pytorch_input.grad),
66 | weight_grad_pair, bias_grad_pair,
67 | rtol=1e-3, atol=1e-3)
68 |
--------------------------------------------------------------------------------
/tests/test_linear_layer.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 |
3 | import pytest
4 | import torch
5 | from torch import autocast, nn
6 | from torch.nn import functional as F
7 |
8 | import attorch
9 | from .utils import assert_close, create_input, create_input_like, default_shapes
10 |
11 |
12 | @pytest.mark.parametrize('shape', default_shapes(min_dim=2, max_dim=3))
13 | @pytest.mark.parametrize('out_dim', [16, 96, 128, 196, 384, 512, 768, 1024])
14 | @pytest.mark.parametrize('bias', [False, True])
15 | @pytest.mark.parametrize('act_func', [None, 'sigmoid', 'logsigmoid', 'tanh', 'relu', 'gelu', 'silu',
16 | 'relu6', 'hardsigmoid', 'hardtanh', 'hardswish', 'selu',
17 | 'mish', 'softplus', 'softsign', 'tanhshrink', 'leaky_relu_0.01',
18 | 'elu_1', 'celu_1', 'hardshrink_0.5', 'softshrink_0.5'])
19 | @pytest.mark.parametrize('input_dtype', [torch.float32, torch.float16])
20 | @pytest.mark.parametrize('amp', [False, True])
21 | def test_linear_layer(
22 | shape: Tuple[int, ...],
23 | out_dim: int,
24 | bias: bool,
25 | act_func: Optional[str],
26 | input_dtype: torch.dtype,
27 | amp: bool,
28 | subset: bool,
29 | ) -> None:
30 | if subset and (shape not in default_shapes(subset=True)):
31 | return
32 |
33 | if input_dtype is torch.float16 and not amp:
34 | return
35 |
36 | attorch_input = create_input(shape, dtype=input_dtype)
37 | pytorch_input = create_input(shape, dtype=input_dtype)
38 |
39 | torch.manual_seed(0)
40 | attorch_linear = attorch.Linear(shape[-1], out_dim,
41 | bias=bias,
42 | act_func=act_func)
43 |
44 | torch.manual_seed(0)
45 | pytorch_linear = nn.Linear(shape[-1], out_dim,
46 | bias=bias, device='cuda')
47 | pytorch_act = nn.Identity() if act_func is None else getattr(F, act_func.rsplit('_', 1)[0])
48 |
49 | with autocast('cuda', enabled=amp):
50 | attorch_output = attorch_linear(attorch_input)
51 | pytorch_output = pytorch_act(pytorch_linear(pytorch_input))
52 |
53 | assert_close((attorch_output, pytorch_output), rtol=1e-3, atol=1e-3)
54 |
55 | attorch_output.backward(create_input_like(attorch_output))
56 | pytorch_output.backward(create_input_like(pytorch_output))
57 |
58 | bias_grad_pair = ((attorch_linear.bias.grad, pytorch_linear.bias.grad)
59 | if bias else (None, None))
60 | assert_close((attorch_input.grad, pytorch_input.grad),
61 | (attorch_linear.weight.grad, pytorch_linear.weight.grad.T.contiguous()),
62 | bias_grad_pair, rtol=1e-3, atol=1e-3)
63 |
--------------------------------------------------------------------------------
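The ```act_func``` argument exercised above lets ```attorch.Linear``` apply the activation as part of the layer itself, standing in for a separate ```nn.Linear``` plus activation module. A minimal sketch, assuming a CUDA device, with arbitrary dimensions:

```python
# attorch.Linear with a built-in activation versus an unfused nn.Linear + nn.ReLU pair.
import torch
from torch import nn
import attorch

x = torch.randn(16, 256, device='cuda', requires_grad=True)

fused = attorch.Linear(256, 512, act_func='relu')              # activation handled by the layer
unfused = nn.Sequential(nn.Linear(256, 512), nn.ReLU()).cuda()

y = fused(x)           # shape: (16, 512), already ReLU-activated
y_unfused = unfused(x) # same shape, computed as two separate ops
```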
/attorch/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for attorch kernels and layers.
3 | """
4 |
5 |
6 | from typing import List, Optional
7 |
8 | import torch
9 | import triton
10 |
11 |
12 | def allow_tf32() -> bool:
13 | """
14 | Returns whether the current GPU architecture supports TF32.
15 | """
16 | return torch.cuda.get_device_capability()[0] >= 8
17 |
18 |
19 | def get_n_stages(n_stages: int = 2) -> int:
20 | """
21 | Receives number of stages for software pipelining and returns it as-is
22 | if the GPU architecture is Ampere or newer and 2 otherwise.
23 | """
24 | return 2 if torch.cuda.get_device_capability()[0] < 8 else n_stages
25 |
26 |
27 | def get_output_dtype(
28 | input_dtype: torch.dtype = torch.float32,
29 | autocast: Optional[str] = None,
30 | ) -> torch.dtype:
31 | """
32 | Returns the appropriate output dtype for automatic mixed precision
33 | given the input dtype and the operation's autocast behaviour.
34 |
35 | Args:
36 | input_dtype: Input dtype.
37 | autocast: The relevant operation's autocast behaviour.
38 | None signifies the input dtype should flow through,
39 | 'fp16' signifies autocasting to FP16 when AMP is enabled,
40 | and 'fp32' signifies autocasting to FP32 when AMP is enabled.
41 | """
42 | dtype = torch.get_autocast_dtype('cuda')
43 | assert dtype is torch.float16, \
44 | f'Only autocast to float16 is supported, received {dtype}'
45 |
46 | if torch.is_autocast_enabled():
47 | if autocast is None:
48 | return input_dtype
49 |
50 | elif autocast == 'fp16':
51 | return torch.float16
52 |
53 | elif autocast == 'fp32':
54 | return torch.float32
55 |
56 | else:
57 | raise RuntimeError(f'Autocast type {autocast} is invalid. '
58 | 'Options are None, fp16, and fp32')
59 |
60 | else:
61 | return input_dtype
62 |
63 |
64 | def element_wise_kernel_configs(
65 | block_name: str = 'BLOCK_SIZE',
66 | ) -> List[triton.Config]:
67 | """
68 | Returns kernel configurations for element-wise operations.
69 |
70 | Args:
71 | block_name: Name of block argument elements are distributed over.
72 | """
73 | return [triton.Config({block_name: 64}, num_warps=2),
74 | triton.Config({block_name: 128}, num_warps=2),
75 | triton.Config({block_name: 256}, num_warps=4),
76 | triton.Config({block_name: 512}, num_warps=4),
77 | triton.Config({block_name: 1024}, num_warps=4)]
78 |
79 |
80 | def warps_kernel_configs() -> List[triton.Config]:
81 | """
82 | Returns kernel configurations with all possible number of warps.
83 | """
84 | return [triton.Config({}, num_warps=2**i) for i in range(6)]
85 |
--------------------------------------------------------------------------------
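A small sketch of ```get_output_dtype```'s behaviour, assuming a CUDA build of PyTorch; outside an autocast region the input dtype always flows through:

```python
# Output dtypes selected by get_output_dtype inside a float16 autocast region.
import torch
from attorch.utils import get_output_dtype

with torch.autocast('cuda', dtype=torch.float16):
    print(get_output_dtype(torch.float32, autocast='fp16'))  # torch.float16
    print(get_output_dtype(torch.float32, autocast='fp32'))  # torch.float32
    print(get_output_dtype(torch.float32, autocast=None))    # torch.float32 (flows through)

print(get_output_dtype(torch.float16, autocast='fp32'))      # torch.float16, since autocast is off
```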
/tests/test_conv_layer.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Union
2 |
3 | import pytest
4 | import torch
5 | from torch import autocast, nn
6 |
7 | import attorch
8 | from .utils import assert_close, create_input, create_input_like, default_shapes
9 |
10 |
11 | @pytest.mark.parametrize('shape', default_shapes(min_dim=3, max_dim=4))
12 | @pytest.mark.parametrize('out_dim', [16, 96, 128, 196, 384, 512, 768, 1024])
13 | @pytest.mark.parametrize('kernel_size', [1, 2, 3, 4, 5, 7, (5, 3), (3, 5)])
14 | @pytest.mark.parametrize('stride', [1, 2, (1, 2), (2, 1)])
15 | @pytest.mark.parametrize('padding', [0, 1, 3, -1])
16 | @pytest.mark.parametrize('groups', [1, 2, 4, -1])
17 | @pytest.mark.parametrize('bias', [False, True])
18 | @pytest.mark.parametrize('input_dtype', [torch.float32, torch.float16])
19 | @pytest.mark.parametrize('amp', [False, True])
20 | def test_conv2d_layer(
21 | shape: Tuple[int, ...],
22 | out_dim: int,
23 | kernel_size: Union[int, Tuple[int, int]],
24 | stride: Union[int, Tuple[int, int]],
25 | padding: Union[int, Tuple[int, int]],
26 | groups: int,
27 | bias: bool,
28 | input_dtype: torch.dtype,
29 | amp: bool,
30 | subset: bool,
31 | ) -> None:
32 | if subset and (shape not in default_shapes(subset=True)):
33 | return
34 |
35 | if input_dtype is torch.float16 and not amp:
36 | return
37 |
38 | if len(shape) == 3 and (isinstance(kernel_size, tuple) or isinstance(padding, tuple)):
39 | return
40 |
41 | if padding == -1:
42 | padding = kernel_size // 2 if isinstance(kernel_size, int) else tuple(k // 2 for k in kernel_size)
43 |
44 | if groups == -1:
45 | groups = shape[1]
46 |
47 | if shape[1] % groups != 0:
48 | groups = 1
49 |
50 | conv_name = 'Conv2d' if len(shape) == 4 else 'Conv1d'
51 | attorch_input = create_input(shape, dtype=input_dtype)
52 | pytorch_input = create_input(shape, dtype=input_dtype)
53 |
54 | torch.manual_seed(0)
55 | attorch_conv = getattr(attorch, conv_name)(shape[1], out_dim,
56 | kernel_size,
57 | stride=stride, padding=padding,
58 | groups=groups, bias=bias)
59 |
60 | torch.manual_seed(0)
61 | pytorch_conv = getattr(nn, conv_name)(shape[1], out_dim,
62 | kernel_size,
63 | stride=stride, padding=padding,
64 | groups=groups, bias=bias,
65 | device='cuda')
66 |
67 | with autocast('cuda', enabled=amp):
68 | attorch_output = attorch_conv(attorch_input)
69 | pytorch_output = pytorch_conv(pytorch_input)
70 |
71 | assert_close((attorch_output, pytorch_output), rtol=1e-3, atol=1e-2)
72 |
73 | attorch_output.backward(create_input_like(attorch_output))
74 | pytorch_output.backward(create_input_like(pytorch_output))
75 |
76 | bias_grad_pair = ((attorch_conv.bias.grad, pytorch_conv.bias.grad)
77 | if bias else (torch.tensor(0), torch.tensor(0)))
78 | assert_close((attorch_input.grad, pytorch_input.grad),
79 | (attorch_conv.weight.grad, pytorch_conv.weight.grad),
80 | bias_grad_pair, rtol=1e-3, atol=1e-2)
81 |
--------------------------------------------------------------------------------
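For reference, attorch's convolution layers take the same core constructor arguments the test forwards to their ```torch.nn``` counterparts. A minimal sketch, assuming a CUDA device, with arbitrary sizes:

```python
# A 3x3 attorch convolution over a random batch of images.
import torch
import attorch

conv = attorch.Conv2d(3, 16, 3, stride=1, padding=1, groups=1, bias=True)
x = torch.randn(8, 3, 32, 32, device='cuda', requires_grad=True)
y = conv(x)         # shape: (8, 16, 32, 32)
y.sum().backward()  # gradients reach x, conv.weight, and conv.bias
```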
/attorch/dropout_kernels.py:
--------------------------------------------------------------------------------
1 | """
2 | Kernels for dropout.
3 | """
4 |
5 |
6 | import triton
7 | import triton.language as tl
8 |
9 | from .utils import element_wise_kernel_configs
10 |
11 |
12 | @triton.jit
13 | def apply_dropout(input, drop_p, seed, offset):
14 | """
15 | Randomly zeroes elements in the input.
16 |
17 | Args:
18 | input: Input. The input must be loaded and cannot be a pointer.
19 | drop_p: Probability of dropping an element.
20 | seed: Seed for generating the dropout mask.
21 | offset: Offset to generate the mask for.
22 |
23 | Returns:
24 | Input with elements randomly zeroed out.
25 | """
26 | random = tl.rand(seed, offset)
27 | return tl.where(random < drop_p, 0, input / (1 - drop_p))
28 |
29 |
30 | @triton.jit
31 | def apply_dropout_grad(output_grad, drop_p, seed, offset):
32 | """
33 | Calculates the input gradient of dropout.
34 |
35 | Args:
36 | output_grad: Output gradients. The output gradients must be
37 | loaded and cannot be a pointer.
38 | drop_p: Probability of dropping an element.
39 | seed: Seed for generating the dropout mask.
40 | offset: Offset to generate the mask for.
41 |
42 | Returns:
43 | Gradient of dropout.
44 | """
45 | random = tl.rand(seed, offset)
46 | return tl.where(random < drop_p, 0, output_grad / (1 - drop_p))
47 |
48 |
49 | @triton.autotune(
50 | configs=element_wise_kernel_configs(),
51 | key=['size'],
52 | )
53 | @triton.jit
54 | def dropout_forward_kernel(
55 | input_pointer, output_pointer, size,
56 | drop_p, seed,
57 | BLOCK_SIZE: tl.constexpr,
58 | ):
59 | """
60 | Randomly zeroes elements in the input.
61 |
62 | Args:
63 | input_pointer: Pointer to the input to perform dropout on.
64 | The input must be of shape [size].
65 | output_pointer: Pointer to a container the result is written to.
66 | The container must be of shape [size].
67 | size: Number of elements in the input.
68 | drop_p: Probability of dropping an element.
69 | seed: Seed for generating the dropout mask.
70 | BLOCK_SIZE: Block size.
71 | """
72 | # This program processes BLOCK_SIZE elements.
73 | pid = tl.program_id(axis=0)
74 | offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
75 | mask = offset < size
76 |
77 | input = tl.load(input_pointer + offset, mask=mask)
78 | output = apply_dropout(input, drop_p, seed, offset)
79 | tl.store(output_pointer + offset, output, mask=mask)
80 |
81 |
82 | @triton.autotune(
83 | configs=element_wise_kernel_configs(),
84 | key=['size'],
85 | )
86 | @triton.jit
87 | def dropout_backward_kernel(
88 | output_grad_pointer, input_grad_pointer, size,
89 | drop_p, seed,
90 | BLOCK_SIZE: tl.constexpr,
91 | ):
92 | """
93 | Calculates the input gradient of dropout.
94 |
95 | Args:
96 | output_grad_pointer: Pointer to dropout's output gradients.
97 | The output gradients must be of shape [size].
98 | input_grad_pointer: Pointer to a container the input's gradients are written to.
99 | The container must be of shape [size].
100 | size: Number of elements in the input.
101 | drop_p: Probability of dropping an element used in dropout.
102 | seed: Seed for generating the dropout mask.
103 | BLOCK_SIZE: Block size.
104 | """
105 | # This program processes BLOCK_SIZE elements.
106 | pid = tl.program_id(axis=0)
107 | offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
108 | mask = offset < size
109 |
110 | output_grad = tl.load(output_grad_pointer + offset, mask=mask)
111 | input_grad = apply_dropout_grad(output_grad, drop_p, seed, offset)
112 | tl.store(input_grad_pointer + offset, input_grad, mask=mask)
113 |
--------------------------------------------------------------------------------
/attorch/dropout_layer.py:
--------------------------------------------------------------------------------
1 | """
2 | Dropout layer with PyTorch autodiff support.
3 | """
4 |
5 |
6 | import warnings
7 | from random import randint
8 | from typing import Optional, Tuple
9 |
10 | import torch
11 | from torch import Tensor
12 | from torch import nn
13 | from torch.amp import custom_bwd, custom_fwd
14 | from triton import cdiv
15 |
16 | from .dropout_kernels import dropout_backward_kernel, dropout_forward_kernel
17 | from .types import Context
18 |
19 |
20 | class DropoutAutoGrad(torch.autograd.Function):
21 | """
22 | Autodiff for dropout.
23 | """
24 | @staticmethod
25 | @custom_fwd(device_type='cuda')
26 | def forward(
27 | ctx: Context,
28 | input: Tensor,
29 | drop_p: float,
30 | training: bool,
31 | ) -> Tensor:
32 | """
33 | Randomly zeroes elements in the input.
34 |
35 | Args:
36 | ctx: Context for variable storage.
37 | input: Input to perform dropout on.
38 | Can have arbitrary shape.
39 | drop_p: Probability of dropping an element.
40 | training: Flag indicating if the model is in training mode.
41 | If False, no dropout is applied.
42 |
43 | Returns:
44 | Input with some elements zeroed out.
45 | """
46 | ctx.do_dropout = True
47 | if not training or drop_p == 0:
48 | ctx.do_dropout = False
49 | return input
50 |
51 | ctx.drop_all = False
52 | if drop_p == 1:
53 | ctx.drop_all = True
54 | return torch.zeros_like(input)
55 |
56 | flattened_input = input.flatten()
57 | size = len(flattened_input)
58 | output = torch.empty_like(flattened_input)
59 |
60 | seed = randint(0, 65535)
61 | ctx.seed = seed
62 | ctx.drop_p = drop_p
63 |
64 | # Launches 1D grid where each program operates over
65 | # BLOCK_SIZE elements.
66 | grid = lambda META: (cdiv(size, META['BLOCK_SIZE']),)
67 | dropout_forward_kernel[grid](flattened_input, output,
68 | size, drop_p, seed)
69 |
70 | return output.view_as(input)
71 |
72 | @staticmethod
73 | @custom_bwd(device_type='cuda')
74 | def backward(
75 | ctx: Context,
76 | output_grad: Tensor,
77 | ) -> Tuple[Optional[Tensor], ...]:
78 | """
79 | Calculates the input gradient of dropout.
80 |
81 | Args:
82 | ctx: Context containing stored variables.
83 | output_grad: Output gradients.
84 | Must be the same shape as the output.
85 |
86 | Returns:
87 | Input gradient of dropout.
88 | """
89 | if not ctx.do_dropout:
90 | return output_grad, None, None
91 |
92 | if ctx.drop_all:
93 | return torch.zeros_like(output_grad), None, None
94 |
95 | orig_shape = output_grad.shape
96 | output_grad = output_grad.flatten()
97 | size = len(output_grad)
98 | input_grad = torch.empty_like(output_grad)
99 |
100 | # Launches 1D grid where each program operates over
101 | # BLOCK_SIZE elements.
102 | grid = lambda META: (cdiv(size, META['BLOCK_SIZE']),)
103 | dropout_backward_kernel[grid](output_grad, input_grad,
104 | size, ctx.drop_p, ctx.seed)
105 |
106 | # Pads the return value with None because backward must return a value
107 | # (a gradient or None) for each argument passed to forward.
108 | return input_grad.view(orig_shape), None, None
109 |
110 |
111 | class Dropout(nn.Dropout):
112 | """
113 | Randomly zeroes elements in the input during training.
114 | See also base class.
115 |
116 | Args:
117 | p: Probability of dropping an element.
118 | inplace: This is a dummy argument and has no effects,
119 | as in-place is currently not supported.
120 | """
121 | def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
122 | super().__init__(p=p, inplace=False)
123 |
124 | if inplace is True:
125 | warnings.warn('In-place dropout currently not supported; '
126 | 'falling back to out-of-place.')
127 |
128 | def forward(self, input: Tensor) -> Tensor:
129 | return DropoutAutoGrad.apply(input, self.p, self.training)
130 |
--------------------------------------------------------------------------------
/tests/test_multi_head_attention_layer.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import pytest
4 | import torch
5 | from torch import autocast, nn
6 | from torch.nn import init
7 |
8 | import attorch
9 | from .utils import assert_close, create_input, create_input_like, default_shapes
10 |
11 |
12 | @pytest.mark.parametrize('shape', default_shapes(min_dim=3, max_dim=3))
13 | @pytest.mark.parametrize('self_attention', [False, True])
14 | @pytest.mark.parametrize('num_heads', [1, 2, 4, 6, 8])
15 | @pytest.mark.parametrize('bias', [False, True])
16 | @pytest.mark.parametrize('causal', [False, True])
17 | @pytest.mark.parametrize('input_dtype', [torch.float32, torch.float16])
18 | @pytest.mark.parametrize('amp', [False, True])
19 | def test_multi_head_attention_layer(
20 | shape: Tuple[int, ...],
21 | self_attention: bool,
22 | num_heads: int,
23 | bias: bool,
24 | causal: bool,
25 | input_dtype: torch.dtype,
26 | amp: bool,
27 | subset: bool,
28 | ) -> None:
29 | if subset and (shape not in default_shapes(subset=True)):
30 | return
31 |
32 | if (input_dtype is torch.float16 and not amp) or (shape[-1] % num_heads != 0):
33 | return
34 |
35 | attorch_input_q = create_input(shape, dtype=input_dtype)
36 | pytorch_input_q = create_input(shape, dtype=input_dtype)
37 |
38 | if self_attention:
39 | attorch_input_k = attorch_input_q
40 | pytorch_input_k = pytorch_input_q
41 |
42 | else:
43 | attorch_input_k = create_input(shape, dtype=input_dtype, seed=1)
44 | pytorch_input_k = create_input(shape, dtype=input_dtype, seed=1)
45 |
46 | attorch_input_v = attorch_input_k
47 | pytorch_input_v = pytorch_input_k
48 |
49 | torch.manual_seed(0)
50 | attorch_multi_head_attention = attorch.MultiheadAttention(shape[-1], num_heads,
51 | bias=bias,
52 | batch_first=True)
53 |
54 | torch.manual_seed(0)
55 | pytorch_multi_head_attention = nn.MultiheadAttention(shape[-1], num_heads,
56 | bias=bias,
57 | batch_first=True,
58 | device='cuda')
59 |
60 | if bias:
61 | torch.manual_seed(0)
62 | init.normal_(attorch_multi_head_attention.in_proj_bias)
63 | init.normal_(attorch_multi_head_attention.out_proj.bias)
64 |
65 | torch.manual_seed(0)
66 | init.normal_(pytorch_multi_head_attention.in_proj_bias)
67 | init.normal_(pytorch_multi_head_attention.out_proj.bias)
68 |
69 | with autocast('cuda', enabled=amp):
70 | attorch_output = attorch_multi_head_attention(attorch_input_q,
71 | attorch_input_k,
72 | attorch_input_v,
73 | causal=causal)
74 | pytorch_output = pytorch_multi_head_attention(pytorch_input_q,
75 | pytorch_input_k,
76 | pytorch_input_v,
77 | attn_mask=torch.empty(2*(shape[1],)) if causal else None,
78 | is_causal=causal,
79 | need_weights=False)[0]
80 |
81 | assert_close((attorch_output, pytorch_output),
82 | rtol=1e-2, atol=1e-2)
83 |
84 | attorch_output.backward(create_input_like(attorch_output))
85 | pytorch_output.backward(create_input_like(attorch_output))
86 |
87 | bias_grad_pair = (((attorch_multi_head_attention.in_proj_bias.grad,
88 | pytorch_multi_head_attention.in_proj_bias.grad),
89 | (attorch_multi_head_attention.out_proj.bias.grad,
90 | pytorch_multi_head_attention.out_proj.bias.grad))
91 | if bias else ((None, None), (None, None)))
92 | assert_close((attorch_input_q.grad, pytorch_input_q.grad),
93 | (attorch_input_k.grad, pytorch_input_k.grad),
94 | (attorch_input_v.grad, pytorch_input_v.grad),
95 | (attorch_multi_head_attention.in_proj_weight.grad,
96 | pytorch_multi_head_attention.in_proj_weight.grad),
97 | (attorch_multi_head_attention.out_proj.weight.grad,
98 | pytorch_multi_head_attention.out_proj.weight.grad),
99 | *bias_grad_pair,
100 | rtol=1e-2, atol=1e-2)
101 |
--------------------------------------------------------------------------------
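Unlike ```torch.nn.MultiheadAttention```, which takes an ```is_causal``` hint plus an attention mask, attorch's layer accepts a ```causal``` flag directly in its forward pass, as the test above exercises. A minimal self-attention sketch, assuming a CUDA device, with arbitrary sizes:

```python
# Causal self-attention with attorch's MultiheadAttention.
import torch
import attorch

batch, seq_len, embed_dim, num_heads = 4, 128, 256, 8
mha = attorch.MultiheadAttention(embed_dim, num_heads, bias=True, batch_first=True)

x = torch.randn(batch, seq_len, embed_dim, device='cuda', requires_grad=True)
out = mha(x, x, x, causal=True)  # each position attends only to itself and earlier positions
out.sum().backward()
```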
/tests/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for tests.
3 | """
4 |
5 |
6 | from typing import List, Optional, Tuple
7 |
8 | import torch
9 | from torch import Tensor
10 |
11 | from attorch.types import Device
12 |
13 |
14 | def default_shapes(
15 | min_dim: int = 0,
16 | max_dim: int = 4,
17 | subset: bool = False,
18 | ) -> List[Tuple[int, ...]]:
19 | """
20 | Returns typical data shapes for testing.
21 |
22 | Args:
23 | min_dim: Minimum dimensionality of the shapes returned.
24 | max_dim: Maximum dimensionality of the shapes returned.
25 | subset: Returns a subset of data shapes, useful for rapid testing.
26 | """
27 | if subset:
28 | shapes = [(96,),
29 | (768,),
30 | (4, 2000),
31 | (8, 1024),
32 | (8, 960, 196),
33 | (64, 768, 128),
34 | (64, 64, 56, 56),
35 | (256, 128, 28, 28)]
36 | else:
37 | shapes = [(96,),
38 | (128,),
39 | (196,),
40 | (384,),
41 | (768,),
42 | (1024,),
43 | (3200,),
44 | (4800,),
45 | (8000,),
46 | (12288,),
47 | (1, 8000),
48 | (4, 2000),
49 | (8, 1024),
50 | (32, 1024),
51 | (128, 1024),
52 | (2048, 768),
53 | (6144, 256),
54 | (8096, 32),
55 | (12288, 1),
56 | (1, 1024, 3072),
57 | (8, 960, 196),
58 | (64, 768, 128),
59 | (128, 960, 196),
60 | (2048, 64, 16),
61 | (1, 3, 224, 224),
62 | (8, 3, 224, 224),
63 | (64, 64, 56, 56),
64 | (256, 128, 28, 28),
65 | (256, 2048, 7, 7)]
66 | return list(filter(lambda shape: min_dim <= len(shape) <= max_dim, shapes))
67 |
68 |
69 | def create_input(
70 | shape: Tuple[int, ...],
71 | dtype: torch.dtype = torch.float32,
72 | device: Device = 'cuda',
73 | requires_grad: bool = True,
74 | seed: Optional[int] = 0,
75 | ) -> Tensor:
76 | """
77 | Creates a tensor filled with random numbers.
78 |
79 | Args:
80 | shape: Shape of tensor.
81 | dtype: Dtype of tensor.
82 | device: Device of tensor.
83 | requires_grad: Flag for recording operations for autodiff.
84 | seed: Seed for generating random numbers. If None, no seed is set.
85 |
86 | Returns:
87 | Tensor with random numbers.
88 | """
89 | if seed is not None:
90 | torch.manual_seed(seed)
91 |
92 | return torch.randn(shape, dtype=dtype, device=device,
93 | requires_grad=requires_grad)
94 |
95 |
96 | def create_input_like(
97 | input: Tensor,
98 | requires_grad: bool = False,
99 | seed: Optional[int] = 0,
100 | ) -> Tensor:
101 | """
102 | Creates a tensor filled with random numbers with the same size, dtype, and
103 | device as input.
104 |
105 | Args:
106 | input: Input.
107 | requires_grad: Flag for recording operations for autodiff.
108 | seed: Seed for generating random numbers. If None, no seed is set.
109 |
110 | Returns:
111 | Tensor with random numbers.
112 | """
113 | if seed is not None:
114 | torch.manual_seed(seed)
115 |
116 | return torch.randn_like(input, requires_grad=requires_grad)
117 |
118 |
119 | def assert_close(
120 | *tensor_pairs: Tuple[Tensor, Tensor],
121 | rtol: Optional[float] = None,
122 | atol: Optional[float] = None,
123 | ) -> None:
124 | """
125 | Asserts that the two tensors in each pair of tensor_pairs are close.
126 | See also torch.testing.assert_close.
127 |
128 | Args:
129 | *tensor_pairs: Pairs of tensors that are asserted to be close.
130 | rtol: Relative tolerance. If specified, atol must also be specified.
131 | Otherwise, it is selected according to the tensors' dtypes.
132 | atol: Absolute tolerance. If specified, rtol must also be specified.
133 | Otherwise, it is selected according to the tensors' dtypes.
134 | """
135 | for pair in tensor_pairs:
136 | try:
137 | torch.testing.assert_close(*pair, rtol=rtol, atol=atol, equal_nan=True)
138 |
139 | except AssertionError as error:
140 | sum_diffs = torch.abs(pair[0] - pair[1]).sum()
141 | elems = pair[0].numel()
142 |             raise RuntimeError(f'Tensors not close; '
143 | f'average difference per element is {sum_diffs / elems}. '
144 | f'Original error: {error}')
145 |
--------------------------------------------------------------------------------
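
A minimal usage sketch of the test utilities above (illustrative only, not part of the repository): it assumes a CUDA device with attorch installed, and mirrors how the test files compare an attorch module against its PyTorch counterpart.

    import torch

    import attorch
    from tests.utils import assert_close, create_input, create_input_like

    # The shared default seed of 0 makes both inputs identical.
    attorch_input = create_input((8, 1024))
    pytorch_input = create_input((8, 1024))

    attorch_output = attorch.Softmax(dim=-1)(attorch_input)
    pytorch_output = torch.nn.Softmax(dim=-1)(pytorch_input)
    assert_close((attorch_output, pytorch_output))

    # Backward passes driven by identical random output gradients.
    attorch_output.backward(create_input_like(attorch_output))
    pytorch_output.backward(create_input_like(pytorch_output))
    assert_close((attorch_input.grad, pytorch_input.grad), rtol=1e-3, atol=1e-3)
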
/attorch/glu_kernels.py:
--------------------------------------------------------------------------------
1 | """
2 | Kernels for gated linear units with arbitrary activation functions.
3 | """
4 |
5 |
6 | import triton
7 | import triton.language as tl
8 |
9 | from .act_kernels import apply_act_func, apply_act_func_grad
10 | from .utils import element_wise_kernel_configs
11 |
12 |
13 | @triton.autotune(
14 | configs=element_wise_kernel_configs(),
15 | key=['size'],
16 | )
17 | @triton.jit
18 | def glu_forward_kernel(
19 | input1_pointer, input2_pointer, output_pointer, size, param,
20 | act_func: tl.constexpr,
21 | BLOCK_SIZE: tl.constexpr,
22 | ):
23 | """
24 | Applies the gated linear unit with an arbitrary activation function
25 | to the input.
26 |
27 | Args:
28 | input1_pointer: Pointer to the first half of the input to gate.
29 | The first half must be contiguous and contain size elements.
30 | input2_pointer: Pointer to the second half of the input to gate.
31 | The second half must be contiguous and contain size elements.
32 | output_pointer: Pointer to a container the result is written to.
33 | The container must be contiguous and contain size elements.
34 | size: Number of elements in each half of the input.
35 | param: Parameter in the case of parameterized activation functions.
36 | act_func: Name of activation function to apply.
 37 |             Options are 'sigmoid', 'logsigmoid', 'tanh', 'relu', 'gelu', 'geluapprox', 'silu',
 38 |             'relu6', 'hardsigmoid', 'hardtanh', 'hardswish', 'selu', 'mish', 'softplus',
 39 |             'softsign', 'tanhshrink', 'leaky_relu', 'elu', 'celu', 'hardshrink', and 'softshrink'.
40 | BLOCK_SIZE: Block size.
41 | """
42 | # This program processes BLOCK_SIZE elements.
43 | pid = tl.program_id(axis=0)
44 | offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
45 | mask = offset < size
46 |
47 | input1 = tl.load(input1_pointer + offset, mask=mask)
48 | input2 = tl.load(input2_pointer + offset, mask=mask)
49 |
50 | output = input1 * apply_act_func(input2, None, None, None, param,
51 | act_func, False)
52 | tl.store(output_pointer + offset, output, mask=mask)
53 |
54 |
55 | @triton.autotune(
56 | configs=element_wise_kernel_configs(),
57 | key=['size'],
58 | )
59 | @triton.jit
60 | def glu_backward_kernel(
61 | output_grad_pointer, input1_pointer, input2_pointer,
62 | input1_grad_pointer, input2_grad_pointer, size, param,
63 | act_func: tl.constexpr,
64 | BLOCK_SIZE: tl.constexpr,
65 | ):
66 | """
67 | Calculates the input gradient of the gated linear unit.
68 |
69 | Args:
70 | output_grad_pointer: Pointer to the unit's output gradients.
71 | The output gradients must be contiguous and contain size elements.
72 | input1_pointer: Pointer to the first half of the input that was gated.
73 | The first half must be contiguous and contain size elements.
74 | input2_pointer: Pointer to the second half of the input that was gated.
75 | The second half must be contiguous and contain size elements.
76 | input1_grad_pointer: Pointer to a container the first half's gradients are written to.
77 | The container must be contiguous and contain size elements.
78 | input2_grad_pointer: Pointer to a container the second half's gradients are written to.
79 | The container must be contiguous and contain size elements.
80 | size: Number of elements in each half of the input.
81 | param: Parameter in the case of parameterized activation functions.
82 | act_func: Name of activation function to apply.
 83 |             Options are 'sigmoid', 'logsigmoid', 'tanh', 'relu', 'gelu', 'geluapprox', 'silu',
 84 |             'relu6', 'hardsigmoid', 'hardtanh', 'hardswish', 'selu', 'mish', 'softplus',
 85 |             'softsign', 'tanhshrink', 'leaky_relu', 'elu', 'celu', 'hardshrink', and 'softshrink'.
86 | BLOCK_SIZE: Block size.
87 | """
88 | # This program processes BLOCK_SIZE elements.
89 | pid = tl.program_id(axis=0)
90 | offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
91 | mask = offset < size
92 |
93 | output_grad = tl.load(output_grad_pointer + offset, mask=mask)
94 | input1 = tl.load(input1_pointer + offset, mask=mask)
95 | input2 = tl.load(input2_pointer + offset, mask=mask)
96 |
97 | input1_grad = output_grad * apply_act_func(input2, None, None, None, param,
98 | act_func, False)
99 | input2_grad = output_grad * input1 * apply_act_func_grad(1, input2,
100 | None, None, None,
101 | param, act_func,
102 | False)
103 |
104 | tl.store(input1_grad_pointer + offset, input1_grad, mask=mask)
105 | tl.store(input2_grad_pointer + offset, input2_grad, mask=mask)
106 |
--------------------------------------------------------------------------------
/attorch/softmax_layers.py:
--------------------------------------------------------------------------------
1 | """
2 | Softmax and related layers with PyTorch autodiff support.
3 | """
4 |
5 |
6 | from typing import Optional, Tuple
7 |
8 | import torch
9 | from torch import Tensor
10 | from torch import nn
11 | from triton import cdiv
12 |
13 | from .softmax_kernels import softmax_backward_kernel, softmax_forward_kernel
14 | from .types import Context
15 | from .utils import get_output_dtype
16 |
17 |
18 | class SoftmaxAutoGrad(torch.autograd.Function):
19 | """
20 | Autodiff for softmax and related functions.
21 | """
22 | @staticmethod
23 | def forward(
24 | ctx: Context,
25 | input: Tensor,
26 | neg: bool,
27 | log: bool,
28 | ) -> Tensor:
29 | """
30 | Normalizes the input using softmax.
31 |
32 | Args:
33 | ctx: Context for variable storage.
34 | input: Input to normalize.
35 | Can have arbitrary shape.
36 | neg: Flag indicating if the input should be negated to get softmin.
37 | log: Flag indicating if the log of softmax should be taken.
38 |
39 | Returns:
40 | Input normalized by softmax.
41 | """
42 | flattened_input = input.unsqueeze(0) if input.ndim == 1 else input
43 | flattened_input = flattened_input.flatten(0, -2)
44 | batch_dim, feat_dim = flattened_input.shape
45 |
46 | output_dtype = get_output_dtype(input.dtype, autocast='fp32')
47 | output = torch.empty_like(flattened_input, dtype=output_dtype)
48 |
49 | # Launches 1D grid where each program operates over BLOCK_SIZE_BATCH rows.
50 | grid = lambda META: (cdiv(batch_dim, META['BLOCK_SIZE_BATCH']),)
51 | softmax_forward_kernel[grid](flattened_input, output, batch_dim, feat_dim,
52 | *flattened_input.stride(), *output.stride(),
53 | neg=neg, log=log)
54 |
55 | ctx.neg = neg
56 | ctx.log = log
57 | if input.requires_grad:
58 | ctx.save_for_backward(output)
59 |
60 | return output.view_as(input)
61 |
62 | @staticmethod
63 | def backward(
64 | ctx: Context,
65 | output_grad: Tensor,
66 | ) -> Tuple[Optional[Tensor], ...]:
67 | """
68 | Calculates the input gradient of softmax.
69 |
70 | Args:
71 | ctx: Context containing stored variables.
72 | output_grad: Output gradients.
73 | Must be the same shape as the output.
74 |
75 | Returns:
76 | Input gradient of softmax.
77 | """
78 | (output,) = ctx.saved_tensors
79 | flattened_output_grad = output_grad.view_as(output)
80 |
81 | batch_dim, feat_dim = output.shape
82 | input_grad = torch.empty_like(output)
83 |
84 | # Launches 1D grid where each program operates over BLOCK_SIZE_BATCH rows.
85 | grid = lambda META: (cdiv(batch_dim, META['BLOCK_SIZE_BATCH']),)
86 | softmax_backward_kernel[grid](flattened_output_grad, output, input_grad,
87 | batch_dim, feat_dim,
88 | *flattened_output_grad.stride(),
89 | *output.stride(), *input_grad.stride(),
90 | neg=ctx.neg, log=ctx.log)
91 |
92 | # Pads output with None because a gradient is necessary for
93 | # all input arguments.
94 | return input_grad.view_as(output_grad), None, None
95 |
96 |
97 | class Softmax(nn.Softmax):
98 | """
99 | Normalizes the input using softmax.
100 | See also base class.
101 |
102 | Args:
103 | dim: Dimension along which softmax will be computed.
104 | Only softmax along the last dimension is supported.
105 | """
106 | def forward(self, input: Tensor) -> Tensor:
107 | if self.dim != -1 and self.dim != input.ndim - 1:
108 |             raise RuntimeError('Only softmax along the last dimension is supported.')
109 |
110 | return SoftmaxAutoGrad.apply(input, False, False)
111 |
112 |
113 | class LogSoftmax(nn.LogSoftmax):
114 | """
115 | Normalizes the input using softmax and takes its log.
116 | See also base class.
117 |
118 | Args:
119 | dim: Dimension along which softmax will be computed.
120 | Only softmax along the last dimension is supported.
121 | """
122 | def forward(self, input: Tensor) -> Tensor:
123 | if self.dim != -1 and self.dim != input.ndim - 1:
124 |             raise RuntimeError('Only softmax along the last dimension is supported.')
125 |
126 | return SoftmaxAutoGrad.apply(input, False, True)
127 |
128 |
129 | class Softmin(nn.Softmin):
130 | """
131 | Normalizes the input using softmin.
132 | See also base class.
133 |
134 | Args:
135 | dim: Dimension along which softmin will be computed.
136 | Only softmin along the last dimension is supported.
137 | """
138 | def forward(self, input: Tensor) -> Tensor:
139 | if self.dim != -1 and self.dim != input.ndim - 1:
140 |             raise RuntimeError('Only softmin along the last dimension is supported.')
141 |
142 | return SoftmaxAutoGrad.apply(input, True, False)
143 |
--------------------------------------------------------------------------------
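
For quick reference, a hypothetical snippet (not part of the repository) showing how the three wrappers above map onto the neg and log flags of SoftmaxAutoGrad, and the last-dimension restriction they enforce; it assumes a CUDA device.

    import torch

    import attorch

    x = torch.randn(4, 16, device='cuda')

    softmax = attorch.Softmax(dim=-1)(x)         # neg=False, log=False.
    log_softmax = attorch.LogSoftmax(dim=-1)(x)  # neg=False, log=True.
    softmin = attorch.Softmin(dim=-1)(x)         # neg=True, log=False.

    # Normalizing over any dimension other than the last one raises a RuntimeError.
    try:
        attorch.Softmax(dim=0)(x)
    except RuntimeError as error:
        print(error)
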
/attorch/glu_layer.py:
--------------------------------------------------------------------------------
1 | """
2 | Gated linear unit layer with arbitrary activation functions
3 | with PyTorch autodiff support.
4 | """
5 |
6 |
7 | from typing import Optional, Tuple
8 |
9 | import torch
10 | from torch import Tensor
11 | from torch import nn
12 | from torch.amp import custom_bwd, custom_fwd
13 | from triton import cdiv
14 |
15 | from .glu_kernels import glu_backward_kernel, glu_forward_kernel
16 | from .types import Context
17 |
18 |
19 | class GLUAutoGrad(torch.autograd.Function):
20 | """
21 | Autodiff for gated linear unit.
22 | """
23 | @staticmethod
24 | @custom_fwd(device_type='cuda')
25 | def forward(
26 | ctx: Context,
27 | input: Tensor,
28 | dim: int,
29 | act_func: str,
30 | ) -> Tensor:
31 | """
32 | Applies the gated linear unit with an arbitrary activation function
33 | to the input.
34 |
35 | Args:
36 | ctx: Context for variable storage.
37 | input: Input to gate.
38 | Can have arbitrary shape but dimension dim must be even.
39 | dim: Dimension over which to gate.
40 | act_func: Name of activation function to apply.
41 | Options are 'sigmoid', 'logsigmoid', 'tanh', 'relu', 'gelu', 'geluapprox', 'silu',
42 | 'relu6', 'hardsigmoid', 'hardtanh', 'hardswish', 'selu', 'mish',
43 | 'softplus', 'softsign', 'tanhshrink', 'leaky_relu_PARAM',
44 | 'elu_PARAM', 'celu_PARAM', 'hardshrink_PARAM', and 'softshrink_PARAM'
45 | where PARAM stands for the parameter in the case of parameterized
46 | activation functions (e.g., 'leaky_relu_0.01' for leaky ReLU with a
47 | negative slope of 0.01).
48 |
49 | Returns:
50 | Input transformed by the gated linear unit
51 | with an arbitrary activation function.
52 | """
53 | param = None
54 | if '_' in act_func:
55 | comps = act_func.split('_')
56 | act_func = '_'.join(comps[:-1])
57 | param = float(comps[-1])
58 |
59 | input1, input2 = input.chunk(2, dim=dim)
60 | input1 = input1.contiguous()
61 | input2 = input2.contiguous()
62 |
63 | requires_grad = input.requires_grad
64 | size = input1.numel()
65 | output = torch.empty_like(input1)
66 |
67 | ctx.param = param
68 | ctx.act_func = act_func
69 | ctx.dim = dim
70 | ctx.size = size
71 | if requires_grad:
72 | ctx.save_for_backward(input1, input2)
73 |
74 | # Launches 1D grid where each program operates over
75 | # BLOCK_SIZE elements.
76 | grid = lambda META: (cdiv(size, META['BLOCK_SIZE']),)
77 | glu_forward_kernel[grid](input1, input2, output, size, param, act_func)
78 |
79 | return output
80 |
81 | @staticmethod
82 | @custom_bwd(device_type='cuda')
83 | def backward(
84 | ctx: Context,
85 | output_grad: Tensor,
86 | ) -> Tuple[Optional[Tensor], ...]:
87 | """
88 | Calculates the input gradient of the gated linear unit.
89 |
90 | Args:
91 | ctx: Context containing stored variables.
92 | output_grad: Output gradients.
93 | Must be the same shape as the output.
94 |
95 | Returns:
96 | Input gradient of the gated linear unit.
97 | """
98 | (input1, input2) = ctx.saved_tensors
99 | input1_grad = torch.empty_like(input1)
100 | input2_grad = torch.empty_like(input2)
101 |
102 | # Launches 1D grid where each program operates over
103 | # BLOCK_SIZE elements.
104 | grid = lambda META: (cdiv(ctx.size, META['BLOCK_SIZE']),)
105 | glu_backward_kernel[grid](output_grad, input1, input2,
106 | input1_grad, input2_grad,
107 | ctx.size, ctx.param, ctx.act_func)
108 |
109 | # Pads output with None because a gradient is necessary for
110 | # all input arguments.
111 | return torch.concat([input1_grad, input2_grad], dim=ctx.dim), None, None
112 |
113 |
114 | class GLU(nn.GLU):
115 | """
116 | Applies the gated linear unit with an arbitrary activation function
117 | to the input.
118 | See also base class.
119 |
120 | Args:
121 | dim: Dimension over which to gate.
122 | act_func: Name of activation function to apply.
123 | Options are 'sigmoid', 'logsigmoid', 'tanh', 'relu', 'gelu', 'geluapprox', 'silu',
124 | 'relu6', 'hardsigmoid', 'hardtanh', 'hardswish', 'selu', 'mish',
125 | 'softplus', 'softsign', 'tanhshrink', 'leaky_relu_PARAM',
126 | 'elu_PARAM', 'celu_PARAM', 'hardshrink_PARAM', and 'softshrink_PARAM'
127 | where PARAM stands for the parameter in the case of parameterized
128 | activation functions (e.g., 'leaky_relu_0.01' for leaky ReLU with a
129 | negative slope of 0.01).
130 | """
131 | def __init__(self, dim: int = -1, act_func: str = 'sigmoid') -> None:
132 | super().__init__(dim)
133 | self.act_func = act_func
134 |
135 | def forward(self, input: Tensor) -> Tensor:
136 | return GLUAutoGrad.apply(input, self.dim, self.act_func)
137 |
--------------------------------------------------------------------------------
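
A short, hypothetical sketch of the activation-name convention described above, including a parameterized variant; it assumes a CUDA device.

    import torch

    import attorch

    x = torch.randn(4, 64, device='cuda', requires_grad=True)

    # Standard GLU: the second half of the last dimension, passed through a sigmoid,
    # gates the first half, so the output has half as many features as the input.
    glu = attorch.GLU(dim=-1)
    output = glu(x)  # Shape: [4, 32].
    output.sum().backward()

    # SwiGLU-style gating with SiLU, and a leaky ReLU gate with a negative slope of 0.01.
    silu_glu = attorch.GLU(dim=-1, act_func='silu')
    leaky_glu = attorch.GLU(dim=-1, act_func='leaky_relu_0.01')
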
/tests/test_batch_norm_layer.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 |
3 | import pytest
4 | import torch
5 | from torch import autocast, nn
6 | from torch.nn import functional as F, init
7 |
8 | import attorch
9 | from .utils import assert_close, create_input, create_input_like, default_shapes
10 |
11 |
12 | @pytest.mark.parametrize('shape', default_shapes(min_dim=2, max_dim=4))
13 | @pytest.mark.parametrize('eps', [1e-5, 1e-6])
14 | @pytest.mark.parametrize('momentum', [0.1, 0.2])
15 | @pytest.mark.parametrize('affine', [False, True])
16 | @pytest.mark.parametrize('track_running_stats', [False, True])
17 | @pytest.mark.parametrize('add_pre_act', [False, True])
18 | @pytest.mark.parametrize('act_func', [None, 'sigmoid', 'logsigmoid', 'tanh', 'relu', 'gelu', 'silu',
19 | 'relu6', 'hardsigmoid', 'hardtanh', 'hardswish', 'selu',
20 | 'mish', 'softplus', 'softsign', 'tanhshrink', 'leaky_relu_0.01',
21 | 'elu_1', 'celu_1', 'hardshrink_0.5', 'softshrink_0.5'])
22 | @pytest.mark.parametrize('input_dtype', [torch.float32, torch.float16])
23 | @pytest.mark.parametrize('amp', [False, True])
24 | def test_batch_norm_layer(
25 | shape: Tuple[int, ...],
26 | eps: float,
27 | momentum: float,
28 | affine: bool,
29 | track_running_stats: bool,
30 | add_pre_act: bool,
31 | act_func: Optional[str],
 32 |         input_dtype: torch.dtype,
33 | amp: bool,
34 | subset: bool,
35 | ) -> None:
36 | if subset and (shape not in default_shapes(subset=True)):
37 | return
38 |
39 | if shape[0] == 1 or (input_dtype is torch.float16 and not amp):
40 | return
41 |
42 | bn_name = 'BatchNorm2d' if len(shape) == 4 else 'BatchNorm1d'
43 | attorch_input = create_input(shape, dtype=input_dtype)
44 | pytorch_input = create_input(shape, dtype=input_dtype)
45 |
46 | if add_pre_act:
47 | attorch_residual = create_input(shape, dtype=input_dtype, seed=1)
48 | pytorch_residual = create_input(shape, dtype=input_dtype, seed=1)
49 |
50 | else:
51 | attorch_residual = pytorch_residual = None
52 |
53 | attorch_batch_norm = getattr(attorch, bn_name)(num_features=shape[1],
54 | eps=eps, momentum=momentum,
55 | affine=affine,
56 | track_running_stats=track_running_stats,
57 | act_func=act_func)
58 | pytorch_batch_norm = getattr(nn, bn_name)(num_features=shape[1],
59 | eps=eps, momentum=momentum,
60 | affine=affine,
61 | track_running_stats=track_running_stats,
62 | device='cuda')
63 | pytorch_act = nn.Identity() if act_func is None else getattr(F, act_func.rsplit('_', 1)[0])
64 |
65 | if affine:
66 | torch.manual_seed(0)
67 | init.normal_(attorch_batch_norm.weight)
68 | init.normal_(attorch_batch_norm.bias)
69 |
70 | torch.manual_seed(0)
71 | init.normal_(pytorch_batch_norm.weight)
72 | init.normal_(pytorch_batch_norm.bias)
73 |
74 | with autocast('cuda', enabled=amp):
75 | if add_pre_act:
76 | attorch_output = attorch_batch_norm(attorch_input, attorch_residual)
77 | pytorch_output = pytorch_act(pytorch_batch_norm(pytorch_input) +
78 | pytorch_residual)
79 |
80 | else:
81 | attorch_output = attorch_batch_norm(attorch_input)
82 | pytorch_output = pytorch_act(pytorch_batch_norm(pytorch_input))
83 |
84 | assert_close((attorch_output, pytorch_output),
85 | (attorch_batch_norm.running_mean, pytorch_batch_norm.running_mean),
86 | (attorch_batch_norm.running_var, pytorch_batch_norm.running_var))
87 |
88 | attorch_output.backward(create_input_like(attorch_output))
89 | pytorch_output.backward(create_input_like(pytorch_output))
90 |
91 | residual_grad_pair = ((attorch_residual.grad, pytorch_residual.grad)
92 | if add_pre_act else (None, None))
93 | weight_grad_pair = ((attorch_batch_norm.weight.grad, pytorch_batch_norm.weight.grad)
94 | if affine else (None, None))
95 | bias_grad_pair = ((attorch_batch_norm.bias.grad, pytorch_batch_norm.bias.grad)
96 | if affine else (None, None))
97 | assert_close((attorch_input.grad, pytorch_input.grad),
98 | residual_grad_pair, weight_grad_pair, bias_grad_pair,
99 | rtol=1e-2, atol=1e-3)
100 |
101 | attorch_batch_norm.eval()
102 | pytorch_batch_norm.eval()
103 |
104 | with autocast('cuda', enabled=amp):
105 | if add_pre_act:
106 | attorch_output = attorch_batch_norm(attorch_input, attorch_residual)
107 | pytorch_output = pytorch_act(pytorch_batch_norm(pytorch_input) +
108 | pytorch_residual)
109 |
110 | else:
111 | attorch_output = attorch_batch_norm(attorch_input)
112 | pytorch_output = pytorch_act(pytorch_batch_norm(pytorch_input))
113 |
114 | assert_close((attorch_output, pytorch_output),
115 | (attorch_batch_norm.running_mean, pytorch_batch_norm.running_mean),
116 | (attorch_batch_norm.running_var, pytorch_batch_norm.running_var))
117 |
--------------------------------------------------------------------------------
/examples/wikitext-2/gpt.py:
--------------------------------------------------------------------------------
1 | """
2 | GPT2 for language modelling.
3 | """
4 |
5 |
6 | from typing import Optional
7 |
8 | import torch
9 | from torch import Tensor
10 | from torch import nn
11 |
12 | import attorch
13 |
14 |
15 | class MLP(nn.Module):
16 | """
17 | Transforms the input using a multilayer perceptron with one hidden layer
18 | and the GELU activation function.
19 |
20 | Args:
21 | use_attorch: Flag to use attorch in lieu of PyTorch as the backend.
22 | in_dim: Number of input features.
23 | hidden_dim: Number of hidden features.
24 | out_dim: Number of output features.
25 | If None, it is set to the number of input features.
26 | """
27 | def __init__(
28 | self,
29 | use_attorch: bool,
30 | in_dim: int,
31 | hidden_dim: int,
32 | out_dim: Optional[int] = None,
33 | ) -> None:
34 | super().__init__()
35 |
36 | self.fc1 = (attorch.Linear(in_dim, hidden_dim, act_func='gelu')
37 | if use_attorch else nn.Linear(in_dim, hidden_dim))
38 | self.act = nn.Identity() if use_attorch else nn.GELU()
39 | self.fc2 = (attorch.Linear(hidden_dim, out_dim or in_dim)
40 | if use_attorch else nn.Linear(hidden_dim, out_dim or in_dim))
41 |
42 | def forward(self, input: Tensor) -> Tensor:
43 | return self.fc2(self.act(self.fc1(input)))
44 |
45 |
46 | class TransformerBlock(nn.Module):
47 | """
48 | Passes the input through a transformer block.
49 |
50 | Args:
51 | use_attorch: Flag to use attorch in lieu of PyTorch as the backend.
52 | dim: Embedding dimension.
53 | num_heads: Number of heads for multi-headed self-attention.
54 | """
55 | def __init__(
56 | self,
57 | use_attorch: bool,
58 | dim: int,
59 | num_heads: int,
60 | ) -> None:
61 | super().__init__()
62 | self.use_attorch = use_attorch
63 | backend = attorch if use_attorch else nn
64 |
65 | self.ln1 = backend.LayerNorm(dim)
66 | self.attn = backend.MultiheadAttention(dim, num_heads,
67 | batch_first=True)
68 |
69 | self.ln2 = backend.LayerNorm(dim)
70 | self.mlp = MLP(use_attorch, dim, 4 * dim)
71 |
72 | def forward(self, input: Tensor) -> Tensor:
73 | if self.use_attorch:
74 | output = input + self.attn(self.ln1(input), causal=True)
75 |
76 | else:
77 | output = self.ln1(input)
78 | output = input + self.attn(output, output, output,
79 | attn_mask=torch.empty(2*(input.shape[1],)),
80 | is_causal=True,
81 | need_weights=False)[0]
82 |
83 | output = output + self.mlp(self.ln2(output))
84 | return output
85 |
86 |
87 | class GPT2(nn.Module):
88 | """
89 | Performs language modelling using GPT2,
90 | optionally computing the loss if return_loss is True.
91 |
92 | Args:
93 | use_attorch: Flag to use attorch in lieu of PyTorch as the backend.
94 | vocab_size: Vocabulary size.
95 | depth: Depth of the transformer.
96 | dim: Embedding dimension.
97 | num_heads: Number of heads for multi-headed self-attention.
98 | max_seq_len: Maximum sequence length of the incoming inputs.
99 | """
100 | def __init__(
101 | self,
102 | use_attorch: bool,
103 | vocab_size: int,
104 | depth: int,
105 | dim: int,
106 | num_heads: int,
107 | max_seq_len: int = 512,
108 | ) -> None:
109 | super().__init__()
110 | backend = attorch if use_attorch else nn
111 |
112 | self.tok_embed = nn.Embedding(vocab_size, dim)
113 | self.pos_embed = nn.Embedding(max_seq_len, dim)
114 | self.transformer = nn.Sequential(*[TransformerBlock(use_attorch, dim, num_heads)
115 | for _ in range(depth)])
116 | self.norm = backend.LayerNorm(dim)
117 | self.fc = backend.Linear(dim, vocab_size)
118 | self.loss_func = backend.CrossEntropyLoss()
119 |
120 | def forward(
121 | self,
122 | input: Tensor,
123 | return_loss: bool = False,
124 | ) -> Tensor:
125 | tok_embed = self.tok_embed(input)
126 | pos_embed = self.pos_embed(torch.arange(0, input.shape[1],
127 | dtype=torch.long,
128 | device=input.device))
129 |
130 | output = self.transformer(tok_embed + pos_embed)
131 | output = self.norm(output)
132 | output = self.fc(output)
133 |
134 | return (self.loss_func(output[:, :-1].contiguous().view(-1, output.shape[-1]),
135 | input[:, 1:].contiguous().view(-1))
136 | if return_loss else output)
137 |
138 |
139 | def gpt2(
140 | use_attorch: bool,
141 | vocab_size: int,
142 |     max_seq_len: int = 512,
143 | downsize: int = 1,
144 | ) -> GPT2:
145 | """
146 | Returns a GPT2 model with optional cross entropy loss.
147 |
148 | Args:
149 | use_attorch: Flag to use attorch in lieu of PyTorch as the backend.
150 | vocab_size: Vocabulary size.
151 | max_seq_len: Maximum sequence length of the incoming inputs.
152 | downsize: The depth and width of the model are calculated by dividing
153 | GPT2's original depth and width by this factor.
154 | """
155 | return GPT2(use_attorch, vocab_size=vocab_size,
156 | depth=12 // downsize, dim=768 // downsize, num_heads=12 // downsize,
157 | max_seq_len=max_seq_len)
158 |
--------------------------------------------------------------------------------
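
An illustrative snippet (not part of the example scripts) showing how the gpt2 factory above might be exercised on random token IDs; the vocabulary size and downsize factor are arbitrary choices here, and gpt2 is assumed to already be in scope since the hyphenated directory name blocks a plain dotted import.

    import torch

    # Assumes gpt2 as defined above is importable or otherwise in scope.
    model = gpt2(use_attorch=True, vocab_size=50_257, max_seq_len=512,
                 downsize=4).to('cuda')

    tokens = torch.randint(0, 50_257, (2, 128), device='cuda')
    logits = model(tokens)                  # Shape: [2, 128, 50_257].
    loss = model(tokens, return_loss=True)  # Scalar next-token cross entropy loss.
    loss.backward()
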
/attorch/pooling_layer.py:
--------------------------------------------------------------------------------
1 | """
2 | Pooling layers with PyTorch autodiff support.
3 | """
4 |
5 |
6 | from typing import Optional, Tuple, Union
7 |
8 | import torch
9 | from torch import Tensor
10 | from torch import nn
11 | from torch.nn.modules.utils import _pair
12 |
13 | from .conv_layer import Conv2dAutoGrad
14 |
15 |
16 | class AvgPool1d(nn.AvgPool1d):
17 | """
18 | Averages 1D windows of pixels in the input.
19 | See also base class.
20 |
21 | Note: The Triton compiler does not perform well with convolutional kernels,
22 | which underlie this pooling layer, and a significant speed disparity between
23 | this module and its PyTorch equivalent should be expected. Use at your own discretion.
24 |
25 | Args:
 26 |         kernel_size: Kernel size.
 27 |             This value is used along the single spatial dimension.
 28 |         stride: Stride of kernel.
 29 |             This value is used along the single spatial dimension.
 30 |             If None, the kernel size is used as the stride.
 31 |         padding: Padding applied to the input.
 32 |             This value is used along the single spatial dimension.
33 | ceil_mode: Flag for using ceil instead of floor to calculate the output shape.
34 | Must be False.
35 | count_include_pad: Flag for including padding when averaging pixels.
36 | Must be True.
37 |
38 | Raises:
 39 |         RuntimeError: 1. Ceil was requested in place of floor to calculate the output shape.
40 | 2. Padding was requested to be excluded in average computations.
41 | """
42 | def __init__(
43 | self,
44 | kernel_size: int,
45 | stride: Optional[int] = None,
46 | padding: int = 0,
47 | ceil_mode: bool = False,
48 | count_include_pad: bool = True,
49 | ) -> None:
50 | if ceil_mode:
51 | raise RuntimeError('The output shape can only be calculated using floor, not ceil.')
52 |
53 | if not count_include_pad:
54 | raise RuntimeError('Padding must be included when averaging pixels.')
55 |
56 | super().__init__(kernel_size, stride, padding, ceil_mode,
57 | count_include_pad)
58 |
59 | self.kernel = None
60 |
61 | def forward(self, input: Tensor) -> Tensor:
62 | if self.kernel is None:
63 | self.kernel = torch.full((input.shape[1], 1, self.kernel_size[0], 1),
64 | 1 / self.kernel_size[0],
65 | device='cuda')
66 |
67 | return Conv2dAutoGrad.apply(input.unsqueeze(-1), self.kernel,
68 | None,
69 | self.stride[0], 1, self.padding[0], 0,
70 | input.shape[1]).squeeze(-1)
71 |
72 |
73 | class AvgPool2d(nn.AvgPool2d):
74 | """
75 | Averages 2D windows of pixels in the input.
76 | See also base class.
77 |
78 | Note: The Triton compiler does not perform well with convolutional kernels,
79 | which underlie this pooling layer, and a significant speed disparity between
80 | this module and its PyTorch equivalent should be expected. Use at your own discretion.
81 |
82 | Args:
83 | kernel_size: Kernel size.
84 | If an int, this value is used along both spatial dimensions.
85 | stride: Stride of kernel.
86 | If an int, this value is used along both spatial dimensions.
87 | If None, the kernel size is used as the stride.
88 | padding: Padding applied to the input.
89 | If an int, this value is used along both spatial dimensions.
90 | ceil_mode: Flag for using ceil instead of floor to calculate the output shape.
91 | Must be False.
92 | count_include_pad: Flag for including padding when averaging pixels.
93 | Must be True.
94 | divisor_override: Average divisor.
95 | Must be None.
96 |
97 | Raises:
 98 |         RuntimeError: 1. Ceil was requested in place of floor to calculate the output shape.
 99 |                       2. Padding was requested to be excluded in average computations.
100 |                       3. Average divisor was overridden.
101 | """
102 | def __init__(
103 | self,
104 | kernel_size: Union[int, Tuple[int, int]],
105 | stride: Optional[Union[int, Tuple[int, int]]] = None,
106 | padding: Union[int, Tuple[int, int]] = 0,
107 | ceil_mode: bool = False,
108 | count_include_pad: bool = True,
109 | divisor_override: Optional[int] = None,
110 | ) -> None:
111 | if ceil_mode:
112 | raise RuntimeError('The output shape can only be calculated using floor, not ceil.')
113 |
114 | if not count_include_pad:
115 | raise RuntimeError('Padding must be included when averaging pixels.')
116 |
117 | if divisor_override is not None:
118 | raise RuntimeError('The average divisor must be the size of the window.')
119 |
120 | super().__init__(kernel_size, stride, padding, ceil_mode,
121 | count_include_pad, divisor_override)
122 |
123 | self.kernel_size = _pair(self.kernel_size)
124 | self.stride = _pair(self.stride)
125 | self.padding = _pair(self.padding)
126 | self.kernel = None
127 |
128 | def forward(self, input: Tensor) -> Tensor:
129 | if self.kernel is None:
130 | self.kernel = torch.full((input.shape[1], 1, *self.kernel_size),
131 | 1 / (self.kernel_size[0] * self.kernel_size[1]),
132 | device='cuda')
133 |
134 | return Conv2dAutoGrad.apply(input, self.kernel, None,
135 | *self.stride, *self.padding,
136 | input.shape[1])
137 |
--------------------------------------------------------------------------------
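
A hypothetical sanity check for the average-pooling wrappers above, comparing against the PyTorch equivalent on a CUDA tensor (bearing in mind the speed caveat in the docstrings); it assumes AvgPool2d is re-exported at the attorch package root.

    import torch

    import attorch

    x = torch.randn(8, 16, 32, 32, device='cuda')

    attorch_pool = attorch.AvgPool2d(kernel_size=2, stride=2)
    pytorch_pool = torch.nn.AvgPool2d(kernel_size=2, stride=2)

    # Both should produce [8, 16, 16, 16] outputs that agree within tolerance.
    torch.testing.assert_close(attorch_pool(x), pytorch_pool(x),
                               rtol=1e-3, atol=1e-3)
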
/examples/mnist/main.py:
--------------------------------------------------------------------------------
1 | """
2 | Trains and benchmarks a multilayer perceptron (MLP) on the MNIST dataset.
3 | """
4 |
5 |
6 | import time
7 | from argparse import ArgumentParser
8 | from functools import partial
9 | from pathlib import Path
10 | from typing import Callable, Tuple
11 |
12 | import torch
13 | from torch import nn
14 | from torch.optim import SGD
15 | from torch.utils.data import DataLoader
16 | from torchvision import transforms as T
17 | from torchvision.datasets import MNIST
18 |
19 | from .mlp import MLP
20 | from ..utils import AvgMeter, benchmark_fw_and_bw
21 |
22 |
23 | def create_dls(batch_size: int = 1_024) -> Tuple[DataLoader, DataLoader]:
24 | """
25 | Creates data loaders for MNIST with normalization.
26 |
 27 |     Args:
28 | batch_size: Batch size.
29 |
30 | Returns:
31 | Training and validation MNIST data loaders.
32 | """
33 | transform = T.Compose([T.ToTensor(),
34 | T.Normalize((0.1307,), (0.3081,))])
35 |
36 | train_dataset = MNIST(root='.', train=True, transform=transform,
37 | download=not Path('MNIST/').exists())
38 | valid_dataset = MNIST(root='.', train=False, transform=transform)
39 |
40 | train_dl = DataLoader(dataset=train_dataset, batch_size=batch_size,
41 | shuffle=True, drop_last=True)
42 | valid_dl = DataLoader(dataset=valid_dataset, batch_size=batch_size,
43 | shuffle=False, drop_last=True)
44 |
45 | return train_dl, valid_dl
46 |
47 |
48 | def train(
49 | model: nn.Module,
50 | train_dl: DataLoader,
51 | valid_dl: DataLoader,
52 | epochs: int = 10,
53 | batch_size: int = 1_024,
54 | ) -> float:
55 | """
56 | Trains and validates a model for classification.
57 |
58 | Args:
59 | model: Model to train. Its forward pass must optionally accept targets
60 | to compute the loss.
61 | train_dl: Data loader for training.
62 | valid_dl: Data loader for validation.
63 | epochs: Number of epochs to train for.
64 | batch_size: Batch size.
65 |
66 | Returns:
67 | Total training and validation time.
68 | """
69 | model = model.to('cuda')
70 | optim = SGD(model.parameters(), lr=batch_size / 64 * 0.005)
71 | optim.zero_grad()
72 |
73 | avg_meter = AvgMeter()
74 | start = time.time()
75 |
76 | for epoch in range(epochs):
77 | print(f'Epoch {epoch+1}/{epochs}')
78 |
79 | model.train()
80 | avg_meter.reset()
81 | for input, target in train_dl:
82 | input = input.to('cuda')
83 | target = target.to('cuda')
 84 |             loss = model(input, target)
 85 |             optim.zero_grad()
 86 |             loss.backward()
 87 |             optim.step()
 88 |             avg_meter.update(loss.item(), len(input))
89 | print(f'Training loss: {avg_meter.avg}')
90 |
91 | model.eval()
92 | avg_meter.reset()
93 | with torch.no_grad():
94 | for input, target in valid_dl:
95 | input = input.to('cuda')
96 | target = target.to('cuda')
97 |
98 | output = model(input)
99 | acc = (output.argmax(dim=-1) == target).float().mean()
100 | avg_meter.update(acc.item(), len(input))
101 | print(f'Validation accuracy: {avg_meter.avg}')
102 |
103 | return time.time() - start
104 |
105 |
106 | def main(model_cls: Callable, epochs: int = 10, batch_size: int = 1_024) -> None:
107 | """
108 | Trains and benchmarks a vision model on the MNIST dataset.
109 |
110 | Args:
111 | model_cls: Model class to train, with a 'num_classes' argument for
112 | specifying the number of output classes.
113 | epochs: Number of epochs to train for.
114 | batch_size: Batch size.
115 | """
116 | train_dl, valid_dl = create_dls(batch_size)
117 | model = model_cls(num_classes=len(MNIST.classes)).to('cuda')
118 |
119 | input, target = next(iter(train_dl))
120 | input = input.to('cuda')
121 | target = target.to('cuda')
122 |
123 | for _ in range(10):
124 | model.train()
125 | with torch.autocast('cuda'):
126 | loss = model(input, target)
127 | loss.backward()
128 |
129 | model.eval()
130 |         with torch.no_grad(), torch.autocast('cuda'):
131 | model(input)
132 |
133 | model.train()
134 | benchmark_fw_and_bw(model, input=input, target=target)
135 |
136 | print('Total training and validation time: '
137 | f'{train(model, train_dl, valid_dl, epochs, batch_size)}')
138 |
139 |
140 | if __name__ == '__main__':
141 | parser = ArgumentParser(description='Trains and benchmarks a multilayer perceptron (MLP) on the MNIST dataset.')
142 | parser.add_argument('--hidden_dim',
143 | type=int,
144 | default=128,
145 | help='Number of hidden features in the MLP.')
146 | parser.add_argument('--depth',
147 | type=int,
148 | default=1,
149 | help='Number of hidden layers in the MLP.')
150 | parser.add_argument('--epochs',
151 | type=int,
152 | default=10,
153 | help='Number of epochs to train for')
154 | parser.add_argument('--batch_size',
155 | type=int,
156 | default=1_024,
157 | help='Batch size')
158 | args = parser.parse_args()
159 |
160 | model_cls = partial(MLP, hidden_dim=args.hidden_dim, depth=args.depth)
161 |
162 | print('attorch run:')
163 | main(partial(model_cls, use_attorch=True),
164 | epochs=args.epochs, batch_size=args.batch_size)
165 |
166 | print('PyTorch run:')
167 | main(partial(model_cls, use_attorch=False),
168 | epochs=args.epochs, batch_size=args.batch_size)
169 |
--------------------------------------------------------------------------------
/attorch/p_loss_kernels.py:
--------------------------------------------------------------------------------
1 | """
2 | Kernels for p-norm-induced losses.
3 | """
4 |
5 |
6 | import triton
7 | import triton.language as tl
8 |
9 | from .utils import element_wise_kernel_configs
10 |
11 |
12 | @triton.autotune(
13 | configs=element_wise_kernel_configs(),
14 | key=['size'],
15 | )
16 | @triton.jit
17 | def p_loss_forward_kernel(
18 | input_pointer, target_pointer, output_pointer,
19 | param, size, p_loss: tl.constexpr, reduction: tl.constexpr,
20 | BLOCK_SIZE: tl.constexpr,
21 | ):
22 | """
23 | Measures the smooth L1, L1, squared L2, or Huber loss of the difference
24 | between the input and target.
25 |
26 | Args:
27 | input_pointer: Pointer to the input.
28 | The input must be of shape [size].
29 | target_pointer: Pointer to the target.
30 | The target must be of shape [size].
31 | output_pointer: Pointer to a container the error is written to.
32 | The container must be of shape [size] if reduction is 'none',
33 | and otherwise of shape [size/BLOCK_SIZE].
34 | param: Parameter of loss function (i.e., beta or delta for smooth L1 and Huber).
35 | size: Number of elements in the input and target.
36 | p_loss: p-norm used to compute the error.
37 | Options are 0 for smooth L1, 1 for L1, 2 for squared L2, and 3 for Huber loss.
38 | reduction: Reduction strategy for the output.
39 | Options are 'none' for no reduction, 'mean' for averaging the error
40 | across all entries, and 'sum' for summing the error across all entries.
41 | If a reduction method is specified, the reduced result of each
42 | program is written to a separate index in the output container,
43 | which should later be summed.
44 | BLOCK_SIZE: Block size.
45 | """
 46 |     # This program processes BLOCK_SIZE elements.
47 | pid = tl.program_id(axis=0)
48 | offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
49 | mask = offset < size
50 |
51 | input = tl.load(input_pointer + offset, mask=mask).to(tl.float32)
52 | target = tl.load(target_pointer + offset, mask=mask).to(tl.float32)
53 | diff = input - target
54 |
55 | if p_loss == 0:
 56 |         error = tl.where(tl.abs(diff) < param, 0.5 * diff * diff / param, tl.abs(diff) - 0.5 * param)
57 |
58 | elif p_loss == 1:
59 | error = tl.abs(diff)
60 |
61 | elif p_loss == 2:
62 | error = diff * diff
63 |
64 | elif p_loss == 3:
 65 |         error = tl.where(tl.abs(diff) < param, 0.5 * diff * diff, param * (tl.abs(diff) - 0.5 * param))
66 |
67 | if reduction == 'none':
68 | tl.store(output_pointer + offset, error, mask=mask)
69 |
70 | elif reduction == 'mean':
71 | tl.store(output_pointer + pid, tl.sum(error) / size)
72 |
73 | elif reduction == 'sum':
74 | tl.store(output_pointer + pid, tl.sum(error))
75 |
76 |
77 | @triton.autotune(
78 | configs=element_wise_kernel_configs(),
79 | key=['size'],
80 | )
81 | @triton.jit
82 | def p_loss_backward_kernel(
83 | output_grad_pointer, input_pointer, target_pointer,
84 | input_grad_pointer, target_grad_pointer, param, size,
85 | p_loss: tl.constexpr, reduction: tl.constexpr,
86 | BLOCK_SIZE: tl.constexpr,
87 | ):
88 | """
 89 |     Calculates the input gradient of the smooth L1, L1, squared L2, or Huber loss.
90 |
91 | Args:
92 | output_grad_pointer: Pointer to the error's output gradients.
93 | The output gradients must be a scalar or of shape [size].
94 | input_pointer: Pointer to the input.
95 | The input must be of shape [size].
96 | target_pointer: Pointer to the target.
97 | The target must be of shape [size].
98 | input_grad_pointer: Pointer to a container the input's gradients are written to.
99 | The container must be of shape [size].
100 | target_grad_pointer: Pointer to a container the target's gradients are written to.
101 | The container must be of shape [size].
102 | param: Parameter of loss function (i.e., beta or delta for smooth L1 and Huber).
103 | size: Number of elements in the input and target.
104 | p_loss: p-norm used to compute the error whose gradient is calculated.
105 | Options are 0 for smooth L1, 1 for L1, 2 for squared L2, and 3 for Huber loss.
106 | reduction: Reduction strategy for the output whose gradient is calculated.
107 | Options are 'none' for no reduction, 'mean' for averaging the error
108 | across all entries, and 'sum' for summing the error across all entries.
109 | BLOCK_SIZE: Block size.
110 | """
111 |     # This program processes BLOCK_SIZE elements.
112 | pid = tl.program_id(axis=0)
113 | offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
114 | mask = offset < size
115 |
116 | output_grad_mask = None
117 | if reduction == 'none':
118 | output_grad_pointer += offset
119 | output_grad_mask = mask
120 |
121 | input = tl.load(input_pointer + offset, mask=mask).to(tl.float32)
122 | target = tl.load(target_pointer + offset, mask=mask).to(tl.float32)
123 | diff = input - target
124 | output_grad = tl.load(output_grad_pointer, mask=output_grad_mask).to(tl.float32)
125 |
126 | if p_loss == 0:
127 |         input_grad = tl.where(tl.abs(diff) < param, diff / param, tl.where(0 <= diff, 1, -1))
128 |
129 | elif p_loss == 1:
130 | input_grad = tl.where(0 <= diff, 1, -1)
131 |
132 | elif p_loss == 2:
133 | input_grad = 2 * diff
134 |
135 | elif p_loss == 3:
136 |         input_grad = tl.where(tl.abs(diff) < param, diff, param * tl.where(0 <= diff, 1, -1))
137 |
138 | if reduction == 'mean':
139 | input_grad /= size
140 |
141 | input_grad *= output_grad
142 | tl.store(input_grad_pointer + offset, input_grad, mask=mask)
143 | tl.store(target_grad_pointer + offset, -input_grad, mask=mask)
144 |
--------------------------------------------------------------------------------
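
For reference, a plain PyTorch sketch (illustrative only) of the per-element error that p_loss_forward_kernel computes for each p_loss code; the actual kernel launch and final reduction live in p_loss_layers.py.

    import torch

    def p_loss_reference(input, target, param, p_loss, reduction='mean'):
        # Mirrors the branches of p_loss_forward_kernel.
        diff = input - target

        if p_loss == 0:    # Smooth L1.
            error = torch.where(diff.abs() < param, 0.5 * diff * diff / param,
                                diff.abs() - 0.5 * param)
        elif p_loss == 1:  # L1.
            error = diff.abs()
        elif p_loss == 2:  # Squared L2.
            error = diff * diff
        else:              # Huber (p_loss == 3).
            error = torch.where(diff.abs() < param, 0.5 * diff * diff,
                                param * (diff.abs() - 0.5 * param))

        if reduction == 'none':
            return error
        return error.mean() if reduction == 'mean' else error.sum()
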
/examples/regression/main.py:
--------------------------------------------------------------------------------
1 | """
2 | Trains and benchmarks a multilayer perceptron (MLP) on synthetic regression data.
3 | """
4 |
5 |
6 | import time
7 | from argparse import ArgumentParser
8 | from types import ModuleType
9 | from typing import Callable, Tuple
10 |
11 | import torch
12 | from torch.optim import SGD
13 | from torch.utils.data import DataLoader, TensorDataset
14 |
15 | import attorch
16 | from ..utils import AvgMeter, benchmark_fw_and_bw
17 |
18 |
19 | def create_dls(
20 | n_samples: int = 50_000,
21 | dim: int = 10,
22 | batch_size: int = 1_024,
23 | ) -> Tuple[DataLoader, DataLoader]:
24 | """
25 | Creates synthetic regression data loaders.
26 |
27 | Args:
28 | n_samples: Number of samples to generate.
29 | dim: Dimensionality of synthetic data.
30 | batch_size: Batch size.
31 |
32 | Returns:
33 | Training and validation data loaders.
34 | """
35 | torch.manual_seed(0)
36 | input = torch.randn(n_samples, dim)
37 | target = input @ torch.randn(dim, 1) + 0.1 * torch.randn(n_samples, 1)
38 | n_train = int(0.8 * n_samples)
39 |
40 | train_dataset = TensorDataset(input[:n_train], target[:n_train])
41 | valid_dataset = TensorDataset(input[n_train:], target[n_train:])
42 |
43 | train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
44 | drop_last=True)
45 | valid_dl = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False,
46 | drop_last=True)
47 |
48 | return train_dl, valid_dl
49 |
50 |
51 | def train(
52 | model: torch.nn.Module,
53 | loss_fn: Callable,
54 | train_dl: DataLoader,
55 | valid_dl: DataLoader,
56 | epochs: int = 10,
57 | batch_size: int = 1_024,
58 | ) -> float:
59 | """
60 | Trains and validates a regression model.
61 |
62 | Args:
63 | model: Model to train.
64 | loss_fn: Loss function.
65 | train_dl: Data loader for training.
66 | valid_dl: Data loader for validation.
67 | epochs: Number of epochs to train for.
68 | batch_size: Batch size.
69 |
70 | Returns:
71 | Total training and validation time.
72 | """
73 | model = model.to('cuda')
74 | optim = SGD(model.parameters(), lr=batch_size / 64 * 0.005)
75 | optim.zero_grad()
76 |
77 | avg_meter = AvgMeter()
78 | start = time.time()
79 |
80 | for epoch in range(epochs):
81 | print(f'Epoch {epoch+1}/{epochs}')
82 |
83 | model.train()
84 | avg_meter.reset()
85 | for input, target in train_dl:
86 | input = input.to('cuda')
87 | target = target.to('cuda')
 88 |             loss = loss_fn(model(input), target)
 89 |             optim.zero_grad()
 90 |             loss.backward()
 91 |             optim.step()
 92 |             avg_meter.update(loss.item(), len(input))
93 | print(f'Training loss: {avg_meter.avg}')
94 |
95 | model.eval()
96 | avg_meter.reset()
97 | with torch.no_grad():
98 | for input, target in valid_dl:
99 | input = input.to('cuda')
100 | target = target.to('cuda')
101 |
102 | loss = loss_fn(model(input), target)
103 | avg_meter.update(loss.item(), len(input))
104 | print(f'Validation loss: {avg_meter.avg}')
105 |
106 | return time.time() - start
107 |
108 |
109 | def main(
110 | nn: ModuleType,
111 | n_samples: int = 50_000,
112 | dim: int = 10,
113 | epochs: int = 10,
114 | batch_size: int = 1_024,
115 | ) -> None:
116 | """
117 | Trains and benchmarks a regression MLP on synthetic data.
118 |
119 | Args:
120 | nn: Neural network module used to construct the MLP.
121 | n_samples: Number of samples to generate.
122 | dim: Dimensionality of synthetic data.
123 | epochs: Number of epochs to train for.
124 | batch_size: Batch size.
125 | """
126 | train_dl, valid_dl = create_dls(n_samples, dim, batch_size)
127 | model = nn.Sequential(nn.Linear(dim, dim // 2),
128 | nn.ReLU(),
129 | nn.Linear(dim // 2, 1)).to('cuda')
130 | loss_fn = nn.MSELoss()
131 |
132 | input, target = next(iter(train_dl))
133 | input = input.to('cuda')
134 | target = target.to('cuda')
135 |
136 | for _ in range(10):
137 | model.train()
138 | with torch.autocast('cuda'):
139 | loss = loss_fn(model(input), target)
140 | loss.backward()
141 |
142 | model.eval()
143 |         with torch.no_grad(), torch.autocast('cuda'):
144 | model(input)
145 |
146 | model.train()
147 | benchmark_fw_and_bw(model, input=input)
148 |
149 | print('Total training and validation time: '
150 | f'{train(model, loss_fn, train_dl, valid_dl, epochs, batch_size)}')
151 |
152 |
153 | if __name__ == '__main__':
154 | parser = ArgumentParser(description='Trains and benchmarks a multilayer perceptron (MLP) on synthetic regression data.')
155 | parser.add_argument('--n_samples',
156 | type=int,
157 | default=50_000,
158 | help='Number of samples to generate')
159 | parser.add_argument('--dim',
160 | type=int,
161 | default=10,
162 | help='Dimensionality of synthetic data')
163 | parser.add_argument('--epochs',
164 | type=int,
165 | default=10,
166 | help='Number of training epochs')
167 | parser.add_argument('--batch_size',
168 | type=int,
169 | default=1_024,
170 | help='Batch size')
171 | args = parser.parse_args()
172 |
173 | print('attorch run:')
174 | main(attorch.nn, n_samples=args.n_samples, dim=args.dim,
175 | epochs=args.epochs, batch_size=args.batch_size)
176 |
177 | print('PyTorch run:')
178 | main(torch.nn, n_samples=args.n_samples, dim=args.dim,
179 | epochs=args.epochs, batch_size=args.batch_size)
180 |
--------------------------------------------------------------------------------
/attorch/rms_norm_layer.py:
--------------------------------------------------------------------------------
1 | """
2 | Root mean square normalization with PyTorch autodiff support.
3 | """
4 |
5 |
6 | from typing import Optional, Tuple
7 |
8 | import torch
9 | from torch import Tensor
10 | from torch import nn
11 | from torch.amp import custom_bwd, custom_fwd
12 | from triton import cdiv
13 |
14 | from .rms_norm_kernels import rms_norm_backward_kernel, rms_norm_forward_kernel
15 | from .softmax_kernels import BLOCK_SIZE_BATCH_heuristic
16 | from .types import Context, Device
17 |
18 |
19 | class RMSNormAutoGrad(torch.autograd.Function):
20 | """
21 | Autodiff for root mean square normalization.
22 | """
23 | @staticmethod
24 | @custom_fwd(device_type='cuda')
25 | def forward(
26 | ctx: Context,
27 | input: Tensor,
28 | weight: Optional[Tensor] = None,
29 | eps: Optional[float] = None,
30 | ) -> Tensor:
31 | """
32 | Root-mean-square-normalizes the input.
33 |
34 | Args:
35 | ctx: Context for variable storage.
36 | input: Input to root-mean-square-normalize.
37 | Can have arbitrary shape.
38 | weight: Optional weights for linear transform.
39 | If provided, must be of shape [feat_dim].
40 | eps: Epsilon added in the square root in the denominator
41 | to avoid division by zero. If None, it defaults to
42 | torch.finfo(input.dtype).eps.
43 |
44 | Returns:
45 | Root-mean-square-normalized input.
46 | """
47 | flattened_input = input.unsqueeze(0) if input.ndim == 1 else input
48 | flattened_input = flattened_input.flatten(0, -2)
49 | batch_dim, feat_dim = flattened_input.shape
50 | eps = torch.finfo(input.dtype).eps if eps is None else eps
51 |
52 | output = torch.empty_like(flattened_input)
53 |
54 | scale_by_weight = weight is not None
55 | requires_grad = (input.requires_grad or
56 | (scale_by_weight and weight.requires_grad))
57 |
58 | if requires_grad:
59 | inv_rms = torch.empty(batch_dim,
60 | device=input.device,
61 | dtype=torch.float32)
62 |
63 | else:
64 | inv_rms = None
65 |
66 | # Launches 1D grid where each program operates over BLOCK_SIZE_BATCH rows.
67 | grid = lambda META: (cdiv(batch_dim, META['BLOCK_SIZE_BATCH']),)
68 | rms_norm_forward_kernel[grid](flattened_input, weight,
69 | inv_rms, output,
70 | batch_dim, feat_dim,
71 | *flattened_input.stride(), *output.stride(),
72 | eps,
73 | scale_by_weight=scale_by_weight,
74 | save_stats=requires_grad)
75 |
76 | ctx.scale_by_weight = scale_by_weight
77 | if requires_grad:
78 | ctx.save_for_backward(flattened_input, inv_rms, weight)
79 |
80 | return output.view_as(input)
81 |
82 | @staticmethod
83 | @custom_bwd(device_type='cuda')
84 | def backward(
85 | ctx: Context,
86 | output_grad: Tensor,
87 | ) -> Tuple[Optional[Tensor], ...]:
88 | """
89 | Calculates the input gradient of root mean square normalization.
90 |
91 | Args:
92 | ctx: Context containing stored variables.
93 | output_grad: Output gradients.
94 | Must be the same shape as the output.
95 |
96 | Returns:
97 | Input gradient of root mean square normalization.
98 | """
99 | scale_by_weight = ctx.scale_by_weight
100 | (flattened_input, inv_rms, weight) = ctx.saved_tensors
101 | flattened_output_grad = output_grad.view_as(flattened_input)
102 |
103 | batch_dim, feat_dim = flattened_output_grad.shape
104 | input_grad = torch.empty_like(flattened_output_grad)
105 |
106 | if scale_by_weight:
107 | BLOCK_SIZE_BATCH = BLOCK_SIZE_BATCH_heuristic({'batch_dim': batch_dim,
108 | 'feat_dim': feat_dim})
109 | out_batch_dim = batch_dim // BLOCK_SIZE_BATCH
110 |
111 | weight_grad = torch.empty((out_batch_dim, feat_dim),
112 | device=flattened_input.device)
113 |
114 | else:
115 | weight_grad = None
116 |
117 | # Launches 1D grid where each program operates over BLOCK_SIZE_BATCH rows.
118 | grid = lambda META: (cdiv(batch_dim, META['BLOCK_SIZE_BATCH']),)
119 | rms_norm_backward_kernel[grid](flattened_output_grad, flattened_input,
120 | inv_rms, weight,
121 | input_grad, weight_grad,
122 | batch_dim, feat_dim,
123 | *flattened_output_grad.stride(),
124 | *flattened_input.stride(),
125 | *input_grad.stride(),
126 | *weight_grad.stride() if scale_by_weight else (1, 1),
127 | scale_by_weight=scale_by_weight)
128 |
129 | if scale_by_weight:
130 | weight_grad = weight_grad.sum(dim=0)
131 |
132 | # Pads output with None because a gradient is necessary for
133 | # all input arguments.
134 | return input_grad.view_as(output_grad), weight_grad, None
135 |
136 |
137 | class RMSNorm(nn.RMSNorm):
138 | """
139 | Root-mean-square-normalizes the input.
140 | See also base class.
141 |
142 | Args:
143 |         normalized_shape: Dimensionality of the last feature dimension, which is normalized.
144 | eps: Epsilon added in the square root in the denominator
145 | to avoid division by zero. If None, it defaults to
146 | torch.finfo(input.dtype).eps.
147 | elementwise_affine: Flag for scaling the normalized output by weights.
148 | device: Device to use.
149 | dtype: Dtype of layer.
150 |
151 | Raises:
152 | RuntimeError: Normalized shape was not an integer.
153 | """
154 | def __init__(
155 | self,
156 | normalized_shape: int,
157 | eps: Optional[float] = None,
158 | elementwise_affine: bool = True,
159 | device: Device = 'cuda',
160 | dtype: torch.dtype = torch.float32,
161 | ) -> None:
162 | if not isinstance(normalized_shape, int):
163 | raise RuntimeError('Normalized shape must be an integer.')
164 |
165 | super().__init__(normalized_shape, eps, elementwise_affine, device, dtype)
166 |
167 | def forward(self, input: Tensor) -> Tensor:
168 | return RMSNormAutoGrad.apply(input, self.weight, self.eps)
169 |
--------------------------------------------------------------------------------
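
A minimal, hypothetical usage sketch for the RMSNorm wrapper above, checked against a plain PyTorch reference; it assumes a CUDA device.

    import torch

    import attorch

    x = torch.randn(8, 256, device='cuda', requires_grad=True)
    rms_norm = attorch.RMSNorm(256)

    output = rms_norm(x)

    # Reference: x / sqrt(mean(x^2) + eps), scaled element-wise by the weight.
    eps = torch.finfo(x.dtype).eps
    reference = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * rms_norm.weight
    torch.testing.assert_close(output, reference, rtol=1e-4, atol=1e-4)
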
/attorch/multi_head_attention_layer.py:
--------------------------------------------------------------------------------
1 | """
2 | Multi-headed attention with PyTorch autodiff support.
3 | """
4 |
5 |
6 | from functools import partial
7 | from typing import Optional
8 |
9 | import torch
10 | from torch import Tensor
11 | from torch import nn
12 | from torch.nn import functional as F
13 |
14 | from .multi_head_attention_kernels import _attention
15 | from .types import Device
16 |
17 |
18 | def extract_heads(input: Tensor, num_heads: int) -> Tensor:
19 | """
20 | Reshapes the projected input to extract heads for multi-headed attention.
21 |
22 | Args:
23 | input: Input to reshape and extract heads from.
24 | num_heads: Number of heads to extract.
25 |
26 | Returns:
27 | Input reshaped with its heads extracted.
28 | """
29 | batch_dim, n_tokens, _ = input.shape
30 | return input.reshape(batch_dim, n_tokens, num_heads, -1).transpose(1, 2)
31 |
32 |
33 | class MultiheadAttention(nn.MultiheadAttention):
34 | """
35 | Applies multi-headed scaled dot-product attention to the inputs.
36 | See also base class.
37 |
38 | Args:
39 | embed_dim: Dimensionality of the query, key, and value inputs.
40 | num_heads: Number of heads.
41 | dropout: Dropout probability on the attention scores,
42 | currently not supported.
43 | bias: Flag for adding bias to the query-key-value-output projections.
44 | add_bias_kv: Flag for appending a bias vector to the key and value sequences,
45 | currently not supported.
46 | add_zero_attn: Flag for appending a zero vector to the key and value sequences,
47 | currently not supported.
48 | kdim: Dimensionality of the key input, which has to be None or equal to embed_dim.
49 | vdim: Dimensionality of the value input, which has to be None or equal to embed_dim.
 50 |         batch_first: Flag to indicate if the batch dimension comes first in the input.
 51 |             Sequence-first inputs are currently not supported.
52 | device: Device to use.
53 | dtype: Dtype of layer.
54 |
55 | Raises:
56 | RuntimeError: 1. Dropout on the attention scores is requested.
57 | 2. Appending a bias vector to the key and value sequences is requested.
58 | 3. Appending a zero vector to the key and value sequences is requested.
59 | 4. The query and key dimensionalities are unequal.
60 | 5. The query and value dimensionalities are unequal.
61 | 6. The input is not batch-first.
62 | """
63 | def __init__(
64 | self,
65 | embed_dim: int,
66 | num_heads: int,
67 | dropout: float = 0.0,
68 | bias: bool = True,
69 | add_bias_kv: bool = False,
70 | add_zero_attn: bool = False,
71 | kdim: Optional[int] = None,
72 | vdim: Optional[int] = None,
73 | batch_first: bool = True,
74 | device: Device = 'cuda',
75 | dtype: torch.dtype = torch.float32,
76 | ) -> None:
77 | if dropout > 0.0:
78 | raise RuntimeError('Dropout on the attention scores is not supported.')
79 |
80 | if add_bias_kv:
81 | raise RuntimeError('Appending a bias vector to the key and value '
82 | 'sequences is not supported.')
83 |
84 | if add_zero_attn:
85 | raise RuntimeError('Appending a zero vector to the key and value '
86 | 'sequences is not supported.')
87 |
88 | if kdim is not None and kdim != embed_dim:
89 | raise RuntimeError(f'The key dimensionality ({kdim}) is not equal to '
90 | f'the query dimensionality ({embed_dim}).')
91 |
92 | if vdim is not None and vdim != embed_dim:
93 | raise RuntimeError(f'The value dimensionality ({vdim}) is not equal to '
94 | f'the query dimensionality ({embed_dim}).')
95 |
96 | if not batch_first:
97 | raise RuntimeError('The input must be batch-first.')
98 |
99 | super().__init__(embed_dim, num_heads, dropout, bias,
100 | add_bias_kv, add_zero_attn,
101 | kdim, vdim, batch_first,
102 | device, dtype)
103 |
104 | def forward(
105 | self,
106 | input_q: Tensor,
107 | input_k: Optional[Tensor] = None,
108 | input_v: Optional[Tensor] = None,
109 | causal: bool = False,
110 | sequence_parallel: bool = False,
111 | ) -> Tensor:
112 | assert input_v is None or input_k is not None, \
113 |             'Key inputs must be provided if value inputs have been passed'
114 |
115 | input_k = input_q if input_k is None else input_k
116 | input_v = input_k if input_v is None else input_v
117 |
118 | if input_k is input_v:
119 | if input_q is input_k:
120 | qkv = F.linear(input_q, self.in_proj_weight, self.in_proj_bias)
121 | q, k, v = map(partial(extract_heads, num_heads=self.num_heads),
122 | qkv.chunk(3, dim=-1))
123 |
124 | else:
125 | weight_q, weight_kv = torch.split(self.in_proj_weight,
126 | [self.embed_dim, 2 * self.embed_dim],
127 | dim=0)
128 |
129 | if self.in_proj_bias is None:
130 | bias_q = bias_kv = None
131 |
132 | else:
133 | bias_q, bias_kv = torch.split(self.in_proj_bias,
134 | [self.embed_dim, 2 * self.embed_dim],
135 | dim=0)
136 |
137 | q = F.linear(input_q, weight_q, bias_q)
138 | kv = F.linear(input_k, weight_kv, bias_kv)
139 | q, k, v = map(partial(extract_heads, num_heads=self.num_heads),
140 | [q, *kv.chunk(2, dim=-1)])
141 |
142 | else:
143 | weight_q, weight_k, weight_v = torch.split(self.in_proj_weight,
144 | 3*[self.embed_dim],
145 | dim=0)
146 |
147 | if self.in_proj_bias is None:
148 | bias_q = bias_k = bias_v = None
149 |
150 | else:
151 | bias_q, bias_k, bias_v = torch.split(self.in_proj_bias,
152 | 3*[self.embed_dim],
153 | dim=0)
154 |
155 | q = F.linear(input_q, weight_q, bias_q)
156 | k = F.linear(input_k, weight_k, bias_k)
157 | v = F.linear(input_v, weight_v, bias_v)
158 | q, k, v = map(partial(extract_heads, num_heads=self.num_heads),
159 | [q, k, v])
160 |
161 |         output = _attention.apply(q, k, v, causal, self.head_dim ** -0.5, sequence_parallel)
162 | output = output.transpose(1, 2).reshape(len(input_q), -1, self.embed_dim)
163 | return self.out_proj(output)
164 |
--------------------------------------------------------------------------------
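For context, a minimal usage sketch of the attention module defined above. Its class name is not shown in this excerpt and is assumed here to be exported as attorch.MultiheadAttention; a CUDA device is assumed, shapes are illustrative, and unlike PyTorch's module the forward pass returns only the attention output.

```python
import torch
import attorch

# Hypothetical usage sketch: the module above is assumed to be exported as
# attorch.MultiheadAttention. It is batch-first and returns the attention
# output directly rather than an (output, weights) tuple.
mha = attorch.MultiheadAttention(embed_dim=64, num_heads=4)   # lives on CUDA by default
x = torch.randn(2, 128, 64, device='cuda')                    # [batch, seq, embed_dim]

out_self = mha(x)                    # self-attention: keys and values default to the queries
out_causal = mha(x, causal=True)     # causal (autoregressive) masking
memory = torch.randn(2, 256, 64, device='cuda')
out_cross = mha(x, memory)           # cross-attention: keys and values come from memory
print(out_self.shape)                # torch.Size([2, 128, 64])
```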
/attorch/cross_entropy_loss_layer.py:
--------------------------------------------------------------------------------
1 | """
2 | Cross entropy loss with PyTorch autodiff support.
3 | """
4 |
5 |
6 | from typing import Optional, Tuple
7 |
8 | import torch
9 | from torch import Tensor
10 | from torch import nn
11 | from triton import cdiv
12 |
13 | from .cross_entropy_loss_kernels import cross_entropy_loss_backward_kernel, \
14 | cross_entropy_loss_forward_kernel
15 | from .softmax_kernels import BLOCK_SIZE_BATCH_heuristic
16 | from .types import Context
17 | from .utils import get_output_dtype
18 |
19 |
20 | class CrossEntropyLossAutoGrad(torch.autograd.Function):
21 | """
22 | Autodiff for cross entropy loss.
23 | """
24 | @staticmethod
25 | def forward(
26 | ctx: Context,
27 | input: Tensor,
28 | target: Tensor,
29 | weight: Optional[Tensor] = None,
30 | ) -> Tensor:
31 | """
32 | Measures the mean cross entropy loss between the input and target,
33 | with optional reweighing of each class.
34 |
35 | Args:
36 | ctx: Context for variable storage.
37 | input: Input.
38 | Must be of shape [batch_dim, feat_dim].
39 | target: Target.
40 | Must be of shape [batch_dim].
41 | weight: Optional class weight vector, with None for no reweighing.
42 | If provided, must be of shape [feat_dim].
43 |
44 | Returns:
45 | Loss.
46 | """
47 |         assert input.ndim == 2, 'Inputs of rank other than 2 are not valid'
48 | assert len(input) == len(target), \
49 | f'Incompatible input shape ({input.shape}) and target shape ({target.shape})'
50 | assert weight is None or len(weight) == input.shape[1], \
51 | f'Dimensionality of weight vector ({len(weight)}) and input features ({input.shape[1]}) not equal'
52 |
53 | batch_dim, feat_dim = input.shape
54 | BLOCK_SIZE_BATCH = BLOCK_SIZE_BATCH_heuristic({'batch_dim': batch_dim,
55 | 'feat_dim': feat_dim})
56 | out_batch_dim = batch_dim // BLOCK_SIZE_BATCH
57 | weighted = weight is not None
58 |
59 | output_dtype = get_output_dtype(input.dtype, autocast='fp32')
60 | output = torch.empty(out_batch_dim,
61 | dtype=output_dtype,
62 | device=input.device)
63 |
64 | if weighted:
65 | sum_weights = torch.empty_like(output, dtype=torch.float32)
66 |
67 | else:
68 | sum_weights = None
69 |
70 | # Launches 1D grid where each program operates over BLOCK_SIZE_BATCH rows.
71 | grid = lambda META: (cdiv(len(input), META['BLOCK_SIZE_BATCH']),)
72 | cross_entropy_loss_forward_kernel[grid](input, target, weight, sum_weights, output,
73 | batch_dim, feat_dim,
74 | *input.stride(),
75 | weighted=weighted)
76 | output = output.sum()
77 |
78 | if weighted:
79 | sum_weights = sum_weights.sum()
80 | output /= sum_weights
81 |
82 | ctx.sum_weights = sum_weights
83 | ctx.weight = weight
84 | ctx.output_dtype = output_dtype
85 | if input.requires_grad:
86 | ctx.save_for_backward(input, target)
87 |
88 | return output
89 |
90 | @staticmethod
91 | def backward(
92 | ctx: Context,
93 | output_grad: Tensor,
94 | ) -> Tuple[Optional[Tensor], ...]:
95 | """
96 | Calculates the input gradient of the loss.
97 |
98 | Args:
99 | ctx: Context containing stored variables.
100 | output_grad: Output gradients.
101 | Must be a scalar.
102 |
103 | Returns:
104 | Input gradient of the loss.
105 | """
106 | (input, target) = ctx.saved_tensors
107 | batch_dim, feat_dim = input.shape
108 | input_grad = torch.empty_like(input, dtype=ctx.output_dtype)
109 |
110 | # Launches 1D grid where each program operates over BLOCK_SIZE_BATCH rows.
111 | grid = lambda META: (cdiv(len(input), META['BLOCK_SIZE_BATCH']),)
112 | cross_entropy_loss_backward_kernel[grid](output_grad, target, input, ctx.weight,
113 | ctx.sum_weights, input_grad,
114 | batch_dim, feat_dim,
115 | *input.stride(),
116 | *input_grad.stride(),
117 | weighted=ctx.weight is not None)
118 |
119 | # Pads output with None because a gradient is necessary for
120 | # all input arguments.
121 | return input_grad, None, None
122 |
123 |
124 | class CrossEntropyLoss(nn.CrossEntropyLoss):
125 | """
126 | Measures the mean cross entropy loss between the input and target,
127 | with optional reweighing of each class.
128 | See also base class.
129 |
130 | Note: To keep its implementation compact, this module does not support
131 | advanced features such as inputs with spatial dimensions or label smoothing.
132 | For greater flexibility, a combination of attorch.LogSoftmax and
133 | attorch.NLLLoss can be used.
134 |
135 | Args:
136 | reduction: Reduction strategy for the output. Only 'mean' is supported.
137 | Providing size_average and reduce overrides this argument.
138 | size_average: Flag for averaging instead of summing the loss entries
139 | when reduce is True. Only averaging is supported.
140 | reduce: Flag for averaging or summing all the loss entries instead of
141 | returning a loss per element. Only averaging is supported.
142 | weight: Optional class weight vector, with None for no reweighing.
143 | If provided, must be of shape [feat_dim].
144 | ignore_index: This argument is not supported.
145 | label_smoothing: This argument is not supported.
146 |
147 | Raises:
148 | RuntimeError: 1. Reduction method was not set to 'mean'.
149 | 2. Label smoothing is requested.
150 | """
151 | def __init__(
152 | self,
153 | reduction: str = 'mean',
154 | size_average: Optional[bool] = None,
155 | reduce: Optional[bool] = None,
156 | weight: Optional[Tensor] = None,
157 | ignore_index: int = -100,
158 | label_smoothing: float = 0.0,
159 | ) -> None:
160 | super().__init__(weight, size_average, ignore_index, reduce,
161 | reduction, label_smoothing)
162 |
163 | if self.reduction != 'mean':
164 | raise RuntimeError('Cross entropy only supports averaging the loss.')
165 |
166 | if label_smoothing > 0.0:
167 | raise RuntimeError('Cross entropy does not support label smoothing.')
168 |
169 | def forward(self, input: Tensor, target: Tensor) -> Tensor:
170 | return CrossEntropyLossAutoGrad.apply(input, target, self.weight)
171 |
--------------------------------------------------------------------------------
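A minimal usage sketch for the cross entropy module above, assuming a CUDA device and that CrossEntropyLoss is exported at the package level (as the star import in attorch/nn.py suggests); values are illustrative.

```python
import torch
import attorch

# Sketch only: the input must be 2D logits of shape [batch_dim, feat_dim] and
# the target a 1D class-index vector, per the assertions in the forward pass.
criterion = attorch.CrossEntropyLoss()
logits = torch.randn(8, 10, device='cuda', requires_grad=True)
target = torch.randint(0, 10, (8,), device='cuda')

loss = criterion(logits, target)   # scalar mean loss
loss.backward()                    # gradients come from the Triton backward kernel
```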
/attorch/nll_loss_layer.py:
--------------------------------------------------------------------------------
1 | """
2 | Negative log likelihood loss with PyTorch autodiff support.
3 | """
4 |
5 |
6 | from typing import Optional, Tuple
7 |
8 | import torch
9 | from torch import Tensor
10 | from torch import nn
11 | from triton import cdiv
12 |
13 | from .nll_loss_kernels import nll_loss_backward_kernel, nll_loss_forward_kernel, \
14 | BLOCK_SIZE_BATCH_heuristic
15 | from .types import Context
16 | from .utils import get_output_dtype
17 |
18 |
19 | class NLLLossAutoGrad(torch.autograd.Function):
20 | """
21 | Autodiff for negative log likelihood loss.
22 | """
23 | @staticmethod
24 | def forward(
25 | ctx: Context,
26 | input: Tensor,
27 | target: Tensor,
28 | reduction: str,
29 | weight: Optional[Tensor] = None,
30 | ) -> Tensor:
31 | """
32 | Measures the negative log likelihood loss between the input and target,
33 | with optional reweighing of each class.
34 |
35 | Args:
36 | ctx: Context for variable storage.
37 | input: Input.
38 | Must be of shape [batch_dim, feat_dim, ...],
39 | where ... denotes an arbitrary number of spatial dimensions.
40 | target: Target.
41 | Must be of shape [batch_dim, ...],
42 | where ... denotes the same spatial dimensions as the input.
43 | reduction: Reduction strategy for the output.
44 | Options are 'none' for no reduction, 'mean' for averaging the loss
45 | across all entries, and 'sum' for summing the loss across all entries.
46 | weight: Optional class weight vector, with None for no reweighing.
47 | If provided, must be of shape [feat_dim].
48 |
49 | Returns:
50 | Loss.
51 | """
52 | assert len(input) == len(target) and input.shape[2:] == target.shape[1:], \
53 | f'Incompatible input shape ({input.shape}) and target shape ({target.shape})'
54 | assert weight is None or len(weight) == input.shape[1], \
55 | f'Dimensionality of weight vector ({len(weight)}) and input features ({input.shape[1]}) not equal'
56 |
57 | flattened_input = input.unsqueeze(-1) if input.ndim == 2 else input
58 | flattened_input = flattened_input.flatten(2, -1)
59 |
60 | flattened_target = target.unsqueeze(-1) if target.ndim == 1 else target
61 | flattened_target = flattened_target.flatten(1, -1)
62 |
63 | batch_dim, _, spatial_dim = flattened_input.shape
64 | BLOCK_SIZE_BATCH = BLOCK_SIZE_BATCH_heuristic({'batch_dim': batch_dim,
65 | 'spatial_dim': spatial_dim})
66 | out_batch_dim = batch_dim // BLOCK_SIZE_BATCH
67 |
68 | output_dtype = get_output_dtype(input.dtype, autocast='fp32')
69 | sum_weights = (torch.empty(out_batch_dim, dtype=torch.float32, device=input.device)
70 | if reduction == 'mean' else None)
71 | output = (torch.empty_like(flattened_target, dtype=output_dtype)
72 | if reduction == 'none' else
73 | torch.empty(out_batch_dim, dtype=output_dtype, device=input.device))
74 |
75 | # Launches 1D grid where each program operates over BLOCK_SIZE_BATCH rows.
76 | grid = lambda META: (cdiv(len(input), META['BLOCK_SIZE_BATCH']),)
77 | nll_loss_forward_kernel[grid](input, target, weight, sum_weights, output,
78 | batch_dim, spatial_dim,
79 | *flattened_input.stride(),
80 | *flattened_target.stride(),
81 | *output.stride() if reduction == 'none' else (1, 1),
82 | reduction=reduction,
83 | weighted=weight is not None)
84 |
85 | if reduction != 'none':
86 | output = output.sum()
87 |
88 | if reduction == 'mean' and weight is not None:
89 | sum_weights = sum_weights.sum()
90 | output /= sum_weights
91 |
92 | else:
93 | output = output.view_as(target)
94 |
95 | ctx.sum_weights = sum_weights
96 | ctx.reduction = reduction
97 | ctx.weight = weight
98 | ctx.output_dtype = output_dtype
99 | if input.requires_grad:
100 | ctx.save_for_backward(input, flattened_target)
101 |
102 | return output
103 |
104 | @staticmethod
105 | def backward(
106 | ctx: Context,
107 | output_grad: Tensor,
108 | ) -> Tuple[Optional[Tensor], ...]:
109 | """
110 | Calculates the input gradient of the loss.
111 |
112 | Args:
113 | ctx: Context containing stored variables.
114 | output_grad: Output gradients.
115 | Must be the same shape as the output.
116 |
117 | Returns:
118 | Input gradient of the loss.
119 | """
120 | (input, flattened_target) = ctx.saved_tensors
121 | flattened_input = input.view(len(flattened_target), -1,
122 | flattened_target.shape[-1])
123 | output_grad = (output_grad.view_as(flattened_target)
124 | if output_grad.ndim > 0 else output_grad)
125 |
126 | batch_dim, _, spatial_dim = flattened_input.shape
127 | input_grad = torch.zeros_like(flattened_input, dtype=ctx.output_dtype)
128 |
129 | # Launches 1D grid where each program operates over BLOCK_SIZE_BATCH rows.
130 | grid = lambda META: (cdiv(len(input), META['BLOCK_SIZE_BATCH']),)
131 | nll_loss_backward_kernel[grid](output_grad, flattened_target, ctx.weight,
132 | ctx.sum_weights, input_grad,
133 | batch_dim, spatial_dim,
134 | *output_grad.stride() if ctx.reduction == 'none' else (1, 1),
135 | *flattened_target.stride(),
136 | *input_grad.stride(),
137 | reduction=ctx.reduction,
138 | weighted=ctx.weight is not None)
139 |
140 | # Pads output with None because a gradient is necessary for
141 | # all input arguments.
142 | return input_grad.view_as(input), None, None, None
143 |
144 |
145 | class NLLLoss(nn.NLLLoss):
146 | """
147 | Measures the negative log likelihood loss between the input and target,
148 | with optional reweighing of each class.
149 | See also base class.
150 |
151 | Args:
152 | reduction: Reduction strategy for the output.
153 | Options are 'none' for no reduction, 'mean' for averaging the loss
154 | across all entries, and 'sum' for summing the loss across all entries.
155 | Providing size_average and reduce overrides this argument.
156 | size_average: Flag for averaging instead of summing the loss entries
157 | when reduce is True.
158 | reduce: Flag for averaging or summing all the loss entries instead of
159 | returning a loss per element.
160 | weight: Optional class weight vector, with None for no reweighing.
161 | If provided, must be of shape [feat_dim].
162 | ignore_index: This argument is not supported.
163 | """
164 | def forward(self, input: Tensor, target: Tensor) -> Tensor:
165 | return NLLLossAutoGrad.apply(input, target, self.reduction, self.weight)
166 |
--------------------------------------------------------------------------------
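A short sketch of NLLLoss with spatial dimensions, which the cross entropy module above deliberately does not support. A CUDA device and a package-level export of NLLLoss are assumed; torch's log_softmax stands in for attorch.LogSoftmax purely to keep the example self-contained.

```python
import torch
import torch.nn.functional as F
import attorch

# Sketch: log-probabilities of shape [batch_dim, feat_dim, ...] and integer
# targets of shape [batch_dim, ...], matching the forward pass's assertions.
criterion = attorch.NLLLoss(reduction='mean')
log_probs = F.log_softmax(torch.randn(4, 10, 7, 7, device='cuda', requires_grad=True), dim=1)
target = torch.randint(0, 10, (4, 7, 7), device='cuda')

loss = criterion(log_probs, target)
loss.backward()
```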
/attorch/p_loss_layers.py:
--------------------------------------------------------------------------------
1 | """
2 | p-norm-induced losses with PyTorch autodiff support.
3 | """
4 |
5 |
6 | from typing import Optional, Tuple
7 |
8 | import torch
9 | from torch import Tensor
10 | from torch import nn
11 | from triton import cdiv
12 |
13 | from .p_loss_kernels import p_loss_backward_kernel, p_loss_forward_kernel
14 | from .types import Context
15 | from .utils import get_output_dtype
16 |
17 |
18 | class PLossAutoGrad(torch.autograd.Function):
19 | """
20 | Autodiff for p-losses.
21 | """
22 | @staticmethod
23 | def forward(
24 | ctx: Context,
25 | input: Tensor,
26 | target: Tensor,
27 | p_loss: int,
28 | reduction: str,
29 | param: float = 1.0,
30 | ) -> Tensor:
31 | """
32 | Measures the smooth L1, L1, squared L2, or Huber loss of the difference
33 | between the input and target.
34 |
35 | Args:
36 | ctx: Context for variable storage.
37 | input: Input.
38 | Can have arbitrary shape.
39 | target: Target.
40 | Must be the same shape as input.
41 | p_loss: p-norm used to compute the error.
42 | Options are 0 for smooth L1, 1 for L1, 2 for squared L2, and 3 for Huber loss.
43 | reduction: Reduction strategy for the output.
44 | Options are 'none' for no reduction, 'mean' for averaging the error
45 | across all entries, and 'sum' for summing the error across all entries.
46 | param: Parameter of loss function (i.e., beta or delta for smooth L1 and Huber).
47 |
48 | Returns:
49 | Error.
50 | """
51 | assert input.shape == target.shape, \
52 | f'Input shape {input.shape} and target shape {target.shape} not equal'
53 |
54 | output_dtype = get_output_dtype(input.dtype, autocast='fp32')
55 |
56 | ctx.p_loss = p_loss
57 | ctx.reduction = reduction
58 | ctx.param = param
59 | ctx.output_dtype = output_dtype
60 | if input.requires_grad or target.requires_grad:
61 | ctx.save_for_backward(input, target)
62 |
63 | flattened_input = input.flatten()
64 | flattened_target = target.flatten()
65 | size = len(flattened_input)
66 |
67 | output = (torch.empty_like(flattened_input, dtype=output_dtype) if reduction == 'none'
68 | else torch.empty(cdiv(size, 32), dtype=output_dtype, device=input.device))
69 |
70 | # Launches 1D grid where each program operates over
71 | # BLOCK_SIZE elements.
72 | grid = lambda META: (cdiv(size, META['BLOCK_SIZE']),)
73 | p_loss_forward_kernel[grid](flattened_input, flattened_target, output,
74 | param, size, p_loss=p_loss, reduction=reduction)
75 |
76 | if reduction != 'none':
77 | BLOCK_SIZE = p_loss_forward_kernel.best_config.kwargs['BLOCK_SIZE']
78 | output = output[:cdiv(size, BLOCK_SIZE)].sum()
79 |
80 | else:
81 | output = output.view_as(input)
82 |
83 | return output
84 |
85 | @staticmethod
86 | def backward(
87 | ctx: Context,
88 | output_grad: Tensor,
89 | ) -> Tuple[Optional[Tensor], ...]:
90 | """
91 | Calculates the input gradient of the error.
92 |
93 | Args:
94 | ctx: Context containing stored variables.
95 | output_grad: Output gradients.
96 | Must be the same shape as the output.
97 |
98 | Returns:
99 | Input gradient of the error.
100 | """
101 | (input, target) = ctx.saved_tensors
102 | flattened_input = input.flatten()
103 | flattened_target = target.flatten()
104 | output_grad = output_grad.flatten()
105 |
106 | size = len(flattened_input)
107 | input_grad = torch.empty_like(flattened_input, dtype=ctx.output_dtype)
108 | target_grad = torch.empty_like(flattened_target, dtype=ctx.output_dtype)
109 |
110 | # Launches 1D grid where each program operates over
111 | # BLOCK_SIZE elements.
112 | grid = lambda META: (cdiv(size, META['BLOCK_SIZE']),)
113 | p_loss_backward_kernel[grid](output_grad, flattened_input, flattened_target,
114 | input_grad, target_grad, ctx.param, size,
115 | p_loss=ctx.p_loss, reduction=ctx.reduction)
116 |
117 | # Pads output with None because a gradient is necessary for
118 | # all input arguments.
119 |         return input_grad.view_as(input), target_grad.view_as(target), None, None, None
120 |
121 |
122 | class L1Loss(nn.L1Loss):
123 | """
124 | Measures the L1 error (mean absolute error) between the input and target.
125 | See also base class.
126 |
127 | Args:
128 | reduction: Reduction strategy for the output.
129 | Options are 'none' for no reduction, 'mean' for averaging the error
130 | across all entries, and 'sum' for summing the error across all entries.
131 | Providing size_average and reduce overrides this argument.
132 | size_average: Flag for averaging instead of summing the error entries
133 | when reduce is True.
134 | reduce: Flag for averaging or summing all the error entries instead of
135 | returning a loss per element.
136 | """
137 | def forward(self, input: Tensor, target: Tensor) -> Tensor:
138 | return PLossAutoGrad.apply(input, target, 1, self.reduction)
139 |
140 |
141 | class MSELoss(nn.MSELoss):
142 | """
143 | Measures the squared L2 error (mean squared error) between the input and target.
144 | See also base class.
145 |
146 | Args:
147 | reduction: Reduction strategy for the output.
148 | Options are 'none' for no reduction, 'mean' for averaging the error
149 | across all entries, and 'sum' for summing the error across all entries.
150 | Providing size_average and reduce overrides this argument.
151 | size_average: Flag for averaging instead of summing the error entries
152 | when reduce is True.
153 | reduce: Flag for averaging or summing all the error entries instead of
154 | returning a loss per element.
155 | """
156 | def forward(self, input: Tensor, target: Tensor) -> Tensor:
157 | return PLossAutoGrad.apply(input, target, 2, self.reduction)
158 |
159 |
160 | class SmoothL1Loss(nn.SmoothL1Loss):
161 | """
162 | Measures the smooth L1 error between the input and target.
163 | See also base class.
164 |
165 | Args:
166 | reduction: Reduction strategy for the output.
167 | Options are 'none' for no reduction, 'mean' for averaging the error
168 | across all entries, and 'sum' for summing the error across all entries.
169 | Providing size_average and reduce overrides this argument.
170 | size_average: Flag for averaging instead of summing the error entries
171 | when reduce is True.
172 | reduce: Flag for averaging or summing all the error entries instead of
173 | returning a loss per element.
174 | beta: Beta value for the softening threshold.
175 | """
176 | def forward(self, input: Tensor, target: Tensor) -> Tensor:
177 | return PLossAutoGrad.apply(input, target, 0, self.reduction, self.beta)
178 |
179 |
180 | class HuberLoss(nn.HuberLoss):
181 | """
182 | Measures the Huber loss between the input and target.
183 | See also base class.
184 |
185 | Args:
186 | reduction: Reduction strategy for the output.
187 | Options are 'none' for no reduction, 'mean' for averaging the error
188 | across all entries, and 'sum' for summing the error across all entries.
189 |         delta: Delta value for the threshold at which the loss transitions from quadratic to linear.
190 | """
191 | def forward(self, input: Tensor, target: Tensor) -> Tensor:
192 | return PLossAutoGrad.apply(input, target, 3, self.reduction, self.delta)
193 |
--------------------------------------------------------------------------------
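A combined usage sketch for the four p-loss modules above, assuming a CUDA device and package-level exports; it simply illustrates how each module maps onto a p_loss code of the shared autograd function.

```python
import torch
import attorch

# Sketch: all four modules route through PLossAutoGrad with a different
# p_loss code (0: smooth L1, 1: L1, 2: squared L2, 3: Huber).
input = torch.randn(32, 16, device='cuda', requires_grad=True)
target = torch.randn(32, 16, device='cuda')

l1 = attorch.L1Loss()(input, target)                     # mean absolute error
mse = attorch.MSELoss(reduction='sum')(input, target)    # summed squared error
smooth_l1 = attorch.SmoothL1Loss(beta=0.5)(input, target)
huber = attorch.HuberLoss(delta=2.0)(input, target)
(l1 + mse + smooth_l1 + huber).backward()
```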
/attorch/cross_entropy_loss_kernels.py:
--------------------------------------------------------------------------------
1 | """
2 | Kernels for cross entropy loss.
3 | """
4 |
5 |
6 | import triton
7 | import triton.language as tl
8 | from triton import next_power_of_2
9 |
10 | from .softmax_kernels import BLOCK_SIZE_BATCH_heuristic
11 | from .utils import warps_kernel_configs
12 |
13 |
14 | @triton.autotune(
15 | configs=warps_kernel_configs(),
16 | key=['batch_dim', 'feat_dim'],
17 | )
18 | @triton.heuristics({'BLOCK_SIZE_BATCH': BLOCK_SIZE_BATCH_heuristic,
19 | 'BLOCK_SIZE_FEAT': lambda args: next_power_of_2(args['feat_dim'])})
20 | @triton.jit
21 | def cross_entropy_loss_forward_kernel(
22 | input_pointer, target_pointer, weight_pointer,
23 | sum_weights_pointer, output_pointer,
24 | batch_dim, feat_dim,
25 | input_batch_stride, input_feat_stride,
26 | weighted: tl.constexpr,
27 | BLOCK_SIZE_BATCH: tl.constexpr, BLOCK_SIZE_FEAT: tl.constexpr,
28 | ):
29 | """
30 | Measures the mean cross entropy loss between the input and target,
31 | with optional reweighing of each class.
32 |
33 | Args:
34 | input_pointer: Pointer to the input.
35 | The input must be of shape [batch_dim, feat_dim].
36 | target_pointer: Pointer to the target.
37 | The target must be of shape [batch_dim].
38 | weight_pointer: Pointer to an optional class weight vector.
39 | The class weight vector, if provided, must be of shape [feat_dim].
40 | sum_weights_pointer: Pointer to a container the sum of the class weights is written to.
41 | The container must be of shape [batch_dim/BLOCK_SIZE_BATCH].
42 | output_pointer: Pointer to a container the loss is written to.
43 | The container must be of shape [batch_dim/BLOCK_SIZE_BATCH].
44 | batch_dim: Batch dimension.
45 | feat_dim: Dimensionality of the features.
46 | input_batch_stride: Stride necessary to jump one element along the
47 | input's batch dimension.
48 | input_feat_stride: Stride necessary to jump one element along the
49 | input's feature dimension.
50 | weighted: Flag for weighing each class.
51 | BLOCK_SIZE_BATCH: Block size across the batch dimension.
52 | BLOCK_SIZE_FEAT: Block size across the feature dimension.
53 | """
54 | # This program processes BLOCK_SIZE_BATCH rows and BLOCK_SIZE_FEAT columns.
55 | batch_pid = tl.program_id(axis=0)
56 |
57 | batch_offset = batch_pid * BLOCK_SIZE_BATCH + tl.arange(0, BLOCK_SIZE_BATCH)
58 | feat_offset = tl.arange(0, BLOCK_SIZE_FEAT)
59 |
60 | batch_mask = batch_offset < batch_dim
61 | feat_mask = feat_offset < feat_dim
62 |
63 | target = tl.load(target_pointer + batch_offset, mask=batch_mask)
64 |
65 | pred_pointer = (input_pointer +
66 | input_feat_stride * target +
67 | input_batch_stride * batch_offset)
68 | input_pointer += (input_batch_stride * batch_offset[:, None] +
69 | input_feat_stride * feat_offset[None, :])
70 |
71 | input = tl.load(input_pointer, mask=batch_mask[:, None] & feat_mask[None, :],
72 | other=-float('inf')).to(tl.float32)
73 | pred = tl.load(pred_pointer, mask=batch_mask).to(tl.float32)
74 | mx = tl.max(input, axis=1)
75 | input -= mx[:, None]
76 | loss = tl.log(tl.sum(tl.exp(input), axis=1)) - pred + mx
77 |
78 | if weighted:
79 | weight = tl.load(weight_pointer + target, mask=batch_mask).to(tl.float32)
80 | loss *= weight
81 | tl.store(sum_weights_pointer + batch_pid, tl.sum(weight))
82 |
83 | else:
84 | loss /= batch_dim
85 |
86 | tl.store(output_pointer + batch_pid, tl.sum(loss))
87 |
88 |
89 | @triton.autotune(
90 | configs=warps_kernel_configs(),
91 | key=['batch_dim', 'feat_dim'],
92 | )
93 | @triton.heuristics({'BLOCK_SIZE_BATCH': BLOCK_SIZE_BATCH_heuristic,
94 | 'BLOCK_SIZE_FEAT': lambda args: next_power_of_2(args['feat_dim'])})
95 | @triton.jit
96 | def cross_entropy_loss_backward_kernel(
97 | output_grad_pointer, target_pointer, input_pointer, weight_pointer,
98 | sum_weights_pointer, input_grad_pointer,
99 | batch_dim, feat_dim,
100 | input_batch_stride, input_feat_stride,
101 | input_grad_batch_stride, input_grad_feat_stride,
102 | weighted: tl.constexpr,
103 | BLOCK_SIZE_BATCH: tl.constexpr, BLOCK_SIZE_FEAT: tl.constexpr,
104 | ):
105 | """
106 | Calculates the input gradient of cross entropy loss.
107 |
108 | Args:
109 | output_grad_pointer: Pointer to the loss's output gradients.
110 | The output gradient must be a scalar.
111 | target_pointer: Pointer to the target.
112 | The target must be of shape [batch_dim].
113 | input_pointer: Pointer to the input.
114 | The input must be of shape [batch_dim, feat_dim].
115 | weight_pointer: Pointer to an optional class weight vector.
116 | The class weight vector, if provided, must be of shape [feat_dim].
117 | sum_weights_pointer: Pointer to the sum of the class weights if the classes were weighed.
118 | The sum of weights must be a scalar.
119 | input_grad_pointer: Pointer to a container the input's gradients are written to.
120 | The container must be of shape [batch_dim, feat_dim].
121 | batch_dim: Batch dimension.
122 | feat_dim: Dimensionality of the features.
123 | input_batch_stride: Stride necessary to jump one element along the
124 | input's batch dimension.
125 | input_feat_stride: Stride necessary to jump one element along the
126 | input's feature dimension.
127 | input_grad_batch_stride: Stride necessary to jump one element along the
128 | input gradient container's batch dimension.
129 | input_grad_feat_stride: Stride necessary to jump one element along the
130 | input gradient container's feature dimension.
131 | weighted: Flag for weighing each class.
132 | BLOCK_SIZE_BATCH: Block size across the batch dimension.
133 | BLOCK_SIZE_FEAT: Block size across the feature dimension.
134 | """
135 | # This program processes BLOCK_SIZE_BATCH rows and BLOCK_SIZE_FEAT columns.
136 | batch_pid = tl.program_id(axis=0)
137 |
138 | batch_offset = batch_pid * BLOCK_SIZE_BATCH + tl.arange(0, BLOCK_SIZE_BATCH)
139 | feat_offset = tl.arange(0, BLOCK_SIZE_FEAT)
140 |
141 | batch_mask = batch_offset < batch_dim
142 | feat_mask = feat_offset < feat_dim
143 |
144 | input_pointer += (input_batch_stride * batch_offset[:, None] +
145 | input_feat_stride * feat_offset[None, :])
146 | input_grad_pointer += (input_grad_batch_stride * batch_offset[:, None] +
147 | input_grad_feat_stride * feat_offset[None, :])
148 |
149 | input = tl.load(input_pointer, mask=batch_mask[:, None] & feat_mask[None, :],
150 | other=-float('inf')).to(tl.float32)
151 | input -= tl.max(input, axis=1)[:, None]
152 | numerator = tl.exp(input)
153 | softmax = numerator / tl.sum(numerator, axis=1)[:, None]
154 |
155 | output_grad = tl.load(output_grad_pointer).to(tl.float32)
156 | target = tl.load(target_pointer + batch_offset, mask=batch_mask)
157 | broadcasted_feat_offset = tl.broadcast_to(feat_offset[None, :],
158 | (BLOCK_SIZE_BATCH, BLOCK_SIZE_FEAT))
159 | broadcasted_target = tl.broadcast_to(target[:, None],
160 | (BLOCK_SIZE_BATCH, BLOCK_SIZE_FEAT))
161 | input_grad = output_grad * (softmax - (broadcasted_feat_offset == broadcasted_target))
162 |
163 | if weighted:
164 | weight = tl.load(weight_pointer + target, mask=batch_mask).to(tl.float32)
165 | sum_weights = tl.load(sum_weights_pointer)
166 | input_grad *= weight[:, None] / sum_weights
167 |
168 | else:
169 | input_grad /= batch_dim
170 |
171 | tl.store(input_grad_pointer, input_grad,
172 | mask=batch_mask[:, None] & feat_mask[None, :])
173 |
--------------------------------------------------------------------------------
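As a cross-check of what the two kernels above compute in the unweighted case: the per-row loss is logsumexp(x) - x[target], averaged over the batch, and its gradient is (softmax(x) - onehot(target)) / batch_dim. The plain PyTorch reference below verifies this against autograd.

```python
import torch
import torch.nn.functional as F

# Reference computation (CPU is fine): per-row loss and its gradient,
# matching the unweighted branches of the forward and backward kernels.
x = torch.randn(4, 5, requires_grad=True)
t = torch.randint(0, 5, (4,))

loss = (torch.logsumexp(x, dim=1) - x[torch.arange(4), t]).mean()
loss.backward()

manual_grad = (torch.softmax(x, dim=1) - F.one_hot(t, 5).float()) / 4
print(torch.allclose(x.grad, manual_grad, atol=1e-6))  # expected: True
```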
/attorch/softmax_kernels.py:
--------------------------------------------------------------------------------
1 | """
2 | Kernels for softmax and related functions.
3 | """
4 |
5 |
6 | from typing import Dict
7 |
8 | import triton
9 | import triton.language as tl
10 | from triton import next_power_of_2
11 |
12 | from .utils import warps_kernel_configs
13 |
14 |
15 | def BLOCK_SIZE_BATCH_heuristic(args: Dict) -> int:
16 | """
17 | Approximates an appropriate batch block size for softmax using a heuristic.
18 |
19 | Args:
20 | args: Arguments to softmax kernel.
21 |
22 | Returns:
23 | Appropriate batch block size.
24 | """
25 | # This heuristic was derived manually.
26 | # Essentially, if the batch dimension is greater than 1024,
27 | # for small feature sizes (less than 64), it is much more efficient
28 | # to process multiple rows at once in a given program.
29 | # Specifically, each time the number of samples is doubled,
30 | # the block size across the batch dimension should be doubled too,
31 | # with an upper bound of 128.
32 | return (min(max(1, next_power_of_2(args['batch_dim'] // 2 ** 10)), 128)
33 | if args['feat_dim'] < 64 else 1)
34 |
35 |
36 | @triton.autotune(
37 | configs=warps_kernel_configs(),
38 | key=['batch_dim', 'feat_dim'],
39 | )
40 | @triton.heuristics({'BLOCK_SIZE_BATCH': BLOCK_SIZE_BATCH_heuristic,
41 | 'BLOCK_SIZE_FEAT': lambda args: next_power_of_2(args['feat_dim'])})
42 | @triton.jit
43 | def softmax_forward_kernel(
44 | input_pointer, output_pointer,
45 | batch_dim, feat_dim,
46 | input_batch_stride, input_feat_stride,
47 | output_batch_stride, output_feat_stride,
48 | neg: tl.constexpr, log: tl.constexpr,
49 | BLOCK_SIZE_BATCH: tl.constexpr, BLOCK_SIZE_FEAT: tl.constexpr,
50 | ):
51 | """
52 | Normalizes the input using softmax.
53 |
54 | Args:
55 | input_pointer: Pointer to the input to normalize.
56 | The input must be of shape [batch_dim, feat_dim].
57 | output_pointer: Pointer to a container the result is written to.
58 | The container must be of shape [batch_dim, feat_dim].
59 | batch_dim: Batch dimension.
60 | feat_dim: Dimensionality of the features.
61 | input_batch_stride: Stride necessary to jump one element along the
62 | input's batch dimension.
63 | input_feat_stride: Stride necessary to jump one element along the
64 | input's feature dimension.
65 | output_batch_stride: Stride necessary to jump one element along the
66 | output container's batch dimension.
67 | output_feat_stride: Stride necessary to jump one element along the
68 | output container's feature dimension.
69 | neg: Flag indicating if the input should be negated to get softmin.
70 | log: Flag indicating if the log of softmax should be taken.
71 | BLOCK_SIZE_BATCH: Block size across the batch dimension.
72 | BLOCK_SIZE_FEAT: Block size across the feature dimension.
73 | """
74 | # This program processes BLOCK_SIZE_BATCH rows and BLOCK_SIZE_FEAT columns.
75 | batch_pid = tl.program_id(axis=0)
76 |
77 | batch_offset = batch_pid * BLOCK_SIZE_BATCH + tl.arange(0, BLOCK_SIZE_BATCH)
78 | feat_offset = tl.arange(0, BLOCK_SIZE_FEAT)
79 |
80 | batch_mask = batch_offset < batch_dim
81 | feat_mask = feat_offset < feat_dim
82 |
83 | input_pointer += (input_batch_stride * batch_offset[:, None] +
84 | input_feat_stride * feat_offset[None, :])
85 | output_pointer += (output_batch_stride * batch_offset[:, None] +
86 | output_feat_stride * feat_offset[None, :])
87 |
88 | input = tl.load(input_pointer, mask=batch_mask[:, None] & feat_mask[None, :],
89 | other=float('inf') if neg else -float('inf')).to(tl.float32)
90 | if neg:
91 | input = -input
92 |
93 | input -= tl.max(input, axis=1)[:, None]
94 | numerator = tl.exp(input)
95 | denominator = tl.sum(numerator, axis=1)[:, None]
96 |
97 | if log:
98 | output = input - tl.log(denominator)
99 |
100 | else:
101 | output = numerator / denominator
102 |
103 | tl.store(output_pointer, output, mask=batch_mask[:, None] & feat_mask[None, :])
104 |
105 |
106 | @triton.autotune(
107 | configs=warps_kernel_configs(),
108 | key=['batch_dim', 'feat_dim'],
109 | )
110 | @triton.heuristics({'BLOCK_SIZE_BATCH': BLOCK_SIZE_BATCH_heuristic,
111 | 'BLOCK_SIZE_FEAT': lambda args: next_power_of_2(args['feat_dim'])})
112 | @triton.jit
113 | def softmax_backward_kernel(
114 | output_grad_pointer, output_pointer, input_grad_pointer,
115 | batch_dim, feat_dim,
116 | output_grad_batch_stride, output_grad_feat_stride,
117 | output_batch_stride, output_feat_stride,
118 | input_grad_batch_stride, input_grad_feat_stride,
119 | neg: tl.constexpr, log: tl.constexpr,
120 | BLOCK_SIZE_BATCH: tl.constexpr, BLOCK_SIZE_FEAT: tl.constexpr,
121 | ):
122 | """
123 | Calculates the input gradient of softmax.
124 |
125 | Args:
126 | output_grad_pointer: Pointer to softmax's output gradients.
127 | The output gradients must be of shape [batch_dim, feat_dim].
128 | output_pointer: Pointer to softmax's output.
129 | The output must be of shape [batch_dim, feat_dim].
130 | input_grad_pointer: Pointer to a container the input's gradients are written to.
131 | The container must be of shape [batch_dim, feat_dim].
132 | batch_dim: Batch dimension.
133 | feat_dim: Dimensionality of the features.
134 | output_grad_batch_stride: Stride necessary to jump one element along the
135 | output gradients' batch dimension.
136 | output_grad_feat_stride: Stride necessary to jump one element along the
137 | output gradients' feature dimension.
138 | output_batch_stride: Stride necessary to jump one element along the
139 | output's batch dimension.
140 | output_feat_stride: Stride necessary to jump one element along the
141 | output's feature dimension.
142 | input_grad_batch_stride: Stride necessary to jump one element along the
143 | input gradient container's batch dimension.
144 | input_grad_feat_stride: Stride necessary to jump one element along the
145 | input gradient container's feature dimension.
146 | neg: Flag indicating if the input was negated to get softmin.
147 | log: Flag indicating if log of softmax was taken.
148 | BLOCK_SIZE_BATCH: Block size across the batch dimension.
149 | BLOCK_SIZE_FEAT: Block size across the feature dimension.
150 | """
151 |     # This program processes BLOCK_SIZE_BATCH rows and BLOCK_SIZE_FEAT columns.
152 | batch_pid = tl.program_id(axis=0)
153 |
154 | batch_offset = batch_pid * BLOCK_SIZE_BATCH + tl.arange(0, BLOCK_SIZE_BATCH)
155 | feat_offset = tl.arange(0, BLOCK_SIZE_FEAT)
156 |
157 | batch_mask = batch_offset < batch_dim
158 | feat_mask = feat_offset < feat_dim
159 |
160 | output_grad_pointer += (output_grad_batch_stride * batch_offset[:, None] +
161 | output_grad_feat_stride * feat_offset[None, :])
162 | output_pointer += (output_batch_stride * batch_offset[:, None] +
163 | output_feat_stride * feat_offset[None, :])
164 | input_grad_pointer += (input_grad_batch_stride * batch_offset[:, None] +
165 | input_grad_feat_stride * feat_offset[None, :])
166 |
167 | output_grad = tl.load(output_grad_pointer,
168 | mask=batch_mask[:, None] & feat_mask[None, :]).to(tl.float32)
169 | output = tl.load(output_pointer,
170 | mask=batch_mask[:, None] & feat_mask[None, :]).to(tl.float32)
171 |
172 | if log:
173 | input_grad = (output_grad -
174 | tl.exp(output) * tl.sum(output_grad, axis=1)[:, None])
175 |
176 | else:
177 | input_grad = output * (output_grad -
178 | tl.sum(output_grad * output, axis=1)[:, None])
179 |
180 | tl.store(input_grad_pointer, -input_grad if neg else input_grad,
181 | mask=batch_mask[:, None] & feat_mask[None, :])
182 |
--------------------------------------------------------------------------------
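The backward kernel above implements the standard softmax and log-softmax gradient identities; the PyTorch snippet below verifies the softmax case against autograd.

```python
import torch

# For y = softmax(x): dL/dx = y * (g - sum(g * y)), where g = dL/dy.
# For y = log_softmax(x): dL/dx = g - exp(y) * sum(g), as in the log branch above.
x = torch.randn(3, 7, requires_grad=True)
g = torch.randn(3, 7)

y = torch.softmax(x, dim=1)
y.backward(g)
manual = y.detach() * (g - (g * y.detach()).sum(dim=1, keepdim=True))
print(torch.allclose(x.grad, manual, atol=1e-6))  # expected: True
```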
/examples/wikitext-2/main.py:
--------------------------------------------------------------------------------
1 | """
2 | Trains and benchmarks a language model on the WikiText-2 dataset.
3 | """
4 |
5 |
6 | import time
7 | from argparse import ArgumentParser
8 | from functools import partial
9 | from math import exp, sqrt
10 | from typing import Callable, Tuple
11 |
12 | import torch
13 | from datasets import load_dataset
14 | from torch import nn
15 | from torch.optim import AdamW
16 | from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR
17 | from torch.utils.data import DataLoader
18 | from transformers import AutoTokenizer
19 |
20 | from .gpt import gpt2
21 | from ..utils import AvgMeter, benchmark_fw_and_bw
22 |
23 |
24 | def create_dls(
25 | batch_size: int = 32,
26 | seq_len: int = 512,
27 | num_workers: int = 4,
28 | ) -> Tuple[DataLoader, DataLoader]:
29 | """
30 | Creates data loaders for WikiText-2 with tokenization.
31 |
32 |     Args:
33 | batch_size: Batch size.
34 | seq_len: Sequence length.
35 | num_workers: Number of workers for data loading.
36 |
37 | Returns:
38 | Training and validation WikiText-2 data loaders.
39 | """
40 | tokenizer = AutoTokenizer.from_pretrained('gpt2')
41 | tokenizer.pad_token = tokenizer.eos_token
42 |
43 | dataset = load_dataset('wikitext', 'wikitext-2-v1')
44 | dataset = dataset.map(lambda input: tokenizer(input['text'],
45 | max_length=seq_len,
46 | padding='max_length',
47 | truncation=True),
48 | batched=True)
49 | dataset.set_format('torch')
50 |
51 | train_dl = DataLoader(dataset=dataset['train'], batch_size=batch_size,
52 | shuffle=True, num_workers=num_workers, pin_memory=True,
53 | drop_last=True)
54 | valid_dl = DataLoader(dataset=dataset['validation'], batch_size=batch_size,
55 | shuffle=False, num_workers=num_workers, pin_memory=True,
56 | drop_last=True)
57 |
58 | return train_dl, valid_dl
59 |
60 |
61 | def train(
62 | model: nn.Module,
63 | train_dl: DataLoader,
64 | valid_dl: DataLoader,
65 | scheduler: str = 'one-cycle',
66 | epochs: int = 10,
67 | batch_size: int = 32,
68 | ) -> float:
69 | """
70 | Trains and validates a model for language modelling.
71 |
72 | Args:
73 |         model: Model to train. Its forward pass must accept an optional
74 |             return_loss flag for returning the loss.
75 |         train_dl: Data loader for training.
76 |         valid_dl: Data loader for validation.
77 |         scheduler: Learning rate scheduler.
78 |             Options are 'one-cycle' and 'cosine'.
79 | epochs: Number of epochs to train for.
80 | batch_size: Batch size.
81 |
82 | Returns:
83 | Total training and validation time.
84 | """
85 | model = model.to('cuda')
86 | optim = AdamW(model.parameters(), lr=1e-4)
87 | optim.zero_grad()
88 | scaler = torch.GradScaler('cuda')
89 |
90 | if scheduler == 'one-cycle':
91 | sched = OneCycleLR(optim, max_lr=sqrt(batch_size / 32) * 4e-4,
92 | steps_per_epoch=len(train_dl), epochs=epochs)
93 |
94 | elif scheduler == 'cosine':
95 | sched = CosineAnnealingLR(optim, T_max=len(train_dl) * epochs)
96 |
97 | else:
98 | raise RuntimeError(f'Scheduler {scheduler} not supported.')
99 |
100 | avg_meter = AvgMeter()
101 | start = time.time()
102 |
103 | for epoch in range(epochs):
104 | print(f'Epoch {epoch+1}/{epochs}')
105 |
106 | model.train()
107 | avg_meter.reset()
108 | for ind, batch in enumerate(train_dl):
109 | if (ind + 1) % 5 == 0:
110 | print(f'Training iteration {ind+1}/{len(train_dl)}', end='\r')
111 |
112 | input = batch['input_ids'].to('cuda')
113 |
114 | with torch.autocast('cuda'):
115 | loss = model(input, return_loss=True)
116 |
117 | scaler.scale(loss).backward()
118 | scaler.step(optim)
119 | scaler.update()
120 | sched.step()
121 | optim.zero_grad()
122 |
123 | avg_meter.update(loss.item(), len(input))
124 | print(f'Training loss: {avg_meter.avg}')
125 |
126 | model.eval()
127 | avg_meter.reset()
128 | with torch.no_grad():
129 | for ind, batch in enumerate(valid_dl):
130 | if (ind + 1) % 5 == 0:
131 | print(f'Validation iteration {ind+1}/{len(valid_dl)}', end='\r')
132 |
133 | input = batch['input_ids'].to('cuda')
134 |
135 | with torch.autocast('cuda'):
136 | loss = model(input, return_loss=True)
137 | avg_meter.update(loss.item(), len(input))
138 | print(f'Validation perplexity: {exp(avg_meter.avg)}')
139 |
140 | return time.time() - start
141 |
142 |
143 | def main(
144 | model_cls: Callable,
145 | scheduler: str = 'one-cycle',
146 | epochs: int = 10,
147 | batch_size: int = 32,
148 | seq_len: int = 512,
149 | num_workers: int = 4,
150 | ) -> None:
151 | """
152 | Trains and benchmarks a language model on the WikiText-2 dataset.
153 |
154 | Args:
155 | model_cls: Model class to train, with a 'vocab_size' argument for
156 | specifying the vocabulary size and a 'max_seq_len' argument for
157 | specifying the maximum sequence length of the incoming inputs.
158 | scheduler: Learning rate scheduler.
159 | Options are 'one-cycle' and 'cosine'.
160 | epochs: Number of epochs to train for.
161 | batch_size: Batch size.
162 | seq_len: Sequence length.
163 | num_workers: Number of workers for data loading.
164 | """
165 | train_dl, valid_dl = create_dls(batch_size, seq_len, num_workers)
166 | vocab_size = AutoTokenizer.from_pretrained('gpt2').vocab_size
167 | model = model_cls(vocab_size=vocab_size, max_seq_len=seq_len).to('cuda')
168 |
169 | batch = next(iter(train_dl))
170 | input = batch['input_ids'].to('cuda')
171 |
172 | for _ in range(10):
173 | model.train()
174 | with torch.autocast('cuda'):
175 | loss = model(input, return_loss=True)
176 | loss.backward()
177 |
178 | model.eval()
179 |         with torch.no_grad(), torch.autocast('cuda'):
180 | model(input, return_loss=True)
181 |
182 | model.train()
183 | benchmark_fw_and_bw(model, input=input, return_loss=True)
184 |
185 | print('Total training and validation time: '
186 | f'{train(model, train_dl, valid_dl, scheduler, epochs, batch_size)}')
187 |
188 |
189 | if __name__ == '__main__':
190 | parser = ArgumentParser(description='Trains and benchmarks a language model on the WikiText-2 dataset.')
191 | parser.add_argument('--model',
192 | type=str,
193 | default='gpt2',
194 | choices=['gpt2'],
195 | help='Name of language model to train')
196 | parser.add_argument('--downsize',
197 | type=int,
198 | default=1,
199 | help='The depth and width of the model are calculated by dividing GPT2\'s original depth and width by this factor.')
200 | parser.add_argument('--scheduler',
201 | type=str,
202 | default='one-cycle',
203 | choices=['one-cycle', 'cosine'],
204 | help='Learning rate scheduler.')
205 | parser.add_argument('--epochs',
206 | type=int,
207 | default=10,
208 | help='Number of epochs to train for')
209 | parser.add_argument('--batch_size',
210 | type=int,
211 | default=32,
212 | help='Batch size')
213 | parser.add_argument('--seq_len',
214 | type=int,
215 | default=512,
216 | help='Sequence length')
217 | parser.add_argument('--num_workers',
218 | type=int,
219 | default=4,
220 | help='Number of workers for data loading')
221 | args = parser.parse_args()
222 |
223 | print('attorch run:')
224 | main(partial(locals()[args.model], use_attorch=True, downsize=args.downsize),
225 | scheduler=args.scheduler, epochs=args.epochs, batch_size=args.batch_size,
226 | seq_len=args.seq_len, num_workers=args.num_workers)
227 |
228 | print('PyTorch run:')
229 | main(partial(locals()[args.model], use_attorch=False, downsize=args.downsize),
230 | scheduler=args.scheduler, epochs=args.epochs, batch_size=args.batch_size,
231 | seq_len=args.seq_len, num_workers=args.num_workers)
232 |
--------------------------------------------------------------------------------
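A small worked example of the peak learning rate rule used in train() above: the one-cycle maximum is sqrt(batch_size / 32) * 4e-4, i.e., a base peak of 4e-4 at batch size 32 scaled by the square root of the relative batch size.

```python
from math import sqrt

# Peak one-cycle learning rates implied by the rule in train().
for batch_size in (16, 32, 64, 128):
    max_lr = sqrt(batch_size / 32) * 4e-4
    print(batch_size, round(max_lr, 6))
# 16 -> 0.000283, 32 -> 0.0004, 64 -> 0.000566, 128 -> 0.0008
```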
/attorch/layer_norm_layer.py:
--------------------------------------------------------------------------------
1 | """
2 | Layer normalization with PyTorch autodiff support.
3 | """
4 |
5 |
6 | from typing import Optional, Tuple
7 |
8 | import torch
9 | from torch import Tensor
10 | from torch import nn
11 | from triton import cdiv
12 |
13 | from .layer_norm_kernels import layer_norm_backward_kernel, layer_norm_forward_kernel
14 | from .softmax_kernels import BLOCK_SIZE_BATCH_heuristic
15 | from .types import Context, Device
16 | from .utils import get_output_dtype
17 |
18 |
19 | class LayerNormAutoGrad(torch.autograd.Function):
20 | """
21 | Autodiff for layer normalization.
22 | """
23 | @staticmethod
24 | def forward(
25 | ctx: Context,
26 | input: Tensor,
27 | weight: Optional[Tensor] = None,
28 | bias: Optional[Tensor] = None,
29 | eps: float = 1e-5,
30 | autocast_to_fp32: bool = True,
31 | ) -> Tensor:
32 | """
33 | Layer-normalizes the input.
34 |
35 | Args:
36 | ctx: Context for variable storage.
37 | input: Input to layer-normalize.
38 | Can have arbitrary shape.
39 | weight: Optional weights for affine transform.
40 | If provided, must be of shape [feat_dim].
41 | bias: Optional bias vector for affine transform when weight is provided.
42 | If provided, must be of shape [feat_dim].
43 | eps: Epsilon added in the square root in the denominator
44 | to avoid division by zero.
45 | autocast_to_fp32: Flag for autocasting the output dtype to fp32.
46 | If False, the input dtype flows through.
47 |
48 | Returns:
49 | Layer-normalized input.
50 | """
51 | flattened_input = input.unsqueeze(0) if input.ndim == 1 else input
52 | flattened_input = flattened_input.flatten(0, -2)
53 | batch_dim, feat_dim = flattened_input.shape
54 |
55 | output_dtype = get_output_dtype(input.dtype,
56 | autocast='fp32' if autocast_to_fp32 else None)
57 | output = torch.empty_like(flattened_input, dtype=output_dtype)
58 |
59 | scale_by_weight = weight is not None
60 | add_bias = scale_by_weight and bias is not None
61 | requires_grad = (input.requires_grad or
62 | (scale_by_weight and weight.requires_grad) or
63 | (add_bias and bias.requires_grad))
64 |
65 | if requires_grad:
66 | mean = torch.empty(batch_dim,
67 | device=input.device,
68 | dtype=torch.float32)
69 | inv_std = torch.empty(batch_dim,
70 | device=input.device,
71 | dtype=torch.float32)
72 |
73 | else:
74 | mean = inv_std = None
75 |
76 | # Launches 1D grid where each program operates over BLOCK_SIZE_BATCH rows.
77 | grid = lambda META: (cdiv(batch_dim, META['BLOCK_SIZE_BATCH']),)
78 | layer_norm_forward_kernel[grid](flattened_input, weight, bias,
79 | mean, inv_std, output,
80 | batch_dim, feat_dim,
81 | *flattened_input.stride(), *output.stride(),
82 | eps,
83 | scale_by_weight=scale_by_weight,
84 | add_bias=add_bias,
85 | save_stats=requires_grad)
86 |
87 | ctx.scale_by_weight = scale_by_weight
88 | ctx.add_bias = add_bias
89 | ctx.output_dtype = output_dtype
90 | if requires_grad:
91 | ctx.save_for_backward(flattened_input, mean, inv_std, weight)
92 |
93 | return output.view_as(input)
94 |
95 | @staticmethod
96 | def backward(
97 | ctx: Context,
98 | output_grad: Tensor,
99 | ) -> Tuple[Optional[Tensor], ...]:
100 | """
101 | Calculates the input gradient of layer normalization.
102 |
103 | Args:
104 | ctx: Context containing stored variables.
105 | output_grad: Output gradients.
106 | Must be the same shape as the output.
107 |
108 | Returns:
109 | Input gradient of layer normalization.
110 | """
111 | scale_by_weight, add_bias = ctx.scale_by_weight, ctx.add_bias
112 | (flattened_input, mean, inv_std, weight) = ctx.saved_tensors
113 | flattened_output_grad = output_grad.view_as(flattened_input)
114 |
115 | batch_dim, feat_dim = flattened_output_grad.shape
116 | input_grad = torch.empty_like(flattened_output_grad,
117 | dtype=ctx.output_dtype)
118 |
119 | if scale_by_weight:
120 | BLOCK_SIZE_BATCH = BLOCK_SIZE_BATCH_heuristic({'batch_dim': batch_dim,
121 | 'feat_dim': feat_dim})
122 | out_batch_dim = batch_dim // BLOCK_SIZE_BATCH
123 |
124 | weight_grad = torch.empty((out_batch_dim, feat_dim),
125 | device=flattened_input.device)
126 | if add_bias:
127 | bias_grad = torch.empty((out_batch_dim, feat_dim),
128 | device=flattened_input.device)
129 |
130 | else:
131 | bias_grad = None
132 |
133 | else:
134 | weight_grad = bias_grad = None
135 |
136 | # Launches 1D grid where each program operates over BLOCK_SIZE_BATCH rows.
137 | grid = lambda META: (cdiv(batch_dim, META['BLOCK_SIZE_BATCH']),)
138 | layer_norm_backward_kernel[grid](flattened_output_grad, flattened_input,
139 | mean, inv_std, weight,
140 | input_grad, weight_grad, bias_grad,
141 | batch_dim, feat_dim,
142 | *flattened_output_grad.stride(),
143 | *flattened_input.stride(),
144 | *input_grad.stride(),
145 | *weight_grad.stride() if scale_by_weight else (1, 1),
146 | *bias_grad.stride() if scale_by_weight and add_bias else (1, 1),
147 | scale_by_weight=scale_by_weight,
148 | add_bias=add_bias)
149 |
150 | if scale_by_weight:
151 | weight_grad = weight_grad.sum(dim=0)
152 | if add_bias:
153 | bias_grad = bias_grad.sum(dim=0)
154 |
155 | # Pads output with None because a gradient is necessary for
156 | # all input arguments.
157 | return input_grad.view_as(output_grad), weight_grad, bias_grad, None, None
158 |
159 |
160 | class LayerNorm(nn.LayerNorm):
161 | """
162 | Layer-normalizes the input.
163 | See also base class.
164 |
165 | Note: During automatic mixed precision training, PyTorch autocasts
166 | the output dtype of layer normalization to fp32. This conversion is not necessary
167 | and retaining the input dtype is generally numerically stable in addition to
168 | being slightly more efficient. Layer normalization in attorch thus offers
169 | an additional argument, autocast_to_fp32, to manually disable this
170 | autocasting behaviour.
171 |
172 | Args:
173 | normalized_shape: Dimensionality of last feature that is normalized.
174 | eps: Epsilon added in the square root in the denominator
175 | to avoid division by zero.
176 | elementwise_affine: Flag for scaling the normalized output by weights.
177 | bias: Flag for adding a bias vector to the normalized output
178 | if elementwise_affine is True.
179 | autocast_to_fp32: Flag for autocasting the output dtype to fp32.
180 | If False, the input dtype flows through.
181 | device: Device to use.
182 | dtype: Dtype of layer.
183 |
184 | Raises:
185 | RuntimeError: Normalized shape was not an integer.
186 | """
187 | def __init__(
188 | self,
189 | normalized_shape: int,
190 | eps: float = 1e-5,
191 | elementwise_affine: bool = True,
192 | bias: bool = True,
193 | autocast_to_fp32: bool = True,
194 | device: Device = 'cuda',
195 | dtype: torch.dtype = torch.float32,
196 | ) -> None:
197 | if not isinstance(normalized_shape, int):
198 | raise RuntimeError('Normalized shape must be an integer.')
199 |
200 | super().__init__(normalized_shape, eps, elementwise_affine, bias,
201 | device, dtype)
202 | self.autocast_to_fp32 = autocast_to_fp32
203 |
204 | def forward(self, input: Tensor) -> Tensor:
205 | return LayerNormAutoGrad.apply(input, self.weight, self.bias, self.eps,
206 | self.autocast_to_fp32)
207 |
--------------------------------------------------------------------------------
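A minimal usage sketch for the layer normalization module above, assuming a CUDA device and a package-level export; inputs with arbitrary leading dimensions are flattened internally, and autocast_to_fp32 controls the output dtype under mixed precision as described in the note.

```python
import torch
import attorch

# Sketch: normalization is over the last dimension; leading dimensions are
# flattened into the batch dimension by the autograd function.
ln = attorch.LayerNorm(256, autocast_to_fp32=False)
x = torch.randn(8, 32, 256, device='cuda', requires_grad=True)

y = ln(x)            # same shape as x
y.sum().backward()   # populates gradients for x, ln.weight, and ln.bias
print(y.shape)       # torch.Size([8, 32, 256])
```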
/attorch/linear_layer.py:
--------------------------------------------------------------------------------
1 | """
2 | Linear layer with fused activation with PyTorch autodiff support.
3 | """
4 |
5 |
6 | from typing import Optional, Tuple
7 |
8 | import torch
9 | from torch import Tensor
10 | from torch import nn
11 | from triton import cdiv
12 |
13 | from .act_kernels import act_func_backward_kernel
14 | from .linear_kernels import linear_forward_kernel
15 | from .types import Context, Device
16 | from .utils import get_output_dtype
17 |
18 |
19 | class LinearAutoGrad(torch.autograd.Function):
20 | """
21 | Autodiff for linear layer.
22 | """
23 | @staticmethod
24 | def forward(
25 | ctx: Context,
26 | input: Tensor,
27 | weight: Tensor,
28 | bias: Optional[Tensor] = None,
29 | act_func: Optional[str] = None,
30 | ) -> Tensor:
31 | """
32 | Linearly transforms the input using weights, optionally adding bias
33 | and fusing an activation function.
34 |
35 | Args:
36 | input: Input to transform.
37 | Must be of shape [..., in_feat_dim].
38 | weight: Weights input is transformed by.
39 | Must be of shape [in_feat_dim, out_feat_dim].
40 | bias: Optional additive bias vector, with None for no bias.
41 | If provided, must be of shape [out_feat_dim].
42 | act_func: Name of activation function to apply, with None for identity.
43 | Options are 'sigmoid', 'logsigmoid', 'tanh', 'relu', 'gelu', 'geluapprox', 'silu',
44 | 'relu6', 'hardsigmoid', 'hardtanh', 'hardswish', 'selu', 'mish',
45 | 'softplus', 'softsign', 'tanhshrink', 'leaky_relu_PARAM',
46 | 'elu_PARAM', 'celu_PARAM', 'hardshrink_PARAM', and 'softshrink_PARAM'
47 | where PARAM stands for the parameter in the case of parameterized
48 | activation functions (e.g., 'leaky_relu_0.01' for leaky ReLU with a
49 | negative slope of 0.01).
50 |
51 | Returns:
52 |             Input linearly transformed, potentially with added bias and
53 | fused activation.
54 | """
55 | assert weight.ndim == 2, \
56 | f'Weights must be 2D, received shape {weight.shape}'
57 | assert bias is None or bias.ndim == 1, \
58 | f'Bias must be 1D, received shape {bias.shape}'
59 |
60 | assert input.shape[-1] == weight.shape[0], \
61 | f'Incompatible input ({input.shape}) and weights ({weight.shape}) shape'
62 | assert bias is None or weight.shape[1] == bias.shape[0], \
63 | f'Incompatible weights ({weight.shape}) and bias ({bias.shape}) shape'
64 |
65 | param = None
66 | if act_func is not None and '_' in act_func:
67 | comps = act_func.split('_')
68 | act_func = '_'.join(comps[:-1])
69 | param = float(comps[-1])
70 |
71 | flattened_input = input.flatten(0, -2)
72 | batch_dim, in_feat_dim = flattened_input.shape
73 | _, out_feat_dim = weight.shape
74 |
75 | requires_grad = (input.requires_grad or
76 | weight.requires_grad or
77 | (bias is not None and bias.requires_grad))
78 | save_pre_act = requires_grad and (act_func is not None)
79 |
80 | output_dtype = get_output_dtype(input.dtype, autocast='fp16')
81 | output = torch.empty((batch_dim, out_feat_dim),
82 | device=input.device,
83 | dtype=output_dtype)
84 | pre_act = torch.empty_like(output) if save_pre_act else output
85 |
86 | # Launches a 1D grid, where each program outputs blocks of
87 | # BLOCK_SIZE_BATCH rows and BLOCK_SIZE_OUT_FEAT columns.
88 | grid = lambda META: (cdiv(batch_dim, META['BLOCK_SIZE_BATCH']) *
89 | cdiv(out_feat_dim, META['BLOCK_SIZE_OUT_FEAT']),)
90 | linear_forward_kernel[grid](flattened_input, weight,
91 | input if bias is None else bias,
92 | pre_act, output,
93 | batch_dim, in_feat_dim, out_feat_dim,
94 | *flattened_input.stride(), *weight.stride(),
95 | *pre_act.stride(), *output.stride(), param,
96 | add_bias=bias is not None, act_func=act_func,
97 | save_pre_act=save_pre_act,
98 | fp16=output_dtype is torch.float16)
99 |
100 | ctx.param = param
101 | ctx.act_func = act_func
102 | ctx.bias_requires_grad = False if bias is None else bias.requires_grad
103 | ctx.output_dtype = output_dtype
104 | if requires_grad:
105 | ctx.save_for_backward(input, pre_act if save_pre_act else None, weight)
106 |
107 | return output.view(*input.shape[:-1], out_feat_dim)
108 |
109 | @staticmethod
110 | def backward(
111 | ctx: Context,
112 | output_grad: Tensor,
113 | ) -> Tuple[Optional[Tensor], ...]:
114 | """
115 | Calculates the input gradient of the linear layer.
116 |
117 | Args:
118 | ctx: Context containing stored variables.
119 | output_grad: Output gradients.
120 | Must be the same shape as the output.
121 |
122 | Returns:
123 | Input gradient of the linear layer.
124 | """
125 | input, pre_act, weight = ctx.saved_tensors
126 |
127 | output_grad = output_grad.flatten(0, -2)
128 | flattened_input = input.flatten(0, -2)
129 |
130 | batch_dim, _ = flattened_input.shape
131 | _, out_feat_dim = weight.shape
132 |
133 | if ctx.act_func is None:
134 | pre_act_grad = output_grad
135 |
136 | else:
137 | size = batch_dim * out_feat_dim
138 | pre_act_grad = torch.empty(size, dtype=pre_act.dtype,
139 | device=pre_act.device)
140 |
141 | # Launches 1D grid where each program operates over
142 | # BLOCK_SIZE elements.
143 | grid = lambda META: (cdiv(size, META['BLOCK_SIZE']),)
144 | act_func_backward_kernel[grid](output_grad, pre_act, pre_act_grad,
145 | size, None, None, ctx.param,
146 | ctx.act_func, False)
147 |
148 | pre_act_grad = pre_act_grad.view_as(pre_act)
149 |
150 | # Using PyTorch's matmul, but linear_forward_kernel
151 | # could have also been used.
152 | with torch.autocast('cuda', dtype=ctx.output_dtype):
153 | input_grad = pre_act_grad @ weight.T if input.requires_grad else None
154 | weight_grad = (flattened_input.T @ pre_act_grad
155 | if weight.requires_grad else None)
156 | bias_grad = pre_act_grad.sum(dim=0) if ctx.bias_requires_grad else None
157 |
158 | # Pads output with None because a gradient is necessary for
159 | # all input arguments.
160 | return (input_grad.view_as(input) if input_grad is not None else None,
161 | weight_grad, bias_grad, None)
162 |
163 |
164 | class Linear(nn.Linear):
165 | """
166 | Linearly transforms the input using weights, optionally adding bias
167 | and fusing an activation function.
168 | See also base class.
169 |
170 | Note: Unlike PyTorch's linear layer, the weight matrix in this module is
171 | of shape [in_features, out_features] instead of [out_features, in_features].
172 | This may cause unexpected issues when manipulating the weights (e.g., porting
173 | parameters, initializing them, and so forth).
174 |
175 | Args:
176 | in_features: Number of input features.
177 | out_features: Number of output features.
178 | bias: Flag for additive bias.
179 | act_func: Name of activation function to apply, with None for identity.
180 | Options are 'sigmoid', 'logsigmoid', 'tanh', 'relu', 'gelu', 'geluapprox', 'silu',
181 | 'relu6', 'hardsigmoid', 'hardtanh', 'hardswish', 'selu', 'mish',
182 | 'softplus', 'softsign', 'tanhshrink', 'leaky_relu_PARAM',
183 | 'elu_PARAM', 'celu_PARAM', 'hardshrink_PARAM', and 'softshrink_PARAM'
184 | where PARAM stands for the parameter in the case of parameterized
185 | activation functions (e.g., 'leaky_relu_0.01' for leaky ReLU with a
186 | negative slope of 0.01).
187 | device: Device to use.
188 | dtype: Dtype of layer.
189 | """
190 | def __init__(
191 | self,
192 | in_features: int,
193 | out_features: int,
194 | bias: bool = True,
195 | act_func: Optional[str] = None,
196 | device: Device = 'cuda',
197 | dtype: torch.dtype = torch.float32,
198 | ) -> None:
199 | super().__init__(in_features, out_features, bias, device, dtype)
200 | self.weight = nn.Parameter(self.weight.T.contiguous())
201 | self.act_func = act_func
202 |
203 | def forward(self, input: Tensor) -> Tensor:
204 | return LinearAutoGrad.apply(input, self.weight, self.bias,
205 | self.act_func)
206 |
--------------------------------------------------------------------------------
/attorch/linear_kernels.py:
--------------------------------------------------------------------------------
1 | """
2 | Kernels for linear layer with fused activation.
3 | """
4 |
5 |
6 | import triton
7 | import triton.language as tl
8 |
9 | from .act_kernels import apply_act_func
10 | from .utils import allow_tf32, get_n_stages
11 |
12 |
13 | def linear_forward_config(
14 | BLOCK_SIZE_BATCH: int,
15 | BLOCK_SIZE_IN_FEAT: int,
16 | BLOCK_SIZE_OUT_FEAT: int,
17 | GROUP_SIZE_BATCH: int = 8,
18 | n_warps: int = 4,
19 | n_stages: int = 2,
20 | ) -> triton.Config:
21 | """
22 | Creates a triton.Config object for linear_forward_kernel
23 | given meta-parameters for auto-tuning.
24 |
25 | Args:
26 | BLOCK_SIZE_BATCH: Block size across the batch dimension.
27 | BLOCK_SIZE_IN_FEAT: Block size across the input feature dimension.
28 | BLOCK_SIZE_OUT_FEAT: Block size across the output feature dimension.
29 | GROUP_SIZE_BATCH: Group size across the batch dimension.
30 | n_warps: Number of warps to use for the kernel when compiled for GPUs.
31 | n_stages: Number of stages the compiler uses to software-pipeline.
32 | On GPU architectures older than Ampere, this is fixed at 2.
33 |
34 | Returns:
35 | Kernel configuration.
36 | """
37 | return triton.Config({'BLOCK_SIZE_BATCH': BLOCK_SIZE_BATCH,
38 | 'BLOCK_SIZE_IN_FEAT': BLOCK_SIZE_IN_FEAT,
39 | 'BLOCK_SIZE_OUT_FEAT': BLOCK_SIZE_OUT_FEAT,
40 | 'GROUP_SIZE_BATCH': GROUP_SIZE_BATCH},
41 | num_warps=n_warps, num_stages=get_n_stages(n_stages))
42 |
43 |
44 | @triton.autotune(
45 | configs=[
46 | linear_forward_config(32, 32, 32, n_warps=2, n_stages=2),
47 | linear_forward_config(64, 32, 32, n_warps=2, n_stages=5),
48 | linear_forward_config(64, 32, 128, n_warps=4, n_stages=4),
49 | linear_forward_config(64, 32, 256, n_warps=4, n_stages=4),
50 | linear_forward_config(128, 32, 32, n_warps=4, n_stages=4),
51 | linear_forward_config(128, 32, 64, n_warps=4, n_stages=4),
52 | linear_forward_config(128, 32, 128, n_warps=4, n_stages=4),
53 | linear_forward_config(128, 64, 256, n_warps=8, n_stages=3),
54 | ],
55 | key=['batch_dim', 'in_feat_dim', 'out_feat_dim', 'fp16'],
56 | )
57 | @triton.heuristics({'tf32': lambda _: allow_tf32()})
58 | @triton.jit
59 | def linear_forward_kernel(
60 | input_pointer, weight_pointer, bias_pointer, pre_act_pointer, output_pointer,
61 | batch_dim, in_feat_dim, out_feat_dim,
62 | input_batch_stride, input_in_feat_stride,
63 | weight_in_feat_stride, weight_out_feat_stride,
64 | pre_act_batch_stride, pre_act_out_feat_stride,
65 | output_batch_stride, output_out_feat_stride, param,
66 | add_bias: tl.constexpr, act_func: tl.constexpr, save_pre_act: tl.constexpr,
67 | fp16: tl.constexpr, tf32: tl.constexpr,
68 | BLOCK_SIZE_BATCH: tl.constexpr, BLOCK_SIZE_IN_FEAT: tl.constexpr,
69 | BLOCK_SIZE_OUT_FEAT: tl.constexpr, GROUP_SIZE_BATCH: tl.constexpr,
70 | ):
71 | """
72 | Linearly transforms the input using weights, optionally adding bias
73 | and fusing an activation function.
74 |
75 | Args:
76 | input_pointer: Pointer to the input to transform.
77 | The input must be of shape [batch_dim, in_feat_dim].
78 | weight_pointer: Pointer to the weights input is transformed by.
79 | The weights must be of shape [in_feat_dim, out_feat_dim].
80 | bias_pointer: Pointer to an optional additive bias vector.
81 | The bias vector, if provided, must be of shape [out_feat_dim].
82 | pre_act_pointer: Pointer to an optional container the pre-activation input
83 | is written to if act_func is not None and save_pre_act is True.
84 | The container, if provided, must be of shape [batch_dim, out_feat_dim].
85 | output_pointer: Pointer to a container the result is written to.
86 | The container must be of shape [batch_dim, out_feat_dim].
87 | batch_dim: Batch dimension of the input and output.
88 | in_feat_dim: Dimensionality of the input features.
89 | out_feat_dim: Dimensionality of the output features.
90 | input_batch_stride: Stride necessary to jump one element along the
91 | input's batch dimension.
92 | input_in_feat_stride: Stride necessary to jump one element along the
93 | input's feature dimension.
94 | weight_in_feat_stride: Stride necessary to jump one element along the
95 | weights' input feature dimension.
96 | weight_out_feat_stride: Stride necessary to jump one element along the
97 | weights' output feature dimension.
98 | pre_act_batch_stride: Stride necessary to jump one element along the
99 | pre-activation input container's batch dimension.
100 | pre_act_out_feat_stride: Stride necessary to jump one element along the
101 | pre-activation input container's feature dimension.
102 | output_batch_stride: Stride necessary to jump one element along the
103 | output container's batch dimension.
104 | output_out_feat_stride: Stride necessary to jump one element along the
105 | output container's feature dimension.
106 | param: Parameter in the case of parameterized activation functions.
107 | add_bias: Flag for adding a bias vector.
108 | act_func: Name of activation function to apply, with None for identity.
109 | Options are 'sigmoid', 'logsigmoid', 'tanh', 'relu', 'gelu', 'geluapprox', 'silu',
110 | 'softplus', 'softsign', 'tanhshrink', 'leaky_relu', 'elu', 'celu', 'hardshrink',
111 | and 'softshrink'.
112 | save_pre_act: Flag for saving the pre-activation input.
113 | fp16: Flag for loading the input, weights, and bias in FP16.
114 | tf32: Flag for performing matrix products in TF32.
115 | BLOCK_SIZE_BATCH: Block size across the batch dimension.
116 | BLOCK_SIZE_IN_FEAT: Block size across the input feature dimension.
117 | BLOCK_SIZE_OUT_FEAT: Block size across the output feature dimension.
118 | GROUP_SIZE_BATCH: Group size across the batch dimension.
119 | """
120 | # Programs are blocked together, GROUP_SIZE_BATCH rows at a time,
121 | # to alleviate L2 miss rates.
122 | pid = tl.program_id(axis=0)
123 | n_batch_pids = tl.cdiv(batch_dim, BLOCK_SIZE_BATCH)
124 | n_out_feat_pids = tl.cdiv(out_feat_dim, BLOCK_SIZE_OUT_FEAT)
125 | pids_per_group = GROUP_SIZE_BATCH * n_out_feat_pids
126 | group_id = pid // pids_per_group
127 | first_batch_pid = group_id * GROUP_SIZE_BATCH
128 | GROUP_SIZE_BATCH = min(n_batch_pids - first_batch_pid, GROUP_SIZE_BATCH)
129 | batch_pid = first_batch_pid + (pid % GROUP_SIZE_BATCH)
130 | out_feat_pid = (pid % pids_per_group) // GROUP_SIZE_BATCH
131 |
132 | batch_offset = (batch_pid * BLOCK_SIZE_BATCH +
133 | tl.arange(0, BLOCK_SIZE_BATCH))
134 | out_feat_offset = (out_feat_pid * BLOCK_SIZE_OUT_FEAT +
135 | tl.arange(0, BLOCK_SIZE_OUT_FEAT))
136 |
137 | batch_mask = batch_offset < batch_dim
138 | out_feat_mask = out_feat_offset < out_feat_dim
139 |
140 | input_pointer += input_batch_stride * batch_offset[:, None]
141 | weight_pointer += weight_out_feat_stride * out_feat_offset[None, :]
142 |
143 | accum = tl.zeros((BLOCK_SIZE_BATCH, BLOCK_SIZE_OUT_FEAT),
144 | dtype=tl.float32)
145 |
146 | for block_ind in range(0, tl.cdiv(in_feat_dim, BLOCK_SIZE_IN_FEAT)):
147 | in_feat_offset = (block_ind * BLOCK_SIZE_IN_FEAT +
148 | tl.arange(0, BLOCK_SIZE_IN_FEAT))
149 | in_feat_mask = in_feat_offset < in_feat_dim
150 |
151 | curr_input_pointer = (input_pointer +
152 | input_in_feat_stride * in_feat_offset[None, :])
153 | curr_weight_pointer = (weight_pointer +
154 | weight_in_feat_stride * in_feat_offset[:, None])
155 |
156 | input_block = tl.load(curr_input_pointer,
157 | mask=batch_mask[:, None] & in_feat_mask[None, :])
158 | weight_block = tl.load(curr_weight_pointer,
159 | mask=out_feat_mask[None, :] & in_feat_mask[:, None])
160 |
161 | if fp16:
162 | input_block = input_block.to(tl.float16)
163 | weight_block = weight_block.to(tl.float16)
164 |
165 | accum += tl.dot(input_block, weight_block, allow_tf32=tf32)
166 |
167 | if add_bias:
168 | bias = tl.load(bias_pointer + out_feat_offset,
169 | mask=out_feat_mask)
170 |
171 | if fp16:
172 | bias = bias.to(tl.float16)
173 |
174 | accum += bias[None, :]
175 |
176 | if act_func is not None:
177 | if save_pre_act:
178 | pre_act_pointer += (pre_act_batch_stride * batch_offset[:, None] +
179 | pre_act_out_feat_stride * out_feat_offset[None, :])
180 | tl.store(pre_act_pointer, accum,
181 | mask=batch_mask[:, None] & out_feat_mask[None, :])
182 |
183 | accum = apply_act_func(accum, None, None, None, param, act_func, False)
184 |
185 | output_pointer += (output_batch_stride * batch_offset[:, None] +
186 | output_out_feat_stride * out_feat_offset[None, :])
187 | tl.store(output_pointer, accum,
188 | mask=batch_mask[:, None] & out_feat_mask[None, :])
189 |
--------------------------------------------------------------------------------
/attorch/rms_norm_kernels.py:
--------------------------------------------------------------------------------
1 | """
2 | Kernels for root mean square normalization.
3 | """
4 |
5 |
6 | import triton
7 | import triton.language as tl
8 | from triton import next_power_of_2
9 |
10 | from .softmax_kernels import BLOCK_SIZE_BATCH_heuristic
11 | from .utils import warps_kernel_configs
12 |
13 |
14 | @triton.autotune(
15 | configs=warps_kernel_configs(),
16 | key=['batch_dim', 'feat_dim'],
17 | )
18 | @triton.heuristics({'BLOCK_SIZE_BATCH': BLOCK_SIZE_BATCH_heuristic,
19 | 'BLOCK_SIZE_FEAT': lambda args: next_power_of_2(args['feat_dim'])})
20 | @triton.jit
21 | def rms_norm_forward_kernel(
22 | input_pointer, weight_pointer,
23 | inv_rms_pointer, output_pointer,
24 | batch_dim, feat_dim,
25 | input_batch_stride, input_feat_stride,
26 | output_batch_stride, output_feat_stride,
27 | eps,
28 | scale_by_weight: tl.constexpr, save_stats: tl.constexpr,
29 | BLOCK_SIZE_BATCH: tl.constexpr, BLOCK_SIZE_FEAT: tl.constexpr,
30 | ):
31 | """
32 | Root-mean-square-normalizes the input.
33 |
34 | Args:
35 | input_pointer: Pointer to the input to root-mean-square-normalize.
36 | The input must be of shape [batch_dim, feat_dim].
37 | weight_pointer: Pointer to optional weights for linear transform.
38 | The weights, if provided, must be of shape [feat_dim].
39 | inv_rms_pointer: Pointer to an optional container the input's inverse
40 | root mean square is written to if save_stats is True.
41 | The container, if provided, must be of shape [batch_dim].
42 | output_pointer: Pointer to a container the result is written to.
43 | The container must be of shape [batch_dim, feat_dim].
44 | batch_dim: Batch dimension.
45 | feat_dim: Dimensionality of the features.
46 | input_batch_stride: Stride necessary to jump one element along the
47 | input's batch dimension.
48 | input_feat_stride: Stride necessary to jump one element along the
49 | input's feature dimension.
50 | output_batch_stride: Stride necessary to jump one element along the
51 | output container's batch dimension.
52 | output_feat_stride: Stride necessary to jump one element along the
53 | output container's feature dimension.
54 | eps: Epsilon added in the square root in the denominator
55 | to avoid division by zero.
56 | scale_by_weight: Flag for scaling the normalized output by weights.
57 | save_stats: Flag for saving the root mean square.
58 | BLOCK_SIZE_BATCH: Block size across the batch dimension.
59 | BLOCK_SIZE_FEAT: Block size across the feature dimension.
60 | """
61 | # This program processes BLOCK_SIZE_BATCH rows and BLOCK_SIZE_FEAT columns.
62 | batch_pid = tl.program_id(axis=0)
63 |
64 | batch_offset = batch_pid * BLOCK_SIZE_BATCH + tl.arange(0, BLOCK_SIZE_BATCH)
65 | feat_offset = tl.arange(0, BLOCK_SIZE_FEAT)
66 |
67 | batch_mask = batch_offset < batch_dim
68 | feat_mask = feat_offset < feat_dim
69 |
70 | input_pointer += (input_batch_stride * batch_offset[:, None] +
71 | input_feat_stride * feat_offset[None, :])
72 | output_pointer += (output_batch_stride * batch_offset[:, None] +
73 | output_feat_stride * feat_offset[None, :])
74 |
75 | input = tl.load(input_pointer,
76 | mask=batch_mask[:, None] & feat_mask[None, :]).to(tl.float32)
77 | inv_rms = tl.rsqrt(tl.sum(input * input, axis=1) / feat_dim + eps)
78 | output = input * inv_rms[:, None]
79 |
80 | if save_stats:
81 | tl.store(inv_rms_pointer + batch_offset, inv_rms, mask=batch_mask)
82 |
83 | if scale_by_weight:
84 | weight = tl.load(weight_pointer + feat_offset, mask=feat_mask)
85 | output *= weight
86 |
87 | tl.store(output_pointer, output,
88 | mask=batch_mask[:, None] & feat_mask[None, :])
89 |
90 |
91 | @triton.autotune(
92 | configs=warps_kernel_configs(),
93 | key=['batch_dim', 'feat_dim'],
94 | )
95 | @triton.heuristics({'BLOCK_SIZE_BATCH': BLOCK_SIZE_BATCH_heuristic,
96 | 'BLOCK_SIZE_FEAT': lambda args: next_power_of_2(args['feat_dim'])})
97 | @triton.jit
98 | def rms_norm_backward_kernel(
99 | output_grad_pointer, input_pointer, inv_rms_pointer, weight_pointer,
100 | input_grad_pointer, weight_grad_pointer,
101 | batch_dim, feat_dim,
102 | output_grad_batch_stride, output_grad_feat_stride,
103 | input_batch_stride, input_feat_stride,
104 | input_grad_batch_stride, input_grad_feat_stride,
105 | weight_grad_batch_stride, weight_grad_feat_stride,
106 | scale_by_weight: tl.constexpr,
107 | BLOCK_SIZE_BATCH: tl.constexpr, BLOCK_SIZE_FEAT: tl.constexpr,
108 | ):
109 | """
110 | Calculates the input gradient of root mean square normalization.
111 |
112 | Args:
113 | output_grad_pointer: Pointer to root mean square normalization's output gradients.
114 | The output gradients must be of shape [batch_dim, feat_dim].
115 | input_pointer: Pointer to the input.
116 | The input must be of shape [batch_dim, feat_dim].
117 | inv_rms_pointer: Pointer to the input's inverse root mean square.
118 | The inverse root mean square should be of shape [batch_dim].
119 | weight_pointer: Pointer to optional weights if affine transform occurred.
120 | The weights, if provided, must be of shape [feat_dim].
121 | input_grad_pointer: Pointer to a container the input's gradients are written to.
122 | The container must be of shape [batch_dim, feat_dim].
123 | weight_grad_pointer: Pointer to an optional container the weights' row-wise gradients
124 | are written to if scale_by_weight is True, which should later be summed.
125 | The container, if provided, must be of shape [batch_dim/BLOCK_SIZE_BATCH, feat_dim].
129 | batch_dim: Batch dimension.
130 | feat_dim: Dimensionality of the features.
131 | output_grad_batch_stride: Stride necessary to jump one element along the
132 | output gradients' batch dimension.
133 | output_grad_feat_stride: Stride necessary to jump one element along the
134 | output gradients' feature dimension.
135 | input_batch_stride: Stride necessary to jump one element along the
136 | input's batch dimension.
137 | input_feat_stride: Stride necessary to jump one element along the
138 | input's feature dimension.
139 | input_grad_batch_stride: Stride necessary to jump one element along the
140 | input gradient container's batch dimension.
141 | input_grad_feat_stride: Stride necessary to jump one element along the
142 | input gradient container's feature dimension.
143 | weight_grad_batch_stride: Stride necessary to jump one element along the
144 | weight gradient container's batch dimension.
145 | weight_grad_feat_stride: Stride necessary to jump one element along the
146 | weight gradient container's feature dimension.
147 | scale_by_weight: Flag for scaling the normalized output by weights.
148 | BLOCK_SIZE_BATCH: Block size across the batch dimension.
149 | BLOCK_SIZE_FEAT: Block size across the feature dimension.
150 | """
151 | # This program processes BLOCK_SIZE_BATCH rows and BLOCK_SIZE_FEAT columns.
152 | batch_pid = tl.program_id(axis=0)
153 |
154 | batch_offset = batch_pid * BLOCK_SIZE_BATCH + tl.arange(0, BLOCK_SIZE_BATCH)
155 | feat_offset = tl.arange(0, BLOCK_SIZE_FEAT)
156 |
157 | batch_mask = batch_offset < batch_dim
158 | feat_mask = feat_offset < feat_dim
159 |
160 | output_grad_pointer += (output_grad_batch_stride * batch_offset[:, None] +
161 | output_grad_feat_stride * feat_offset[None, :])
162 | input_pointer += (input_batch_stride * batch_offset[:, None] +
163 | input_feat_stride * feat_offset[None, :])
164 | input_grad_pointer += (input_grad_batch_stride * batch_offset[:, None] +
165 | input_grad_feat_stride * feat_offset[None, :])
166 |
167 | output_grad = tl.load(output_grad_pointer,
168 | mask=batch_mask[:, None] & feat_mask[None, :]).to(tl.float32)
169 | input = tl.load(input_pointer, mask=batch_mask[:, None] & feat_mask[None, :]).to(tl.float32)
170 | inv_rms = tl.load(inv_rms_pointer + batch_offset, mask=batch_mask)
171 | pre_lin = input * inv_rms[:, None]
172 |
173 | if scale_by_weight:
174 | weight = tl.load(weight_pointer + feat_offset, mask=feat_mask)
175 | weight_output_grad_prod = weight * output_grad
176 |
177 | else:
178 | weight_output_grad_prod = output_grad
179 |
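# Gradient of y = x * inv_rms (optionally scaled by weights):
# dL/dx = inv_rms * (wg - x * sum(x * wg) * inv_rms^2 / feat_dim),
# where wg denotes the weight-scaled output gradients.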
180 | term1 = input * tl.sum(input * weight_output_grad_prod, axis=1)[:, None]
181 | term2 = inv_rms[:, None] * inv_rms[:, None]
182 | input_grad = (inv_rms[:, None] *
183 | (weight_output_grad_prod - term1 * term2 / feat_dim))
184 |
185 | tl.store(input_grad_pointer, input_grad,
186 | mask=batch_mask[:, None] & feat_mask[None, :])
187 |
188 | if scale_by_weight:
189 | weight_grad_pointer += (weight_grad_batch_stride * batch_pid +
190 | weight_grad_feat_stride * feat_offset)
191 | tl.store(weight_grad_pointer,
192 | tl.sum(output_grad * pre_lin, axis=0),
193 | mask=feat_mask)
194 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # attorch
2 |
3 | • **[Introduction](#introduction)**
4 | • **[Installation](#installation)**
5 | • **[Layers](#layers)**
6 | • **[Math Functions](#math-functions)**
7 | • **[PyTorch Fallback](#pytorch-fallback)**
8 | • **[Tests](#tests)**
9 | • **[Examples](#examples)**
10 | • **[Citations](#citations)**
11 |
12 | ## Introduction
13 |
14 | attorch is a subset of PyTorch's [```nn```](https://pytorch.org/docs/stable/nn.html) module, written purely in Python using OpenAI's [Triton](https://github.com/openai/triton). Its goal is to be an easily hackable, self-contained, and readable collection of neural network modules whilst maintaining or improving upon the efficiency of PyTorch. In other words, it intends to be a forkable project endowed with a simple, intuitive design that can serve as an accessible starting point for those who are seeking to develop custom deep learning operations but are not satisfied with the speed of a pure PyTorch implementation and do not have the technical expertise or resources to write CUDA kernels.
15 |
16 | There already exist a number of wonderful PyTorch-like frameworks powered by Triton, including [kernl](https://github.com/ELS-RD/kernl/tree/main), [xFormers](https://github.com/facebookresearch/xformers), [Unsloth](https://github.com/unslothai/unsloth), and [```fla```](https://github.com/sustcsonglin/flash-linear-attention), but most concentrate mainly on Transformers and NLP applications, whereas attorch aims to be more inclusive by also presenting a variety of layers pertaining to areas beyond NLP, such as computer vision. Moreover, attorch is not an inference-only package and fully supports both forward and backward passes, meaning it can be used during training as well as inference, though its performance for the latter is generally not on par with dedicated inference engines.
17 |
18 | ## Installation
19 |
20 | The only dependencies of attorch are ```torch==2.4.0``` and ```triton==3.0.0```. Please install the specified versions of these two libraries and clone this repository to get started.
21 |
22 | ## Layers
23 |
24 | Currently implemented layers, with automatic mixed precision (AMP) support, are:
25 |
26 | * ```attorch.Conv1d```: 1D-convolves over the input using weights, optionally adding bias.
27 | * ```attorch.Conv2d```: 2D-convolves over the input using weights, optionally adding bias.
28 | * ```attorch.AvgPool1d```: Averages 1D windows of pixels in the input.
29 | * ```attorch.AvgPool2d```: Averages 2D windows of pixels in the input.
30 | * ```attorch.MultiheadAttention```: Applies multi-headed scaled dot-product attention to the inputs.
31 | * ```attorch.ELU```: Applies ELU to the input, optionally fusing dropout.
32 | * ```attorch.Hardshrink```: Applies hard shrink to the input, optionally fusing dropout.
33 | * ```attorch.Hardsigmoid```: Applies hard sigmoid to the input, optionally fusing dropout.
34 | * ```attorch.Hardswish```: Applies hard Swish to the input, optionally fusing dropout.
35 | * ```attorch.Hardtanh```: Applies hard tanh to the input, optionally fusing dropout.
36 | * ```attorch.LeakyReLU```: Applies leaky ReLU to the input, optionally fusing dropout.
37 | * ```attorch.LogSigmoid```: Applies the log of sigmoid to the input, optionally fusing dropout.
38 | * ```attorch.CELU```: Applies CELU to the input, optionally fusing dropout.
39 | * ```attorch.GELU```: Applies GELU to the input, optionally fusing dropout.
40 | * ```attorch.ReLU```: Applies ReLU to the input, optionally fusing dropout.
41 | * ```attorch.ReLU6```: Applies ReLU6 to the input, optionally fusing dropout.
42 | * ```attorch.SELU```: Applies SELU to the input, optionally fusing dropout.
43 | * ```attorch.SiLU```: Applies SiLU to the input, optionally fusing dropout.
44 | * ```attorch.Mish```: Applies Mish to the input, optionally fusing dropout.
45 | * ```attorch.Softplus```: Applies softplus to the input, optionally fusing dropout.
46 | * ```attorch.Softshrink```: Applies softshrink to the input, optionally fusing dropout.
47 | * ```attorch.Softsign```: Applies softsign to the input, optionally fusing dropout.
48 | * ```attorch.Sigmoid```: Applies sigmoid to the input, optionally fusing dropout.
49 | * ```attorch.Tanh```: Applies tanh to the input, optionally fusing dropout.
50 | * ```attorch.Tanhshrink```: Applies tanhshrink to the input, optionally fusing dropout.
51 | * ```attorch.GLU```: Applies the gated linear unit with an arbitrary activation function to the input.
52 | * ```attorch.LogSoftmax```: Normalizes the input using softmax and takes its log.
53 | * ```attorch.Softmax```: Normalizes the input using softmax.
54 | * ```attorch.Softmin```: Normalizes the input using softmin.
55 | * ```attorch.BatchNorm1d```: Batch-normalizes the 2D or 3D input, optionally fusing an activation function and adding a residual to the pre-activation result.
56 | * ```attorch.BatchNorm2d```: Batch-normalizes the 4D input, optionally fusing an activation function and adding a residual to the pre-activation result.
57 | * ```attorch.LayerNorm```: Layer-normalizes the input.
58 | * ```attorch.RMSNorm```: Root-mean-square-normalizes the input.
59 | * ```attorch.Linear```: Linearly transforms the input using weights, optionally adding bias and fusing an activation function.
60 | * ```attorch.Dropout```: Randomly zeroes elements in the input during training.
61 | * ```attorch.L1Loss```: Measures the L1 error (mean absolute error) between the input and target.
62 | * ```attorch.MSELoss```: Measures the squared L2 error (mean squared error) between the input and target.
63 | * ```attorch.CrossEntropyLoss```: Measures the mean cross entropy loss between the input and target, with optional reweighting of each class.
64 | * ```attorch.NLLLoss```: Measures the negative log likelihood loss between the input and target, with optional reweighting of each class.
65 | * ```attorch.HuberLoss```: Measures the Huber loss between the input and target.
66 | * ```attorch.SmoothL1Loss```: Measures the smooth L1 error between the input and target.
67 |
68 | Unless otherwise noted in their docstrings, the aforementioned layers behave identically to their PyTorch equivalents.
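
For example, the transposed weight layout of ```attorch.Linear``` noted in its docstring means that porting parameters from ```torch.nn.Linear``` involves a transpose. The snippet below is an illustrative sketch (the shapes are arbitrary, and since attorch layers are placed on the GPU by default, a CUDA device is assumed):

```python
import torch
from torch import nn

import attorch


# attorch.Linear stores its weight as [in_features, out_features],
# the transpose of torch.nn.Linear's [out_features, in_features] layout.
attorch_lin = attorch.Linear(512, 2048)
torch_lin = nn.Linear(512, 2048)

# Porting PyTorch parameters therefore requires a transpose.
with torch.no_grad():
    attorch_lin.weight.copy_(torch_lin.weight.T)
    attorch_lin.bias.copy_(torch_lin.bias)
```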
69 |
70 | ## Math Functions
71 | Triton kernels are generally composed of two parts: one handles the loading and storing of the relevant tensors, the other transforms the data using appropriate mathematical functions. For instance, a layer normalization kernel reads one or several rows from the input (load), standardizes the features (math), and writes the results into a container (store). A selection of these pure math functions is supplied by ```attorch.math```, the objective being to facilitate the implementation of custom kernels and operation fusion. Although only the forward passes of these functions are available in ```attorch.math```, thanks to their purity and absence of I/O actions, their gradients can be automatically derived via the [```triton-autodiff```](https://github.com/srush/triton-autodiff) library. Significant portions of attorch's kernels can be refactored by supplanting their math bits with the corresponding ```attorch.math``` transformations or their derivatives, but doing so would sacrifice the single-file and self-contained design of attorch, so ```attorch.math``` and the rest of the library will remain separate.
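
As an illustration, a custom kernel can reuse ```attorch.math``` for its math portion while handling I/O itself. The following is a minimal sketch, not part of attorch: the kernel name, pointer arithmetic, and launch configuration are illustrative, and it assumes a contiguous, row-major input whose feature dimension fits in a single block.

```python
import triton
import triton.language as tl

from attorch.math import softmax


@triton.jit
def row_softmax_kernel(input_pointer, output_pointer,
                       batch_dim, feat_dim,
                       BLOCK_SIZE_BATCH: tl.constexpr,
                       BLOCK_SIZE_FEAT: tl.constexpr):
    # Each program normalizes BLOCK_SIZE_BATCH rows; since the input is
    # assumed to be contiguous and row-major, the batch stride is feat_dim.
    batch_pid = tl.program_id(axis=0)
    batch_offset = batch_pid * BLOCK_SIZE_BATCH + tl.arange(0, BLOCK_SIZE_BATCH)
    feat_offset = tl.arange(0, BLOCK_SIZE_FEAT)
    mask = (batch_offset[:, None] < batch_dim) & (feat_offset[None, :] < feat_dim)
    offset = feat_dim * batch_offset[:, None] + feat_offset[None, :]

    # Load (I/O), normalize using attorch.math (math), and store (I/O).
    input = tl.load(input_pointer + offset, mask=mask, other=-float('inf'))
    tl.store(output_pointer + offset, softmax(input, False, False), mask=mask)
```

Such a kernel would be launched over a 1D grid of ceil(batch_dim / BLOCK_SIZE_BATCH) programs with BLOCK_SIZE_FEAT set to the next power of two of feat_dim, mirroring the heuristics attorch's own row-wise kernels use.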
72 |
73 | ## PyTorch Fallback
74 |
75 | To enable easier integration of attorch and PyTorch layers, ```attorch.nn``` is offered, which provides an interface to attorch's modules with PyTorch fallback should a desired layer not be available, as seen below:
76 |
77 | ```python
78 | from attorch import nn
79 |
80 |
81 | lin = nn.Linear(10, 20) # Uses attorch's linear layer
82 | gap = nn.AdaptiveAvgPool2d(1) # Uses PyTorch's global pooling since GAP is not available in attorch
83 | ```
84 |
85 | Even though attorch comes with convolutional and pooling layers, these modules are extremely slow compared to their PyTorch counterparts. Therefore, ```attorch.nn``` exposes PyTorch's, and not attorch's, convolutions and pools, as illustrated below.
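
The snippet below is purely illustrative; the constructor arguments are arbitrary.

```python
from attorch import nn


# Both resolve to PyTorch's implementations rather than attorch's
# Triton-based convolution and pooling layers.
conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
pool = nn.AvgPool2d(kernel_size=2)
```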
86 |
87 | ## Tests
88 |
89 | Each module can be tested against its PyTorch counterpart to ensure correctness. These tests are included under ```tests/``` and can be executed using ```pytest```. A switch, ```--subset```, is provided that runs the tests on a smaller subset of data shapes for faster yet less thorough evaluation. It should be noted that some tests may fail owing to numerical precision issues, but in most practical use cases, that should not be a problem.
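
For instance, the suite can be run on the reduced set of shapes with ```pytest tests --subset``` from the repository root or, equivalently, from Python via ```pytest.main```, which accepts the same arguments as the command line (an illustrative snippet, not part of attorch):

```python
import pytest


# Equivalent to running `pytest tests --subset` from the repository root.
pytest.main(['tests', '--subset'])
```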
90 |
91 | ## Examples
92 |
93 | [`examples/`](https://github.com/BobMcDear/attorch/tree/main/examples) contains a handful of common deep learning workflows implemented in both attorch and PyTorch. Although attorch layers can generally be used as drop-in replacements for their ```torch.nn``` counterparts, some modules offer optional kernel fusion and therefore require additional arguments that are not compatible with their PyTorch analogs. Such distinctions are illustrated in these examples and sketched briefly below.
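
For instance, a fused attorch layer and its unfused PyTorch analog might be constructed as follows (a sketch with arbitrary shapes):

```python
import attorch
from torch import nn


# attorch: the bias addition and activation are fused into the matmul kernel.
fused = attorch.Linear(256, 1024, act_func='relu')

# PyTorch: the same computation requires a separate activation module.
unfused = nn.Sequential(nn.Linear(256, 1024), nn.ReLU())
```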
94 |
95 | ## Citations
96 |
97 | ```bibtex
98 | @inproceedings{tillet2019triton,
99 | title={Triton: an intermediate language and compiler for tiled neural network computations},
100 | author={Tillet, Philippe and Kung, Hsiang-Tsung and Cox, David},
101 | booktitle={Proceedings of the 3rd ACM SIGPLAN International Workshop on Machine Learning and Programming Languages},
102 | pages={10--19},
103 | year={2019}
104 | }
105 | ```
106 |
107 | ```bibtex
108 | @Misc{xFormers2022,
109 | author = {Benjamin Lefaudeux and Francisco Massa and Diana Liskovich and Wenhan Xiong and Vittorio Caggiano and Sean Naren and Min Xu and Jieru Hu and Marta Tintore and Susan Zhang and Patrick Labatut and Daniel Haziza and Luca Wehrstedt and Jeremy Reizenstein and Grigory Sizov},
110 | title = {xFormers: A modular and hackable Transformer modelling library},
111 | howpublished = {\url{https://github.com/facebookresearch/xformers}},
112 | year = {2022}
113 | }
114 | ```
115 |
116 | ```bibtex
117 | @software{yang2024fla,
118 | title = {FLA: A Triton-Based Library for Hardware-Efficient Implementations of Linear Attention Mechanism},
119 | author = {Yang, Songlin and Zhang, Yu},
120 | url = {https://github.com/fla-org/flash-linear-attention},
121 | month = jan,
122 | year = {2024}
123 | }
124 | ```
125 |
--------------------------------------------------------------------------------
/attorch/math.py:
--------------------------------------------------------------------------------
1 | """
2 | Pure math operations to be performed on loaded Triton tensors.
3 | """
4 |
5 |
6 | import triton
7 | import triton.language as tl
8 |
9 | from .act_kernels import apply_act_func
10 |
11 |
12 | @triton.jit
13 | def accum_linear(accum, input1, input2,
14 | fp16: tl.constexpr, tf32: tl.constexpr):
15 | """
16 | Accumulates matrix multiplications of input tensors for linear functions.
17 |
18 | Args:
19 | accum: Accumulator holding aggregation of matrix multiplications.
20 | The accumulator must be of shape [BLOCK_SIZE1, BLOCK_SIZE3].
21 | input1: First operand of matrix multiplication.
22 | The operand must be of shape [BLOCK_SIZE1, BLOCK_SIZE2].
23 | input2: Second operand of matrix multiplication.
24 | The operand must be of shape [BLOCK_SIZE2, BLOCK_SIZE3].
25 | fp16: Flag for converting operands to FP16.
26 | tf32: Flag for performing matrix multiplication in TF32.
27 |
28 | Returns:
29 | Accumulator with the result of the new matrix multiplication added to it.
30 | """
31 | if fp16:
32 | input1 = input1.to(tl.float16)
33 | input2 = input2.to(tl.float16)
34 |
35 | return accum + tl.dot(input1, input2, allow_tf32=tf32)
36 |
37 |
38 | @triton.jit
39 | def glu(input1, input2, param, act_func: tl.constexpr):
40 | """
41 | Applies the gated linear unit with an arbitrary activation function
42 | to the input.
43 |
44 | Args:
45 | input1: First half of input to gate.
46 | The first half must be of the same shape as the second half.
47 | input2: Second half of input to gate.
48 | The second half must be of the same shape as the first half.
49 | param: Parameter in the case of parameterized activation functions.
50 | act_func: Name of activation function to apply.
51 | Options are 'sigmoid', 'logsigmoid', 'tanh', 'relu', 'gelu', 'geluapprox', 'silu',
52 | 'softplus', 'softsign', 'tanhshrink', 'leaky_relu', 'elu', 'celu', 'hardshrink',
53 | and 'softshrink'.
54 | 
55 | Returns:
56 | Input transformed by the gated linear unit with
57 | an arbitrary activation function applied to its
58 | second half as the gate.
59 | """
60 | return input1 * apply_act_func(input2, None, None, None, param, act_func, False)
61 |
62 |
63 | @triton.jit
64 | def softmax(input,
65 | neg: tl.constexpr,
66 | log: tl.constexpr):
67 | """
68 | Normalizes the input using softmax along the last dimension.
69 |
70 | Args:
71 | input: Input to normalize.
72 | The input must be of shape [BLOCK_SIZE1, BLOCK_SIZE2].
73 | neg: Flag indicating if the input should be negated to get softmin.
74 | log: Flag indicating if the log of softmax should be taken.
75 |
76 | Returns:
77 | Input normalized by softmax.
78 | """
79 | input = input.to(tl.float32)
80 | if neg:
81 | input = -input
82 |
83 | input = input - tl.max(input, axis=1)[:, None]
84 | numerator = tl.exp(input)
85 | denominator = tl.sum(numerator, axis=1)[:, None]
86 |
87 | if log:
88 | output = input - tl.log(denominator)
89 |
90 | else:
91 | output = numerator / denominator
92 |
93 | return output
94 |
95 |
96 | @triton.jit
97 | def calc_mean_and_inv_std(input, last_dim, eps,
98 | last_dim_mask: tl.constexpr):
99 | """
100 | Calculates the mean and inverse standard deviation of the input
101 | along the last dimension.
102 |
103 | Args:
104 | input: Input whose mean and inverse standard deviation are calculated.
105 | The input must be of shape [BLOCK_SIZE1, BLOCK_SIZE2].
106 | last_dim: Size of the last dimension of input.
107 | eps: Epsilon added in the square root in the denominator
108 | to avoid division by zero.
109 | last_dim_mask: Mask for the last dimension indicating
110 | which elements should be included in the calculations.
111 | The mask must be of shape [BLOCK_SIZE2].
112 |
113 | Returns:
114 | Mean and inverse standard deviation of the input.
115 | """
116 | input = input.to(tl.float32)
117 |
118 | mean = tl.sum(input, axis=1) / last_dim
119 | diff = tl.where(last_dim_mask[None, :], input - mean[:, None], 0)
120 | inv_std = tl.rsqrt(tl.sum(diff * diff, axis=1) / last_dim + eps)
121 |
122 | return mean, inv_std
123 |
124 |
125 | @triton.jit
126 | def update_welford(input, prev_count, prev_mean, prev_var, curr_count,
127 | mask: tl.constexpr):
128 | """
129 | Updates count, mean, and variance (M2) statistics for Welford's algorithm.
130 |
131 | Args:
132 | input: Input used to update statistics.
133 | The input must be of the same shape as the mask.
134 | prev_count: Previous count statistic to update.
135 | prev_mean: Previous mean statistic to update.
136 | prev_var: Previous variance (M2) statistic to update.
137 | curr_count: Count of elements in current input.
138 | mask: Mask indicating which elements should be included in the calculations.
139 | The mask must be of the same shape as the input.
140 |
141 | Returns:
142 | Updated count, mean, and variance (M2) statistics
143 | """
144 | input = input.to(tl.float32)
145 |
146 | count = prev_count + curr_count
147 | mean = prev_mean + (tl.sum(input) - curr_count * prev_mean) / count
148 | deltas = tl.where(mask, (input - mean) * (input - prev_mean), 0.)
149 | var = prev_var + tl.sum(deltas)
150 |
151 | return count, mean, var
152 |
153 |
154 | @triton.jit
155 | def update_ema(prev_ema, new_val, momentum):
156 | """
157 | Updates exponential moving average.
158 |
159 | Args:
160 | prev_ema: Previous exponential moving average.
161 | new_val: Value used to update the exponential moving average.
162 | momentum: Momentum.
163 |
164 | Returns:
165 | Updated running statistic.
166 | """
167 | return (1 - momentum) * prev_ema + momentum * new_val
168 |
169 |
170 | @triton.jit
171 | def standardize(input, mean, inv_std, weight, bias):
172 | """
173 | Standardizes the input given its mean and inverse standard deviation,
174 | multiplies the result by weights, and adds a bias vector.
175 |
176 | Args:
177 | input: Input to standardize.
178 | mean: Mean of input.
179 | inv_std: Inverse standard deviation of input.
180 | weight: Weight multiplied by the standardized input.
181 | bias: Bias added to the result of the weight multiplication.
182 |
183 | Returns:
184 | Standardized input, scaled by the weights and shifted by the bias.
185 | """
186 | return weight * inv_std * (input - mean) + bias
187 |
188 |
189 | @triton.jit
190 | def calc_p_loss(input, target, param, size,
191 | p_loss: tl.constexpr, reduction: tl.constexpr):
192 | """
193 | Measures the smooth L1, L1, squared L2, or Huber loss of the difference
194 | between the input and target.
195 |
196 | Args:
197 | input: Input.
198 | The input must be of shape [BLOCK_SIZE].
199 | target: Target.
200 | The target must be of shape [BLOCK_SIZE].
201 | param: Parameter of loss function (i.e., beta or delta for smooth L1 and Huber).
202 | size: Number of elements in the input and target.
203 | This value is used only if reduction is 'mean'.
204 | p_loss: p-norm used to compute the error.
205 | Options are 0 for smooth L1, 1 for L1, 2 for squared L2, and 3 for Huber loss.
206 | reduction: Reduction strategy for the output.
207 | Options are 'none' for no reduction, 'mean' for averaging the error
208 | across all entries, and 'sum' for summing the error across all entries.
209 |
210 | Returns:
211 | Error.
212 | """
213 | input = input.to(tl.float32)
214 | target = target.to(tl.float32)
215 |
216 | diff = input - target
217 |
218 | if p_loss == 0:
219 | error = tl.where(tl.abs(diff) < param, 0.5 * diff * diff / param, tl.abs(diff) - 0.5 * param)
220 |
221 | elif p_loss == 1:
222 | error = tl.abs(diff)
223 |
224 | elif p_loss == 2:
225 | error = diff * diff
226 |
227 | elif p_loss == 3:
228 | error = tl.where(tl.abs(diff) < param, 0.5 * diff * diff, param * (tl.abs(diff) - 0.5 * param))
229 |
230 | if reduction == 'none':
231 | output = error
232 |
233 | elif reduction == 'mean':
234 | output = tl.sum(error) / size
235 |
236 | elif reduction == 'sum':
237 | output = tl.sum(error)
238 |
239 | return output
240 |
241 |
242 | @triton.jit
243 | def nll_loss(input, size,
244 | reduction: tl.constexpr):
245 | """
246 | Measures the negative log likelihood loss given log-probabilities of target class.
247 |
248 | Args:
249 | input: Input containing predicted log-probabilities corresponding to target class.
250 | The input can have arbitrary shape.
251 | size: Number of elements in the input.
252 | This value is used only if reduction is 'mean'.
253 | reduction: Reduction strategy for the output.
254 | Options are 'none' for no reduction, 'mean' for averaging the loss
255 | across all entries, and 'sum' for summing the loss across all entries.
256 |
257 | Returns:
258 | Loss.
259 | """
260 | input = input.to(tl.float32)
261 |
262 | if reduction == 'none':
263 | output = -input
264 |
265 | elif reduction == 'mean':
266 | output = -tl.sum(input) / size
267 |
268 | elif reduction == 'sum':
269 | output = -tl.sum(input)
270 |
271 | return output
272 |
273 |
274 | @triton.jit
275 | def cross_entropy_loss(input, pred):
276 | """
277 | Measures the per-row cross entropy loss given
278 | input and predicted logits corresponding to target class.
279 |
280 | Args:
281 | input: Input.
282 | The input must be of shape [BLOCK_SIZE1, BLOCK_SIZE2].
283 | pred: Predicted logits corresponding to target class.
284 | The predictions must be of shape [BLOCK_SIZE1].
285 |
286 | Returns:
287 | Loss.
288 | """
289 | input = input.to(tl.float32)
290 | pred = pred.to(tl.float32)
291 |
292 | mx = tl.max(input, axis=1)
293 | input -= mx[:, None]
294 | loss = tl.log(tl.sum(tl.exp(input), axis=1)) - pred + mx
295 |
296 | return loss
297 |
--------------------------------------------------------------------------------