├── .gitmodules
├── stk
│   ├── backend
│   │   ├── __init__.py
│   │   ├── autocast.py
│   │   ├── sputnik.py
│   │   └── triton_kernels.py
│   ├── random
│   │   ├── __init__.py
│   │   ├── random_ops.py
│   │   └── random_ops_test.py
│   ├── __init__.py
│   ├── ops
│   │   ├── __init__.py
│   │   ├── eltwise_ops.py
│   │   ├── linear_ops.py
│   │   ├── matrix_ops_test.py
│   │   ├── eltwise_ops_test.py
│   │   ├── matrix_ops.py
│   │   └── linear_ops_test.py
│   └── matrix.py
├── requirements.txt
├── entrypoint.sh
├── .gitignore
├── Dockerfile
├── media
│   └── block_sparse_matmul_benchmarks.png
├── docker.sh
├── Makefile
├── MANIFEST.in
├── .github
│   └── workflows
│       └── ci.yaml
├── setup.py
├── Dockerfile-runner
├── README.md
└── LICENSE

/.gitmodules:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/stk/backend/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | triton==2.1.0
2 | 
--------------------------------------------------------------------------------
/entrypoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | ./run.sh & wait $!
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *~
2 | *.egg-info
3 | build
4 | __pycache__/
5 | core
6 | dist
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvcr.io/nvidia/pytorch:23.09-py3
2 | 
3 | WORKDIR /mount/stk
4 | 
--------------------------------------------------------------------------------
/stk/random/__init__.py:
--------------------------------------------------------------------------------
1 | from stk.random.random_ops import dense_mask, mask, randn
2 | 
--------------------------------------------------------------------------------
/stk/__init__.py:
--------------------------------------------------------------------------------
1 | import stk.random
2 | import stk.ops
3 | from stk.matrix import Matrix
4 | 
--------------------------------------------------------------------------------
/media/block_sparse_matmul_benchmarks.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stanford-futuredata/stk/HEAD/media/block_sparse_matmul_benchmarks.png
--------------------------------------------------------------------------------
/stk/ops/__init__.py:
--------------------------------------------------------------------------------
1 | from stk.ops.linear_ops import dds, dsd, sdd
2 | from stk.ops.matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
3 | from stk.ops.eltwise_ops import mul
4 | 
--------------------------------------------------------------------------------
/docker.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # Get the repo location. Assumes that this
4 | # script is executed from inside the repo.
5 | BASEDIR=`cd .. && pwd`
6 | docker run -it --runtime=nvidia -v ${BASEDIR}:/mount stk-dev
7 | 
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | clean:
2 | 	rm -rf dist/*
3 | 
4 | dist: clean
5 | 	python3 setup.py sdist
6 | 
7 | upload: dist
8 | 	twine upload dist/*
9 | 
10 | upload-test: dist
11 | 	twine upload --repository-url https://test.pypi.org/legacy/ dist/*
12 | 
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | recursive-include csrc *.cc
2 | 
3 | recursive-include third_party/sputnik/sputnik *.h
4 | recursive-include third_party/sputnik/sputnik *.cu
5 | recursive-include third_party/sputnik/sputnik *.cc
6 | 
7 | recursive-include third_party/sputnik/third_party/cutlass/include *.h
--------------------------------------------------------------------------------
/.github/workflows/ci.yaml:
--------------------------------------------------------------------------------
1 | name: run-tests
2 | on: [push]
3 | jobs:
4 |   linear_tests:
5 |     runs-on: [self-hosted]
6 |     steps:
7 |     - uses: actions/checkout@v3
8 |     - name: Correctness tests
9 |       run: |
10 |         echo 'Setup...'
11 |         python setup.py install --prefix ~/.local
12 |         echo 'Run tests...'
13 |         python stk/ops/linear_ops_test.py
14 |         python stk/ops/matrix_ops_test.py
15 |         python stk/random/random_ops_test.py
16 | 
--------------------------------------------------------------------------------
/stk/ops/eltwise_ops.py:
--------------------------------------------------------------------------------
1 | from stk.matrix import Matrix
2 | 
3 | def mul(a, b):
4 |     """Performs element-wise multiplication of matrices a and b.
5 | 
6 |     It is the user's responsibility to make sure that a and b
7 |     follow the same matrix topology. This function assumes it is safe
8 |     to use the topology of a.
9 | 
10 |     Args:
11 |         a: stk.Matrix.
12 |         b: stk.Matrix with a's matrix topology.
13 | 
14 |     Returns:
15 |         stk.Matrix where the entries correspond to torch.mul(a, b).
16 | """ 17 | assert isinstance(a, Matrix) 18 | assert isinstance(b, Matrix) 19 | assert a.size() == b.size() 20 | 21 | return Matrix(a.size(), 22 | a.data * b.data, 23 | a.row_indices, 24 | a.column_indices, 25 | a.offsets, 26 | a.column_indices_t, 27 | a.offsets_t, 28 | a.block_offsets_t) 29 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | install_requires = [ 4 | 'torch>=2.3.0,<3.0', 5 | 'triton>=2.1.0', 6 | ] 7 | 8 | extra_deps = {} 9 | 10 | extra_deps['dev'] = [ 11 | 'absl-py', 12 | ] 13 | 14 | extra_deps['all'] = list(set(dep for deps in extra_deps.values() for dep in deps)) 15 | 16 | setup( 17 | name="stanford-stk", 18 | version="0.7.1", 19 | author="Trevor Gale", 20 | author_email="tgale@stanford.edu", 21 | description="Sparse Toolkit", 22 | long_description=open('README.md').read(), 23 | long_description_content_type='text/markdown', 24 | url="https://github.com/stanford-futuredata/stk", 25 | classifiers=[ 26 | "Programming Language :: Python :: 3", 27 | "License :: OSI Approved :: BSD License", 28 | "Operating System :: Unix", 29 | ], 30 | packages=find_packages(), 31 | install_requires=install_requires, 32 | extras_require=extra_deps, 33 | ) 34 | -------------------------------------------------------------------------------- /stk/random/random_ops.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from stk.ops import matrix_ops 4 | 5 | 6 | @torch.no_grad() 7 | def dense_mask(rows, cols, sparsity, blocking=1): 8 | assert sparsity >= 0.0 and sparsity <= 1.0 9 | assert rows % blocking == 0 and cols % blocking == 0 10 | 11 | block_rows, block_cols = (rows // blocking, cols // blocking) 12 | nnz = round(block_rows * block_cols * (1 - sparsity)) 13 | 14 | out = np.ones(block_rows * block_cols) 15 | mask = np.random.choice(out.size, out.size - nnz, replace=False) 16 | out[mask] = 0.0 17 | 18 | out = np.tile( 19 | np.reshape(out, [block_rows, 1, block_cols, 1]), 20 | (1, blocking, 1, blocking)) 21 | out = np.reshape(out, [rows, cols]) 22 | return torch.from_numpy(out.astype(np.float32)) 23 | 24 | 25 | @torch.no_grad() 26 | def mask(m, n, sparsity, blocking=1): 27 | out = dense_mask(m, n, sparsity, blocking).type(torch.float16) 28 | return matrix_ops.to_sparse(out, blocking=blocking) 29 | 30 | 31 | @torch.no_grad() 32 | def randn(shape, sparsity, blocking=1): 33 | shape_2d = (np.prod(shape[:-1]), shape[-1]) 34 | out = mask(*shape_2d, sparsity, blocking) 35 | out.data.copy_(torch.randn(*out.data.shape)) 36 | return out.view(*shape) 37 | -------------------------------------------------------------------------------- /Dockerfile-runner: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:23.01-py3 2 | 3 | RUN DEBIAN_FRONTEND=noninteractive 4 | 5 | # Install Triton version:2.1.0 6 | RUN pip uninstall -y triton 7 | WORKDIR /tmp/install 8 | RUN git clone https://github.com/openai/triton.git 9 | RUN pip install cmake 10 | WORKDIR /tmp/install/triton/python 11 | RUN pip install -e . 
12 | 
13 | # Install github actions-runner and dependencies
14 | ENV TZ=US/Pacific
15 | RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
16 | 
17 | RUN useradd -m ci
18 | RUN cd /home/ci && mkdir actions-runner && cd actions-runner
19 | 
20 | WORKDIR /home/ci/actions-runner
21 | 
22 | RUN curl -o actions-runner-linux-x64-2.304.0.tar.gz -L https://github.com/actions/runner/releases/download/v2.304.0/actions-runner-linux-x64-2.304.0.tar.gz
23 | RUN tar xzf ./actions-runner-linux-x64-2.304.0.tar.gz
24 | 
25 | RUN chown -R ci ~ci && /home/ci/actions-runner/bin/installdependencies.sh
26 | RUN ./bin/installdependencies.sh
27 | 
28 | COPY entrypoint.sh entrypoint.sh
29 | RUN chmod +x entrypoint.sh
30 | USER ci
31 | 
32 | ARG GITHUB_REPO
33 | ARG ACCESS_TOKEN
34 | 
35 | RUN cd /home/ci/actions-runner
36 | RUN ./config.sh --url ${GITHUB_REPO} --pat ${ACCESS_TOKEN}
37 | 
38 | ENTRYPOINT ["./entrypoint.sh"]
39 | 
--------------------------------------------------------------------------------
/stk/backend/autocast.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import torch
3 | 
4 | 
5 | def _is_eligible(x):
6 |     return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
7 | 
8 | 
9 | def _cast(x, dtype):
10 |     if isinstance(x, torch.Tensor) and _is_eligible(x):
11 |         return x.to(dtype)
12 |     elif isinstance(x, dict):
13 |         return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
14 |     elif isinstance(x, list) or isinstance(x, tuple):
15 |         return type(x)(map(lambda y: _cast(y, dtype), x))
16 |     return x
17 | 
18 | 
19 | def custom_fwd(fwd):
20 |     """Wrap a custom autograd function that always uses autocast dtype."""
21 | 
22 |     @functools.wraps(fwd)
23 |     def decorate_fwd(*args, **kwargs):
24 |         if torch.is_autocast_enabled():
25 |             with torch.autocast(device_type="cuda", enabled=False):
26 |                 dtype = torch.get_autocast_gpu_dtype()
27 |                 return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
28 |         return fwd(*args, **kwargs)
29 |     return decorate_fwd
30 | 
31 | 
32 | def custom_bwd(bwd):
33 |     @functools.wraps(bwd)
34 |     def decorate_bwd(*args, **kwargs):
35 |         with torch.autocast(device_type="cuda", enabled=False):
36 |             return bwd(*args, **kwargs)
37 |     return decorate_bwd
38 | 
--------------------------------------------------------------------------------
/stk/ops/linear_ops.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from stk.backend import sputnik
4 | from stk.matrix import Matrix
5 | 
6 | 
7 | def dsd(a, b):
8 |     assert isinstance(a, Matrix)
9 |     assert isinstance(b, torch.Tensor)
10 |     return sputnik.dsd(
11 |         a.size(),
12 |         a.data, a.offsets,
13 |         a.row_indices,
14 |         a.column_indices,
15 |         a.offsets_t,
16 |         a.column_indices_t,
17 |         a.block_offsets_t,
18 |         not a.is_contiguous(),
19 |         b)
20 | 
21 | 
22 | def dds(a, b):
23 |     assert isinstance(a, torch.Tensor)
24 |     assert isinstance(b, Matrix)
25 |     return sputnik.dds(
26 |         a,
27 |         b.size(),
28 |         b.data, b.offsets,
29 |         b.row_indices,
30 |         b.column_indices,
31 |         b.offsets_t,
32 |         b.column_indices_t,
33 |         b.block_offsets_t,
34 |         not b.is_contiguous())
35 | 
36 | 
37 | def sdd(a, b, topo):
38 |     assert isinstance(a, torch.Tensor)
39 |     assert isinstance(b, torch.Tensor)
40 |     assert isinstance(topo, Matrix)
41 |     assert topo.is_contiguous()
42 |     out = sputnik.sdd(
43 |         a, b,
44 |         topo.size(),
45 |         topo.data,
46 |         topo.offsets,
47 |         topo.row_indices,
48 |         topo.column_indices,
49 |         topo.offsets_t,
50 | 
topo.column_indices_t, 51 | topo.block_offsets_t) 52 | return Matrix(topo.size(), 53 | out, 54 | topo.row_indices, 55 | topo.column_indices, 56 | topo.offsets, 57 | topo.column_indices_t, 58 | topo.offsets_t, 59 | topo.block_offsets_t) 60 | -------------------------------------------------------------------------------- /stk/ops/matrix_ops_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from absl.testing import parameterized 4 | import stk 5 | import torch 6 | 7 | 8 | @parameterized.parameters( 9 | (8, 16, 0.0, 1), 10 | (8, 16, 0.5, 1), 11 | (8, 16, .95, 1), 12 | (16, 8, 0.0, 1), 13 | (16, 8, 0.5, 1), 14 | (16, 8, .95, 1), 15 | (8, 16, 0.0, 8), 16 | (8, 16, 0.5, 8), 17 | (8, 16, 1.0, 8), 18 | (16, 8, 0.0, 8), 19 | (16, 8, 0.5, 8), 20 | (16, 8, 1.0, 8), 21 | (128, 256, 0.5, 16), 22 | (256, 128, 0.75, 32), 23 | (512, 512, .875, 128)) 24 | class MatrixOpsTest(parameterized.TestCase): 25 | 26 | def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking): 27 | mask = stk.random.dense_mask(rows, cols, sparsity, blocking) 28 | x = (torch.randn(rows, cols) * mask).type(torch.float16) 29 | 30 | # Convert the matrix to sparse format. 31 | sparse_x = stk.ops.to_sparse(x, blocking) 32 | 33 | # Validate the matrix. 34 | sparse_x.validate() 35 | 36 | # Validate the shape. 37 | self.assertEqual(sparse_x.dim(), 2) 38 | self.assertEqual(sparse_x.size()[0], rows) 39 | self.assertEqual(sparse_x.size()[1], cols) 40 | 41 | # Validate the sparsity. 42 | numblocks = rows // blocking * cols // blocking 43 | nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 44 | self.assertEqual(sparse_x.nnz, nnz) 45 | 46 | # Convert back to dense format. 47 | dense_x = stk.ops.to_dense(sparse_x) 48 | 49 | # Validate the shape. 50 | self.assertEqual(dense_x.dim(), 2) 51 | self.assertEqual(dense_x.size()[0], rows) 52 | self.assertEqual(dense_x.size()[1], cols) 53 | 54 | # Validate the sparsity 55 | self.assertEqual(torch.count_nonzero(dense_x).item(), nnz) 56 | 57 | # Validate the output. 58 | self.assertTrue(torch.all(torch.eq(x, dense_x))) 59 | 60 | 61 | if __name__ == '__main__': 62 | unittest.main() 63 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # :abacus: Sparse Toolkit 2 | 3 | A light-weight PyTorch library for block-sparse matrices and block-sparse matrix multiplication. 4 | 5 | STK is built around a core sparse matrix class ([stk.Matrix](stk/matrix.py)), which uses a hybrid [blocked-CSR-COO](https://arxiv.org/abs/2211.15841) sparse matrix encoding to enable efficient matrix products with sparse inputs and outputs in transposed or non-transposed order. The library supports the following operations: 6 | 7 | ``` 8 | op: transpose or non-transpose 9 | 10 | [Sparse Matrix Multiplication] 11 | stk.ops.dsd: dense = op(sparse) x op(dense) 12 | stk.ops.dds: dense = op(dense) x op(sparse) 13 | stk.ops.sdd: sparse = op(dense) x op(dense) 14 | 15 | [Sparse Matrix Conversion] 16 | stk.ops.to_sparse: torch.Tensor => stk.Matrix 17 | stk.ops.to_dense: stk.Matrix => torch.Tensor 18 | 19 | [Sparse Matrix Generation] 20 | stk.random.dense_mask: Create a random, block-sparse dense matrix. 21 | stk.random.mask: Create a random, block-sparse sparse matrix. 22 | ``` 23 | 24 | STK is designed for applications where the sparse matrices change rapidly. 
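For example, a workload that re-samples its sparsity pattern on every step might use STK like the following sketch (hypothetical sizes; assumes a CUDA device and `float16` operands, matching the tests in this repository):

```
import torch
import stk

# Hypothetical problem sizes; all dimensions are multiples of the block size.
m, k, n = 1024, 1024, 1024
blocking, sparsity = 128, 0.75

a = torch.randn(m, k, dtype=torch.float16, device="cuda")
b = torch.randn(k, n, dtype=torch.float16, device="cuda")

# Draw a fresh block-sparse topology (e.g., once per iteration).
topo = stk.random.mask(m, n, sparsity, blocking).cuda()

# sparse = dense x dense: compute only the blocks present in `topo`.
c = stk.ops.sdd(a, b, topo)

# dense = sparse x dense, reusing the sparse product.
x = torch.randn(n, k, dtype=torch.float16, device="cuda")
d = stk.ops.dsd(c, x)

# Convert between formats as needed.
c_dense = stk.ops.to_dense(c)
c_sparse = stk.ops.to_sparse(c_dense, blocking)
```

Because the topology is just another `stk.Matrix`, a new pattern can be drawn at any step without rebuilding the dense operands.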
This is complementary to libraries like [triton-blocksparse](https://github.com/ptillet/torch-blocksparse), which assume that sparse matrix topologies do not change between invocations. 25 | 26 | # :rocket: Performance 27 | 28 | ![STK Performance](media/block_sparse_matmul_benchmarks.png) 29 | 30 | Block-sparse matrix multiplication with STK is able to match the performance of cuBLAS on a range of problems. On these benchmarks from [MegaBlocks](https://github.com/stanford-futuredata/megablocks) dMoE models, STK realizes **98.6%** of cuBLAS throughput with `128x128` blocks on average. 31 | 32 | ``` 33 | Hardware: A100-SXM4-80GB 34 | Software: CUDA 11.5, CUTLASS 2.5 35 | ``` 36 | 37 | # :building_construction: Installation 38 | 39 | NOTE: This assumes that you have `torch` and `numpy` installed. 40 | 41 | `pip install stanford-stk` 42 | 43 | # :writing_hand: Citation 44 | 45 | ``` 46 | @article{megablocks-arxiv, 47 | author = {Trevor Gale and Deepak Narayanan and Cliff Young and Matei Zaharia}, 48 | title = {MegaBlocks: Efficient Sparse Training with Mixture-of-Experts}, 49 | journal = {CoRR}, 50 | volume = {abs/2211.15841}, 51 | year = {2022}, 52 | } 53 | ``` 54 | -------------------------------------------------------------------------------- /stk/random/random_ops_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from absl.testing import parameterized 4 | import stk 5 | import torch 6 | 7 | 8 | @parameterized.parameters( 9 | (8, 16, 0.0, 1), 10 | (8, 16, 0.5, 1), 11 | (8, 16, .95, 1), 12 | (16, 8, 0.0, 1), 13 | (16, 8, 0.5, 1), 14 | (16, 8, .95, 1), 15 | (8, 16, 0.0, 8), 16 | (8, 16, 0.5, 8), 17 | (8, 16, 1.0, 8), 18 | (16, 8, 0.0, 8), 19 | (16, 8, 0.5, 8), 20 | (16, 8, 1.0, 8), 21 | (128, 256, 0.5, 16), 22 | (256, 128, 0.75, 32), 23 | (512, 512, .875, 128)) 24 | class RandomOpsTest(parameterized.TestCase): 25 | 26 | def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking): 27 | mask = stk.random.dense_mask( 28 | rows, cols, sparsity, blocking) 29 | 30 | # Validate the shape. 31 | self.assertEqual(mask.dim(), 2) 32 | self.assertEqual(mask.size()[0], rows) 33 | self.assertEqual(mask.size()[1], cols) 34 | 35 | # Validate the sparsity 36 | numblocks = rows // blocking * cols // blocking 37 | nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 38 | self.assertEqual( 39 | torch.count_nonzero(mask).item(), 40 | nnz) 41 | 42 | # Check values are zero or one. 43 | self.assertTrue( 44 | torch.all(torch.logical_or( 45 | torch.eq(mask, 0), 46 | torch.eq(mask, 1)))) 47 | 48 | def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking): 49 | mask = stk.random.mask( 50 | rows, cols, sparsity, blocking) 51 | 52 | # Validate the matrix. 53 | mask.validate() 54 | 55 | # Validate the shape. 56 | self.assertEqual(mask.dim(), 2) 57 | self.assertEqual(mask.size()[0], rows) 58 | self.assertEqual(mask.size()[1], cols) 59 | 60 | # Validate the sparsity. 61 | numblocks = rows // blocking * cols // blocking 62 | nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 63 | self.assertEqual(mask.nnz, nnz) 64 | 65 | # Check values are zero or one. 
66 | self.assertTrue( 67 | torch.all(torch.logical_or( 68 | torch.eq(mask.data, 0), 69 | torch.eq(mask.data, 1)))) 70 | 71 | 72 | if __name__ == '__main__': 73 | unittest.main() 74 | -------------------------------------------------------------------------------- /stk/ops/eltwise_ops_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import itertools 3 | import torch 4 | from absl.testing import parameterized 5 | 6 | import stk 7 | from stk.ops.linear_ops_test import allclose, _dense_and_sparse 8 | 9 | _MATRIX_SIZES = ( 10 | (128, 128, 0.0), 11 | (256, 256, 0.5), 12 | (2048, 1024, 0.8), 13 | (512, 128, 0.0), 14 | (128, 512, 0.0), 15 | (1024, 512, 0.0), 16 | (1024, 512, 0.5), 17 | (1024, 512, 0.75), 18 | (512, 1024, 0.0), 19 | (512, 1024, 0.5), 20 | (512, 1024, 0.75), 21 | (1024, 1024, 0.0), 22 | (1024, 1024, 0.5), 23 | (1024, 1024, 0.75), 24 | ) 25 | 26 | _DTYPE = ( 27 | torch.float16, torch.bfloat16 28 | ) 29 | 30 | def _generate_testcases(): 31 | testcases = itertools.product(_MATRIX_SIZES, _DTYPE) 32 | testcases = [(*size, 128, dtype) for 33 | (size, dtype) in testcases] 34 | return testcases 35 | 36 | _ELTWISE_OP_TESTS = _generate_testcases() 37 | 38 | def _dense_and_sparse_like(x, std=0.1): 39 | dense_data = torch.randn_like(x.data, device=x.device) * std 40 | sparse = stk.Matrix(x.size(), 41 | dense_data, 42 | x.row_indices, 43 | x.column_indices, 44 | x.offsets) 45 | dense = stk.ops.to_dense(sparse) 46 | 47 | return (dense.requires_grad_(True), 48 | sparse.requires_grad_(True)) 49 | 50 | @parameterized.parameters(_ELTWISE_OP_TESTS) 51 | class EltwiseOpsTest(parameterized.TestCase): 52 | 53 | def testEltwiseMul(self, m, n, sparsity, blocking, dtype): 54 | 55 | a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype) 56 | b_dense, b = _dense_and_sparse_like(a) 57 | 58 | out = stk.ops.mul(a, b) 59 | expected_out = torch.mul(a_dense, b_dense) 60 | 61 | # Compute the gradients w.r.t. the inputs. 62 | expected_out.sum().backward() 63 | stk.ops.sum(out).backward() 64 | 65 | # Validate the results. 66 | out = stk.ops.to_dense(out) 67 | self.assertEqual(out.dim(), 2) 68 | self.assertEqual(expected_out.size(), out.size()) 69 | self.assertTrue(allclose(out, expected_out)) 70 | 71 | # LHS gradient. 72 | grad = stk.ops.to_dense(a.grad) 73 | expected_grad = a_dense.grad 74 | self.assertEqual(grad.dim(), 2) 75 | self.assertEqual(expected_grad.size(), grad.size()) 76 | self.assertTrue(allclose(grad, expected_grad)) 77 | 78 | # RHS gradient. 79 | grad = stk.ops.to_dense(b.grad) 80 | expected_grad = b_dense.grad 81 | self.assertEqual(grad.dim(), 2) 82 | self.assertEqual(expected_grad.size(), grad.size()) 83 | self.assertTrue(allclose(grad, expected_grad)) 84 | 85 | if __name__ == '__main__': 86 | unittest.main() 87 | -------------------------------------------------------------------------------- /stk/ops/matrix_ops.py: -------------------------------------------------------------------------------- 1 | from stk.backend import sputnik 2 | from stk.matrix import Matrix 3 | import torch 4 | import numpy as np 5 | 6 | 7 | @torch.no_grad() 8 | def row_indices(shape, data, offsets, column_indices): 9 | return sputnik.row_indices(shape, data, offsets, column_indices) 10 | 11 | 12 | # TODO(tgale): Replace this helper with a custom kernel. This operation 13 | # is much simpler to do than how it's currently implemented. 14 | @torch.no_grad() 15 | def _expand_for_blocking(idxs, blocking): 16 | # Duplicate for block column dimension. 
17 | idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1) 18 | 19 | # Update the column indices. 20 | idxs[:, :, 1] *= blocking 21 | idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking]) 22 | 23 | # Duplicate for block row dimension. 24 | idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2]) 25 | idxs = idxs.repeat(1, blocking, 1, 1) 26 | 27 | # Update the row indices. 28 | idxs[:, :, :, 0] *= blocking 29 | idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1]) 30 | idxs = torch.reshape(idxs, [-1, 2]) 31 | return idxs 32 | 33 | 34 | # TODO(tgale): Add input type checking. 35 | @torch.no_grad() 36 | def to_dense(x): 37 | assert isinstance(x, Matrix) 38 | 39 | shape = (np.prod(x.shape[:-1]), x.shape[-1]) 40 | row_idxs = x.row_indices.type(torch.int32) 41 | col_idxs = x.column_indices.type(torch.int32) 42 | indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking) 43 | indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64) 44 | 45 | out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device) 46 | out.scatter_(0, indices, x.data.flatten()) 47 | return out.reshape(x.size()) 48 | 49 | 50 | @torch.no_grad() 51 | def _mask(x, blocking=1): 52 | assert x.dim() == 2 53 | assert x.size()[0] % blocking == 0 54 | assert x.size()[1] % blocking == 0 55 | block_rows = x.size()[0] // blocking 56 | block_cols = x.size()[1] // blocking 57 | x = torch.reshape(x, [block_rows, blocking, block_cols, blocking]) 58 | x = torch.sum(torch.abs(x), dim=(1, 3)) 59 | return x != 0 60 | 61 | 62 | # TODO(tgale): Add input type checking. 63 | @torch.no_grad() 64 | def to_sparse(x, blocking=1): 65 | m = _mask(x, blocking) 66 | 67 | # TODO(tgale): Set to appropriate type for input matrix. 68 | row_nnzs = torch.sum(m, dim=1).type(torch.int32) 69 | zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device) 70 | offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)]) 71 | offsets = offsets.type(torch.int32) 72 | 73 | indices = torch.nonzero(m).type(torch.int16) 74 | row_indices = indices[:, 0] 75 | column_indices = indices[:, 1] 76 | 77 | # Nonzero indices in the dense matrix. 78 | nonzero_indices = torch.nonzero(m) 79 | nonzero_indices = _expand_for_blocking(nonzero_indices, blocking) 80 | nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1] 81 | 82 | # Gather the data and construct the sparse matrix. 
83 | data = torch.gather(x.flatten(), dim=0, index=nonzero_indices) 84 | data = torch.reshape(data, [-1, blocking, blocking]) 85 | return Matrix(x.size(), data, row_indices, column_indices, offsets) 86 | 87 | 88 | @torch.no_grad() 89 | def ones_like(x): 90 | return Matrix(x.size(), 91 | torch.ones_like(x.data), 92 | x.row_indices, 93 | x.column_indices, x.offsets) 94 | 95 | 96 | def sum(x): 97 | assert isinstance(x, Matrix) 98 | return x.data.sum() 99 | -------------------------------------------------------------------------------- /stk/ops/linear_ops_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import itertools 3 | import numpy as np 4 | import torch 5 | from absl.testing import parameterized 6 | 7 | import stk 8 | 9 | 10 | def allclose(x, y, pct=0.25): 11 | mask = torch.isclose(x, y, rtol=5e-2) 12 | pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100 13 | if pct_diff > pct: 14 | print("{:.2f}% of values not close.".format(pct_diff)) 15 | return False 16 | return True 17 | 18 | 19 | # An assortment of problems designed to make sure 20 | # the bindings are operating correctly. 21 | _MATRIX_SIZES = ( 22 | (128, 128, 128, 0.0), 23 | (256, 256, 256, 0.5), 24 | (2048, 1024, 512, 0.8), 25 | (512, 128, 128, 0.0), 26 | (128, 128, 512, 0.0), 27 | (1024, 512, 512, 0.0), 28 | (1024, 512, 512, 0.5), 29 | (1024, 512, 512, 0.75), 30 | (512, 512, 1024, 0.0), 31 | (512, 512, 1024, 0.5), 32 | (512, 512, 1024, 0.75), 33 | (1024, 1024, 1024, 0.0), 34 | (1024, 1024, 1024, 0.5), 35 | (1024, 1024, 1024, 0.75), 36 | ) 37 | 38 | _TRANSPOSE = ( 39 | (False, False), 40 | (False, True), 41 | (True, False), 42 | (True, True), 43 | ) 44 | 45 | _DTYPE = ( 46 | torch.float16, torch.bfloat16 47 | ) 48 | 49 | def _generate_testcases(): 50 | testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE) 51 | testcases = [(*size, *trans, 128, dtype) for 52 | (size, trans, dtype) in testcases] 53 | return testcases 54 | 55 | _LINEAR_OP_TESTS = _generate_testcases() 56 | 57 | def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1): 58 | mask = stk.random.dense_mask(rows, cols, sparsity, blocking) 59 | dense = (torch.randn(rows, cols) * std * mask).type(dtype) 60 | sparse = stk.ops.to_sparse(dense, blocking) 61 | cuda_device = torch.device("cuda") 62 | return (dense.to(cuda_device).requires_grad_(True), 63 | sparse.to(cuda_device).requires_grad_(True)) 64 | 65 | 66 | def _dense(rows, cols, dtype, std=0.1): 67 | cuda_device = torch.device("cuda") 68 | out = (torch.randn(rows, cols) * std).type(dtype) 69 | return out.to(cuda_device).requires_grad_(True) 70 | 71 | 72 | def _dense_2x(rows, cols, dtype): 73 | a = _dense(rows, cols, dtype) 74 | return a, a.detach().requires_grad_(True) 75 | 76 | 77 | def _with_transpose(op, a, b, trans_a, trans_b): 78 | a = a.t() if trans_a else a 79 | b = b.t() if trans_b else b 80 | return op(a, b) 81 | 82 | 83 | def _mmm(a, b, topo): 84 | mask = stk.ops.to_dense(stk.ops.ones_like(topo)) 85 | return torch.mm(a, b) * mask 86 | 87 | 88 | def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b): 89 | a = a.t() if trans_a else a 90 | b = b.t() if trans_b else b 91 | return op(a, b, topo) 92 | 93 | 94 | def _mask(x, mask): 95 | mask = stk.ops.to_dense(stk.ops.ones_like(mask)) 96 | return x * mask 97 | 98 | 99 | @parameterized.parameters(*_LINEAR_OP_TESTS) 100 | class LinearOpsTest(parameterized.TestCase): 101 | 102 | def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): 
103 | # Construct the operands. 104 | a_shape = (k, m) if trans_a else (m, k) 105 | a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype) 106 | b_shape = (n, k) if trans_b else (k, n) 107 | b, bcp = _dense_2x(*b_shape, dtype) 108 | 109 | # Execute the matmul. 110 | out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b) 111 | expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b) 112 | 113 | # Compute the gradients w.r.t. the inputs. 114 | expected_out.sum().backward() 115 | out.sum().backward() 116 | 117 | # Validate the results. 118 | self.assertEqual(out.dim(), 2) 119 | self.assertEqual(expected_out.size()[0], out.size()[0]) 120 | self.assertEqual(expected_out.size()[1], out.size()[1]) 121 | self.assertTrue(allclose(out, expected_out)) 122 | 123 | # LHS gradient. 124 | grad = stk.ops.to_dense(a.grad) 125 | expected_grad = _mask(a_dense.grad, a.grad) 126 | self.assertEqual(grad.dim(), 2) 127 | self.assertEqual(expected_grad.size()[0], grad.size()[0]) 128 | self.assertEqual(expected_grad.size()[1], grad.size()[1]) 129 | self.assertTrue(allclose(grad, expected_grad)) 130 | 131 | # RHS gradient. 132 | grad = b.grad 133 | expected_grad = bcp.grad 134 | self.assertEqual(grad.dim(), 2) 135 | self.assertEqual(expected_grad.size()[0], grad.size()[0]) 136 | self.assertEqual(expected_grad.size()[1], grad.size()[1]) 137 | self.assertTrue(allclose(grad, expected_grad)) 138 | 139 | def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): 140 | # Construct the operands. 141 | a_shape = (k, m) if trans_a else (m, k) 142 | a, acp = _dense_2x(*a_shape, dtype) 143 | b_shape = (n, k) if trans_b else (k, n) 144 | b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype) 145 | 146 | # Execute the matmul. 147 | out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b) 148 | expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b) 149 | 150 | # Compute the gradients w.r.t. the inputs. 151 | expected_out.sum().backward() 152 | out.sum().backward() 153 | 154 | # Validate the results. 155 | self.assertEqual(out.dim(), 2) 156 | self.assertEqual(expected_out.size()[0], out.size()[0]) 157 | self.assertEqual(expected_out.size()[1], out.size()[1]) 158 | self.assertTrue(allclose(out, expected_out)) 159 | 160 | # LHS gradient. 161 | grad = a.grad 162 | expected_grad = acp.grad 163 | self.assertEqual(grad.dim(), 2) 164 | self.assertEqual(expected_grad.size()[0], grad.size()[0]) 165 | self.assertEqual(expected_grad.size()[1], grad.size()[1]) 166 | self.assertTrue(allclose(grad, expected_grad)) 167 | 168 | # RHS gradient. 169 | grad = stk.ops.to_dense(b.grad) 170 | expected_grad = _mask(b_dense.grad, b.grad) 171 | self.assertEqual(grad.dim(), 2) 172 | self.assertEqual(expected_grad.size()[0], grad.size()[0]) 173 | self.assertEqual(expected_grad.size()[1], grad.size()[1]) 174 | self.assertTrue(allclose(grad, expected_grad)) 175 | 176 | def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): 177 | # Construct the operands. 178 | a_shape = (k, m) if trans_a else (m, k) 179 | a, acp = _dense_2x(*a_shape, dtype) 180 | b_shape = (n, k) if trans_b else (k, n) 181 | b, bcp = _dense_2x(*b_shape, dtype) 182 | _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype) 183 | 184 | # Execute the matmul. 185 | out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b) 186 | expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b) 187 | 188 | # Compute the gradients w.r.t. 
the inputs. 189 | expected_out.sum().backward() 190 | stk.ops.sum(out).backward() 191 | 192 | # Validate the results. 193 | out = stk.ops.to_dense(out) 194 | self.assertEqual(out.dim(), 2) 195 | self.assertEqual(expected_out.size()[0], out.size()[0]) 196 | self.assertEqual(expected_out.size()[1], out.size()[1]) 197 | self.assertTrue(allclose(out, expected_out)) 198 | 199 | # LHS gradient. 200 | grad = a.grad 201 | expected_grad = acp.grad 202 | self.assertEqual(grad.dim(), 2) 203 | self.assertEqual(expected_grad.size()[0], grad.size()[0]) 204 | self.assertEqual(expected_grad.size()[1], grad.size()[1]) 205 | self.assertTrue(allclose(grad, expected_grad)) 206 | 207 | # RHS gradient. 208 | grad = b.grad 209 | expected_grad = bcp.grad 210 | self.assertEqual(grad.dim(), 2) 211 | self.assertEqual(expected_grad.size()[0], grad.size()[0]) 212 | self.assertEqual(expected_grad.size()[1], grad.size()[1]) 213 | self.assertTrue(allclose(grad, expected_grad)) 214 | 215 | if __name__ == '__main__': 216 | unittest.main() 217 | -------------------------------------------------------------------------------- /stk/backend/sputnik.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from stk.backend import triton_kernels as backend 4 | from stk.backend.autocast import custom_bwd, custom_fwd 5 | 6 | 7 | def _standardize_shape(x, transpose): 8 | if transpose: 9 | return torch.Size((x[1], x[0])) 10 | return x 11 | 12 | 13 | def _sparse_transpose(x): 14 | return (torch.Size((x[0][1], x[0][0])), ) + x[1:] 15 | 16 | 17 | def _transpose_helper(x, transpose): 18 | if isinstance(x, torch.Tensor): 19 | return x.t() if transpose else x 20 | if transpose: 21 | x = _sparse_transpose(x) 22 | return x + (transpose,) 23 | 24 | 25 | def _wrap(x): 26 | if isinstance(x, torch.Tensor): 27 | return (x,) 28 | return x 29 | 30 | 31 | def _is_transposed(x): 32 | return (not x.is_contiguous() and 33 | x.stride()[0] == 1 and 34 | x.stride()[1] == x.size()[0]) 35 | 36 | 37 | def _call_helper(op, out, a, b, trans_a, trans_b): 38 | args = (_wrap(_transpose_helper(a, trans_a)) + 39 | _wrap(_transpose_helper(b, trans_b))) 40 | if isinstance(out, tuple): 41 | args = args + out 42 | return op(*args) 43 | 44 | 45 | def _preprocess_inputs(lhs, rhs, dy): 46 | if isinstance(lhs, torch.Tensor) and _is_transposed(lhs): 47 | lhs = lhs.t() 48 | if isinstance(rhs, torch.Tensor) and _is_transposed(rhs): 49 | rhs = rhs.t() 50 | if (isinstance(dy, torch.Tensor) and 51 | not dy.is_contiguous() and 52 | not _is_transposed(dy)): 53 | dy = dy.contiguous() 54 | if isinstance(dy, tuple) and not dy[1].is_contiguous(): 55 | dy = (dy[0], dy[1].contiguous()) + dy[2:] 56 | return lhs, rhs, dy 57 | 58 | 59 | def _postprocess_outputs(x, transpose, grad): 60 | if isinstance(x, torch.Tensor) and transpose: 61 | return grad.t() 62 | return grad 63 | 64 | 65 | def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): 66 | lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) 67 | 68 | a, b = (rhs, dy) if trans_lhs else (dy, rhs) 69 | trans_a = trans_lhs and trans_rhs 70 | trans_b = trans_lhs or not trans_rhs 71 | out = _call_helper(op, lhs, a, b, trans_a, trans_b) 72 | return _postprocess_outputs(lhs, trans_lhs, out) 73 | 74 | 75 | def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): 76 | lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) 77 | 78 | a, b = (dy, lhs) if trans_rhs else (lhs, dy) 79 | trans_a = not trans_lhs or trans_rhs 80 | trans_b = trans_lhs and trans_rhs 81 | out = _call_helper(op, rhs, a, b, 
trans_a, trans_b) 82 | return _postprocess_outputs(rhs, trans_rhs, out) 83 | 84 | 85 | class DSD(torch.autograd.Function): 86 | 87 | @staticmethod 88 | @custom_fwd 89 | def forward(ctx, 90 | shape, 91 | data, 92 | offsets, 93 | row_indices, 94 | column_indices, 95 | offsets_t, 96 | column_indices_t, 97 | block_offsets_t, 98 | transpose_a, 99 | rhs): 100 | ctx.save_for_backward(data, 101 | offsets, 102 | row_indices, 103 | column_indices, 104 | offsets_t, 105 | column_indices_t, 106 | block_offsets_t, 107 | rhs) 108 | ctx.shape = _standardize_shape(shape, transpose_a) 109 | ctx.transpose_a = transpose_a 110 | 111 | out = torch.empty( 112 | (shape[0], rhs.size()[1]), 113 | dtype=rhs.dtype, 114 | device=rhs.device) 115 | 116 | backend.dsd(shape, 117 | data, 118 | offsets, 119 | row_indices, 120 | column_indices, 121 | offsets_t, 122 | column_indices_t, 123 | block_offsets_t, 124 | transpose_a, 125 | rhs, 126 | out) 127 | return out 128 | 129 | @staticmethod 130 | @custom_bwd 131 | def backward(ctx, dy): 132 | saved_tensors = ctx.saved_tensors 133 | lhs = (ctx.shape,) + saved_tensors[:-1] 134 | rhs = saved_tensors[-1] 135 | trans_a = ctx.transpose_a 136 | trans_b = _is_transposed(rhs) 137 | 138 | ddata = None 139 | if ctx.needs_input_grad[1]: 140 | ddata = _lhs_gradient(sdd, 141 | lhs, 142 | rhs, 143 | dy, 144 | trans_a, 145 | trans_b) 146 | drhs = None 147 | if ctx.needs_input_grad[-1]: 148 | op = dds if trans_b else dsd 149 | drhs = _rhs_gradient(op, 150 | lhs, 151 | rhs, 152 | dy, 153 | trans_a, 154 | trans_b) 155 | return None, ddata, None, None, None, None, None, None, None, drhs 156 | 157 | 158 | dsd = DSD.apply 159 | 160 | 161 | class DDS(torch.autograd.Function): 162 | 163 | @staticmethod 164 | @custom_fwd 165 | def forward(ctx, 166 | lhs, 167 | shape, 168 | data, 169 | offsets, 170 | row_indices, 171 | column_indices, 172 | offsets_t, 173 | column_indices_t, 174 | block_offsets_t, 175 | transpose_b): 176 | ctx.save_for_backward(lhs, 177 | data, 178 | offsets, 179 | row_indices, 180 | column_indices, 181 | offsets_t, 182 | column_indices_t, 183 | block_offsets_t) 184 | ctx.shape = _standardize_shape(shape, transpose_b) 185 | ctx.transpose_b = transpose_b 186 | out = torch.empty((lhs.size()[0], shape[1]), 187 | dtype=lhs.dtype, 188 | device=lhs.device) 189 | backend.dds(lhs, 190 | shape, 191 | data, 192 | offsets, 193 | row_indices, 194 | column_indices, 195 | offsets_t, 196 | column_indices_t, 197 | block_offsets_t, 198 | transpose_b, 199 | out) 200 | return out 201 | 202 | @staticmethod 203 | @custom_bwd 204 | def backward(ctx, dy): 205 | saved_tensors = ctx.saved_tensors 206 | lhs = saved_tensors[0] 207 | rhs = (ctx.shape,) + saved_tensors[1:] 208 | trans_a = _is_transposed(lhs) 209 | trans_b = ctx.transpose_b 210 | 211 | dlhs = None 212 | if ctx.needs_input_grad[0]: 213 | op = dsd if trans_a else dds 214 | dlhs = _lhs_gradient(op, 215 | lhs, 216 | rhs, 217 | dy, 218 | trans_a, 219 | trans_b) 220 | ddata = None 221 | if ctx.needs_input_grad[2]: 222 | ddata = _rhs_gradient(sdd, 223 | lhs, 224 | rhs, 225 | dy, 226 | trans_a, 227 | trans_b) 228 | return dlhs, None, ddata, None, None, None, None, None, None, None 229 | 230 | 231 | dds = DDS.apply 232 | 233 | 234 | class SDD(torch.autograd.Function): 235 | 236 | @staticmethod 237 | @custom_fwd 238 | def forward(ctx, 239 | lhs, 240 | rhs, 241 | shape, 242 | data, 243 | offsets, 244 | row_indices, 245 | column_indices, 246 | offsets_t, 247 | column_indices_t, 248 | block_offsets_t): 249 | ctx.save_for_backward( 250 | lhs, 251 | rhs, 252 | 
offsets,
253 |             row_indices,
254 |             column_indices,
255 |             offsets_t,
256 |             column_indices_t,
257 |             block_offsets_t)
258 |         ctx.shape = shape
259 |         out = torch.empty(
260 |             data.shape,
261 |             dtype=lhs.dtype,
262 |             device=lhs.device)
263 |         backend.sdd(lhs,
264 |                     rhs,
265 |                     shape,
266 |                     out,
267 |                     offsets,
268 |                     row_indices,
269 |                     column_indices)
270 |         return out
271 | 
272 |     @staticmethod
273 |     @custom_bwd
274 |     def backward(ctx, dy):
275 |         saved_tensors = ctx.saved_tensors
276 |         lhs, rhs = saved_tensors[:2]
277 |         dy = (ctx.shape, dy) + saved_tensors[2:]
278 |         trans_a = _is_transposed(lhs)
279 |         trans_b = _is_transposed(rhs)
280 | 
281 |         dlhs = None
282 |         if ctx.needs_input_grad[0]:
283 |             op = dds if trans_a else dsd
284 |             dlhs = _lhs_gradient(op,
285 |                                  lhs,
286 |                                  rhs,
287 |                                  dy,
288 |                                  trans_a,
289 |                                  trans_b)
290 |         drhs = None
291 |         if ctx.needs_input_grad[1]:
292 |             op = dsd if trans_b else dds
293 |             drhs = _rhs_gradient(op,
294 |                                  lhs,
295 |                                  rhs,
296 |                                  dy,
297 |                                  trans_a,
298 |                                  trans_b)
299 |         return dlhs, drhs, None, None, None, None, None, None, None, None
300 | 
301 | 
302 | sdd = SDD.apply
303 | 
304 | class RowIndices(torch.autograd.Function):
305 | 
306 |     @staticmethod
307 |     def forward(ctx, shape, data, offsets, column_indices):
308 |         out = torch.empty(
309 |             column_indices.shape,
310 |             dtype=column_indices.dtype,
311 |             device=column_indices.device)
312 |         backend.row_indices(shape, data, offsets, column_indices, out)
313 |         return out
314 | 
315 | 
316 | row_indices = RowIndices.apply
317 | 
--------------------------------------------------------------------------------
/stk/matrix.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | 
4 | # 1. Add heavyweight (data) validation helper.
5 | # 2. Add construction helpers
6 | # 3. Make indentation consistent
7 | # 4. Replace asserts with descriptive errors.
8 | 
9 | ##
10 | ### Validation helpers.
11 | ##
12 | 
13 | 
14 | def _validate_matrix(shape, data, row_indices, column_indices, offsets):
15 |     # Data should be [nnz, block_size, block_size]
16 |     if data.dim() == 1:
17 |         data = torch.reshape(data, [data.numel(), 1, 1])
18 | 
19 |     # Blocks should be square.
20 |     if data.shape[-2] != data.shape[-1]:
21 |         raise ValueError(
22 |             "Expected square blocking in data. "
23 |             f"Got block shape {[data.shape[-2], data.shape[-1]]}")
24 | 
25 |     # Flatten batch dimensions on data - original shape preserved
26 |     # in shape argument.
27 |     block_size = data.shape[-1]
28 |     data = data.view([-1, block_size, block_size])
29 | 
30 |     if data.dim() != 3:
31 |         raise ValueError(
32 |             "Expected 3D shape for data (nnz, block, block). "
33 |             f"Got {data.dim()}D shape.")
34 | 
35 |     block_size = data.shape[1]
36 |     if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
37 |         raise ValueError(
38 |             "Matrix shape must be divisible by blocking. "
39 |             f"Got shape {shape} with "
40 |             f"{[block_size, block_size]} blocking.")
41 | 
42 |     if np.prod(shape) < data.numel():
43 |         raise ValueError(
44 |             "Invalid matrix. Number of nonzeros exceeds matrix capacity "
45 |             f"({data.numel()} v. {np.prod(shape)})")
46 | 
47 |     if row_indices.dim() != 1:
48 |         raise ValueError(
49 |             f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
50 | 
51 |     if column_indices.dim() != 1:
52 |         raise ValueError(
53 |             f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
54 | 
55 |     if offsets.dim() != 1:
56 |         raise ValueError(
57 |             f"Expected 1D offsets. 
Got {offsets.dim()}D offsets.") 58 | 59 | if row_indices.numel() != data.shape[0]: 60 | raise ValueError( 61 | "Expected 1 index per nonzero block. " 62 | f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks") 63 | 64 | if column_indices.numel() != data.shape[0]: 65 | raise ValueError( 66 | "Expected 1 index per nonzero block. " 67 | f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks") 68 | 69 | block_rows = np.prod(shape[:-1]) / block_size 70 | if offsets.numel() != block_rows + 1: 71 | raise ValueError( 72 | "Expected one offset per block row plus one. " 73 | f"Got {offsets.numel()} offsets with {block_rows} block rows.") 74 | 75 | is_cuda = (data.is_cuda and 76 | row_indices.is_cuda and 77 | column_indices.is_cuda and 78 | offsets.is_cuda) 79 | is_cpu = (not data.is_cuda and 80 | not row_indices.is_cuda and 81 | not column_indices.is_cuda and 82 | not offsets.is_cuda) 83 | if not (is_cuda or is_cpu): 84 | raise ValueError( 85 | "Expected data & meta-data on common device. " 86 | f"Got data on {data.device}, row_indices on {row_indices.device} " 87 | f"column_indices on {column_indices.device} and " 88 | f"offsets on {offsets.device}.") 89 | 90 | if data.dtype != torch.float16: 91 | raise ValueError( 92 | f"Expected float16 data. Got {data.dtype} data.") 93 | if row_indices.dtype != torch.int16: 94 | raise ValueError( 95 | f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.") 96 | if column_indices.dtype != torch.int16: 97 | raise ValueError( 98 | f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.") 99 | if offsets.dtype != torch.int32: 100 | raise ValueError( 101 | f"Expected int32 offsets. Got {offsets.dtype} offsets.") 102 | return data 103 | 104 | 105 | def _transpose(size, data, row_indices, column_indices, offsets): 106 | block_columns = size[1] // data.shape[1] 107 | 108 | # Sort row indices by column indices to get the transposed matrix's 109 | # column indices. 110 | gather_indices = column_indices.argsort() 111 | column_indices_t = row_indices.gather(0, gather_indices) 112 | block_offsets_t = gather_indices.int() 113 | 114 | # NOTE: Histogram is not implemented for any integer type on CPU. Do 115 | # the histogram in 32-bit float, which can exactly represent 16-bit 116 | # integers. 117 | column_indices_float = column_indices.float() 118 | 119 | zero = torch.zeros((1,), dtype=torch.int32, device=data.device) 120 | nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns) 121 | nnz_per_column = nnz_per_column.int() 122 | offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)]) 123 | return column_indices_t, offsets_t, block_offsets_t 124 | 125 | 126 | class Matrix(torch.nn.Module): 127 | """A matrix stored in sparse format. 128 | 129 | Underlying format is block compressed sparse row (BCSR). 130 | 131 | TODO(tgale): Make this mirror torch.Tensor API as much as possible. 132 | """ 133 | 134 | def __init__(self, 135 | size, 136 | data, 137 | row_indices, 138 | column_indices, 139 | offsets, 140 | column_indices_t=None, 141 | offsets_t=None, 142 | block_offsets_t=None): 143 | super().__init__() 144 | self._size = size 145 | self._data = data 146 | self._row_indices = row_indices 147 | self._column_indices = column_indices 148 | self._offsets = offsets 149 | 150 | # Produce the transpose meta-data if it is not passed in. 
151 |         if ((column_indices_t is None) or (offsets_t is None) or
152 |             (block_offsets_t is None)):
153 |             column_indices_t, offsets_t, block_offsets_t = _transpose(
154 |                 size, data, row_indices, column_indices, offsets)
155 |         self._column_indices_t = column_indices_t
156 |         self._offsets_t = offsets_t
157 |         self._block_offsets_t = block_offsets_t
158 | 
159 |         self._transposed = False
160 | 
161 |         # Validate that our metadata will not overflow.
162 |         max_dim = np.iinfo(np.int16).max * self.blocking
163 |         if column_indices.dtype == torch.int16:
164 |             if size[0] > max_dim or size[1] > max_dim:
165 |                 raise ValueError(
166 |                     f"Sparse matrix with shape {size} exceeds representable "
167 |                     "size with 16-bit indices.")
168 | 
169 |     def validate(self):
170 |         _validate_matrix(self._size,
171 |                          self._data,
172 |                          self._row_indices,
173 |                          self._column_indices,
174 |                          self._offsets)
175 | 
176 |         # TODO(tgale): Add heavyweight data validation.
177 | 
178 |     def to(self, device):
179 |         # TODO(tgale): Handle type conversions here. We
180 |         # need to set the appropriate meta-data type for
181 |         # the given floating-point type.
182 |         self._data = self._data.to(device)
183 |         self._row_indices = self._row_indices.to(device)
184 |         self._column_indices = self._column_indices.to(device)
185 |         self._offsets = self._offsets.to(device)
186 |         self._column_indices_t = self._column_indices_t.to(device)
187 |         self._offsets_t = self._offsets_t.to(device)
188 |         self._block_offsets_t = self._block_offsets_t.to(device)
189 |         return self
190 | 
191 |     def cuda(self):
192 |         return self.to(torch.cuda.current_device())
193 | 
194 |     def clone(self):
195 |         return Matrix(
196 |             self.size(),
197 |             self.data.clone(),
198 |             self.row_indices.clone(),
199 |             self.column_indices.clone(),
200 |             self.offsets.clone(),
201 |             self.column_indices_t.clone(),
202 |             self.offsets_t.clone(),
203 |             self.block_offsets_t.clone())
204 | 
205 |     def t(self):
206 |         if self.dim() != 2:
207 |             raise ValueError(
208 |                 "t() expects a tensor with <= 2 dimensions, "
209 |                 f"but self is {self.dim()}D.")
210 |         out = Matrix(self.size(),
211 |                      self.data,
212 |                      self.row_indices,
213 |                      self.column_indices,
214 |                      self.offsets,
215 |                      self.column_indices_t,
216 |                      self.offsets_t,
217 |                      self.block_offsets_t)
218 |         out._transposed = not self._transposed
219 |         out._size = torch.Size((self._size[1], self._size[0]))
220 |         return out
221 | 
222 |     def contiguous(self):
223 |         raise ValueError("Not yet implemented.")
224 | 
225 |     def is_contiguous(self):
226 |         return not self._transposed
227 | 
228 |     @property
229 |     def is_cuda(self):
230 |         return self._data.is_cuda
231 | 
232 |     @property
233 |     def device(self):
234 |         return self._data.device
235 | 
236 |     def size(self):
237 |         return self._size
238 | 
239 |     @property
240 |     def shape(self):
241 |         return self.size()
242 | 
243 |     def dim(self):
244 |         return len(self._size)
245 | 
246 |     @property
247 |     def data(self):
248 |         return self._data
249 | 
250 |     @property
251 |     def row_indices(self):
252 |         return self._row_indices
253 | 
254 |     @property
255 |     def column_indices(self):
256 |         return self._column_indices
257 | 
258 |     @property
259 |     def offsets(self):
260 |         return self._offsets
261 | 
262 |     @property
263 |     def offsets_t(self):
264 |         return self._offsets_t
265 | 
266 |     @property
267 |     def column_indices_t(self):
268 |         return self._column_indices_t
269 | 
270 |     @property
271 |     def block_offsets_t(self):
272 |         return self._block_offsets_t
273 | 
274 |     @property
275 |     def dtype(self):
276 |         return self.data.dtype
277 | 
278 |     @property
279 |     def nnz(self):
280 | 
return self.data.numel() 281 | 282 | @property 283 | def blocking(self): 284 | return self.data.shape[1] 285 | 286 | @property 287 | def requires_grad(self): 288 | return self.data.requires_grad 289 | 290 | def requires_grad_(self, x): 291 | self.data.requires_grad_(x) 292 | return self 293 | 294 | def view(self, *shape): 295 | assert self.is_contiguous() 296 | if shape[-1] != self.size()[-1]: 297 | raise ValueError( 298 | "Can't change view on compressed dimension. " 299 | f"{self.size()[-1]} v. {shape[-1]}.") 300 | if np.prod(shape) != np.prod(self.size()): 301 | raise ValueError( 302 | "Mismatch in numel of Matrix and new shape. " 303 | f"{np.prod(self.size())} v. {np.prod(shape)}") 304 | return Matrix(shape, 305 | self.data, 306 | self.row_indices, 307 | self.column_indices, 308 | self.offsets, 309 | self.column_indices_t, 310 | self.offsets_t, 311 | self.block_offsets_t) 312 | 313 | @property 314 | def grad(self): 315 | # TODO(tgale): Make sure this mirrors torch.Tensor 316 | # behavior in the case where we ask for the gradient 317 | # of a non-contiguous tensor. 318 | size = self.size() 319 | if not self.is_contiguous(): 320 | size = torch.Size((size[1], size[0])) 321 | out = Matrix(size, 322 | self.data.grad, 323 | self.row_indices, 324 | self.column_indices, 325 | self.offsets, 326 | self.column_indices_t, 327 | self.offsets_t, 328 | self.block_offsets_t) 329 | return out if self.is_contiguous() else out.t() 330 | -------------------------------------------------------------------------------- /stk/backend/triton_kernels.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import triton 3 | import triton.language as tl 4 | from dataclasses import dataclass 5 | 6 | @dataclass 7 | class TritonConfig: 8 | BLOCK_M: int = 128 9 | BLOCK_N: int = 128 10 | BLOCK_K: int = 32 11 | BLOCK_SIZE: int = 128 12 | NUM_STAGES: int = 4 13 | NUM_WARPS: int = 4 14 | 15 | def _validate_matmul_dims(M: int, K: int, N: int): 16 | error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}" 17 | assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M) 18 | assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K) 19 | assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N) 20 | 21 | @triton.autotune( 22 | configs=[ 23 | # basic configs for compute-bound matmuls 24 | triton.Config({ 25 | 'BLOCK_M': TritonConfig.BLOCK_M, 26 | 'BLOCK_N': TritonConfig.BLOCK_N, 27 | 'BLOCK_K': TritonConfig.BLOCK_K, 28 | 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE 29 | }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), 30 | ], 31 | key=['M', 'N', 'K'], 32 | ) 33 | @triton.jit 34 | def _sdd_kernel(A, B, C, M, N, K, 35 | stride_am, stride_ak, 36 | stride_bk, stride_bn, 37 | stride_cm, stride_cn, 38 | row_indices, column_indices, 39 | BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, 40 | BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, 41 | ): 42 | # matrix multiplication 43 | pid = tl.program_id(0) 44 | pid_m = tl.load(row_indices + pid) 45 | pid_n = tl.load(column_indices + pid) 46 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 47 | rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) 48 | ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) 49 | rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) 50 | rk = tl.arange(0, BLOCK_K) 51 | # pointers 52 | A = A + (ram[:, None] * 
stride_am + rk[None, :] * stride_ak) 53 | B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) 54 | # do matrix multiplication 55 | acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) 56 | for k in range(0, tl.cdiv(K, BLOCK_K)): 57 | a = tl.load(A) 58 | b = tl.load(B) 59 | acc += tl.dot(a, b) 60 | A += BLOCK_K * stride_ak 61 | B += BLOCK_K * stride_bk 62 | #Store to sparse matrix 63 | acc = acc.to(C.dtype.element_ty) 64 | BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE 65 | cm = tl.arange(0, BLOCK_M) 66 | cn = tl.arange(0, BLOCK_N) 67 | C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) 68 | tl.store(C, acc, mask=True) 69 | 70 | @triton.autotune( 71 | configs=[ 72 | # basic configs for compute-bound matmuls 73 | triton.Config({ 74 | 'BLOCK_M': TritonConfig.BLOCK_M, 75 | 'BLOCK_N': TritonConfig.BLOCK_N, 76 | 'BLOCK_K': TritonConfig.BLOCK_K, 77 | 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE 78 | }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), 79 | ], 80 | key=['M', 'N', 'K'], 81 | ) 82 | @triton.jit 83 | def _dsd_kernel(A, B, C, M, N, K, 84 | stride_am, stride_ak, 85 | stride_bk, stride_bn, 86 | stride_cm, stride_cn, 87 | row_indices, column_indices, offsets, 88 | block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, 89 | BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, 90 | BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, 91 | ): 92 | 93 | # matrix multiplication 94 | pid_m = tl.program_id(0) 95 | pid_n = tl.program_id(1) 96 | 97 | num_pid_m = tl.num_programs(0) 98 | num_pid_n = tl.num_programs(1) 99 | pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) 100 | 101 | start_inx = tl.load(offsets + pid_m) 102 | end_inx = tl.load(offsets + pid_m + 1) 103 | 104 | # pointers to sparse matrix 105 | rm = tl.arange(0, BLOCK_M) 106 | rak = tl.arange(0, BLOCK_K) 107 | 108 | A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) 109 | 110 | # pointers to dense matrix 111 | rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) 112 | rbk = tl.arange(0, BLOCK_K) 113 | B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) 114 | 115 | # do matrix multiplication 116 | acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) 117 | nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) 118 | 119 | BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE 120 | ak_sub_incr = BLOCK_K * stride_ak 121 | bk_sub_incr = BLOCK_K * stride_bk 122 | bk_block_incr = BLOCK_SIZE * stride_bk 123 | 124 | for k in range(nsub_blocks * (end_inx - start_inx)): 125 | sub_block_inx = k % nsub_blocks 126 | block_inx = k // nsub_blocks 127 | 128 | if trans_A: 129 | ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr 130 | else: 131 | ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr 132 | 133 | ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr 134 | 135 | a = tl.load(ptr_A) 136 | b = tl.load(ptr_B) 137 | acc += tl.dot(a, b) 138 | 139 | acc = acc.to(C.dtype.element_ty) 140 | 141 | cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 142 | cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) 143 | 144 | C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) 145 | tl.store(C, acc, mask=True) 146 | 147 | @triton.autotune( 148 | configs=[ 149 | # basic configs for compute-bound matmuls 150 | triton.Config({ 151 | 'BLOCK_M': TritonConfig.BLOCK_M, 152 | 'BLOCK_N': TritonConfig.BLOCK_N, 153 | 'BLOCK_K': TritonConfig.BLOCK_K, 
154 | 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE 155 | }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), 156 | ], 157 | key=['M', 'N', 'K'], 158 | ) 159 | @triton.jit 160 | def _dds_kernel(A, B, C, M, N, K, 161 | stride_am, stride_ak, 162 | stride_bk, stride_bn, 163 | stride_cm, stride_cn, 164 | row_indices, column_indices, offsets, 165 | block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, 166 | BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, 167 | BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, 168 | ): 169 | 170 | # matrix multiplication 171 | pid_m = tl.program_id(0) 172 | pid_n = tl.program_id(1) 173 | 174 | num_pid_m = tl.num_programs(0) 175 | num_pid_n = tl.num_programs(1) 176 | pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) 177 | 178 | start_inx = tl.load(offsets + pid_n) 179 | end_inx = tl.load(offsets + pid_n + 1) 180 | 181 | # pointers to dense matrix 182 | rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 183 | rak = tl.arange(0, BLOCK_K) 184 | 185 | A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) 186 | 187 | # pointers to sparse matrix 188 | rn = tl.arange(0, BLOCK_N) 189 | rbk = tl.arange(0, BLOCK_K) 190 | B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) 191 | 192 | # do matrix multiplication 193 | acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) 194 | nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) 195 | 196 | BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE 197 | 198 | ak_sub_incr = BLOCK_K * stride_ak 199 | ak_block_incr = BLOCK_SIZE * stride_ak 200 | bk_sub_incr = BLOCK_K * stride_bk 201 | 202 | for k in range(nsub_blocks * (end_inx - start_inx)): 203 | sub_block_inx = k % nsub_blocks 204 | block_inx = k // nsub_blocks 205 | 206 | if trans_B: 207 | ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr 208 | else: 209 | ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr 210 | 211 | ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr 212 | a = tl.load(ptr_A) 213 | b = tl.load(ptr_B) 214 | acc += tl.dot(a, b) 215 | 216 | acc = acc.to(C.dtype.element_ty) 217 | cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 218 | cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) 219 | C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) 220 | tl.store(C, acc, mask=True) 221 | 222 | def dsd(shape, 223 | data, 224 | offsets, 225 | row_indices, 226 | column_indices, 227 | offsets_t, 228 | column_indices_t, 229 | block_offsets_t, 230 | transpose_a, 231 | rhs, 232 | out 233 | ): 234 | 235 | device = rhs.device 236 | trans_A = transpose_a 237 | trans_B = False 238 | 239 | if rhs.stride(0) > 1 and rhs.stride(1) > 1: 240 | trans_B = True 241 | 242 | # checks constraints 243 | assert shape[1] == rhs.shape[0], "incompatible dimensions" 244 | M, K = shape 245 | _, N = rhs.shape 246 | 247 | _validate_matmul_dims(M, K, N) 248 | 249 | # accumulator types 250 | ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 251 | 252 | stride_am, stride_ak = data.stride(1), data.stride(2) 253 | stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) 254 | a_column_indices = column_indices 255 | a_offsets = offsets 256 | 257 | # launch kernel 258 | grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) 259 | 260 | if trans_A: 261 | stride_am, stride_ak = data.stride(2), data.stride(1) 262 | a_column_indices, 
a_offsets = column_indices_t, offsets_t
263 | 
264 |     if trans_B:
265 |         stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
266 | 
267 |     _dsd_kernel[grid](
268 |         data.data, rhs, out, M, N, K,
269 |         stride_am, stride_ak,
270 |         stride_bk, stride_bn,
271 |         out.stride(0), out.stride(1),
272 |         row_indices, a_column_indices, a_offsets,
273 |         block_offsets_t, trans_A, trans_B,
274 |         GROUP_M=128, ACC_TYPE=ACC_TYPE
275 |     )
276 |     # return out
277 | 
278 | def dds(lhs,
279 |         shape,
280 |         data,
281 |         offsets,
282 |         row_indices,
283 |         column_indices,
284 |         offsets_t,
285 |         column_indices_t,
286 |         block_offsets_t,
287 |         transpose_b,
288 |         out
289 |     ):
290 | 
291 |     device = lhs.device
292 |     trans_B = transpose_b
293 |     trans_A = False
294 | 
295 |     if lhs.stride(0) > 1 and lhs.stride(1) > 1:
296 |         trans_A = True
297 | 
298 |     # checks constraints
299 |     assert lhs.shape[1] == shape[0], "incompatible dimensions"
300 |     M, K = lhs.shape
301 |     _, N = shape
302 | 
303 |     _validate_matmul_dims(M, K, N)
304 | 
305 |     # accumulator types
306 |     ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
307 | 
308 |     stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
309 |     stride_bk, stride_bn = data.stride(1), data.stride(2)
310 |     b_column_indices = column_indices_t
311 |     b_offsets = offsets_t
312 | 
313 |     # launch kernel
314 |     grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
315 | 
316 |     if trans_A:
317 |         stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
318 |     if trans_B:
319 |         stride_bk, stride_bn = data.stride(2), data.stride(1)
320 |         b_column_indices, b_offsets = column_indices, offsets
321 | 
322 |     _dds_kernel[grid](
323 |         lhs, data, out, M, N, K,
324 |         stride_am, stride_ak,
325 |         stride_bk, stride_bn,
326 |         out.stride(0), out.stride(1),
327 |         row_indices, b_column_indices, b_offsets,
328 |         block_offsets_t, trans_A, trans_B,
329 |         GROUP_M=128, ACC_TYPE=ACC_TYPE
330 |     )
331 | 
332 | def sdd(lhs,
333 |         rhs,
334 |         shape,
335 |         out,
336 |         offsets,
337 |         row_indices,
338 |         column_indices
339 |     ):
340 | 
341 |     device = out.device
342 |     trans_A = False
343 |     trans_B = False
344 | 
345 |     if lhs.stride(0) > 1 and lhs.stride(1) > 1:
346 |         trans_A = True
347 |     if rhs.stride(0) > 1 and rhs.stride(1) > 1:
348 |         trans_B = True
349 | 
350 |     # checks constraints
351 |     assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
352 |     M, K = lhs.shape
353 |     _, N = rhs.shape
354 | 
355 |     _validate_matmul_dims(M, K, N)
356 | 
357 |     # accumulator types
358 |     ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
359 | 
360 |     # launch kernel
361 |     nnz_blocks = len(row_indices)
362 |     grid = lambda META: (nnz_blocks,)
363 | 
364 |     stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
365 |     stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
366 | 
367 |     if trans_A:
368 |         stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
369 |     if trans_B:
370 |         stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
371 | 
372 |     _sdd_kernel[grid](
373 |         lhs, rhs, out, M, N, K,
374 |         stride_am, stride_ak,
375 |         stride_bk, stride_bn,
376 |         out.stride(1), out.stride(2),
377 |         row_indices, column_indices,
378 |         GROUP_M=128, ACC_TYPE=ACC_TYPE
379 |     )
380 | 
381 | @triton.jit
382 | def _row_indices_kernel(offsets, out):
383 |     pid = tl.program_id(0)
384 |     row_offset = tl.load(offsets + pid)
385 |     nnz_blocks = tl.load(offsets + pid + 1) - row_offset
386 |     for nnz_block in range(nnz_blocks):
387 |         tl.store(out + row_offset + nnz_block, pid)
388 | 
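# --- Illustrative CPU reference (a hypothetical helper, not part of stk) ---
# _row_indices_kernel above writes one block-row id per nonzero block,
# driven by the CSR-style offsets array: offsets[i + 1] - offsets[i] is the
# number of nonzero blocks in block-row i. The minimal sketch below
# reproduces that computation in plain PyTorch and can be used to sanity
# check the kernel; `torch` is already in scope at module level, as the
# launchers above reference it.

def _row_indices_reference(offsets):
    # Expand each block-row id by its nonzero-block count, mirroring the
    # per-row store loop in _row_indices_kernel.
    counts = offsets[1:] - offsets[:-1]
    return torch.repeat_interleave(
        torch.arange(len(counts), dtype=offsets.dtype), counts)

# Example: offsets = torch.tensor([0, 2, 2, 5]) describes block-rows with
# 2, 0, and 3 nonzero blocks and yields tensor([0, 0, 2, 2, 2]).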
389 | def row_indices(
390 |     shape, data, offsets, column_indices, out
391 | ):
392 |     block_rows = len(offsets) - 1
393 |     _row_indices_kernel[(block_rows, )](offsets, out)
394 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 | 
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 | 1. Definitions.
8 | 
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 | 
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 | 
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 | 
176 | END OF TERMS AND CONDITIONS
177 | 
178 | APPENDIX: How to apply the Apache License to your work.
179 | 
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 | 
189 | Copyright [yyyy] [name of copyright owner]
190 | 
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 | 
195 |     http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 | 
203 | Copyright 2023 STK authors
204 | 
205 | Licensed under the Apache License, Version 2.0 (the "License");
206 | you may not use this file except in compliance with the License.
207 | You may obtain a copy of the License at
208 | 
209 |     http://www.apache.org/licenses/LICENSE-2.0
210 | 
211 | Unless required by applicable law or agreed to in writing, software
212 | distributed under the License is distributed on an "AS IS" BASIS,
213 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
214 | See the License for the specific language governing permissions and
215 | limitations under the License.
--------------------------------------------------------------------------------