├── pytorch_block_sparse ├── tests │ ├── __init__.py │ ├── test_data_parallel.py │ ├── test_save.py │ ├── test_emulate.py │ ├── test_basic.py │ ├── test_sparse_optimizer.py │ ├── test_integration.py │ ├── test_replace.py │ ├── test_linear_nn.py │ ├── test_matmul_back.py │ └── test_matmul.py ├── __init__.py ├── native │ ├── block_sparse_native.cpp │ ├── block_sparse_cutlass_kernel.cu │ ├── block_sparse_cutlass_kernel_back.cu │ ├── cutlass_dispatch.h │ └── cutlass_dispatch_back.h ├── cutlass │ ├── util │ │ ├── printable.h │ │ ├── util.h │ │ ├── matrix_transform.h │ │ ├── math.h │ │ ├── debug.h │ │ └── device_introspection.h │ └── gemm │ │ ├── epilogue_function.h │ │ ├── block_loader.h │ │ ├── dp_accummulate.h │ │ ├── grid_raster_sparse.h │ │ ├── thread_accumulator.h │ │ ├── k_split_control.h │ │ └── dispatch_policies.h ├── util.py ├── block_sparse_linear.py └── sparse_optimizer.py ├── setup.cfg ├── MANIFEST.in ├── .gitignore ├── LICENSE.TXT ├── doc ├── Troubleshooting.md ├── DevNotes.md └── notebooks │ └── ModelSparsification.ipynb ├── setup.py └── README.md /pytorch_block_sparse/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | graft pytorch_block_sparse/cutlass/ 3 | graft pytorch_block_sparse/native/ 4 | graft pytorch_block_sparse/tests/ 5 | global-exclude *.py[cod] 6 | global-exclude *~ 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled python modules. 2 | *.pyc 3 | *.pyo 4 | 5 | # Setuptools distribution folder. 6 | /dist/ 7 | /build/ 8 | 9 | # Python egg metadata, regenerated from source files by setuptools. 
10 | /*.egg-info 11 | /*.egg 12 | 13 | # emacs Files 14 | *~ 15 | 16 | # Python cache files 17 | __pycache__/ -------------------------------------------------------------------------------- /pytorch_block_sparse/__init__.py: -------------------------------------------------------------------------------- 1 | from .block_sparse import BlockSparseMatrix, BlockSparseMatrixEmulator 2 | from .block_sparse_linear import BlockSparseLinear 3 | from .sparse_optimizer import SparseOptimizer 4 | from .util import BlockSparseModelPatcher 5 | 6 | __all__ = [BlockSparseMatrix, BlockSparseMatrixEmulator, BlockSparseLinear, BlockSparseModelPatcher, SparseOptimizer] 7 | -------------------------------------------------------------------------------- /pytorch_block_sparse/tests/test_data_parallel.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest import TestCase 3 | 4 | import torch 5 | import torch.nn 6 | 7 | from pytorch_block_sparse import BlockSparseLinear 8 | 9 | 10 | class TestFun(TestCase): 11 | def test1(self): 12 | linear = BlockSparseLinear(64, 128, False).to("cuda") 13 | model = torch.nn.DataParallel(linear) 14 | 15 | input_tensor = torch.randn(64, 64).cuda() 16 | 17 | _ = model(input_tensor) 18 | 19 | 20 | if __name__ == "__main__": 21 | unittest.main() 22 | -------------------------------------------------------------------------------- /pytorch_block_sparse/native/block_sparse_native.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | 6 | int blocksparse_matmul_cutlass(torch::Tensor dense_a, 7 | bool pytorch_contiguous_a, 8 | torch::Tensor ptr_b, 9 | torch::Tensor indices_b, 10 | torch::Tensor data_b, 11 | int m, 12 | int n, 13 | int k, 14 | int block_size_rows_b, 15 | int block_size_cols_b, 16 | torch::Tensor dense_out); 17 | 18 | int blocksparse_matmul_back_cutlass(torch::Tensor dense_a, 19 | bool pytorch_contiguous_a, 20 | torch::Tensor dense_b, 21 | bool pytorch_contiguous_b, 22 | int m, 23 | int n, 24 | int k, 25 | int block_size_rows_b, 26 | int block_size_cols_b, 27 | torch::Tensor sparse_c, 28 | torch::Tensor sparse_blocks_c, 29 | long sparse_blocks_length_c); 30 | 31 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 32 | m.def("blocksparse_matmul_cutlass", &blocksparse_matmul_cutlass, "blocksparse_matmul_cutlass"); 33 | m.def("blocksparse_matmul_back_cutlass", &blocksparse_matmul_back_cutlass, "blocksparse_matmul_back_cutlass"); 34 | } -------------------------------------------------------------------------------- /LICENSE.TXT: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. 2 | Copyright (c) 2020, HUGGING FACE INC. All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | * Redistributions of source code must retain the above copyright 7 | notice, this list of conditions and the following disclaimer. 8 | * Redistributions in binary form must reproduce the above copyright 9 | notice, this list of conditions and the following disclaimer in the 10 | documentation and/or other materials provided with the distribution. 11 | * Neither the name of the NVIDIA CORPORATION nor the 12 | names of its contributors may be used to endorse or promote products 13 | derived from this software without specific prior written permission. 
14 | * Neither the name of HUGGING FACE INC nor the 15 | names of its contributors may be used to endorse or promote products 16 | derived from this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 19 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 20 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 22 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 23 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 24 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 25 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 26 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 27 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /doc/Troubleshooting.md: -------------------------------------------------------------------------------- 1 | # Troubleshooting 2 | 3 | ## Locate the problem 4 | If your model has an unexpected behaviour and you suspect that pytorch_block_sparse may be the culprit, 5 | you can easily test the hypothesis. 6 | - First make sure that your model works with a dense version, aka torch.nn.Linear. 7 | 8 | - Then, if everything is ok, use the `block_sparse_linear.PseudoBlockSparseLinear` layer, 9 | using your `BlockSparseLinear` object to initialize it, instead of directly using the `BlockSparseLinear`. 10 | `PseudoBlockSparseLinear` uses only PyTorch primitives, so if your model still has issues after that change, 11 | that means that the CUDA kernels are not responsible and that 12 | the problem may be that sparsity creates learning instability, or that there is a problem in your own code. 13 | 14 | - If your issue disappears after using `PseudoBlockSparseLinear`, 15 | please file a PR with the details to reproduce it; we will be glad to investigate. 16 | 17 | ## Helper included in BlockSparseModelPatcher 18 | 19 | If you are using `BlockSparseModelPatcher`, there is an easy way to switch to PseudoBlockSparseLinear. 20 | Just use the `"pseudo_linear":True` (key,value) in the `add_pattern(...,patch_info=)` parameter: 21 | 22 | 23 | ```python 24 | from pytorch_block_sparse import BlockSparseModelPatcher 25 | # Create a model patcher 26 | mp = BlockSparseModelPatcher() 27 | 28 | # Selecting some layers to sparsify.
29 | # We use a PseudoBlockSparseLayer to check CUDA kernels 30 | mp.add_pattern("roberta\.encoder\.layer\.[0-9]+\.intermediate\.dense", patch_info={"density":0.5, "pseudo_linear":True}) 31 | mp.add_pattern("roberta\.encoder\.layer\.[0-9]+\.output\.dense", patch_info={"density":0.5, "pseudo_linear":True}) 32 | mp.add_pattern("roberta\.encoder\.layer\.[0-9]+\.attention\.output\.dense", patch_info={"density":0.5, "pseudo_linear":True}) 33 | mp.patch_model(model) 34 | 35 | print(f"Final model parameters count={model.num_parameters()}") 36 | 37 | ``` 38 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from setuptools import setup 5 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 6 | 7 | rootdir = os.path.dirname(os.path.realpath(__file__)) 8 | 9 | version = "0.1.2" 10 | 11 | ext_modules = [] 12 | 13 | if torch.cuda.is_available(): 14 | ext = CUDAExtension( 15 | "block_sparse_native", 16 | [ 17 | "pytorch_block_sparse/native/block_sparse_native.cpp", 18 | "pytorch_block_sparse/native/block_sparse_cutlass_kernel_back.cu", 19 | "pytorch_block_sparse/native/block_sparse_cutlass_kernel.cu", 20 | ], 21 | extra_compile_args=["-I", "%s/pytorch_block_sparse" % rootdir], 22 | ) 23 | ext_modules = [ext] 24 | else: 25 | print("WARNING: torch cuda seems unavailable, emulated features only will be available.") 26 | 27 | setup( 28 | name="pytorch_block_sparse", 29 | version=version, 30 | description="PyTorch extension for fast block sparse matrices computation," 31 | " drop in replacement for torch.nn.Linear.", 32 | long_description="pytorch_block_sparse is a PyTorch extension for fast block sparse matrices computation," 33 | " drop in replacement for torch.nn.Linear", 34 | classifiers=[ 35 | "Development Status :: 4 - Beta", 36 | "License :: OSI Approved :: BSD License", 37 | "Programming Language :: Python :: 3.0", 38 | ], 39 | keywords="PyTorch,sparse,matrices,machine learning", 40 | url="https://github.com/huggingface/pytorch_block_sparse", 41 | author="François Lagunas", 42 | author_email="francois.lagunas@m4x.org", 43 | download_url=f"https://test.pypi.org/project/pytorch-block-sparse/{version}/", 44 | license='BSD 3-Clause "New" or "Revised" License', 45 | packages=["pytorch_block_sparse"], 46 | install_requires=[], 47 | include_package_data=True, 48 | zip_safe=False, 49 | ext_modules=ext_modules, 50 | cmdclass={"build_ext": BuildExtension}, 51 | ) 52 | -------------------------------------------------------------------------------- /pytorch_block_sparse/tests/test_save.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | import unittest 3 | from unittest import TestCase 4 | 5 | import torch 6 | 7 | from pytorch_block_sparse import BlockSparseLinear, BlockSparseMatrix 8 | 9 | 10 | class TestFun(TestCase): 11 | def test0(self): 12 | sizes = [64, 64] 13 | block_size = (32, 32) 14 | block_count = 2 15 | bsm = BlockSparseMatrix.randn(sizes, block_count, blocks=None, block_shape=block_size, device="cuda") 16 | 17 | with tempfile.NamedTemporaryFile() as tf: 18 | torch.save(bsm, tf.name) 19 | 20 | bsm2 = torch.load(tf.name) 21 | 22 | self.assertTrue((bsm.to_dense() == bsm2.to_dense()).all()) 23 | 24 | def test1(self): 25 | sizes = [256, 256] 26 | 27 | linear = BlockSparseLinear(sizes[0], sizes[1], True, 0.5) 28 | 29 | with tempfile.NamedTemporaryFile() as tf: 30 | 
torch.save(linear, tf.name) 31 | 32 | linear2 = torch.load(tf.name) 33 | 34 | self.assertTrue((linear.weight.to_dense() == linear2.weight.to_dense()).all()) 35 | self.assertTrue((linear.bias == linear2.bias).all()) 36 | 37 | def test2(self): 38 | sizes = [256, 256] 39 | 40 | linear = BlockSparseLinear(sizes[0], sizes[1], True, 0.5) 41 | 42 | with tempfile.NamedTemporaryFile() as tf: 43 | state_dict = linear.state_dict() 44 | torch.save(state_dict, tf.name) 45 | 46 | linear2 = BlockSparseLinear(sizes[0], sizes[1], True, 0.5) 47 | 48 | linear2.load_state_dict(torch.load(tf.name)) 49 | 50 | self.assertTrue((linear.weight.to_dense() == linear2.weight.to_dense()).all()) 51 | self.assertTrue((linear.bias == linear2.bias).all()) 52 | 53 | def tst3(self): 54 | sizes = [256, 256] 55 | 56 | linear = BlockSparseLinear(sizes[0], sizes[1], True, 0.5) 57 | 58 | with tempfile.NamedTemporaryFile() as tf: 59 | state_dict = linear.state_dict() 60 | torch.save(state_dict, tf.name) 61 | 62 | linear2 = BlockSparseLinear(sizes[0], sizes[1], True, 1.0) 63 | 64 | linear2.load_state_dict(torch.load(tf.name)) 65 | 66 | self.assertTrue((linear.weight.to_dense() == linear2.weight.to_dense()).all()) 67 | self.assertTrue((linear.bias == linear2.bias).all()) 68 | 69 | 70 | if __name__ == "__main__": 71 | unittest.main() 72 | -------------------------------------------------------------------------------- /pytorch_block_sparse/tests/test_emulate.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import torch 4 | 5 | from pytorch_block_sparse import BlockSparseMatrix, BlockSparseMatrixEmulator 6 | 7 | 8 | class TestFun(TestCase): 9 | def help_contruct(self, shape, block_mask, data, block_shape=(16, 16)): 10 | try: 11 | real = BlockSparseMatrix(shape, block_mask, data, block_shape) 12 | except Exception: 13 | real = None 14 | emul = BlockSparseMatrixEmulator(shape, block_mask, data, block_shape) 15 | 16 | return real, emul 17 | 18 | def help_randn( 19 | cls, 20 | shape, 21 | n_blocks, 22 | blocks=None, 23 | block_shape=(32, 32), 24 | device="cuda", 25 | positive=False, 26 | ): 27 | try: 28 | real = BlockSparseMatrix.randn(shape, n_blocks, blocks, block_shape, device=device, positive=positive) 29 | except Exception: 30 | real = None 31 | emul = BlockSparseMatrixEmulator.randn(shape, n_blocks, blocks, block_shape, device=device, positive=positive) 32 | 33 | return real, emul 34 | 35 | def test0(self): 36 | d = dict 37 | test_sizes = [d(nb=2, s=(4, 8), bs=(1, 4))] 38 | map = d(nb="n_blocks", s="shape", bs="block_shape") 39 | 40 | for ts in test_sizes: 41 | ts = {map[k]: v for k, v in ts.items()} 42 | self.help_randn(**ts, device="cpu") 43 | 44 | def test_from_dense(self): 45 | dense = torch.randn(8, 8).cuda() 46 | d = dict 47 | 48 | tests = [ 49 | d(blocks=[[0, 0], [1, 3], [3, 2]], block_shape=(1, 2)), 50 | d(blocks=[[0, 0], [1, 1], [3, 1]], block_shape=(2, 4)), 51 | d(block_shape=(2, 4)), 52 | ] 53 | 54 | for test in tests: 55 | blocks = test.get("blocks") 56 | nblocks = test.get("nblocks") 57 | block_shape = test["block_shape"] 58 | mask = BlockSparseMatrixEmulator.ones( 59 | dense.shape, block_shape=block_shape, blocks=blocks, n_blocks=nblocks 60 | ).to_dense() 61 | 62 | versions = [] 63 | for slow in False, True: 64 | sparse = BlockSparseMatrixEmulator.from_dense(dense, block_shape=block_shape, blocks=blocks, slow=slow) 65 | versions.append(sparse) 66 | 67 | for i, sparse in enumerate(versions): 68 | dense2 = sparse.to_dense() 69 | 
self.assertTrue(((dense * mask) == dense2).all()) 70 | -------------------------------------------------------------------------------- /pytorch_block_sparse/tests/test_basic.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest import TestCase 3 | 4 | import torch 5 | from torch import tensor 6 | 7 | from pytorch_block_sparse import BlockSparseMatrix 8 | 9 | 10 | class TestFun(TestCase): 11 | def test0(self): 12 | tests = [ 13 | dict( 14 | size=[128, 64], 15 | blocks=[ 16 | (0, 0), 17 | (1, 0), 18 | (2, 0), 19 | (0, 1), 20 | ], 21 | row_start_ends_a=tensor([0, 2, 3, 4, 4]), 22 | cols_a=tensor([[0, 0], [1, 1], [0, 2], [0, 3]]), 23 | col_start_ends_b=tensor([0, 3, 4]), 24 | rows_b=tensor([[0, 0], [1, 2], [2, 3], [0, 1]]), 25 | ) 26 | ] 27 | block_shape = (32, 32) 28 | device = "cuda" 29 | for test_info in tests: 30 | size = test_info["size"] 31 | blocks = test_info["blocks"] 32 | bsm = BlockSparseMatrix.randn( 33 | (size[0], size[1]), 34 | None, 35 | blocks=blocks, 36 | block_shape=block_shape, 37 | device=device, 38 | ) 39 | 40 | for key in test_info: 41 | if "row" in key or "col" in key: 42 | bsm_a = getattr(bsm, key) 43 | ref = test_info[key].to(device=device, dtype=torch.int32) 44 | check = (bsm_a == ref).all() 45 | if not check: 46 | raise Exception(f"Non matching attribute {key}:\n{bsm_a}\n!=\n{ref} (ref).") 47 | 48 | def test1(self): 49 | sizes = [(32, 32), (64, 32), (32, 64), (64, 64), (256, 64)] 50 | for size in sizes: 51 | print(f"size={size}") 52 | block_shape = (32, 32) 53 | block_count = size[0] * size[1] // (block_shape[0] * block_shape[1]) 54 | device = "cuda" 55 | 56 | bsm = BlockSparseMatrix.randn(size, block_count, block_shape=block_shape, device=device) 57 | a = bsm.to_dense() 58 | bsm.check_with_dense(a) 59 | 60 | bsm2 = BlockSparseMatrix.from_dense(a, block_shape, block_count=None) 61 | bsm2.check_with_dense(a) 62 | 63 | a2 = bsm2.to_dense() 64 | 65 | if not (a == a2).all(): 66 | print((a == a2)[::8, ::8]) 67 | raise Exception("Non matching matrices, BlockSparseMatrix.from_dense is not correct.") 68 | 69 | def test2(self): 70 | bsm = BlockSparseMatrix.zeros((32, 32), 1, block_shape=(32, 32), device="cuda") 71 | hash(bsm) 72 | 73 | 74 | if __name__ == "__main__": 75 | unittest.main() 76 | -------------------------------------------------------------------------------- /pytorch_block_sparse/cutlass/util/printable.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without 5 | * modification, are permitted provided that the following conditions are met: 6 | * * Redistributions of source code must retain the above copyright 7 | * notice, this list of conditions and the following disclaimer. 8 | * * Redistributions in binary form must reproduce the above copyright 9 | * notice, this list of conditions and the following disclaimer in the 10 | * documentation and/or other materials provided with the distribution. 11 | * * Neither the name of the NVIDIA CORPORATION nor the 12 | * names of its contributors may be used to endorse or promote products 13 | * derived from this software without specific prior written permission. 
14 | * 15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | * 26 | ******************************************************************************/ 27 | 28 | #pragma once 29 | 30 | /** 31 | * \file 32 | * \brief Pure virtual base class for printable types 33 | */ 34 | 35 | #include 36 | 37 | 38 | namespace cutlass { 39 | 40 | 41 | /****************************************************************************** 42 | * printable_t 43 | ******************************************************************************/ 44 | 45 | /** 46 | * Pure virtual base class for printable types 47 | */ 48 | struct printable_t 49 | { 50 | /// Returns the instance as a string 51 | __host__ __device__ inline 52 | virtual char const* to_string() const = 0; 53 | 54 | /// Insert the formatted instance into the output stream 55 | virtual void print(std::ostream& out) const = 0; 56 | 57 | /// Destructor 58 | virtual ~printable_t() {} 59 | }; 60 | 61 | 62 | /// Insert the formatted \p printable into the output stream 63 | inline std::ostream& operator<<( 64 | std::ostream& out, 65 | printable_t const& printable) 66 | { 67 | printable.print(out); 68 | return out; 69 | } 70 | 71 | 72 | } // namespace cutlass 73 | -------------------------------------------------------------------------------- /pytorch_block_sparse/tests/test_sparse_optimizer.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest import TestCase 3 | 4 | import torch 5 | import torch.optim as optim 6 | 7 | from pytorch_block_sparse import ( 8 | BlockSparseLinear, 9 | BlockSparseMatrix, 10 | SparseOptimizer, 11 | ) 12 | from pytorch_block_sparse.sparse_optimizer import ( 13 | MagnitudeSparseOptimizerStrategy, 14 | ) 15 | 16 | 17 | class TestFun(TestCase): 18 | def check_differences(self, bsm, reference_dense, expected_block_changes): 19 | dense = bsm.to_dense() 20 | 21 | differences = reference_dense != dense 22 | block_shape = bsm.block_shape 23 | differences = float(differences.float().sum() / (block_shape[0] * block_shape[1])) 24 | 25 | self.assertEqual(differences, expected_block_changes) 26 | 27 | def test0(self): 28 | size = (256, 256) 29 | block_count = 32 30 | cleanup_ratio = 0.1 31 | block_shape = (32, 32) 32 | bsm = BlockSparseMatrix.randn(size, block_count, block_shape=block_shape, device="cuda") 33 | 34 | dense0 = bsm.to_dense() 35 | 36 | strategy = MagnitudeSparseOptimizerStrategy(cleanup_ratio) 37 | strategy.run(bsm) 38 | 39 | expected_block_changes = int(cleanup_ratio * block_count) * 2 40 | self.check_differences(bsm, dense0, expected_block_changes) 41 | 42 | def test_sparse_optimizer(self): 43 | size = (256, 256) 44 | block_count = 32 45 | cleanup_ratio = 0.1 46 | block_shape = (32, 32) 47 | bsm = BlockSparseMatrix.randn(size, 
block_count, block_shape=block_shape, device="cuda") 48 | dense0 = bsm.to_dense() 49 | 50 | so = SparseOptimizer([bsm], lr=cleanup_ratio) 51 | 52 | so.step() 53 | 54 | expected_block_changes = int(cleanup_ratio * block_count) * 2 55 | self.check_differences(bsm, dense0, expected_block_changes) 56 | 57 | def test_sparse_optimizer_attached_optimizer(self): 58 | size = (256, 256) 59 | density = 0.5 60 | cleanup_ratio = 0.1 61 | 62 | linear = BlockSparseLinear(size[0], size[1], True, density).cuda() 63 | 64 | sparse_objects = SparseOptimizer.sparse_objects(linear) 65 | 66 | self.assertEqual(len(sparse_objects), 1) 67 | 68 | so = SparseOptimizer(sparse_objects, lr=cleanup_ratio) 69 | 70 | adam = optim.Adam(linear.parameters()) 71 | 72 | so.attach_optimizer(adam) 73 | 74 | # Run forward and backward 75 | a = torch.randn([1, size[0]]).abs().cuda() 76 | out = linear(a) 77 | 78 | loss = out.sum() 79 | 80 | loss.backward() 81 | 82 | adam.step() 83 | 84 | dense0 = linear.weight.to_dense() 85 | 86 | so.step() 87 | 88 | block_count = linear.block_count 89 | expected_block_changes = int(cleanup_ratio * block_count) * 2 90 | self.check_differences(linear.weight, dense0, expected_block_changes) 91 | 92 | 93 | if __name__ == "__main__": 94 | unittest.main() 95 | -------------------------------------------------------------------------------- /pytorch_block_sparse/cutlass/util/util.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without 5 | * modification, are permitted provided that the following conditions are met: 6 | * * Redistributions of source code must retain the above copyright 7 | * notice, this list of conditions and the following disclaimer. 8 | * * Redistributions in binary form must reproduce the above copyright 9 | * notice, this list of conditions and the following disclaimer in the 10 | * documentation and/or other materials provided with the distribution. 11 | * * Neither the name of the NVIDIA CORPORATION nor the 12 | * names of its contributors may be used to endorse or promote products 13 | * derived from this software without specific prior written permission. 14 | * 15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
25 | * 26 | ******************************************************************************/ 27 | 28 | #pragma once 29 | 30 | /** 31 | * \file 32 | * \brief Umbrella header file for utilities 33 | */ 34 | 35 | #include "debug.h" 36 | #include "device_introspection.h" 37 | #include "io_intrinsics.h" 38 | #include "math.h" 39 | #include "nv_std.h" 40 | #include "printable.h" 41 | #include "matrix_transform.h" 42 | 43 | 44 | 45 | namespace cutlass { 46 | 47 | 48 | /****************************************************************************** 49 | * int_constant 50 | ******************************************************************************/ 51 | 52 | /** 53 | * Shorthand for nv_std::integral_constant of int32_t type 54 | */ 55 | template 56 | struct int_constant : nv_std::integral_constant 57 | {}; 58 | 59 | 60 | /****************************************************************************** 61 | * Uninitialized 62 | ******************************************************************************/ 63 | 64 | /** 65 | * \brief A storage-backing wrapper that allows types with non-trivial constructors to be aliased in unions 66 | */ 67 | template 68 | struct __align__(16) uninitialized 69 | { 70 | /// Backing storage 71 | uint8_t storage[sizeof(T)]; 72 | 73 | /// Alias 74 | __host__ __device__ __forceinline__ T& alias() 75 | { 76 | return reinterpret_cast(*this); 77 | } 78 | }; 79 | 80 | 81 | 82 | } // namespace cutlass 83 | -------------------------------------------------------------------------------- /doc/DevNotes.md: -------------------------------------------------------------------------------- 1 | # Development Notes 2 | 3 | 4 | This python package provides a PyTorch extension. 5 | 6 | 7 | ## Organisation 8 | ### Build 9 | 10 | The setup.py script uses the standard PyTorch extension mechanism to build the package: 11 | 12 | ``` 13 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 14 | ... 15 | ext_modules=[ 16 | CUDAExtension('block_sparse_native', 17 | ['pytorch_block_sparse/native/block_sparse_native.cpp', 18 | 'pytorch_block_sparse/native/block_sparse_cutlass_kernel_back.cu', 19 | 'pytorch_block_sparse/native/block_sparse_cutlass_kernel.cu'], 20 | extra_compile_args=['-I', '%s/pytorch_block_sparse' % rootdir] 21 | ), 22 | ], 23 | cmdclass={ 24 | 'build_ext': BuildExtension 25 | } 26 | ``` 27 | 28 | ### Native functions python interface 29 | A single C++ file `block_sparse_native.cpp` provides the native functions visible from python. 30 | These functions provide access to CUDA kernels which compute: 31 | - dense x sparse -> dense 32 | - dense x dense on sparse support -> sparse 33 | 34 | ### CUDA/Cutlass kernels 35 | The `*.cu` files in the `native` directory provide the kernels themselves. 36 | They use the cutlass primitives available in the `cutlass` subdirectory. 37 | 38 | Multiple levels of C++ templating provide dispatch/code generation of the kernels. 39 | 40 | The main files in the `cutlass/gemm` directory are `block_task.h` and `block_task_back.h`.
41 | They express the final CUDA kernel that will be executed, using 42 | - `block_loader_.*` to load A and B matrix tiles in an efficient way 43 | - `thread_accumulator.h` to store the result tiles 'R' 44 | - `epilogue_function` to combine R with C `C' = alpha * R + beta * C` 45 | - `grid_raster_.*` to list the output tiles that must be computed 46 | 47 | ### block_sparse python module 48 | This library includes as little native code as possible, because native code is hard to write/debug/understand. 49 | 50 | The native functions perform the performance-critical tasks, and the python code in `block_sparse.py` does 51 | all the preparatory work, which is executed only once, or infrequently. 52 | 53 | The main job of `block_sparse.py` is to build indexes into the sparse matrices. 54 | Three sets of sparse indices are built: 55 | - row-wise index of non-zero entries (for dense x sparse) 56 | - column-wise index of non-zero entries (for dense x sparse with transposition) 57 | - linear list of 2D coordinates of non-zero entries (for dense x dense on sparse support) 58 | 59 | These structures are created using standard PyTorch primitives, and so are easy to debug, understand, 60 | or reimplement in other languages. 61 | 62 | ### block_sparse_linear python module 63 | The block_sparse_linear module is a thin layer on top of `block_sparse`. 64 | It uses the linear algebra primitives of block_sparse to create a drop-in replacement for `torch.nn.Linear`, 65 | with the proper back-propagation primitives, implemented using a `torch.autograd.Function` subclass. 66 | 67 | ## Testing 68 | Debugging CUDA kernels is hard. Fortunately, it's easy to compare the kernel results with 69 | a reference PyTorch implementation. 70 | The `tests` directory provides some code to test and measure the performance of the library.
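For example, a minimal sanity check (not part of the test suite; the sizes, density and tolerance below are arbitrary assumptions) could compare the CUDA path of `BlockSparseLinear` against the PyTorch-only `PseudoBlockSparseLinear` reference built from the same layer:

```python
import torch

from pytorch_block_sparse import BlockSparseLinear
from pytorch_block_sparse.block_sparse_linear import PseudoBlockSparseLinear

# Block sparse layer running the CUDA kernels...
linear = BlockSparseLinear(256, 256, True, 0.5).cuda()
# ...and a PyTorch-only reference wrapping the same BlockSparseLinear object,
# so both share the same weights and sparsity pattern.
reference = PseudoBlockSparseLinear(linear)

x = torch.randn(8, 256).cuda()

# The two code paths should agree up to floating point noise.
assert torch.allclose(linear(x), reference(x), atol=1e-5)
```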
71 | 72 | ## TODO 73 | 74 | block_sparse 75 | - add input parameters sanity checks 76 | - add dispatch for 77 | - different matrix size -> different dispatch strategy (tile sizes in k-dimension) 78 | - different block sizes 79 | 80 | tests 81 | - Refactor/cleanup tests 82 | 83 | doc 84 | - schema of sparse index structures 85 | 86 | cutlass 87 | - move to 2.x version 88 | 89 | cleanup algorithms 90 | - add algorithms to measure weights importance and optimize the sparsity pattern 91 | 92 | 93 | 94 | 95 | 96 | -------------------------------------------------------------------------------- /pytorch_block_sparse/util.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import torch 4 | 5 | from pytorch_block_sparse import BlockSparseLinear 6 | from pytorch_block_sparse.block_sparse_linear import PseudoBlockSparseLinear 7 | 8 | 9 | class ModelPatcher: 10 | def __init__(self): 11 | self.patterns = [] 12 | 13 | def is_patchable(self, module_name, module, raiseError): 14 | return True 15 | 16 | def get_patchable_layers(self, model): 17 | # Layer names (displayed as regexps)") 18 | ret = [] 19 | for k, v in model.named_modules(): 20 | if self.is_patchable(k, v, raiseError=False): 21 | r = re.escape(k) 22 | ret.append({"regexp": r, "layer": v}) 23 | return ret 24 | 25 | def add_pattern(self, pattern, patch_info): 26 | self.patterns.append(dict(pattern=pattern, patch_info=patch_info)) 27 | 28 | def pattern_match(self, module_name): 29 | for pattern_def in self.patterns: 30 | if re.match(pattern_def["pattern"], module_name): 31 | return True, pattern_def["patch_info"] 32 | return False, -1 33 | 34 | def new_child_module(self, child_module_name, child_module, patch_info): 35 | raise NotImplementedError("Implement this in subclasses") 36 | 37 | def replace_module(self, father, child_module_name, child_name, child_module, patch_info): 38 | new_child_module = self.new_child_module(child_module_name, child_module, patch_info) 39 | if new_child_module is not None: 40 | setattr(father, child_name, new_child_module) 41 | 42 | def patch_model(self, model): 43 | modules = {} 44 | modified = False 45 | for k, v in model.named_modules(): 46 | modules[k] = v 47 | match, patch_info = self.pattern_match(k) 48 | if match and self.is_patchable(k, v, raiseError=True): 49 | parts = k.split(".") 50 | father_module_name = ".".join(parts[:-1]) 51 | child_name = parts[-1] 52 | father = modules[father_module_name] 53 | self.replace_module(father, k, child_name, v, patch_info) 54 | modified = True 55 | if not modified: 56 | print( 57 | "Warning: the patcher did not patch anything!" 
58 | " Check patchable layers with `mp.get_patchable_layers(model)`" 59 | ) 60 | 61 | 62 | class BlockSparseModelPatcher(ModelPatcher): 63 | """Use {"density":d} with d in [0,1] in patch_info} 64 | Use {"pseudo_linear":True} in patch_info to use a pytorch only implementation, if you think there is a bug 65 | in pytorch_block_sparse library""" 66 | 67 | def is_patchable(self, module_name, module, raiseError): 68 | if isinstance(module, torch.nn.Linear): 69 | return True 70 | else: 71 | if raiseError: 72 | raise Exception(f"Cannot patch {module_name}: this is not a Linear layer:\n{module}") 73 | return False 74 | 75 | def new_child_module(self, child_module_name, child_module, patch_info): 76 | density = patch_info["density"] 77 | pseudo = patch_info.get("pseudo_linear") 78 | if pseudo: 79 | patch_type = "PseudoBlockSparseLinear (debug)" 80 | else: 81 | patch_type = "BlockSparseLinear" 82 | 83 | self.is_patchable(child_module_name, child_module, raiseError=True) 84 | print( 85 | f"Patching with {patch_type} '{child_module_name}' with density={density}, in={child_module.in_features}," 86 | f" out={child_module.out_features},bias={child_module.bias is not None} " 87 | ) 88 | ret = BlockSparseLinear(0, 0, False, torch_nn_linear=child_module, density=density) 89 | if pseudo: 90 | ret = PseudoBlockSparseLinear(ret) 91 | 92 | return ret 93 | -------------------------------------------------------------------------------- /pytorch_block_sparse/cutlass/gemm/epilogue_function.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without 5 | * modification, are permitted provided that the following conditions are met: 6 | * * Redistributions of source code must retain the above copyright 7 | * notice, this list of conditions and the following disclaimer. 8 | * * Redistributions in binary form must reproduce the above copyright 9 | * notice, this list of conditions and the following disclaimer in the 10 | * documentation and/or other materials provided with the distribution. 11 | * * Neither the name of the NVIDIA CORPORATION nor the 12 | * names of its contributors may be used to endorse or promote products 13 | * derived from this software without specific prior written permission. 14 | * 15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
25 | * 26 | ******************************************************************************/ 27 | 28 | #pragma once 29 | 30 | /** 31 | * \file 32 | * Epilogue operation to compute final output 33 | */ 34 | 35 | namespace cutlass { 36 | namespace gemm { 37 | 38 | //// Used by GEMM to compute the final result C <= alpha * accumulator + beta * C 39 | template < 40 | typename accum_t, 41 | typename output_t, 42 | typename scalar_t 43 | > 44 | class blas_scaled_epilogue 45 | { 46 | public: 47 | 48 | scalar_t alpha; 49 | scalar_t beta; 50 | 51 | inline __device__ __host__ 52 | blas_scaled_epilogue( 53 | scalar_t alpha, 54 | scalar_t beta) 55 | : 56 | alpha(alpha), 57 | beta(beta) 58 | {} 59 | 60 | 61 | /// Epilogue operator 62 | inline __device__ __host__ 63 | output_t operator()( 64 | accum_t accumulator, 65 | output_t c, 66 | size_t idx) const 67 | { 68 | return output_t(alpha * scalar_t(accumulator) + beta * scalar_t(c)); 69 | } 70 | 71 | 72 | /// Epilogue operator 73 | inline __device__ __host__ 74 | output_t operator()( 75 | accum_t accumulator, 76 | size_t idx) const 77 | { 78 | return output_t(alpha * scalar_t(accumulator)); 79 | } 80 | 81 | /** 82 | * Configure epilogue as to whether the thread block is a secondary 83 | * accumulator in an inter-block k-splitting scheme 84 | */ 85 | inline __device__ 86 | void set_secondary_accumulator() 87 | { 88 | beta = scalar_t(1); 89 | } 90 | 91 | 92 | /// Return whether the beta-scaled addend needs initialization 93 | inline __device__ 94 | bool must_init_addend() 95 | { 96 | return (beta != scalar_t(0)); 97 | } 98 | }; 99 | 100 | 101 | 102 | 103 | } // namespace gemm 104 | } // namespace cutlass 105 | -------------------------------------------------------------------------------- /pytorch_block_sparse/cutlass/util/matrix_transform.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without 5 | * modification, are permitted provided that the following conditions are met: 6 | * * Redistributions of source code must retain the above copyright 7 | * notice, this list of conditions and the following disclaimer. 8 | * * Redistributions in binary form must reproduce the above copyright 9 | * notice, this list of conditions and the following disclaimer in the 10 | * documentation and/or other materials provided with the distribution. 11 | * * Neither the name of the NVIDIA CORPORATION nor the 12 | * names of its contributors may be used to endorse or promote products 13 | * derived from this software without specific prior written permission. 14 | * 15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | * 26 | ******************************************************************************/ 27 | 28 | #pragma once 29 | 30 | /** 31 | * \file 32 | * \brief Enumeration of dense matrix view transformations 33 | */ 34 | 35 | #include "printable.h" 36 | 37 | namespace cutlass { 38 | 39 | 40 | /****************************************************************************** 41 | * matrix_transform_t 42 | ******************************************************************************/ 43 | 44 | /** 45 | * \brief Enumeration of dense matrix view transformations 46 | * 47 | * These enumerators (and corresponding tag types) describe which view 48 | * transformation needs to be applied prior to operation upon a given dense 49 | * matrix. Its values correspond to Fortran characters 'n' (non-transpose), 50 | * 't'(transpose) and 'c'(conjugate transpose) that are often 51 | * used as parameters to legacy BLAS implementations 52 | */ 53 | struct matrix_transform_t : printable_t 54 | { 55 | /// \brief Enumerants (same as CUBLAS) 56 | enum kind_t 57 | { 58 | /// Invalid view 59 | Invalid = -1, 60 | 61 | /// Non-transpose view 62 | NonTranspose = 0, 63 | 64 | /// Transpose view 65 | Transpose = 1, 66 | 67 | /// Conjugate transpose view 68 | ConjugateTranpose = 2, 69 | }; 70 | 71 | /// Enumerant value 72 | kind_t kind; 73 | 74 | /// Default constructor 75 | matrix_transform_t() : kind(Invalid) {} 76 | 77 | /// Copy constructor 78 | matrix_transform_t(const kind_t &other_kind) : kind(other_kind) {} 79 | 80 | /// Cast to kind_t 81 | operator kind_t() const { return kind; } 82 | 83 | /// Returns the instance as a string 84 | __host__ __device__ inline 85 | char const* to_string() const 86 | { 87 | switch (kind) 88 | { 89 | case NonTranspose: return "NonTranspose"; 90 | case Transpose: return "Transpose"; 91 | case ConjugateTranpose: return "ConjugateTranpose"; 92 | default: return "Invalid"; 93 | } 94 | } 95 | 96 | /// Insert the formatted instance into the output stream 97 | void print(std::ostream& out) const { out << to_string(); } 98 | 99 | }; 100 | 101 | 102 | } // namespace cutlass 103 | -------------------------------------------------------------------------------- /doc/notebooks/ModelSparsification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## How to sparsify a Pytorch model" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "from transformers import RobertaConfig\n", 17 | "from transformers import RobertaForMaskedLM\n", 18 | "from pytorch_block_sparse import BlockSparseModelPatcher\n", 19 | "import re\n", 20 | "\n", 21 | "config = RobertaConfig(\n", 22 | " vocab_size=52_000,\n", 23 | " max_position_embeddings=514,\n", 24 | " num_attention_heads=12,\n", 25 | " num_hidden_layers=6,\n", 26 | " type_vocab_size=1,\n", 27 | ")\n", 28 | "\n", 29 | "model 
= RobertaForMaskedLM(config=config).cuda()\n", 30 | "\n", 31 | "# =>84 million parameters\n", 32 | "print(f\"Initial model parameters count={model.num_parameters()}\")\n", 33 | " " 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "# Create a model patcher\n", 43 | "mp = BlockSparseModelPatcher()\n", 44 | "\n", 45 | "# Show names that can be used: this returns a list of all names in the network that are patchable.\n", 46 | "# These names are escaped to be used as regexps in mp.add_pattern()\n", 47 | "patchables = mp.get_patchable_layers(model)\n", 48 | "\n", 49 | "dedup_layers = []\n", 50 | "\n", 51 | "# Pretty print the regexps: replace layer number with regexp matching numbers, and dedup them\n", 52 | "# This is a bit specific to Roberta, but should work for most transformers, it's just for ease of reading.\n", 53 | "for patchable in patchables:\n", 54 | " r = patchable[\"regexp\"]\n", 55 | " r = re.sub(r'[0-9]+', '[0-9]+', r)\n", 56 | " if r not in dedup_layers:\n", 57 | " dedup_layers.append(r)\n", 58 | " layer = patchable['layer']\n", 59 | " print(f\"{r}\\n => {layer.in_features}x{layer.out_features}, bias={layer.bias is not None}\")" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "\n", 69 | "\n", 70 | "# Selecting some layers to sparsify.\n", 71 | "# This is the \"artful\" part, as some parts are more prone to be sparsified, other may impact model precision too much.\n", 72 | "\n", 73 | "# Match layers using regexp (we escape the ., just because, it's more correct, but it does not change anything here)\n", 74 | "# the [0-9]+ match any layer number.\n", 75 | "# We setup a density of 0.5 on these layers, you can test other layers / densities .\n", 76 | "mp.add_pattern(\"roberta\\.encoder\\.layer\\.[0-9]+\\.intermediate\\.dense\", {\"density\":0.5})\n", 77 | "mp.add_pattern(\"roberta\\.encoder\\.layer\\.[0-9]+\\.output\\.dense\", {\"density\":0.5})\n", 78 | "mp.add_pattern(\"roberta\\.encoder\\.layer\\.[0-9]+\\.attention\\.output\\.dense\", {\"density\":0.5})\n", 79 | "mp.patch_model(model)\n", 80 | "\n", 81 | "print(f\"Final model parameters count={model.num_parameters()}\")\n", 82 | "\n", 83 | "# => 68 million parameters instead of 84 million parameters (embeddings are taking a lof space in Roberta)" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [] 92 | } 93 | ], 94 | "metadata": { 95 | "kernelspec": { 96 | "display_name": "Python 3", 97 | "language": "python", 98 | "name": "python3" 99 | }, 100 | "language_info": { 101 | "codemirror_mode": { 102 | "name": "ipython", 103 | "version": 3 104 | }, 105 | "file_extension": ".py", 106 | "mimetype": "text/x-python", 107 | "name": "python", 108 | "nbconvert_exporter": "python", 109 | "pygments_lexer": "ipython3", 110 | "version": "3.7.3" 111 | }, 112 | "pycharm": { 113 | "stem_cell": { 114 | "cell_type": "raw", 115 | "metadata": { 116 | "collapsed": false 117 | }, 118 | "source": [] 119 | } 120 | } 121 | }, 122 | "nbformat": 4, 123 | "nbformat_minor": 4 124 | } 125 | -------------------------------------------------------------------------------- /pytorch_block_sparse/cutlass/util/math.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 
2017, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without 5 | * modification, are permitted provided that the following conditions are met: 6 | * * Redistributions of source code must retain the above copyright 7 | * notice, this list of conditions and the following disclaimer. 8 | * * Redistributions in binary form must reproduce the above copyright 9 | * notice, this list of conditions and the following disclaimer in the 10 | * documentation and/or other materials provided with the distribution. 11 | * * Neither the name of the NVIDIA CORPORATION nor the 12 | * names of its contributors may be used to endorse or promote products 13 | * derived from this software without specific prior written permission. 14 | * 15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | * 26 | ******************************************************************************/ 27 | 28 | #pragma once 29 | 30 | /** 31 | * \file 32 | * \brief Math utilities 33 | */ 34 | 35 | #include "nv_std.h" 36 | 37 | namespace cutlass { 38 | 39 | 40 | /****************************************************************************** 41 | * Static math utilities 42 | ******************************************************************************/ 43 | 44 | /** 45 | * Statically determine if N is a power-of-two 46 | */ 47 | template 48 | struct is_pow2 : nv_std::integral_constant 49 | {}; 50 | 51 | 52 | 53 | 54 | 55 | /** 56 | * Statically determine log2(N), rounded down 57 | */ 58 | template 59 | struct log2_down 60 | { 61 | /// Static logarithm value 62 | enum { value = log2_down> 1), Count + 1>::value }; 63 | }; 64 | 65 | // Base case 66 | template 67 | struct log2_down 68 | { 69 | enum { value = Count }; 70 | }; 71 | 72 | 73 | 74 | 75 | /** 76 | * Statically determine log2(N), rounded up 77 | */ 78 | template 79 | struct log2_up 80 | { 81 | /// Static logarithm value 82 | enum { value = log2_up> 1), Count + 1>::value }; 83 | }; 84 | 85 | // Base case 86 | template 87 | struct log2_up 88 | { 89 | enum { value = ((1 << Count) < N) ? Count + 1 : Count }; 90 | }; 91 | 92 | 93 | 94 | /** 95 | * Statically estimate sqrt(N) to the nearest power-of-two 96 | */ 97 | template 98 | struct sqrt_est 99 | { 100 | enum { value = 1 << (log2_up::value / 2) }; 101 | }; 102 | 103 | 104 | 105 | /** 106 | * For performing a constant-division with a compile-time assertion that the 107 | * Divisor evenly-divides the Dividend. 
108 | */ 109 | template 110 | struct divide_assert 111 | { 112 | enum { value = Dividend / Divisor}; 113 | 114 | static_assert((Dividend % Divisor == 0), "Not an even multiple"); 115 | }; 116 | 117 | 118 | 119 | 120 | 121 | /****************************************************************************** 122 | * Rounding 123 | ******************************************************************************/ 124 | 125 | /** 126 | * Round dividend up to the nearest multiple of divisor 127 | */ 128 | template 129 | inline __host__ __device__ 130 | dividend_t round_nearest(dividend_t dividend, divisor_t divisor) 131 | { 132 | return ((dividend + divisor - 1) / divisor) * divisor; 133 | } 134 | 135 | 136 | /** 137 | * Greatest common divisor 138 | */ 139 | template 140 | inline __host__ __device__ 141 | value_t gcd(value_t a, value_t b) 142 | { 143 | for (;;) 144 | { 145 | if (a == 0) return b; 146 | b %= a; 147 | if (b == 0) return a; 148 | a %= b; 149 | } 150 | } 151 | 152 | 153 | /** 154 | * Least common multiple 155 | */ 156 | template 157 | inline __host__ __device__ 158 | value_t lcm(value_t a, value_t b) 159 | { 160 | value_t temp = gcd(a, b); 161 | 162 | return temp ? (a / temp * b) : 0; 163 | } 164 | 165 | 166 | } // namespace cutlass 167 | 168 | -------------------------------------------------------------------------------- /pytorch_block_sparse/cutlass/util/debug.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without 5 | * modification, are permitted provided that the following conditions are met: 6 | * * Redistributions of source code must retain the above copyright 7 | * notice, this list of conditions and the following disclaimer. 8 | * * Redistributions in binary form must reproduce the above copyright 9 | * notice, this list of conditions and the following disclaimer in the 10 | * documentation and/or other materials provided with the distribution. 11 | * * Neither the name of the NVIDIA CORPORATION nor the 12 | * names of its contributors may be used to endorse or promote products 13 | * derived from this software without specific prior written permission. 14 | * 15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
25 | * 26 | ******************************************************************************/ 27 | 28 | #pragma once 29 | 30 | /** 31 | * \file 32 | * \brief Debugging and logging functionality 33 | */ 34 | 35 | #include 36 | 37 | namespace cutlass { 38 | 39 | /****************************************************************************** 40 | * Debug and logging macros 41 | ******************************************************************************/ 42 | 43 | /** 44 | * Formats and prints the given message to stdout 45 | */ 46 | #if !defined(CUDA_LOG) 47 | #if !defined(__CUDA_ARCH__) 48 | #define CUDA_LOG(format, ...) printf(format, __VA_ARGS__) 49 | #else 50 | inline __host__ __device__ unsigned get_threadidx_x() { return threadIdx.x; } 51 | inline __host__ __device__ unsigned get_threadidx_y() { return threadIdx.y; } 52 | inline __host__ __device__ unsigned get_threadidx_z() { return threadIdx.z; } 53 | inline __host__ __device__ unsigned get_blockidx_x() { return blockIdx.x; } 54 | inline __host__ __device__ unsigned get_blockidx_y() { return blockIdx.y; } 55 | inline __host__ __device__ unsigned get_blockidx_z() { return blockIdx.z; } 56 | #define CUDA_LOG(format, ...) \ 57 | printf("[block (%d,%d,%d), thread (%d,%d,%d)]: " format, \ 58 | get_blockidx_x(), get_blockidx_y(), get_blockidx_z(), \ 59 | get_threadidx_x(), get_threadidx_y(), get_threadidx_z(), \ 60 | __VA_ARGS__); 61 | #endif 62 | #endif 63 | 64 | 65 | /** 66 | * Formats and prints the given message to stdout only if DEBUG is defined 67 | */ 68 | #if !defined(CUDA_LOG_DEBUG) 69 | #ifdef DEBUG 70 | #define CUDA_LOG_DEBUG(format, ...) CUDA_LOG(format, __VA_ARGS__) 71 | #else 72 | #define CUDA_LOG_DEBUG(format, ...) 73 | #endif 74 | #endif 75 | 76 | 77 | /** 78 | * \brief The corresponding error message is printed to \p stderr (or \p stdout in device code) along with the supplied source context. 79 | * 80 | * \return The CUDA error. 
81 | */ 82 | __host__ __device__ inline cudaError_t cuda_perror_impl( 83 | cudaError_t error, 84 | const char* filename, 85 | int line) 86 | { 87 | (void)filename; 88 | (void)line; 89 | if (error) 90 | { 91 | #if !defined(__CUDA_ARCH__) 92 | fprintf(stderr, "CUDA error %d [%s, %d]: %s\n", error, filename, line, cudaGetErrorString(error)); 93 | fflush(stderr); 94 | #else 95 | printf("CUDA error %d [%s, %d]\n", error, filename, line); 96 | #endif 97 | } 98 | return error; 99 | } 100 | 101 | 102 | /** 103 | * \brief Perror macro 104 | */ 105 | #ifndef CUDA_PERROR 106 | #define CUDA_PERROR(e) cuda_perror_impl((cudaError_t) (e), __FILE__, __LINE__) 107 | #endif 108 | 109 | 110 | /** 111 | * \brief Perror macro with exit 112 | */ 113 | #ifndef CUDA_PERROR_EXIT 114 | #define CUDA_PERROR_EXIT(e) if (cuda_perror_impl((cudaError_t) (e), __FILE__, __LINE__)) { exit(1); } 115 | #endif 116 | 117 | 118 | /** 119 | * \brief Perror macro only if DEBUG is defined 120 | */ 121 | #ifndef CUDA_PERROR_DEBUG 122 | #ifdef DEBUG 123 | #define CUDA_PERROR_DEBUG(e) CUDA_PERROR(e) 124 | #else 125 | #define CUDA_PERROR_DEBUG(e) (e) 126 | #endif 127 | #endif 128 | 129 | 130 | } // namespace cutlass 131 | -------------------------------------------------------------------------------- /pytorch_block_sparse/native/block_sparse_cutlass_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | /* 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | */ 13 | 14 | // CUBLAS GEMM API 15 | #include 16 | 17 | // Cutlass GEMM API 18 | #include 19 | #include 20 | #include 21 | 22 | // Dispatch routines to CUTLASS 23 | #include "cutlass_dispatch.h" 24 | 25 | using namespace std; 26 | using namespace cutlass; 27 | 28 | /** 29 | * Compute C = A * B, where B is block sparse, A and C dense 30 | **/ 31 | template < 32 | typename func_t, ///< Test function type 33 | gemm::tiling_strategy::kind_t TilingStrategy, 34 | matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A 35 | matrix_transform_t::kind_t TransformB, ///< Transformation op for matrix B 36 | typename value_t, ///< Multiplicand value type (matrices A and B) 37 | typename accum_t> ///< Accumulator value type (matrix C and scalars) 38 | cudaError_t forward_full(value_t* A_data, 39 | value_t* B_data, 40 | int* B_ptr, 41 | int* B_indices, 42 | accum_t* C_data, 43 | int m, ///< Height of C in rows 44 | int n, ///< Width of C in columns 45 | int k ///< Width (height) of A (B) 46 | ) 47 | { 48 | typedef gemm::gemm_policy block_task_policy_t; 49 | 50 | cudaStream_t stream = 0; 51 | 52 | func_t func; 53 | 54 | cudaError_t error = func(m, 55 | n, 56 | k, 57 | A_data, 58 | B_data, 59 | B_ptr, 60 | B_indices, 61 | C_data, 62 | accum_t(1.0), 63 | accum_t(0.0), 64 | stream, 65 | false).result; 66 | 67 | return error; 68 | } 69 | 70 | /** 71 | * Compute C = A.matmul(B), where A and B are dense, and only on the sparse support of C 72 | **/ 73 | template 78 | cudaError_t forward(value_t* A_data, 79 | value_t* B_data, 80 | int* B_ptr, 81 | int* B_indices, 82 | accum_t* C_data, 83 | int m, 84 | int n, 85 | int k) 86 | 87 | { 88 | const math_operation_class_t math_op = math_operation_class_t::scalar; 89 | 90 | cudaError_t error = forward_full, 96 | gemm::tiling_strategy::Custom, 97 | TransformA, 98 | TransformB, 99 | value_t, 100 | accum_t>(A_data,B_data,B_ptr, B_indices, C_data, m, n, k); 101 | return error; 102 | } 103 | 104 | typedef cudaError_t 
(*forward_t)(float* A_data, 105 | float* B_data, 106 | int* B_ptr, 107 | int* B_indices, 108 | float* C_data, 109 | int m, 110 | int n, 111 | int k); 112 | 113 | int blocksparse_matmul_cutlass(torch::Tensor dense_a, 114 | bool pytorch_contiguous_a, 115 | torch::Tensor ptr_b, 116 | torch::Tensor indices_b, 117 | torch::Tensor data_b, 118 | int m, 119 | int n, 120 | int k, 121 | int block_size_rows_b, 122 | int block_size_cols_b, 123 | torch::Tensor dense_out) 124 | { 125 | typedef float value_t; 126 | typedef float accum_t; 127 | //static const matrix_transform_t::kind_t TransformA = matrix_transform_t::Transpose; 128 | static const matrix_transform_t::kind_t TransformB = matrix_transform_t::Transpose; 129 | 130 | value_t* A_data = (value_t*)dense_a.data_ptr(); 131 | value_t* B_data = (value_t*)data_b.data_ptr(); 132 | int* B_ptr = (int*)ptr_b.data_ptr(); 133 | int* B_indices = (int*)indices_b.data_ptr(); 134 | value_t* C_data = (value_t*)dense_out.data_ptr(); 135 | 136 | static const matrix_transform_t::kind_t NonTranspose = matrix_transform_t::NonTranspose; 137 | static const matrix_transform_t::kind_t Transpose = matrix_transform_t::Transpose; 138 | 139 | forward_t forward_fun; 140 | 141 | assert(pytorch_contiguous_a); 142 | //if (pytorch_contiguous_a) { 143 | // forward_fun = forward; 144 | //} else { 145 | forward_fun = forward; 146 | //} 147 | cudaError_t error = forward_fun(A_data,B_data,B_ptr, B_indices, C_data, m, n, k); 148 | 149 | return error; 150 | } 151 | 152 | 153 | -------------------------------------------------------------------------------- /pytorch_block_sparse/native/block_sparse_cutlass_kernel_back.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | // CUBLAS GEMM API 15 | #include 16 | 17 | // Cutlass GEMM API 18 | #include 19 | #include 20 | #include 21 | 22 | // Dispatch routines to CUTLASS 23 | #include "cutlass_dispatch_back.h" 24 | 25 | using namespace std; 26 | using namespace cutlass; 27 | 28 | 29 | /** 30 | * Compute C = A.matmul(B), where A and B are dense, and only on the sparse support of C 31 | **/ 32 | template 39 | cudaError_t back_full(value_t* A_data, 40 | value_t* B_data, 41 | accum_t* C_data, 42 | int2* C_blocks, 43 | long C_blocks_length, 44 | int m, ///< Height of C in rows 45 | int n, ///< Width of C in columns 46 | int k) 47 | { 48 | 49 | typedef gemm::gemm_policy block_task_back_policy_t; 50 | 51 | cudaStream_t stream = 0; 52 | 53 | func_t func; 54 | 55 | cudaError_t error = func(m, 56 | n, 57 | k, 58 | A_data, 59 | B_data, 60 | C_data, 61 | C_blocks, 62 | C_blocks_length, 63 | accum_t(1.0), 64 | accum_t(0.0), 65 | stream, 66 | false).result; 67 | 68 | return error; 69 | } 70 | 71 | /** 72 | * Compute C = A.matmul(B), where A and B are dense, and only on the sparse support of C 73 | **/ 74 | template 79 | cudaError_t back(value_t* A_data, 80 | value_t* B_data, 81 | accum_t* C_data, 82 | int2* C_blocks, 83 | long C_blocks_length, 84 | int m, ///< Height of C in rows 85 | int n, ///< Width of C in columns 86 | int k) 87 | 88 | { 89 | cudaError_t error = back_full, 95 | gemm::tiling_strategy::CustomBack, 96 | TransformA, 97 | TransformB, 98 | value_t, 99 | accum_t>(A_data, B_data, C_data, C_blocks, C_blocks_length, m, n, k); 100 | return error; 101 | } 102 | 103 | typedef cudaError_t (*back_t)(float* A_data, 104 | float* B_data, 105 | float* C_data, 106 | 
int2* C_blocks, 107 | long C_blocks_length, 108 | int m, 109 | int n, 110 | int k); 111 | 112 | /** 113 | * matrix a must be of dimensions [m,k] 114 | * matrix b must be of dimensions [k,n] 115 | * if pytorch_contiguous_a is true, then dense_a must be contiguous, ortherwise dense_a.t() must be contiguous. 116 | * if pytorch_contiguous_b is true, then dense_b must be contiguous, ortherwise dense_b.t() must be contiguous. 117 | **/ 118 | int blocksparse_matmul_back_cutlass(torch::Tensor dense_a, 119 | bool pytorch_contiguous_a, 120 | torch::Tensor dense_b, 121 | bool pytorch_contiguous_b, 122 | int m, 123 | int n, 124 | int k, 125 | int block_size_rows_b, 126 | int block_size_cols_b, 127 | torch::Tensor sparse_c, 128 | torch::Tensor sparse_blocks_c, 129 | long sparse_blocks_length_c) 130 | { 131 | typedef float value_t; 132 | typedef float accum_t; 133 | 134 | value_t* A_data = (value_t*)dense_a.data_ptr(); 135 | value_t* B_data = (value_t*)dense_b.data_ptr(); 136 | value_t* C_data = (value_t*)sparse_c.data_ptr(); 137 | int2* C_blocks = (int2*)sparse_blocks_c.data_ptr(); 138 | long C_blocks_length = sparse_blocks_length_c; 139 | 140 | back_t back_fun; 141 | 142 | static const matrix_transform_t::kind_t NonTranspose = matrix_transform_t::NonTranspose; 143 | static const matrix_transform_t::kind_t Transpose = matrix_transform_t::Transpose; 144 | 145 | if (pytorch_contiguous_a) { 146 | if (pytorch_contiguous_b) { 147 | back_fun = back; 148 | } else { 149 | back_fun = back; 150 | } 151 | } else { 152 | if (pytorch_contiguous_b) { 153 | back_fun = back; 154 | } else { 155 | back_fun = back; 156 | } 157 | } 158 | 159 | return back_fun(A_data,B_data, C_data, C_blocks, C_blocks_length, m, n, k); 160 | } 161 | -------------------------------------------------------------------------------- /pytorch_block_sparse/tests/test_integration.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import tempfile 4 | import unittest 5 | from typing import Any, Dict, Union 6 | from unittest import TestCase 7 | 8 | import torch 9 | import torch.nn as nn 10 | from transformers import ( 11 | RobertaConfig, 12 | RobertaForMaskedLM, 13 | RobertaTokenizerFast, 14 | Trainer, 15 | TrainingArguments, 16 | ) 17 | 18 | from pytorch_block_sparse import BlockSparseModelPatcher, SparseOptimizer 19 | 20 | 21 | class TestFun(TestCase): 22 | def helper(self, model, input_tensor, patterns, patch_info, param_counts): 23 | for i in range(2): 24 | parameter_count = 0 25 | for param in model.parameters(): 26 | parameter_count += param.numel() 27 | 28 | self.assertEqual(parameter_count, param_counts[i]) 29 | 30 | if i == 0: 31 | mp = BlockSparseModelPatcher() 32 | for p in patterns: 33 | mp.add_pattern(p, patch_info) 34 | mp.patch_model(model) 35 | _ = model(input_tensor) 36 | 37 | def test0(self): 38 | density = 0.5 39 | for bias in [False, True]: 40 | for patch_info in [ 41 | {"density": 0.5}, 42 | {"density": density, "pseudo_linear": True}, 43 | ]: 44 | linear = torch.nn.Linear(64, 128, bias) 45 | model = torch.nn.Sequential(linear).cuda() 46 | input_tensor = torch.randn(64, 64).cuda() 47 | 48 | pc = linear.weight.numel() 49 | if "pseudo_linear" in patch_info: 50 | pc_sparse = pc 51 | else: 52 | pc_sparse = int(pc * density) 53 | 54 | if bias: 55 | pc += linear.bias.numel() 56 | pc_sparse += linear.bias.numel() 57 | 58 | self.helper( 59 | model, 60 | input_tensor, 61 | ["0"], 62 | patch_info=patch_info, 63 | param_counts=[pc, pc_sparse], 64 | ) 65 | 66 | def 
roberta_build(self, sparse=False, base_model=None, density=1.0, eval=True): 67 | if base_model is None: 68 | config = RobertaConfig( 69 | vocab_size=52_000, 70 | max_position_embeddings=514, 71 | num_attention_heads=12, 72 | num_hidden_layers=6, 73 | type_vocab_size=1, 74 | ) 75 | 76 | model = RobertaForMaskedLM(config=config).cuda() 77 | else: 78 | model = base_model 79 | 80 | if sparse: 81 | mp = BlockSparseModelPatcher() 82 | mp.add_pattern( 83 | "roberta\\.encoder\\.layer\\.[0-9]+.intermediate\\.dense", 84 | {"density": density}, 85 | ) 86 | mp.add_pattern("roberta\\.encoder\\.layer\\.[0-9]+.output\\.dense", {"density": density}) 87 | mp.patch_model(model) 88 | 89 | if eval: 90 | model.eval() 91 | 92 | return model, model.num_parameters() 93 | 94 | def test1(self): 95 | model0, num_parameters0 = self.roberta_build() 96 | 97 | input_ids = torch.tensor([[4, 5, 6, 7] * 8]).cuda() 98 | input_ids = input_ids.expand((1, 32)) 99 | 100 | out0 = model0(input_ids) 101 | 102 | model1, num_parameters1 = self.roberta_build(sparse=True, base_model=model0) 103 | out1 = model1(input_ids) 104 | 105 | self.assertTrue(torch.isclose(out0[0], out1[0], atol=1e-3).all()) 106 | 107 | model2, num_parameters2 = self.roberta_build(sparse=True, density=0.5, eval=True) 108 | model2.eval() 109 | 110 | _ = model2(input_ids) 111 | 112 | self.assertEqual(num_parameters0, num_parameters1) 113 | self.assertGreater(70000000, num_parameters2) 114 | 115 | def test_with_trainer(self): 116 | test_dir = pathlib.Path(os.path.dirname(os.path.realpath(__file__))) 117 | data_dir = test_dir / "data" 118 | 119 | with tempfile.TemporaryDirectory() as tmpdir: 120 | model, num_parameters = self.roberta_build(sparse=True, density=0.5, eval=False) 121 | 122 | tokenizer = RobertaTokenizerFast.from_pretrained(str(data_dir), max_len=512) 123 | 124 | from transformers import DataCollatorForLanguageModeling 125 | 126 | data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15) 127 | 128 | from transformers import LineByLineTextDataset 129 | 130 | dataset = LineByLineTextDataset( 131 | tokenizer=tokenizer, 132 | file_path=data_dir / "oscar.eo.small.txt", 133 | block_size=128, 134 | ) 135 | 136 | training_args = TrainingArguments( 137 | output_dir=tmpdir, 138 | num_train_epochs=1, 139 | per_device_train_batch_size=16, # Adapt it to your size 140 | save_steps=10_000, 141 | ) 142 | 143 | class CustomTrainer(Trainer): 144 | def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: 145 | if self.first_step: 146 | so.attach_optimizer(self.optimizer) 147 | self.first_step = False 148 | self.sparse_optimizer.step() 149 | ret = super().training_step(model, inputs) 150 | return ret 151 | 152 | trainer = CustomTrainer( 153 | model=model, 154 | args=training_args, 155 | data_collator=data_collator, 156 | train_dataset=dataset, 157 | ) 158 | 159 | cleanup_ratio = 0.1 160 | sparse_objects = SparseOptimizer.sparse_objects(model) 161 | 162 | self.assertEqual(len(sparse_objects), 12) 163 | so = SparseOptimizer(sparse_objects, lr=cleanup_ratio) 164 | 165 | trainer.sparse_optimizer = so 166 | trainer.first_step = True 167 | 168 | trainer.train() 169 | 170 | 171 | if __name__ == "__main__": 172 | unittest.main() 173 | -------------------------------------------------------------------------------- /pytorch_block_sparse/tests/test_replace.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest import TestCase 3 
| 4 | import torch 5 | 6 | from pytorch_block_sparse import BlockSparseMatrix 7 | 8 | 9 | class TestFun(TestCase): 10 | def test_block_norm(self): 11 | nblocks = 6 12 | block_shape = (32, 32) 13 | bsm = BlockSparseMatrix.randn((256, 256), nblocks, block_shape=block_shape, device="cuda") 14 | n = bsm.block_norm() 15 | self.assertEqual(n.dim(), 1) 16 | self.assertEqual(n.shape[0], nblocks) 17 | 18 | d = bsm.data.reshape(-1, block_shape[0] * block_shape[1]) 19 | d = (d * d).sum(-1).sqrt() 20 | self.assertTrue(d.isclose(n).all()) 21 | 22 | def test_block_replace(self): 23 | tests = [ 24 | dict( 25 | size=[128, 64], 26 | blocks=[ 27 | (0, 0), 28 | (1, 0), 29 | (2, 0), 30 | (0, 1), 31 | ], 32 | block_info=[(0, 0), (0, 1), (1, 0), (2, 0)], 33 | block_replace=[ 34 | (3, 1, 0), 35 | (2, 1, 2), 36 | (1, 1, 1), 37 | ], # row, col, block_index 38 | after=dict( 39 | row_start_ends_a=[0, 0, 1, 3, 4], 40 | cols_a=[[1, 1], [0, 3], [1, 2], [1, 0]], 41 | block_mask=[[0, 0], [0, 1], [1, 1], [0, 1]], 42 | ), 43 | ), 44 | dict( 45 | size=[128, 64], 46 | blocks=[ 47 | (0, 0), 48 | (1, 0), 49 | (2, 0), 50 | (0, 1), 51 | ], 52 | block_info=[(0, 0), (0, 1), (1, 0), (2, 0)], 53 | block_replace=[(0, 1, 0)], # row, col, block_index 54 | error="Block position (0,1) was already used", 55 | ), 56 | ] 57 | block_shape = (32, 32) 58 | device = "cuda" 59 | verbose = False 60 | for test_info in tests[:1]: 61 | size = test_info["size"] 62 | blocks = test_info["blocks"] 63 | block_replace = torch.tensor(test_info["block_replace"]) 64 | bsm = BlockSparseMatrix.randn( 65 | (size[0], size[1]), 66 | None, 67 | blocks=blocks, 68 | block_shape=block_shape, 69 | device=device, 70 | positive=True, 71 | ) 72 | bsm.check_ = True 73 | 74 | if verbose: 75 | print(block_replace) 76 | block_mask0 = bsm.block_mask_build(None) 77 | print(block_mask0) 78 | 79 | dbsm0 = bsm.to_dense() 80 | block_positions = bsm.build_coo_block_index().t() 81 | for i, b in enumerate(test_info["block_info"]): 82 | block_position = tuple(block_positions[i].cpu().numpy()) 83 | self.assertEqual(b, block_position) 84 | 85 | try: 86 | bsm.block_replace(block_replace) 87 | except Exception as e: 88 | if test_info.get("error") == str(e): 89 | continue 90 | raise 91 | 92 | for k, v in test_info["after"].items(): 93 | if k != "block_mask": 94 | r = getattr(bsm, k) 95 | else: 96 | r = bsm.block_mask_build(None).long() 97 | v = torch.tensor(v, device=r.device) 98 | 99 | self.assertTrue((r == v).all()) 100 | 101 | dbsm = bsm.to_dense() 102 | bsm.check_with_dense(dbsm) 103 | 104 | # Check changed positions 105 | bs = block_shape 106 | for b in block_replace: 107 | block_index = b[2] 108 | bp = block_positions[block_index] 109 | block0 = dbsm0[ 110 | bp[0] * bs[0] : (bp[0] + 1) * bs[0], 111 | bp[1] * bs[1] : (bp[1] + 1) * bs[1], 112 | ] 113 | block = dbsm[b[0] * bs[0] : (b[0] + 1) * bs[0], b[1] * bs[1] : (b[1] + 1) * bs[1]] 114 | 115 | self.assertTrue((block0 == block).all()) 116 | 117 | # Check unchanged positions 118 | for i, b in enumerate(block_positions): 119 | if i not in block_replace[:, 2]: 120 | bp = b 121 | block0 = dbsm0[ 122 | bp[0] * bs[0] : (bp[0] + 1) * bs[0], 123 | bp[1] * bs[1] : (bp[1] + 1) * bs[1], 124 | ] 125 | block = dbsm[ 126 | b[0] * bs[0] : (b[0] + 1) * bs[0], 127 | b[1] * bs[1] : (b[1] + 1) * bs[1], 128 | ] 129 | self.assertTrue((block0 == block).all()) 130 | 131 | # Check that empty positions are indeed empty 132 | block_mask = bsm.block_mask_build(None) 133 | 134 | if verbose: 135 | print(block_mask) 136 | 137 | block_mask = 
block_mask.repeat_interleave(32, dim=0).repeat_interleave(32, dim=1).float() 138 | self.assertEqual((dbsm * (1 - block_mask)).abs().sum(), 0) 139 | 140 | # Part 2: check multiplication behaviour 141 | a = torch.randn((1, size[1]), device=bsm.data.device).abs() 142 | 143 | c = bsm.reverse_matmul(a, transpose=True) 144 | c_0 = a.matmul(dbsm.t()) 145 | 146 | # Basic check 147 | all_compare = torch.isclose(c, c_0) 148 | if not all_compare.all(): 149 | # print((all_compare != True).nonzero()) 150 | # print((c-c_0).abs().max()) 151 | self.assertTrue(False) 152 | 153 | # Check matmul with sparse support 154 | b = torch.randn((1, size[0]), device=bsm.data.device).abs() 155 | 156 | bsm.matmul_with_output_sparse_support(b, a, overwrite_data=True) 157 | dbsm_back = bsm.to_dense() 158 | dbsm0_back = b.t().mm(a) 159 | dbsm0_back = dbsm0_back * bsm.to_dense(data_replace=torch.ones_like(bsm.data)) 160 | 161 | self.assertTrue(dbsm0_back.isclose(dbsm_back).all()) 162 | 163 | 164 | if __name__ == "__main__": 165 | unittest.main() 166 | -------------------------------------------------------------------------------- /pytorch_block_sparse/tests/test_linear_nn.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest import TestCase 3 | 4 | import torch 5 | import torch.optim as optim 6 | 7 | from pytorch_block_sparse import ( 8 | BlockSparseLinear, 9 | BlockSparseMatrix, 10 | BlockSparseMatrixEmulator, 11 | ) 12 | from pytorch_block_sparse.block_sparse_linear import PseudoBlockSparseLinear 13 | 14 | 15 | class TestFun(TestCase): 16 | def test0(self): 17 | d = dict 18 | tests = [ 19 | d(size_a=[8, 4], size_b=[4, 4], block_shape_b=(4, 1), density=0.5), 20 | d(size_a=[32, 32], size_b=[32, 32], block_shape_b=(4, 8), density=1.0), 21 | d(size_a=[32, 32], size_b=[32, 32], density=1.0), 22 | d(size_a=[256, 32], size_b=[32, 32], density=1.0), 23 | ] 24 | verbose = False 25 | for test in tests: 26 | lr = 0.001 27 | 28 | stride = 1 29 | size_a = test["size_a"] 30 | size_b = test["size_b"] 31 | block_shape_b = test.get("block_shape_b", (32, 32)) 32 | # print(f"size_a={size_a}, size_b={size_b}") 33 | # Create the sparse linear layer 34 | linear = BlockSparseLinear(size_b[0], size_b[1], True, test["density"], block_shape=block_shape_b) 35 | if verbose: 36 | print(f"linear weight {linear.weight.data.shape}\n", linear.weight.data[::stride, ::stride]) 37 | if hasattr(linear.weight, "_dense"): 38 | print( 39 | f"linear weight dense {linear.weight._dense.shape}\n", linear.weight._dense[::stride, ::stride] 40 | ) 41 | 42 | # TODO : this does nothing 43 | linear.cuda() 44 | 45 | # Input vector 46 | a1 = torch.nn.Parameter(torch.ones([size_a[0], size_a[1]]).cuda()) 47 | a2 = torch.nn.Parameter(torch.ones([size_a[0], size_a[1]]).cuda()) 48 | 49 | # Build a dense equivalent to the sparse 50 | dense = torch.nn.Parameter(linear.weight.to_dense().cuda()) 51 | bias = torch.nn.Parameter(torch.zeros(size_b[1]).cuda()) 52 | if verbose: 53 | print("dense\n", dense[::stride, ::stride]) 54 | 55 | optimizer0 = optim.Adam([a1] + list(linear.parameters()), lr=lr) 56 | optimizer1 = optim.Adam([a2, dense, bias], lr=lr) 57 | 58 | for i in range(40): 59 | s = dense.isclose(linear.weight.to_dense(), atol=1e-05).all() 60 | 61 | if not s: 62 | raise Exception("Matrices are different") 63 | 64 | optimizer0.zero_grad() 65 | optimizer1.zero_grad() 66 | 67 | # Apply the linear function 68 | b1 = linear(a1) 69 | 70 | # Compute a reference value 71 | b2 = a2.matmul(dense.t()) + bias 72 | 73 | 
# Check that both results match 74 | s = b1.isclose(b2, atol=1e-05).all() 75 | 76 | if not s: 77 | raise Exception("Output are differents") 78 | 79 | loss1 = b1.sum() 80 | loss2 = b2.sum() 81 | 82 | loss1.backward() 83 | loss2.backward() 84 | 85 | s = a1.grad.isclose(a2.grad, atol=1e-05).all() 86 | if not s: 87 | raise Exception("Input gradients are differents") 88 | 89 | a_grad = linear.weight.reverse_matmul(torch.ones_like(a1), transpose=False) 90 | 91 | s = a_grad.isclose(a2.grad, atol=1e-05).all() 92 | if not s: 93 | print("input gradient 0\n", a_grad[::stride, ::stride]) 94 | print("input gradient 1\n", a1.grad[::stride, ::stride]) 95 | print("input gradient 2\n", a2.grad[::stride, ::stride]) 96 | 97 | raise Exception("Input gradients are differents, manual check") 98 | 99 | if verbose: 100 | print("a_grad\n", a_grad[::stride, ::stride]) 101 | print("a1 grad\n", a1.grad[::stride, ::stride]) 102 | print("a2 grad\n", a2.grad[::stride, ::stride]) 103 | 104 | print(linear.weight.get_differentiable_data().grad) 105 | 106 | if isinstance(linear.weight, BlockSparseMatrix): 107 | dense_grad = linear.weight.to_dense(data_replace=linear.weight.get_differentiable_data().grad) 108 | elif isinstance(linear.weight, BlockSparseMatrixEmulator): 109 | dense_grad = linear.weight.get_differentiable_data().grad 110 | else: 111 | raise RuntimeError("Unknown linear weight type {linear.weight.__class__}") 112 | dense_mask = linear.weight.to_dense( 113 | data_replace=torch.ones_like(linear.weight.get_differentiable_data().grad) 114 | ) 115 | 116 | dense_grad_reference = dense.grad * dense_mask 117 | 118 | if verbose: 119 | print("dense_grad\n", dense_grad[::stride, ::stride]) 120 | print( 121 | "dense_grad_reference\n", 122 | dense_grad_reference[::stride, ::stride], 123 | ) 124 | 125 | s = dense_grad.isclose(dense_grad_reference, atol=1e-05).all() 126 | 127 | if not s: 128 | raise Exception("Weight gradients are differents") 129 | 130 | optimizer0.step() 131 | optimizer1.step() 132 | 133 | with torch.no_grad(): 134 | dense *= dense_mask 135 | 136 | def test_pseudo_sparse(self): 137 | tests = [{"size_a": [256, 64], "size_b": [64, 128], "density": 1.0}] 138 | for test in tests: 139 | size_a = test["size_a"] 140 | size_b = test["size_b"] 141 | print(f"size_a={size_a}, size_b={size_b}") 142 | # Create the sparse linear layer 143 | linear = BlockSparseLinear(size_b[0], size_b[1], True, test["density"]) 144 | with torch.no_grad(): 145 | linear.weight.data.copy_(linear.weight.data.abs()) 146 | pseudo_linear = PseudoBlockSparseLinear(linear) 147 | 148 | a1 = torch.randn([size_a[0], size_a[1]]).cuda().abs() 149 | 150 | b1_l = linear(a1) 151 | 152 | b1_pl = pseudo_linear(a1) 153 | 154 | self.assertTrue(torch.isclose(b1_l, b1_pl, rtol=1e-5).all()) 155 | 156 | 157 | if __name__ == "__main__": 158 | unittest.main() 159 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Fast Block Sparse Matrices for Pytorch 2 | 3 | This PyTorch extension provides a **drop-in replacement** for torch.nn.Linear using **block sparse matrices** instead of dense ones. 4 | 5 | It enables very easy experimentation with sparse matrices since you can directly replace Linear layers in your model with sparse ones. 6 | 7 | ## Motivation 8 | The goal of this library is to show that **sparse matrices can be used in neural networks**, instead of dense ones, without significantly altering the precision. 
9 | 10 | This is great news, as sparse matrices unlock savings in both space and compute: a **50% sparse matrix** will use **only 50% memory**, and theoretically only 50% of the computation. 11 | In this library we make use of Cutlass to improve the CUDA performance versus a naive implementation. 12 | However, due to the highly optimized nature of the cuBLAS-based torch.nn.Linear, the current version of the library is still slower, by roughly a factor of 2 (this may be improved in the future). 13 | 14 | However, the performance gain of using sparse matrices grows with the sparsity, so a **75% sparse matrix** is roughly **2x** faster than the dense equivalent. 15 | This is a huge improvement over PyTorch sparse matrices: their current implementation is an order of magnitude slower than the dense one. 16 | 17 | Combined with other methods like distillation and quantization, this allows us to obtain networks which are both smaller and faster! 18 | 19 | ## Original code 20 | This work is based on the [cutlass tilesparse](https://github.com/YulhwaKim/cutlass_tilesparse) proof of concept by [Yulhwa Kim](https://github.com/YulhwaKim). 21 | 22 | It uses C++ CUDA templates for block-sparse matrix multiplication based on [CUTLASS](https://developer.nvidia.com/blog/cutlass-linear-algebra-cuda/). 23 | 24 | ## Basic usage 25 | You can use the BlockSparseLinear drop-in replacement for torch.nn.Linear in your own model: 26 | 27 | ```python 28 | # from torch.nn import Linear 29 | from pytorch_block_sparse import BlockSparseLinear 30 | 31 | ... 32 | 33 | # self.fc = nn.Linear(1024, 256) 34 | self.fc = BlockSparseLinear(1024, 256, density=0.1) 35 | ``` 36 | 37 | ## Advanced usage: converting whole models 38 | 39 | Or you can use the BlockSparseModelPatcher utility to easily modify an existing model before training it (you will need to train it from scratch rather than sparsify a pre-trained model). 40 | 41 | Here is an example with a Roberta model from Hugging Face ([full example](doc/notebooks/ModelSparsification.ipynb)): 42 | 43 | ```python 44 | from pytorch_block_sparse import BlockSparseModelPatcher 45 | # Create a model patcher 46 | mp = BlockSparseModelPatcher() 47 | 48 | # Selecting some layers to sparsify. 49 | # This is the "artful" part: some layers are more amenable to sparsification, while others impact model precision too much. 50 | 51 | # Match layers using regexp (we escape the ., just because it's more correct, but it does not change anything here); 52 | # the [0-9]+ matches any layer number. 53 | # We set a density of 0.5 on these layers; you can experiment with other layers / densities. 54 | mp.add_pattern("roberta\.encoder\.layer\.[0-9]+\.intermediate\.dense", {"density":0.5}) 55 | mp.add_pattern("roberta\.encoder\.layer\.[0-9]+\.output\.dense", {"density":0.5}) 56 | mp.add_pattern("roberta\.encoder\.layer\.[0-9]+\.attention\.output\.dense", {"density":0.5}) 57 | mp.patch_model(model) 58 | 59 | print(f"Final model parameters count={model.num_parameters()}") 60 | 61 | # => 68 million parameters instead of 84 million (embeddings take a lot of space in Roberta) 62 | ``` 63 | 64 | You can use the provided [notebook](doc/notebooks/01_how_to_train_sparse/01_how_to_train_sparse.ipynb) to train a partially sparse Roberta.
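A BlockSparseLinear layer (or a patched model) behaves like any other module, so the usual PyTorch training loop applies. Below is a minimal sketch; the layer sizes, optimizer settings and toy data are purely illustrative and not taken from the notebook:

```python
import torch
from pytorch_block_sparse import BlockSparseLinear

# A toy model mixing a block-sparse layer with dense ones (a CUDA device is required).
model = torch.nn.Sequential(
    BlockSparseLinear(1024, 1024, True, density=0.25),  # in_features, out_features, bias, density
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 10),
).cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Random stand-in data, just to show the shape of the loop.
x = torch.randn(32, 1024, device="cuda")
target = torch.randint(0, 10, (32,), device="cuda")

for _ in range(10):
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(x), target)
    loss.backward()
    optimizer.step()
```

For end-to-end examples, including patching a full RoBERTa model and driving the sparsity pattern with SparseOptimizer, see the tests under `pytorch_block_sparse/tests/`.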
65 | 66 | ## Performance 67 | It's notoriously hard to approach cuBLAS performance with custom CUDA kernels. 68 | OpenAI kernels, for example, make ample use of assembly language to achieve good performance. 69 | 70 | The promise of Cutlass was to provide tools that abstract the different parts of CUDA kernels using smart C++ templates. 71 | 72 | This allows the `pytorch_block_sparse` library to achieve roughly 50% of cuBLAS performance: 73 | depending on the exact matrix computation, it achieves 40% to 55% of the cuBLAS performance on large matrices 74 | (which is the case when using large batch x sequence sizes in Transformers, for example). 75 | Practically, this means that a Transformer with BlockSparseLinear layers at 50% sparsity is roughly as fast as the dense version. 76 | This may improve in future releases, especially once newer versions of CUTLASS are used. 77 | 78 | ## Related work 79 | OpenAI announced in January 2020 that their very advanced (and complex) TensorFlow code [would be ported](https://openai.com/blog/openai-pytorch/) to PyTorch. 80 | Unfortunately this has not happened yet. 81 | 82 | The June 2020 Google and Stanford paper [Sparse GPU Kernels for Deep Learning](https://arxiv.org/abs/2006.10901) is promising too, as the code should be released at some point. 83 | This would be even more general, as the sparsity pattern is not constrained, and the performance looks very good, with some smart ad hoc optimizations. 84 | 85 | ## Future work 86 | - Implement some published methods (and provide new ones) to optimize the sparse pattern during training, while doing the classic parameter optimization using backprop. The basic idea is to remove some smaller-magnitude weights (or blocks of weights) at some positions and try other ones. 87 | - [Movement Pruning: Adaptive Sparsity by Fine-Tuning](https://arxiv.org/abs/2005.07683) 88 | - [Sparse Networks from Scratch: Faster Training without Losing Performance](https://arxiv.org/abs/1907.04840) 89 | - [Structured Pruning of Large Language Models](https://arxiv.org/abs/1910.04732) 90 | - [Learning Sparse Neural Networks through L0 Regularization](https://arxiv.org/abs/1712.01312) 91 | - Upgrade to the latest CUTLASS version to optimize speed for the latest architectures (using Tensor Cores, for example) 92 | - Use the new Ampere 50% sparse pattern within blocks themselves: more information on the [Hugging Face Blog](https://medium.com/huggingface/sparse-neural-networks-2-n-gpu-performance-b8bc9ce950fc). 93 | 94 | ## Installation 95 | You can just use pip: 96 | ``` 97 | pip install pytorch-block-sparse 98 | ``` 99 | 100 | Or from source: clone this git repository, and in the root directory just execute: 101 | ``` 102 | python setup.py install 103 | ``` 104 | 105 | # Development Notes 106 | You will find them [here](doc/DevNotes.md). 107 | -------------------------------------------------------------------------------- /pytorch_block_sparse/cutlass/gemm/block_loader.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without 5 | * modification, are permitted provided that the following conditions are met: 6 | * * Redistributions of source code must retain the above copyright 7 | * notice, this list of conditions and the following disclaimer. 
8 | * * Redistributions in binary form must reproduce the above copyright 9 | * notice, this list of conditions and the following disclaimer in the 10 | * documentation and/or other materials provided with the distribution. 11 | * * Neither the name of the NVIDIA CORPORATION nor the 12 | * names of its contributors may be used to endorse or promote products 13 | * derived from this software without specific prior written permission. 14 | * 15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | * 26 | ******************************************************************************/ 27 | 28 | #pragma once 29 | 30 | /** 31 | * \file 32 | * block-wide tile-loading abstractions 33 | */ 34 | 35 | #include "../util/util.h" 36 | 37 | namespace cutlass { 38 | namespace gemm { 39 | 40 | 41 | /****************************************************************************** 42 | * load_algorithm 43 | ******************************************************************************/ 44 | 45 | /** 46 | * \brief Enumeration of matrix loading algorithms 47 | */ 48 | struct load_algorithm 49 | { 50 | /// \brief Enumerants. See corresponding tag types. 51 | enum kind_t 52 | { 53 | CongruousCopy = 0, 54 | CrosswiseCopy = 1, 55 | CrosswiseCopyPruneDense = 2, 56 | CongruousCopyPruneSparse = 3, 57 | CongruousCopyPruneDense = 4, 58 | CrosswiseCopyPruneSparse = 5, 59 | }; 60 | 61 | /** 62 | * \brief Generic tag 63 | */ 64 | template 65 | struct any_tag : nv_std::integral_constant {}; 66 | 67 | /** 68 | * \brief Copy from a global matrix that is row-major in relation 69 | * to the local row-major tile 70 | */ 71 | typedef any_tag contiguous_tag_t; 72 | 73 | /** 74 | * \brief Copy from a global matrix that is column-major in relation 75 | * to the local row-major tile 76 | */ 77 | typedef any_tag crosswise_tag_t; 78 | 79 | }; 80 | 81 | 82 | /****************************************************************************** 83 | * block_loader 84 | ******************************************************************************/ 85 | 86 | /** 87 | * \brief A three-phase data loading abstraction (prefetch, commit, and 88 | * advance) for iterating over ranges of block-wide matrix tiles. 89 | * 90 | * Each iteration sequence produces a KxL (height-by-width) block-wide tile of 91 | * value_t in shared memory. The layout of the shared 92 | * block-wide tile is a row-major (L-major) tiling of dp_vector_t items, which are 93 | * themselves column-major (K-major) vectors of value_t. Its dimensions are: 94 | * K = BlockDpVectorsK * (sizeof(dp_vector_t) / sizeof(value_t) 95 | * L = BlockDpVectorsL 96 | * 97 | * NB: This generic class is not directly constructible. Architecture- and 98 | * algorithm-specific template specializations will provide the API 99 | * functionality prescribed here. 
100 | * 101 | */ 102 | template < 103 | int BlockThreads, ///< Number of threads in each thread block (blockDim.x) 104 | int BlockDpVectorsK, ///< Extent of block-wide tile in dp_vector_t along the K-axis (height) 105 | int BlockDpVectorsL, ///< Extent of block-wide tile in dp_vector_t along the L-axis (width) 106 | typename value_t, ///< Input matrix value type 107 | int LeadingDimAlignBytes, ///< Byte alignment of input matrix leading dimension 108 | bool AllowRaggedTiles, ///< Whether the input matrix's dimensions need not be an even-multiple of the block-wide tile dimensions 109 | typename dp_vector_t, ///< Dot-product vector type along the K-axis 110 | load_algorithm::kind_t LoadAlgorithm> ///< Algorithm for loading a shared tile of KxL matrix data 111 | struct block_loader 112 | { 113 | //------------------------------------------------------------------------- 114 | // Constructor API 115 | //------------------------------------------------------------------------- 116 | 117 | /// Constructor 118 | block_loader( 119 | value_t *d_matrix, ///< Pointer to input matrix 120 | int matrix_values_l, ///< Extent of the input matrix in value_t along the L-axis 121 | int matrix_values_stride_k, ///< Distance in value_t within pitched-linear memory between successive coordinates along the K-axis 122 | int matrix_values_stride_l, ///< Distance in value_t within pitched-linear memory between successive coordinates along the L-axis 123 | int2 block_begin_item_coords, ///< Thread block's starting value_t coordinates (l, k) within the input matrix 124 | int block_end_item_k); ///< Thread block's ending coordinate (k) within the input matrix (one-past) 125 | 126 | //------------------------------------------------------------------------- 127 | // Loader API 128 | //------------------------------------------------------------------------- 129 | 130 | /** 131 | * Request the current block-wide tile 132 | */ 133 | void request(); 134 | 135 | 136 | /** 137 | * Advance the loader to the next block-wide tile in the K-axis 138 | */ 139 | void next(); 140 | 141 | 142 | /** 143 | * Commit the previously-requested block-wide tile to shared memory 144 | * 145 | * NB: To facilitate padding for avoiding shared memory bank conflicts, we 146 | * allow the row stride _BlockDpVectorsL to be arbitrarily bigger than the 147 | * tile width BlockDpVectorsL. 148 | */ 149 | template 150 | void commit( 151 | dp_vector_t (&scratch_tile)[BlockDpVectorsK][_BlockDpVectorsL]); 152 | 153 | }; 154 | 155 | 156 | } // namespace gemm 157 | } // namespace cutlass 158 | 159 | 160 | /****************************************************************************** 161 | * Tail-include specializations that adhere to the block_loader API 162 | ******************************************************************************/ 163 | 164 | #include "block_loader_crosswise.h" 165 | #include "block_loader_congruous_dp1.h" 166 | #include "block_loader_crosswise_prune_dense.h" 167 | #include "block_loader_congruous_dp1_prune_sparse.h" 168 | -------------------------------------------------------------------------------- /pytorch_block_sparse/cutlass/gemm/dp_accummulate.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. 
3 | * 4 | * Redistribution and use in source and binary forms, with or without 5 | * modification, are permitted provided that the following conditions are met: 6 | * * Redistributions of source code must retain the above copyright 7 | * notice, this list of conditions and the following disclaimer. 8 | * * Redistributions in binary form must reproduce the above copyright 9 | * notice, this list of conditions and the following disclaimer in the 10 | * documentation and/or other materials provided with the distribution. 11 | * * Neither the name of the NVIDIA CORPORATION nor the 12 | * names of its contributors may be used to endorse or promote products 13 | * derived from this software without specific prior written permission. 14 | * 15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | * 26 | ******************************************************************************/ 27 | 28 | #pragma once 29 | 30 | /** 31 | * \file 32 | * Abstraction for exposing architecture-specific "dot-product-accumulate" 33 | * ISA operations 34 | */ 35 | 36 | #include 37 | 38 | #include "../util/util.h" 39 | 40 | 41 | namespace cutlass { 42 | namespace gemm { 43 | 44 | 45 | /****************************************************************************** 46 | * dp_accummulate 47 | ******************************************************************************/ 48 | 49 | 50 | /** 51 | * \brief Abstraction for exposing architecture-specific "dot-product-accumulate" 52 | * ISA operations 53 | * 54 | * Given two K-component vectors a and b having type value_t[K] and an addend c 55 | * of type accum_t, the "dot-product-accumulate" of type accum_t is computed 56 | * as d = x[0]*y[0] + x[1]*y[1] + ... + x[K-1]*y[K-1] + c. 57 | * 58 | * We use the notation "dpK" to connote a K-component dot-product-accumulate. 59 | * For example, "dp1" is a simple multiply-add. 60 | * 61 | * For given pairing of value_t and accum_t types, the corresponding 62 | * dp_accummulate class will: 63 | * 64 | * - Define the member-type dp_vector_t as the appropriate K-component vector 65 | * type needed to leverage architecture-specific "dot-product accumulate" 66 | * ISA operations. 67 | * - Implement the corresponding dot-product operation between two dp_vector_t 68 | * inputs a and b. 
69 | * 70 | */ 71 | template < 72 | typename value_t, ///< Component value type 73 | typename accum_t> ///< Accumulator value type 74 | struct dp_accummulate; 75 | 76 | 77 | 78 | /// Default "dp1" dot-product-accumulate traits specialization for value_t->accum_t 79 | template < 80 | typename value_t, ///< Component value type 81 | typename accum_t> ///< Accumulator value type 82 | struct dp_accummulate 83 | { 84 | /// Single-component "dp1" dot-product vector type 85 | typedef value_t dp_vector_t; 86 | 87 | 88 | /// Compute "dp1" float->float 89 | inline __device__ 90 | static void mad( 91 | float &d, 92 | const float &a, 93 | const float &b, 94 | const float &c) 95 | { 96 | asm volatile ( "fma.rn.f32 %0, %1, %2, %3;\n" 97 | : "=f"(d) : "f"(a), "f"(b), "f"(c)); 98 | } 99 | 100 | 101 | /// Compute "dp1" double->double 102 | inline __device__ 103 | static void mad( 104 | double &d, 105 | const double &a, 106 | const double &b, 107 | const double &c) 108 | { 109 | asm volatile ("fma.rn.f64 %0, %1, %2, %3;\n" 110 | : "=d"(d) : "d"(a), "d"(b), "d"(c)); 111 | } 112 | 113 | 114 | /// Compute "dp1" int16_t->int32_t 115 | inline __device__ 116 | static void mad( 117 | int32_t &d, 118 | const int16_t &a, 119 | const int16_t &b, 120 | const int32_t &c) 121 | { 122 | asm volatile ("mad.wide.s16 %0, %1, %2, %3;\n" 123 | : "=r"(d) : "h"(a), "h"(b), "r"(c)); 124 | } 125 | 126 | 127 | /// Compute "dp1" uint16_t->uint32_t 128 | inline __device__ 129 | static void mad( 130 | uint32_t &d, 131 | const uint16_t &a, 132 | const uint16_t &b, 133 | const uint32_t &c) 134 | { 135 | asm volatile ("mad.wide.u16 %0, %1, %2, %3;\n" 136 | : "=r"(d) : "h"(a), "h"(b), "r"(c)); 137 | } 138 | 139 | 140 | /// Compute "dp1" int32_t->int32_t 141 | inline __device__ 142 | static void mad( 143 | int32_t &d, 144 | const int32_t &a, 145 | const int32_t &b, 146 | const int32_t &c) 147 | { 148 | asm volatile ("mad.lo.s32 %0, %1, %2, %3;\n" 149 | : "=r"(d) : "r"(a), "r"(b), "r"(c)); 150 | } 151 | 152 | 153 | /// Compute "dp1" uint32_t->uint32_t 154 | inline __device__ 155 | static void mad( 156 | uint32_t &d, 157 | const uint32_t &a, 158 | const uint32_t &b, 159 | const uint32_t &c) 160 | { 161 | asm volatile ("mad.lo.u32 %0, %1, %2, %3;\n" 162 | : "=r"(d) : "r"(a), "r"(b), "r"(c)); 163 | } 164 | 165 | }; 166 | 167 | 168 | 169 | #if (CUTLASS_ARCH >= 610) // Specializations only enabled for Pascal SM610+ 170 | 171 | 172 | /// "dp4" dot-product-accumulate traits specialization for int8_t->int32_t 173 | template <> 174 | struct dp_accummulate< 175 | int8_t, ///< Component value type 176 | int32_t> ///< Accumulator value type 177 | { 178 | /// Four-component signed "idp4" 179 | typedef int32_t dp_vector_t; 180 | 181 | /// Compute "dp4" int16_t->int32_t 182 | inline __device__ 183 | static void mad( 184 | int32_t &d, 185 | const int32_t &a, 186 | const int32_t &b, 187 | const int32_t &c) 188 | { 189 | asm volatile ( "dp4a.s32.s32 %0, %1, %2, %3;\n" 190 | : "=r"(d) : "r"(a), "r"(b), "r"(c)); 191 | } 192 | }; 193 | 194 | 195 | /// "dp4" dot-product-accumulate traits specialization for uint8_t->uint32_t 196 | template <> 197 | struct dp_accummulate< 198 | uint8_t, ///< Component value type 199 | uint32_t> ///< Accumulator value type 200 | { 201 | /// Four-component unsigned "idp4" 202 | typedef uint32_t dp_vector_t; 203 | 204 | /// Compute "dp4" uint16_t->uint32_t 205 | inline __device__ 206 | static void mad( 207 | uint32_t &d, 208 | const uint32_t &a, 209 | const uint32_t &b, 210 | const uint32_t &c) 211 | { 212 | asm volatile ( 
"dp4a.u32.u32 %0, %1, %2, %3;\n" 213 | : "=r"(d) : "r"(a), "r"(b), "r"(c)); 214 | } 215 | }; 216 | 217 | 218 | #endif // Specializations only enabled for Pascal SM610+ 219 | 220 | 221 | } // namespace gemm 222 | } // namespace cutlass 223 | 224 | -------------------------------------------------------------------------------- /pytorch_block_sparse/cutlass/gemm/grid_raster_sparse.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without 5 | * modification, are permitted provided that the following conditions are met: 6 | * * Redistributions of source code must retain the above copyright 7 | * notice, this list of conditions and the following disclaimer. 8 | * * Redistributions in binary form must reproduce the above copyright 9 | * notice, this list of conditions and the following disclaimer in the 10 | * documentation and/or other materials provided with the distribution. 11 | * * Neither the name of the NVIDIA CORPORATION nor the 12 | * names of its contributors may be used to endorse or promote products 13 | * derived from this software without specific prior written permission. 14 | * 15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | * 26 | ******************************************************************************/ 27 | 28 | #pragma once 29 | 30 | /** 31 | * \file 32 | * Abstraction for enumerating \p block_task within an input matrix 33 | */ 34 | 35 | #include 36 | 37 | #include "../util/util.h" 38 | 39 | 40 | namespace cutlass { 41 | namespace gemm { 42 | 43 | 44 | /****************************************************************************** 45 | * grid_raster_strategy 46 | ******************************************************************************/ 47 | 48 | /** 49 | * \brief Strategies for enumerating \p block_task within an input matrix 50 | */ 51 | struct grid_raster_sparse_strategy 52 | { 53 | /// \brief Enumerants 54 | enum kind_t 55 | { 56 | /** 57 | * Default \p block_task assignment (currently ColumnMajor for N*, 58 | * RowMajor for TT, and TiledCohort for TN) 59 | */ 60 | Sparse, 61 | }; 62 | }; 63 | 64 | 65 | 66 | /****************************************************************************** 67 | * grid_raster 68 | ******************************************************************************/ 69 | 70 | /** 71 | * \brief Abstraction for enumerating \p block_task within an input matrix 72 | * 73 | * NB: This generic class is not directly constructible. Algorithm-specific 74 | * template specializations will provide the API functionality prescribed here. 
75 | */ 76 | template < 77 | int BlockItemsY, ///< Height in rows of a block-wide tile in matrix C 78 | int BlockItemsX, ///< Width in columns of a block-wide tile in matrix C 79 | matrix_transform_t::kind_t TransformA, ///< View transform enumerant for matrix A 80 | matrix_transform_t::kind_t TransformB, ///< View transform enumerant for matrix B 81 | grid_raster_sparse_strategy::kind_t RasterStrategy> ///< Strategy for enumerating \p block_task within an input matrix 82 | struct grid_raster_sparse 83 | { 84 | //------------------------------------------------------------------------- 85 | // Device API 86 | //------------------------------------------------------------------------- 87 | 88 | /// Thread block's base item coordinates (x, y) in matrix C 89 | int2 block_item_coords_src; 90 | int2 block_item_coords_dst; 91 | 92 | /// Constructor 93 | grid_raster_sparse(int2* mapping, long mapping_length); 94 | 95 | /// Whether the thread block base coordinates are out-of-bounds for an m*n matrix C 96 | bool is_block_oob(int m, int n); 97 | 98 | 99 | //------------------------------------------------------------------------- 100 | // Grid launch API 101 | //------------------------------------------------------------------------- 102 | 103 | /// Compute the kernel grid extents (in thread blocks) for consuming an m*n matrix C 104 | static dim3 grid_dims(long mapping_length); 105 | }; 106 | 107 | 108 | 109 | /****************************************************************************** 110 | * grid_raster_sparse (Sparse specialization) 111 | ******************************************************************************/ 112 | 113 | /** 114 | * \brief Abstraction for enumerating \p block_task within an input matrix 115 | * (ColumnMajor specialization) 116 | * 117 | * Maps thread blocksin column-major fashion 118 | */ 119 | template < 120 | int BlockItemsY, ///< Height in rows of a block-wide tile in matrix C 121 | int BlockItemsX, ///< Width in columns of a block-wide tile in matrix C 122 | matrix_transform_t::kind_t TransformA, ///< View transform enumerant for matrix A 123 | matrix_transform_t::kind_t TransformB> ///< View transform enumerant for matrix B 124 | struct grid_raster_sparse< 125 | BlockItemsY, 126 | BlockItemsX, 127 | TransformA, 128 | TransformB, 129 | grid_raster_sparse_strategy::Sparse> ///< Strategy for enumerating \p block_task within an input matrix 130 | { 131 | //------------------------------------------------------------------------- 132 | // Device API 133 | //------------------------------------------------------------------------- 134 | 135 | /// Thread block's base item coordinates (x, y) in matrix C 136 | int2 block_item_coords_src; 137 | int2 block_item_coords_dst; 138 | 139 | /// Constructor 140 | inline __device__ 141 | grid_raster_sparse(int2* mapping, long mapping_length) 142 | { 143 | // printf("ColumnMajor\n"); 144 | // blockDim.x is the fastest changing grid dim on current architectures 145 | int2 pos = mapping[blockIdx.x]; 146 | 147 | block_item_coords_src = make_int2( 148 | BlockItemsX * pos.x, 149 | BlockItemsY * pos.y); 150 | block_item_coords_dst = make_int2( 151 | BlockItemsX * blockIdx.x, 152 | 0); 153 | 154 | // printf("blockIdx.y: %d, blockIdx.x: %d \n", blockIdx.y, blockIdx.x); 155 | } 156 | 157 | /// Whether the base \p block_item_coords are out-of-bounds for an m*n matrix C 158 | inline __device__ 159 | bool is_block_oob(int m, int n) 160 | { 161 | // ColumnMajor never rasterizes fully out-of-bounds thread blocks 162 | return false; 163 | } 164 | 
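    // NB: the launch grid contains one thread block per entry of `mapping` (see grid_dims() below),
    // so only the non-empty output tiles listed in `mapping` are rasterized. Each thread block reads
    // its (x, y) block position from mapping[blockIdx.x] in the constructor above:
    // block_item_coords_src is that position scaled to item coordinates in the full matrix C, while
    // block_item_coords_dst enumerates the tiles linearly along the x-axis (one BlockItemsX-wide
    // slot per mapping entry).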
165 | //------------------------------------------------------------------------- 166 | // Grid launch API 167 | //------------------------------------------------------------------------- 168 | 169 | /// Compute the kernel grid extents (in thread blocks) for consuming an m*n matrix C 170 | inline __host__ __device__ 171 | static dim3 grid_dims(long mapping_length) 172 | { 173 | // blockDim.x is the fastest changing grid dim on current architectures 174 | return dim3(mapping_length); 175 | } 176 | }; 177 | 178 | } // namespace gemm 179 | } // namespace cutlass 180 | -------------------------------------------------------------------------------- /pytorch_block_sparse/cutlass/util/device_introspection.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without 5 | * modification, are permitted provided that the following conditions are met: 6 | * * Redistributions of source code must retain the above copyright 7 | * notice, this list of conditions and the following disclaimer. 8 | * * Redistributions in binary form must reproduce the above copyright 9 | * notice, this list of conditions and the following disclaimer in the 10 | * documentation and/or other materials provided with the distribution. 11 | * * Neither the name of the NVIDIA CORPORATION nor the 12 | * names of its contributors may be used to endorse or promote products 13 | * derived from this software without specific prior written permission. 14 | * 15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | * 26 | ******************************************************************************/ 27 | 28 | #pragma once 29 | 30 | /** 31 | * \file 32 | * \brief Utilities for device introspection 33 | */ 34 | 35 | #include "debug.h" 36 | #include "nv_std.h" 37 | #include "printable.h" 38 | 39 | namespace cutlass { 40 | 41 | 42 | /****************************************************************************** 43 | * math_operation_class_t 44 | * 45 | * Enumeration to select the appropriate math operation 46 | * 47 | * The assumption is multiple math operations may be used to compute GEMM 48 | * for a given selection of operand and accumulator types. 
49 | * 50 | ******************************************************************************/ 51 | 52 | /// Math operation 53 | enum class math_operation_class_t 54 | { 55 | scalar, // scalar (and vector) multiply-accumulate operations 56 | matrix // Volta tensor operations 57 | }; 58 | 59 | /****************************************************************************** 60 | * arch_family_t 61 | ******************************************************************************/ 62 | 63 | /** 64 | * \brief Enumeration of NVIDIA GPU architectural families 65 | */ 66 | struct arch_family_t 67 | { 68 | /// \brief Enumerants 69 | enum kind_t 70 | { 71 | Unsupported = 0, 72 | Kepler = 3, 73 | Maxwell = 5, 74 | Volta = 7, 75 | }; 76 | 77 | /// Enumerant value 78 | kind_t kind; 79 | 80 | /// Default constructor 81 | arch_family_t() : kind(Unsupported) {} 82 | 83 | /// Copy constructor 84 | arch_family_t(const kind_t &other_kind) : kind(other_kind) {} 85 | 86 | /// Cast to kind_t 87 | operator kind_t() const { return kind; } 88 | 89 | /// Returns the instance as a string 90 | __host__ __device__ inline 91 | char const* to_string() const 92 | { 93 | switch (kind) 94 | { 95 | case Kepler: return "Kepler"; 96 | case Maxwell: return "Maxwell"; 97 | case Volta: return "Volta"; 98 | case Unsupported: 99 | default: return "Unsupported"; 100 | } 101 | } 102 | 103 | /// Insert the formatted instance into the output stream 104 | void print(std::ostream& out) const { out << to_string(); } 105 | 106 | }; 107 | 108 | 109 | /** 110 | * Macro for architecture targeted by the current compiler pass 111 | */ 112 | #if defined(__CUDA_ARCH__) 113 | #define CUTLASS_ARCH __CUDA_ARCH__ 114 | #else 115 | #define CUTLASS_ARCH 0 116 | #endif 117 | 118 | 119 | /** 120 | * Macro for architecture family targeted by the current compiler pass 121 | */ 122 | #define CUTLASS_ARCH_FAMILY \ 123 | ( \ 124 | (CUTLASS_ARCH < 300) ? \ 125 | arch_family_t::Unsupported : \ 126 | (CUTLASS_ARCH < 500) ? \ 127 | arch_family_t::Kepler : \ 128 | (CUTLASS_ARCH < 700) ? 
\ 129 | arch_family_t::Maxwell : \ 130 | arch_family_t::Volta \ 131 | ) 132 | 133 | 134 | 135 | 136 | /****************************************************************************** 137 | * Device introspection 138 | ******************************************************************************/ 139 | 140 | /** 141 | * Empty kernel for querying PTX manifest metadata (e.g., version) for the current device 142 | */ 143 | template 144 | __global__ void empty_kernel(void) { } 145 | 146 | 147 | 148 | /** 149 | * \brief Retrieves the PTX version that will be used on the current device (major * 100 + minor * 10) 150 | */ 151 | inline cudaError_t ptx_version(int &version) 152 | { 153 | struct Dummy 154 | { 155 | /// Type definition of the empty_kernel kernel entry point 156 | typedef void (*EmptyKernelPtr)(); 157 | 158 | /// Force empty_kernel to be generated if this class is used 159 | EmptyKernelPtr Empty() 160 | { 161 | return empty_kernel; 162 | } 163 | }; 164 | 165 | cudaError_t error = cudaSuccess; 166 | do 167 | { 168 | cudaFuncAttributes empty_kernel_attrs; 169 | if (CUDA_PERROR_DEBUG(error = cudaFuncGetAttributes(&empty_kernel_attrs, empty_kernel))) break; 170 | version = empty_kernel_attrs.ptxVersion * 10; 171 | } 172 | while (0); 173 | 174 | return error; 175 | } 176 | 177 | 178 | /** 179 | * \brief Retrieves the SM version (major * 100 + minor * 10) for the current device 180 | */ 181 | inline cudaError_t get_sm_version(int &sm_version) 182 | { 183 | cudaError_t error = cudaSuccess; 184 | 185 | // Get device ordinal 186 | int device_ordinal; 187 | if (CUDA_PERROR_DEBUG(error = cudaGetDevice(&device_ordinal))) 188 | return error; 189 | 190 | // Fill in SM version 191 | int major, minor; 192 | if (CUDA_PERROR_DEBUG(error = cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device_ordinal))) 193 | return error; 194 | if (CUDA_PERROR_DEBUG(error = cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device_ordinal))) 195 | return error; 196 | sm_version = major * 100 + minor * 10; 197 | 198 | return error; 199 | } 200 | 201 | 202 | /** 203 | * \brief Retrieves the count for the current device 204 | */ 205 | inline cudaError_t get_sm_count(int &sm_count) 206 | { 207 | cudaError_t error = cudaSuccess; 208 | 209 | // Get device ordinal 210 | int device_ordinal; 211 | if (CUDA_PERROR_DEBUG(error = cudaGetDevice(&device_ordinal))) 212 | return error; 213 | 214 | // Get SM count 215 | if (CUDA_PERROR_DEBUG(error = cudaDeviceGetAttribute (&sm_count, cudaDevAttrMultiProcessorCount, device_ordinal))) 216 | return error; 217 | 218 | return error; 219 | } 220 | 221 | 222 | } // namespace cutlass 223 | 224 | 225 | -------------------------------------------------------------------------------- /pytorch_block_sparse/cutlass/gemm/thread_accumulator.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without 5 | * modification, are permitted provided that the following conditions are met: 6 | * * Redistributions of source code must retain the above copyright 7 | * notice, this list of conditions and the following disclaimer. 
8 | * * Redistributions in binary form must reproduce the above copyright 9 | * notice, this list of conditions and the following disclaimer in the 10 | * documentation and/or other materials provided with the distribution. 11 | * * Neither the name of the NVIDIA CORPORATION nor the 12 | * names of its contributors may be used to endorse or promote products 13 | * derived from this software without specific prior written permission. 14 | * 15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | * 26 | ******************************************************************************/ 27 | 28 | #pragma once 29 | 30 | /** 31 | * \file 32 | * Thread-level multiply-accumulate abstraction 33 | */ 34 | 35 | #include "../util/util.h" 36 | #include "dp_accummulate.h" 37 | 38 | 39 | namespace cutlass { 40 | namespace gemm { 41 | 42 | 43 | /****************************************************************************** 44 | * thread_accumulator (generic specialization) 45 | ******************************************************************************/ 46 | 47 | /** 48 | * \brief Thread-level multiply-accumulate abstraction (generic specialization) 49 | * 50 | * The thread_accumulator class maintains a MxN tile of accumulators in 51 | * registers to which MxNxK matrix products of two thread tiles A (MxK) 52 | * and B (KxN) can be added, where: 53 | * M = ThreadItemsY 54 | * N = ThreadItemsX 55 | * K = sizeof(dp_vector_t) / sizeof(value_t). 56 | * 57 | * In order to leverage architecture-specific "dot-product accumulate" ISA 58 | * operations, K is dictated by the thread_accumulator class in the form of 59 | * the member-type dp_vector_t, which defines a K-component vector of value_t. 60 | * The multiplicand inputs A and B are provided as arrays of dp_vector_t having 61 | * extents ThreadItemsY and ThreadItemsX, respectively. (In the single 62 | * component "dp1" scenario where dp_vector_t == value_t and thus K == 1, the 63 | * multiplication is simply the outer product of two vectors.) 64 | * 65 | * The accumulators are zero-initialized in a two-phase process (construction + 66 | * initialization) that requires shared storage in the form of the member-type 67 | * scratch_storage_t during construction. (A single scratch_storage_t instance 68 | * can be uniformly referenced across all threads in the block during 69 | * construction *if* the block is synchronized between construction and 70 | * initialization.) 71 | * 72 | * NB: This generic class is not directly constructible. Architecture- and 73 | * algorithm-specific template specializations will provide the API 74 | * functionality prescribed here. 
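 *
 * Typical usage, as prescribed by the contract above (a sketch, not additional API):
 * construct the accumulator from a block-shared scratch_storage_t, synchronize the
 * block, call init(), then repeatedly call multiply_accumulate() and read results
 * back with get().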
75 | */ 76 | template < 77 | int ThreadItemsY, ///< Height of thread tile in accum_t 78 | int ThreadItemsX, ///< Width of thread tile in accum_t 79 | typename value_t, ///< Multiplicand value type 80 | typename accum_t, ///< Accumulator value type 81 | int ACCUM_BYTES = ///< Size in bytes of accum_t 82 | sizeof(accum_t), 83 | arch_family_t::kind_t ArchFamily = ///< Architectural family enumerant 84 | CUTLASS_ARCH_FAMILY> 85 | struct thread_accumulator 86 | { 87 | protected: 88 | 89 | //------------------------------------------------------------------------- 90 | // Constants and types 91 | //------------------------------------------------------------------------- 92 | 93 | /// Specialized dot-product traits type 94 | typedef dp_accummulate dp_accum_traits_t; 95 | 96 | 97 | public: 98 | 99 | //------------------------------------------------------------------------- 100 | // Member types 101 | //------------------------------------------------------------------------- 102 | 103 | /// Dot-product vector type 104 | typedef typename dp_accum_traits_t::dp_vector_t dp_vector_t; // for "dp1", it is value_t 105 | 106 | /// Scratch storage layout 107 | struct scratch_storage_t {}; 108 | 109 | 110 | protected: 111 | 112 | //------------------------------------------------------------------------- 113 | // Data members 114 | //------------------------------------------------------------------------- 115 | 116 | /// Thread's tile of accumulators 117 | accum_t accumulators[ThreadItemsY][ThreadItemsX]; 118 | 119 | 120 | //------------------------------------------------------------------------- 121 | // Utility methods 122 | //------------------------------------------------------------------------- 123 | 124 | /** 125 | * Compute a multiply-add at accumulator coordinates (x, y) 126 | */ 127 | inline __device__ 128 | void mad_xy( 129 | dp_vector_t (&tile_a)[ThreadItemsY], 130 | dp_vector_t (&tile_b)[ThreadItemsX], 131 | int x, 132 | int y) 133 | { 134 | dp_accum_traits_t::mad( 135 | accumulators[y][x], 136 | tile_a[y], 137 | tile_b[x], 138 | accumulators[y][x]); 139 | } 140 | 141 | public: 142 | 143 | //------------------------------------------------------------------------- 144 | // Constructor API 145 | //------------------------------------------------------------------------- 146 | 147 | /// Constructor 148 | inline __device__ 149 | thread_accumulator( 150 | scratch_storage_t &scratch) 151 | {} 152 | 153 | 154 | //------------------------------------------------------------------------- 155 | // Accumulator API 156 | //------------------------------------------------------------------------- 157 | 158 | /** 159 | * \brief Zero-initialize thread accumulators. 160 | * 161 | * If a common reference to a single block-wide shared instance of scratch_storage_t 162 | * is used during construction, the block must be synchronized after construction 163 | * but prior to the invocation of init(). 
164 | */ 165 | inline __device__ 166 | void init() 167 | { 168 | #pragma unroll 169 | for (int y = 0; y < ThreadItemsY; ++y) { 170 | #pragma unroll 171 | for (int x = 0; x < ThreadItemsX; ++x) 172 | { 173 | accumulators[y][x] = accum_t(0); 174 | } 175 | } 176 | } 177 | 178 | 179 | /** 180 | * Retrieve the accumulator at thread tile coordinates (x, y) 181 | */ 182 | inline __device__ 183 | accum_t get(int x, int y) 184 | { 185 | // Accumulators are row-major 186 | return accumulators[y][x]; 187 | } 188 | 189 | 190 | /** 191 | * \brief Compute the product of tile_a and tile_b and add the result to 192 | * the tile of accumulators. 193 | */ 194 | inline __device__ 195 | void multiply_accumulate( 196 | dp_vector_t (&tile_a)[ThreadItemsY], 197 | dp_vector_t (&tile_b)[ThreadItemsX]) 198 | { 199 | // Simply traverse the accumulator tile in row-major order 200 | #pragma unroll 201 | for (int y = 0; y < ThreadItemsY; ++y) 202 | { 203 | #pragma unroll 204 | for (int x = 0; x < ThreadItemsX; ++x) 205 | { 206 | mad_xy(tile_a, tile_b, x, y); 207 | } 208 | } 209 | } 210 | }; 211 | 212 | } // namespace gemm 213 | } // namespace cutlass 214 | -------------------------------------------------------------------------------- /pytorch_block_sparse/tests/test_matmul_back.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest import TestCase 3 | 4 | import torch 5 | 6 | from pytorch_block_sparse import BlockSparseMatrix 7 | 8 | 9 | class TestFun(TestCase): 10 | def helper_( 11 | self, 12 | sizes, 13 | block_size, 14 | block_count=None, 15 | blocks=None, 16 | density=None, 17 | iterations=1, 18 | non_contiguous_a=False, 19 | non_contiguous_b=False, 20 | ): 21 | device = "cuda" 22 | 23 | if isinstance(sizes[0], tuple): 24 | sizes_0 = sizes[0] 25 | else: 26 | sizes_0 = (sizes[0],) 27 | 28 | # Build positive matrices to easily check results 29 | a = torch.randn(sizes_0 + (sizes[1],), device=device).abs() 30 | b = torch.randn(sizes_0 + (sizes[2],), device=device).abs() 31 | 32 | if non_contiguous_a: 33 | a = a.transpose(-2, -1).contiguous().transpose(-2, -1) 34 | 35 | if non_contiguous_b: 36 | b = b.transpose(-2, -1).contiguous().transpose(-2, -1) 37 | 38 | if block_count is None and blocks is None: 39 | total_block_count = sizes[1] * sizes[2] / block_size[0] / block_size[1] 40 | block_count = int(total_block_count * density) 41 | 42 | bsm = BlockSparseMatrix.zeros((sizes[2], sizes[1]), block_count, blocks, block_size, device=device) 43 | 44 | results = {} 45 | 46 | kinds = ["pytorch", "cutlass"] 47 | kinds.reverse() 48 | for kind in kinds: 49 | start = torch.cuda.Event(enable_timing=True) 50 | end = torch.cuda.Event(enable_timing=True) 51 | 52 | start.record() 53 | for i in range(iterations): 54 | if kind == "pytorch": 55 | aa = a.reshape(-1, a.shape[-1]) 56 | bb = b.reshape(-1, b.shape[-1]) 57 | bb = bb.t() 58 | c = bb.mm(aa) 59 | elif kind == "cutlass": 60 | bsm.matmul_with_output_sparse_support(b, a, overwrite_data=True) 61 | c = bsm 62 | 63 | end.record() 64 | torch.cuda.synchronize() 65 | elapsed = start.elapsed_time(end) 66 | 67 | result = dict(kind=kind, elapsed=elapsed, output=c) 68 | results[kind] = result 69 | 70 | if "pytorch" in results: 71 | c0 = results["pytorch"]["output"] 72 | 73 | for k, t in results.items(): 74 | if k == "pytorch": 75 | t["comparison"] = True 76 | continue 77 | c = t["output"] 78 | 79 | c_dense = c.to_dense() 80 | 81 | c0_ = c0 * (c_dense != 0) 82 | 83 | s = c_dense.isclose(c0_, rtol=1e-4).all() 84 | 85 | if not 
s.item(): 86 | print( 87 | "max difference %s=" % t["kind"], 88 | float((c_dense - c0_).abs().max()), 89 | float(c.data.abs().max()), 90 | ) 91 | raise Exception( 92 | "Comparison NOK : matmul_with_output_sparse_support issue for ", 93 | k, 94 | ) 95 | t["comparison"] = False 96 | else: 97 | # print("Comparison OK for matmul_with_output_sparse_support for ", k) 98 | # print("max difference %s=" % t["kind"], float((c_dense - c0_).abs().max())) 99 | t["comparison"] = True 100 | 101 | return results 102 | 103 | def helper( 104 | self, 105 | sizes, 106 | block_size, 107 | density, 108 | iterations, 109 | inner_iterations, 110 | block_count=None, 111 | blocks=None, 112 | non_contiguous_a=False, 113 | non_contiguous_b=False, 114 | ): 115 | 116 | import functools 117 | import operator 118 | 119 | if isinstance(sizes[0], int): 120 | sizes_0 = sizes[0] 121 | else: 122 | sizes_0 = functools.reduce(operator.mul, sizes[0], 1) 123 | 124 | flops = float(2 * sizes_0 * sizes[1] * sizes[2]) 125 | 126 | report = {} 127 | for i in range(iterations): 128 | results = self.helper_( 129 | sizes, 130 | block_size, 131 | block_count=block_count, 132 | blocks=blocks, 133 | density=density, 134 | iterations=inner_iterations, 135 | non_contiguous_a=non_contiguous_a, 136 | non_contiguous_b=non_contiguous_b, 137 | ) 138 | 139 | if "pytorch" in results: 140 | pytorch_time = results["pytorch"]["elapsed"] 141 | else: 142 | pytorch_time = None 143 | 144 | for kind, d in results.items(): 145 | if kind == "pytorch": 146 | continue 147 | if kind not in report: 148 | report[kind] = {True: 0, False: 0} 149 | if "comparison" in d: 150 | report[kind][d["comparison"]] += 1 151 | 152 | kind = d["kind"] 153 | kind_elapsed = d["elapsed"] 154 | if pytorch_time is None: 155 | ratio = "Unknown" 156 | else: 157 | ratio = kind_elapsed / pytorch_time 158 | 159 | print( 160 | "kind = %s, elapsed=%f, gflops = %f, ratio = %s" 161 | % ( 162 | kind, 163 | kind_elapsed, 164 | flops * inner_iterations / kind_elapsed / 1e6, 165 | ratio, 166 | ) 167 | ) 168 | 169 | return results 170 | 171 | def check(self, results, sizes, block_size, blocks, verbose=False): 172 | if isinstance(sizes[0], tuple): 173 | sizes_0 = 1 174 | for s in sizes[0]: 175 | sizes_0 *= s 176 | else: 177 | sizes_0 = sizes[0] 178 | 179 | cutlass_result = results["cutlass"]["output"] 180 | pytorch_result = results["pytorch"]["output"] 181 | 182 | if verbose: 183 | # print(cutlass_result) 184 | 185 | stride = 4 186 | print("cutlass block[0][0]", cutlass_result.data[::stride, ::stride].t()) 187 | print("pytorch blocks[0][0]", pytorch_result[::stride, ::stride]) 188 | for i in range(cutlass_result.blocks.shape[0] // 2): 189 | b = cutlass_result.blocks[i * 2 : i * 2 + 2].flip(0) * torch.tensor( 190 | block_size, device=cutlass_result.blocks.device 191 | ) 192 | b_pytorch = pytorch_result[b[0] : b[0] + block_size[0], b[1] : b[1] + block_size[1]] 193 | 194 | b_cutlass = cutlass_result.data[i * 32 : i * 32 + 32].t() 195 | 196 | compare = b_pytorch.isclose(b_cutlass, rtol=1e-4) 197 | if not compare.all().item(): 198 | rel_diff = ((b_pytorch - b_cutlass).abs() / (1e-9 + b_pytorch.abs())).abs().max() 199 | max_diff = (b_pytorch - b_cutlass).abs().max() 200 | print( 201 | f"rel diff={rel_diff}, max diff={max_diff}, max_pytorch={b_pytorch.abs().max()}, max_cutlass={b_cutlass.abs().max()}" 202 | ) 203 | raise Exception(f"Comparison failed out_shape={cutlass_result.shape} blocks={blocks} sizes={sizes}") 204 | 205 | def test0(self): 206 | bsize = 32 207 | tests = [ 208 | dict( 209 | sizes=[bsize * 
2, bsize * 4, bsize * 8], 210 | block_tests=[ 211 | [(0, 0)], 212 | [(0, 1)], 213 | [(1, 0)], 214 | [(1, 0), (0, 2)], 215 | [(1, 0), (2, 0), (3, 0)], 216 | ], 217 | ), 218 | dict(sizes=[1, bsize, bsize], block_tests=[[(0, 0)]]), 219 | ] 220 | block_size = (32, 32) 221 | 222 | for test in tests: 223 | sizes = test["sizes"] 224 | blocks_tests = test["block_tests"] 225 | for blocks in blocks_tests: 226 | for non_contiguous_a in [False, True]: 227 | for non_contiguous_b in [False, True]: 228 | results = self.helper( 229 | sizes, 230 | block_size, 231 | density=None, 232 | blocks=blocks, 233 | iterations=1, 234 | inner_iterations=1, 235 | non_contiguous_a=non_contiguous_a, 236 | non_contiguous_b=non_contiguous_b, 237 | ) 238 | self.check(results, sizes, block_size, blocks, verbose=False) 239 | 240 | def test1(self): 241 | size = 512 242 | 243 | test_sizes = [ 244 | [(size * 16, 8), size * 2, size * 4], 245 | [1, size * 2, size * 4], 246 | ] 247 | test_densities = [1.0] # 0.47, 1.0] 248 | 249 | block_size = (32, 32) 250 | iterations = 4 251 | inner_iterations = 10 252 | 253 | for sizes in test_sizes: 254 | for density in test_densities: 255 | for non_contiguous_a in [False, True]: 256 | for non_contiguous_b in [False, True]: 257 | results = self.helper( 258 | sizes, 259 | block_size, 260 | density, 261 | iterations, 262 | inner_iterations, 263 | block_count=None, 264 | non_contiguous_a=non_contiguous_a, 265 | non_contiguous_b=non_contiguous_b, 266 | ) 267 | try: 268 | self.check( 269 | results, 270 | sizes, 271 | block_size, 272 | results["cutlass"]["output"].blocks, 273 | ) 274 | except Exception: 275 | raise Exception( 276 | f"Comparison NOK : matmul_with_output_sparse_support issue for sizes={sizes}, density={density}, non_contiguous_a={non_contiguous_a}, non_contiguous_b={non_contiguous_b}" 277 | ) 278 | 279 | 280 | if __name__ == "__main__": 281 | unittest.main() 282 | -------------------------------------------------------------------------------- /pytorch_block_sparse/native/cutlass_dispatch.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without 5 | * modification, are permitted provided that the following conditions are met: 6 | * * Redistributions of source code must retain the above copyright 7 | * notice, this list of conditions and the following disclaimer. 8 | * * Redistributions in binary form must reproduce the above copyright 9 | * notice, this list of conditions and the following disclaimer in the 10 | * documentation and/or other materials provided with the distribution. 11 | * * Neither the name of the NVIDIA CORPORATION nor the 12 | * names of its contributors may be used to endorse or promote products 13 | * derived from this software without specific prior written permission. 14 | * 15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | * 26 | ******************************************************************************/ 27 | 28 | #pragma once 29 | 30 | /** 31 | * \file Dispatch routines for CUTLASS GEMM kernels 32 | */ 33 | 34 | // CUDA includes 35 | #include 36 | 37 | // Cutlass GEMM API 38 | #include 39 | #include 40 | #include 41 | 42 | namespace cutlass { 43 | 44 | 45 | 46 | /****************************************************************************** 47 | * Cutlass dispatch entrypoints 48 | ******************************************************************************/ 49 | 50 | // 51 | // Compile-time overrides for alignment and ragged handling. 52 | // 53 | 54 | // If zero, all feasible alignment options are supported. 55 | #ifndef GEMM_ALIGNMENT 56 | #define GEMM_ALIGNMENT 0 57 | #endif 58 | 59 | // If true, kernels are compiled with ragged handling enabled. 60 | #ifndef GEMM_RAGGED 61 | #define GEMM_RAGGED true 62 | #endif 63 | 64 | // 65 | // Dispatch logic given problem size specialization, math operation class, layout 66 | // and type of operands, and epilogue operation. 67 | // 68 | 69 | /** 70 | * Cutlass GEMM dispatch 71 | */ 72 | template < 73 | gemm::tiling_strategy::kind_t _TilingStrategy, ///< Tile-sizing classification category 74 | math_operation_class_t _math_op, // Indicates 75 | matrix_transform_t::kind_t _TransformA, ///< Transformation op for matrix A 76 | matrix_transform_t::kind_t _TransformB, ///< Transformation op for matrix B 77 | typename _value, ///< Multiplicand value type (matrices A and B) 78 | typename _accum, ///< Accumulator value type (matrix C and scalars) 79 | typename _epilogue_op_t ///< Epilogue opeartion to update matrix C 80 | = gemm::blas_scaled_epilogue<_accum, _accum, _accum> 81 | > 82 | struct cutlass_gemm_dispatch 83 | { 84 | // 85 | // Type alias definitions 86 | // 87 | 88 | static const gemm::tiling_strategy::kind_t TilingStrategy = _TilingStrategy; 89 | static const math_operation_class_t math_op = _math_op; 90 | static const matrix_transform_t::kind_t TransformA = _TransformA; 91 | static const matrix_transform_t::kind_t TransformB = _TransformB; 92 | 93 | 94 | using value_t = _value; 95 | using accum_t = _accum; 96 | using epilogue_op_t = _epilogue_op_t; 97 | 98 | 99 | // 100 | // Methods 101 | // 102 | 103 | /// Returns leading dimension for A matrix operand 104 | int leading_dim_a(int m, int k) const 105 | { 106 | return (TransformA == matrix_transform_t::NonTranspose ? m : k); 107 | } 108 | 109 | /// Returns leading dimension for B matrix operand 110 | int leading_dim_b(int k, int n) const 111 | { 112 | return (TransformB == matrix_transform_t::NonTranspose ? 
k : n); 113 | } 114 | 115 | /// Launches a GEMM 116 | template 117 | gemm::launch_configuration launch( 118 | int m, 119 | int n, 120 | int k, 121 | epilogue_op_t epilogue_op, 122 | value_t *A, 123 | value_t *B, 124 | int *B_ptr, 125 | int *B_indices, 126 | accum_t *C, 127 | cudaStream_t stream = 0, 128 | bool debug_synchronous = false) 129 | { 130 | 131 | // printf("operand_alignment: %d \n", operand_alignment); // it is 16, while value_t is 4. 132 | return gemm::device_gemm< 133 | TilingStrategy, 134 | math_op, 135 | TransformA, 136 | operand_alignment, 137 | TransformB, 138 | operand_alignment, 139 | value_t, 140 | accum_t, 141 | epilogue_op_t, 142 | accumulator_alignment> 143 | ( 144 | m, 145 | n, 146 | k, 147 | epilogue_op, 148 | A, 149 | B, 150 | B_ptr, 151 | B_indices, 152 | C, 153 | stream, 154 | debug_synchronous); 155 | } 156 | 157 | /// Dispatches a CUTLASS GEMM 158 | gemm::launch_configuration operator()( 159 | int m, ///< Rows of GEMM problem 160 | int n, ///< Columns of GEMM problem 161 | int k, ///< Inner dimension of GEMM problem 162 | value_t *A, ///< A matrix 163 | value_t *B, ///< B matrix 164 | int *B_ptr, ///< ptr of pruned B matrix 165 | int *B_indices, ///< indices of pruned B matrix 166 | accum_t *C, ///< C matrix 167 | accum_t alpha, ///< Scalar used for multiplicands 168 | accum_t beta, ///< Scalar used for addend 169 | cudaStream_t stream = 0, ///< CUDA stream to launch kernels within. 170 | bool debug_synchronous = false) ///< Whether or not to synchronize the stream 171 | /// after every kernel launch to check for errors. 172 | { 173 | 174 | // Forces kernel selection to choose specific alignment (in bytes) 175 | int const force_operand_alignment = GEMM_ALIGNMENT; 176 | 177 | // Problem size must be multiple of the smallest vector load size 178 | typedef value_t operand_load_t; 179 | int const accumulator_alignment = sizeof(accum_t); 180 | 181 | int const lda = leading_dim_a(m, k); 182 | int const ldb = leading_dim_b(k, n); 183 | 184 | epilogue_op_t epilogue(alpha, beta); 185 | 186 | // TODO: opportunity for metaprogramming loop 187 | 188 | // Prefer the largest granularity of vector load that is compatible with 189 | // problem size and data alignment. 190 | if ((!force_operand_alignment || force_operand_alignment == 16) && 191 | !((sizeof(operand_load_t) * lda) % 16) && 192 | !((sizeof(operand_load_t) * ldb) % 16)) 193 | { 194 | // printf("here!\n"); // Here!! 
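      // 16-byte case: both operand leading dimensions are 16-byte aligned, so the
      // widest vector loads are selected here; the 8- and 4-byte branches below are
      // the fallbacks when this test fails or GEMM_ALIGNMENT forces a narrower width.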
195 | #if !(GEMM_ALIGNMENT) || (GEMM_ALIGNMENT == 16) 196 | return launch<__NV_STD_MAX(16, sizeof(value_t)), accumulator_alignment>( 197 | m, 198 | n, 199 | k, 200 | epilogue, 201 | A, 202 | B, 203 | B_ptr, 204 | B_indices, 205 | C, 206 | stream, 207 | debug_synchronous); 208 | #endif 209 | } 210 | else if ((!force_operand_alignment || force_operand_alignment == 8) && 211 | !((sizeof(operand_load_t) * lda) % 8) && 212 | !((sizeof(operand_load_t) * ldb) % 8)) 213 | { 214 | #if !(GEMM_ALIGNMENT) || (GEMM_ALIGNMENT == 8) 215 | return launch<__NV_STD_MAX(8, sizeof(value_t)), accumulator_alignment>( 216 | m, 217 | n, 218 | k, 219 | epilogue, 220 | A, 221 | B, 222 | B_ptr, 223 | B_indices, 224 | C, 225 | stream, 226 | debug_synchronous); 227 | #endif 228 | } 229 | else if ((!force_operand_alignment || force_operand_alignment == 4) && 230 | !((sizeof(operand_load_t) * lda) % 4) && 231 | !((sizeof(operand_load_t) * ldb) % 4)) 232 | { 233 | #if !(GEMM_ALIGNMENT) || (GEMM_ALIGNMENT == 4) 234 | return launch<__NV_STD_MAX(4, sizeof(value_t)), accumulator_alignment>( 235 | m, 236 | n, 237 | k, 238 | epilogue, 239 | A, 240 | B, 241 | B_ptr, 242 | B_indices, 243 | C, 244 | stream, 245 | debug_synchronous); 246 | #endif 247 | } 248 | 249 | return gemm::launch_configuration(cudaErrorInvalidValue); 250 | } 251 | }; 252 | 253 | 254 | } // namespace cutlass 255 | -------------------------------------------------------------------------------- /pytorch_block_sparse/native/cutlass_dispatch_back.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without 5 | * modification, are permitted provided that the following conditions are met: 6 | * * Redistributions of source code must retain the above copyright 7 | * notice, this list of conditions and the following disclaimer. 8 | * * Redistributions in binary form must reproduce the above copyright 9 | * notice, this list of conditions and the following disclaimer in the 10 | * documentation and/or other materials provided with the distribution. 11 | * * Neither the name of the NVIDIA CORPORATION nor the 12 | * names of its contributors may be used to endorse or promote products 13 | * derived from this software without specific prior written permission. 14 | * 15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
25 | * 26 | ******************************************************************************/ 27 | 28 | #pragma once 29 | 30 | /** 31 | * \file Dispatch routines for CUTLASS GEMM kernels 32 | */ 33 | 34 | // CUDA includes 35 | #include 36 | 37 | // Cutlass GEMM API 38 | #include 39 | #include 40 | #include 41 | 42 | namespace cutlass { 43 | 44 | 45 | 46 | /****************************************************************************** 47 | * Cutlass dispatch entrypoints 48 | ******************************************************************************/ 49 | 50 | // 51 | // Compile-time overrides for alignment and ragged handling. 52 | // 53 | 54 | // If zero, all feasible alignment options are supported. 55 | #ifndef GEMM_ALIGNMENT 56 | #define GEMM_ALIGNMENT 0 57 | #endif 58 | 59 | // If true, kernels are compiled with ragged handling enabled. 60 | #ifndef GEMM_RAGGED 61 | #define GEMM_RAGGED true 62 | #endif 63 | 64 | // 65 | // Dispatch logic given problem size specialization, math operation class, layout 66 | // and type of operands, and epilogue operation. 67 | // 68 | 69 | /** 70 | * Cutlass GEMM dispatch 71 | */ 72 | template < 73 | gemm::tiling_strategy::kind_t _TilingStrategy, ///< Tile-sizing classification category 74 | math_operation_class_t _math_op, // Indicates 75 | matrix_transform_t::kind_t _TransformA, ///< Transformation op for matrix A 76 | matrix_transform_t::kind_t _TransformB, ///< Transformation op for matrix B 77 | typename _value, ///< Multiplicand value type (matrices A and B) 78 | typename _accum, ///< Accumulator value type (matrix C and scalars) 79 | typename _epilogue_op_t ///< Epilogue opeartion to update matrix C 80 | = gemm::blas_scaled_epilogue<_accum, _accum, _accum> 81 | > 82 | struct cutlass_gemm_dispatch_back 83 | { 84 | // 85 | // Type alias definitions 86 | // 87 | 88 | static const gemm::tiling_strategy::kind_t TilingStrategy = _TilingStrategy; 89 | static const math_operation_class_t math_op = _math_op; 90 | static const matrix_transform_t::kind_t TransformA = _TransformA; 91 | static const matrix_transform_t::kind_t TransformB = _TransformB; 92 | 93 | 94 | using value_t = _value; 95 | using accum_t = _accum; 96 | using epilogue_op_t = _epilogue_op_t; 97 | 98 | 99 | // 100 | // Methods 101 | // 102 | 103 | /// Returns leading dimension for A matrix operand 104 | int leading_dim_a(int m, int k) const 105 | { 106 | return (TransformA == matrix_transform_t::NonTranspose ? m : k); 107 | } 108 | 109 | /// Returns leading dimension for B matrix operand 110 | int leading_dim_b(int k, int n) const 111 | { 112 | return (TransformB == matrix_transform_t::NonTranspose ? k : n); 113 | } 114 | 115 | /// Launches a GEMM 116 | template 117 | gemm::launch_configuration_back launch_back( 118 | int m, 119 | int n, 120 | int k, 121 | epilogue_op_t epilogue_op, 122 | value_t *A, 123 | value_t *B, 124 | accum_t *C, 125 | int2 *C_blocks, 126 | long C_blocks_length, 127 | cudaStream_t stream = 0, 128 | bool debug_synchronous = false) 129 | { 130 | 131 | // printf("operand_alignment: %d \n", operand_alignment); // it is 16, while value_t is 4. 
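    // Instantiate the backward kernel with the same tiling/transform/alignment
    // template parameters as the forward dispatcher; the extra C_blocks /
    // C_blocks_length arguments describe the pruned output blocks to be produced.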
132 | return gemm::device_gemm_back< 133 | TilingStrategy, 134 | math_op, 135 | TransformA, 136 | operand_alignment, 137 | TransformB, 138 | operand_alignment, 139 | value_t, 140 | accum_t, 141 | epilogue_op_t, 142 | accumulator_alignment> 143 | ( 144 | m, 145 | n, 146 | k, 147 | epilogue_op, 148 | A, 149 | B, 150 | C, 151 | C_blocks, 152 | C_blocks_length, 153 | stream, 154 | debug_synchronous); 155 | } 156 | 157 | /// Dispatches a CUTLASS GEMM 158 | gemm::launch_configuration_back operator()( 159 | int m, ///< Rows of GEMM problem 160 | int n, ///< Columns of GEMM problem 161 | int k, ///< Inner dimension of GEMM problem 162 | value_t *A, ///< A matrix 163 | value_t *B, ///< B matrix 164 | accum_t *C, ///< C matrix 165 | int2 *C_blocks, ///< ptr of pruned C matrix 166 | long C_blocks_length, ///< indices of pruned C matrix 167 | accum_t alpha, ///< Scalar used for multiplicands 168 | accum_t beta, ///< Scalar used for addend 169 | cudaStream_t stream = 0, ///< CUDA stream to launch kernels within. 170 | bool debug_synchronous = false) ///< Whether or not to synchronize the stream 171 | /// after every kernel launch to check for errors. 172 | { 173 | 174 | // Forces kernel selection to choose specific alignment (in bytes) 175 | int const force_operand_alignment = GEMM_ALIGNMENT; 176 | 177 | // Problem size must be multiple of the smallest vector load size 178 | typedef value_t operand_load_t; 179 | int const accumulator_alignment = sizeof(accum_t); 180 | 181 | int const lda = leading_dim_a(m, k); 182 | int const ldb = leading_dim_b(k, n); 183 | 184 | epilogue_op_t epilogue(alpha, beta); 185 | 186 | // TODO: opportunity for metaprogramming loop 187 | 188 | // Prefer the largest granularity of vector load that is compatible with 189 | // problem size and data alignment. 190 | if ((!force_operand_alignment || force_operand_alignment == 16) && 191 | !((sizeof(operand_load_t) * lda) % 16) && 192 | !((sizeof(operand_load_t) * ldb) % 16)) 193 | { 194 | // printf("here!\n"); // Here!! 
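      // Same alignment-selection cascade as the forward dispatcher: try 16-byte
      // vector loads first, then fall through to the 8- and 4-byte branches below.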
195 | #if !(GEMM_ALIGNMENT) || (GEMM_ALIGNMENT == 16) 196 | return launch_back<__NV_STD_MAX(16, sizeof(value_t)), accumulator_alignment>( 197 | m, 198 | n, 199 | k, 200 | epilogue, 201 | A, 202 | B, 203 | C, 204 | C_blocks, 205 | C_blocks_length, 206 | stream, 207 | debug_synchronous); 208 | #endif 209 | } 210 | else if ((!force_operand_alignment || force_operand_alignment == 8) && 211 | !((sizeof(operand_load_t) * lda) % 8) && 212 | !((sizeof(operand_load_t) * ldb) % 8)) 213 | { 214 | #if !(GEMM_ALIGNMENT) || (GEMM_ALIGNMENT == 8) 215 | return launch_back<__NV_STD_MAX(8, sizeof(value_t)), accumulator_alignment>( 216 | m, 217 | n, 218 | k, 219 | epilogue, 220 | A, 221 | B, 222 | C, 223 | C_blocks, 224 | C_blocks_length, 225 | stream, 226 | debug_synchronous); 227 | #endif 228 | } 229 | else if ((!force_operand_alignment || force_operand_alignment == 4) && 230 | !((sizeof(operand_load_t) * lda) % 4) && 231 | !((sizeof(operand_load_t) * ldb) % 4)) 232 | { 233 | #if !(GEMM_ALIGNMENT) || (GEMM_ALIGNMENT == 4) 234 | return launch_back<__NV_STD_MAX(4, sizeof(value_t)), accumulator_alignment>( 235 | m, 236 | n, 237 | k, 238 | epilogue, 239 | A, 240 | B, 241 | C, 242 | C_blocks, 243 | C_blocks_length, 244 | stream, 245 | debug_synchronous); 246 | #endif 247 | } 248 | 249 | return gemm::launch_configuration_back(cudaErrorInvalidValue); 250 | } 251 | }; 252 | 253 | 254 | } // namespace cutlass 255 | -------------------------------------------------------------------------------- /pytorch_block_sparse/block_sparse_linear.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Tuple 3 | 4 | import torch 5 | import torch.autograd 6 | import torch.nn as nn 7 | 8 | from .block_sparse import ( 9 | BlockSparseMatrix, 10 | BlockSparseMatrixBase, 11 | BlockSparseMatrixEmulator, 12 | ) 13 | 14 | 15 | class BlockSparseLinearFunction(torch.autograd.Function): 16 | @staticmethod 17 | def forward(ctx, input, weight_data, weight): 18 | check = False 19 | verbose = False 20 | 21 | if verbose or check: 22 | dense_weight = weight.to_dense() 23 | 24 | if verbose: 25 | stride = 8 26 | print("BlockSparseLinearFunction.forward input\n", input[::stride, ::stride]) 27 | print( 28 | "BlockSparseLinearFunction.forward dense_weight\n", 29 | dense_weight[::stride, ::stride], 30 | ) 31 | print( 32 | "BlockSparseLinearFunction.forward weight\n", 33 | weight.data[::stride, ::stride], 34 | ) 35 | 36 | assert isinstance(weight, BlockSparseMatrixBase) 37 | 38 | ctx.save_for_backward(input, weight_data) 39 | ctx.weight = weight 40 | output = weight.reverse_matmul(input, transpose=True) 41 | if check: 42 | dense = weight.to_dense() 43 | output1 = input.matmul(dense.t()) 44 | if not output1.isclose(output, ator=1e-05).all(): 45 | raise Exception("BlockSparseLinearFunction.forward non matching output 1") 46 | else: 47 | if verbose: 48 | print("BlockSparseLinearFunction.forward matching output 1") 49 | 50 | if verbose: 51 | print("BlockSparseLinearFunction.forward output\n", output[::stride, ::stride]) 52 | 53 | return output 54 | 55 | @staticmethod 56 | def backward(ctx, grad_output): 57 | check = False 58 | verbose = False 59 | input, weight_data = ctx.saved_tensors 60 | weight = ctx.weight 61 | assert isinstance(weight, BlockSparseMatrixBase) 62 | 63 | if verbose or check: 64 | dense_weight = weight.to_dense() 65 | 66 | if verbose: 67 | stride = 8 68 | print("input\n", input[::stride, ::stride]) 69 | print( 70 | "grad_output\n", 71 | grad_output.stride(), 72 | 
grad_output.storage, 73 | grad_output.layout, 74 | grad_output[::stride, ::stride], 75 | ) 76 | print("dense_weight\n", dense_weight[::stride, ::stride]) 77 | print("weight\n", weight.data[::stride, ::stride]) 78 | 79 | if ctx.needs_input_grad[0]: 80 | grad_input1 = weight.reverse_matmul(grad_output, transpose=False) 81 | 82 | if verbose or check: 83 | grad_input0 = grad_output.matmul(dense_weight) 84 | atol = 1e-4 85 | 86 | if check: 87 | if not grad_input0.isclose(grad_input1).all(): 88 | print(f"grad_output.shape={grad_output.shape}, grad_output.stride={grad_output.stride()}") 89 | print( 90 | "grad_input0/1 comparison\n", 91 | (grad_input0 - grad_input1)[1::32, 1::32, 1::32], 92 | ) 93 | print( 94 | "grad_input0/1 comparison\n", 95 | (grad_input0 - grad_input1).abs().max(), 96 | ) 97 | print( 98 | "grad_input0/1 comparison: count of differences\n", 99 | ((grad_input0 - grad_input1).abs() > atol).sum(), 100 | ) 101 | print( 102 | "grad_input0/1 comparison: position of differences\n", 103 | ((grad_input0 - grad_input1).abs() > atol).nonzero(), 104 | ) 105 | 106 | print("grad_input0 max\n", grad_input0.abs().max()) 107 | print("grad_input1 max\n", grad_input1.abs().max()) 108 | 109 | raise Exception("Non matching grad_input") 110 | else: 111 | if verbose: 112 | print("Backward matching grad_input") 113 | 114 | if verbose: 115 | grad_input2 = weight.reverse_matmul(torch.ones_like(grad_output), transpose=False) 116 | print("grad_input0\n", grad_input0[::stride, ::stride]) 117 | print("grad_input1\n", grad_input1[::stride, ::stride]) 118 | print("grad_input2\n", grad_input2[::stride, ::stride]) 119 | else: 120 | grad_input1 = None 121 | 122 | if ctx.needs_input_grad[1]: 123 | grad_weight1 = weight.matmul_with_output_sparse_support(grad_output, input) 124 | if verbose or check: 125 | grad_weight0 = ( 126 | grad_output.reshape(-1, grad_output.shape[-1]) 127 | .transpose(-1, -2) 128 | .matmul(input.reshape(-1, input.shape[-1])) 129 | ) 130 | if check: 131 | grad_weight1b = weight.to_dense(data_replace=grad_weight1) 132 | grad_weight1mask = weight.to_dense(data_replace=torch.ones_like(grad_weight1)) 133 | grad_weight0 *= grad_weight1mask 134 | 135 | if not grad_weight0.isclose(grad_weight1b).all(): 136 | print("grad_weight0\n", grad_weight0[::stride, ::stride]) 137 | print("grad_weight1\n", grad_weight1[::stride, ::stride]) 138 | raise Exception("Non matching grad_weight") 139 | else: 140 | if verbose: 141 | print("Backward matching grad_weight") 142 | 143 | if verbose: 144 | print("grad_weight0\n", grad_weight0[::stride, ::stride]) 145 | print("grad_weight1\n", grad_weight1[::stride, ::stride]) 146 | else: 147 | grad_weight1 = None 148 | 149 | if grad_weight1 is not None: 150 | assert not (grad_weight1 == 0).all() 151 | if grad_input1 is not None: 152 | assert grad_input1.shape == input.shape 153 | 154 | return grad_input1, grad_weight1, None 155 | 156 | 157 | class BlockSparseLinear(nn.Module): 158 | OPTIMIZED_BLOCK_SIZE = 32 159 | 160 | def __init__( 161 | self, 162 | in_features: int, 163 | out_features: int, 164 | bias: bool = True, 165 | density: float = 0.5, 166 | torch_nn_linear=None, 167 | verbose: bool = False, 168 | block_shape: Tuple[int, int] = (32, 32), 169 | ): 170 | super(BlockSparseLinear, self).__init__() 171 | self.fn = BlockSparseLinearFunction.apply 172 | self.verbose = verbose 173 | self.block_shape = block_shape 174 | self._optimized = ( 175 | self.block_shape[0] == self.OPTIMIZED_BLOCK_SIZE and self.block_shape[1] == self.OPTIMIZED_BLOCK_SIZE 176 | ) 177 | 178 | if 
torch_nn_linear is not None: 179 | in_features = torch_nn_linear.in_features 180 | out_features = torch_nn_linear.out_features 181 | bias = torch_nn_linear.bias is not None 182 | 183 | if in_features % self.block_shape[1] != 0: 184 | raise Exception( 185 | f"BlockSparseLinear invalid in_features={in_features}, should be multiple of {self.block_shape[1]}" 186 | ) 187 | if out_features % self.block_shape[0] != 0: 188 | raise Exception( 189 | f"BlockSparseLinear invalid in_features={in_features}, should be multiple of {self.block_shape[0]}" 190 | ) 191 | 192 | if density < 0 or density > 1: 193 | raise Exception(f"BlockSparseLinear invalid density={density}") 194 | 195 | self.block_count = int(density * (in_features * out_features / (self.block_shape[0] * self.block_shape[1]))) 196 | 197 | self.in_features = in_features 198 | self.out_features = out_features 199 | 200 | block_shape = self.block_shape 201 | 202 | if self._optimized: 203 | BlockSparseMatrixConstructor = BlockSparseMatrix 204 | else: 205 | BlockSparseMatrixConstructor = BlockSparseMatrixEmulator 206 | 207 | if torch_nn_linear is not None: 208 | with torch.no_grad(): 209 | weight = BlockSparseMatrixConstructor.from_dense(torch_nn_linear.weight, block_shape, self.block_count) 210 | weight.multiply_(1.0 / math.sqrt(density)) 211 | else: 212 | weight = BlockSparseMatrixConstructor.randn( 213 | (out_features, in_features), 214 | self.block_count, 215 | blocks=None, 216 | block_shape=block_shape, 217 | device="cuda", 218 | ) 219 | self.weight = weight 220 | 221 | if bias: 222 | self.bias = nn.Parameter(torch.zeros(out_features, device="cuda")) 223 | if torch_nn_linear is not None: 224 | with torch.no_grad(): 225 | self.bias.copy_(torch_nn_linear.bias) 226 | else: 227 | self.register_parameter("bias", None) 228 | 229 | def forward(self, x): 230 | x = self.fn(x, self.weight.get_differentiable_data(), self.weight) 231 | if self.bias is not None: 232 | x = x + self.bias 233 | return x 234 | 235 | 236 | class PseudoBlockSparseLinear(torch.nn.Module): 237 | """For debugging purposes mostly: emulate a BlockSparseLinear with only PyTorch primitives.""" 238 | 239 | def __init__(self, block_sparse_linear): 240 | super(PseudoBlockSparseLinear, self).__init__() 241 | 242 | block_sparse_matrix = block_sparse_linear.weight.cuda() 243 | self.weight = torch.nn.Parameter(block_sparse_matrix.to_dense()) 244 | mask = block_sparse_matrix.to_dense(data_replace=torch.ones_like(block_sparse_matrix.data)) == 1 245 | if block_sparse_linear.bias is not None: 246 | self.bias = torch.nn.Parameter(block_sparse_linear.bias) 247 | else: 248 | self.register_parameter("bias", None) 249 | 250 | self.register_buffer("mask", mask) 251 | self.in_features = block_sparse_linear.in_features 252 | self.out_features = block_sparse_linear.out_features 253 | self.density = mask.sum().item() / (mask.shape[0] * mask.shape[1]) 254 | 255 | def forward(self, input): 256 | weight = self.weight * self.mask 257 | return torch.nn.functional.linear(input, weight, self.bias) 258 | 259 | def extra_repr(self): 260 | return "in_features={}, out_features={}, bias={}, fill_ratio={}".format( 261 | self.in_features, self.out_features, self.bias is not None, self.density 262 | ) 263 | -------------------------------------------------------------------------------- /pytorch_block_sparse/tests/test_matmul.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest import TestCase 3 | 4 | import torch 5 | 6 | from pytorch_block_sparse 
import BlockSparseMatrix 7 | 8 | 9 | class TestFun(TestCase): 10 | def helper( 11 | self, 12 | sizes, 13 | block_size, 14 | block_count=None, 15 | density=None, 16 | blocks=None, 17 | iterations=1, 18 | device="cuda", 19 | transpose=True, 20 | verbose=False, 21 | ): 22 | device = device 23 | if isinstance(sizes[0], tuple): 24 | sizes_0 = sizes[0] 25 | else: 26 | sizes_0 = (sizes[0],) 27 | 28 | # Build positive matrix to easily check results 29 | if transpose: 30 | a = torch.randn(sizes_0 + (sizes[1],), device=device).abs() 31 | else: 32 | a = torch.randn(sizes_0 + (sizes[2],), device=device).abs() 33 | 34 | # torch.set_printoptions(precision=10, edgeitems=100000, linewidth=10000) 35 | if verbose: 36 | print("a=", a, "\n") 37 | 38 | if block_count is None and blocks is None: 39 | total_block_count = sizes[1] * sizes[2] / block_size[0] / block_size[1] 40 | block_count = int(total_block_count * density) 41 | 42 | bsm = BlockSparseMatrix.randn( 43 | (sizes[2], sizes[1]), 44 | block_count, 45 | blocks=blocks, 46 | block_shape=block_size, 47 | device=device, 48 | positive=True, 49 | ) # Build positive matrix to easily check results 50 | 51 | dbsm = bsm.to_dense() 52 | if verbose: 53 | print("b=", dbsm, "\n") 54 | print("a.shape", a.shape) 55 | bsm.check_with_dense(dbsm) 56 | 57 | timings = {} 58 | for kind in ["pytorch", "cutlass"]: 59 | start = torch.cuda.Event(enable_timing=True) 60 | end = torch.cuda.Event(enable_timing=True) 61 | 62 | start.record() 63 | 64 | for i in range(iterations): 65 | if kind == "pytorch": 66 | if transpose: 67 | dbsm_ = dbsm.t() 68 | else: 69 | dbsm_ = dbsm 70 | c = a.matmul(dbsm_) 71 | 72 | if verbose: 73 | print("c=", c, "\n") 74 | 75 | elif kind == "cutlass": 76 | c = bsm.reverse_matmul(a, transpose) 77 | elif kind == "cublas": 78 | import block_sparse_native 79 | 80 | prr = torch.zeros((sizes[2], sizes[0]), device=device) 81 | prr = prr.t() 82 | _ = block_sparse_native.blocksparse_matmul_transpose_dense(a, dbsm, prr) 83 | elif kind == "cuda": 84 | c = bsm.matmul_cuda(a) 85 | 86 | end.record() 87 | torch.cuda.synchronize() 88 | elapsed = start.elapsed_time(end) 89 | 90 | timing = dict(kind=kind, elapsed=elapsed, result=c) 91 | timings[kind] = timing 92 | 93 | if "pytorch" in timings: 94 | c0 = timings["pytorch"]["result"] 95 | for k, t in timings.items(): 96 | if k == "pytorch": 97 | t["comparison"] = True 98 | continue 99 | c = t["result"] 100 | torch.set_printoptions(precision=8, edgeitems=100000, linewidth=10000) 101 | stride = 32 102 | shift = 0 103 | c_ = c[shift::stride, shift::stride] 104 | c0_ = c0[shift::stride, shift::stride] 105 | if verbose: 106 | print("c shape", c.shape) 107 | print("c\n", c_) 108 | print("c0\n", c0_) 109 | print("c!=0\n", (c_ != 0).long()) 110 | print("c0!=0\n", (c0_ != 0).long()) 111 | print("equals\n", ((c_ - c0_).abs() < 1e-06).long()) 112 | print("equals nonzero\n", ((c_ - c0_).abs() > 1e-06).nonzero().t()) 113 | 114 | atol = 1e-8 115 | rtol = 1e-5 116 | # Matrix are positive, so this is ok 117 | s = c.isclose(c0).all() 118 | if not s.item(): 119 | print( 120 | f"max difference for {t['kind']} = { (c - c0).abs().max()}," 121 | f" max_values={c.abs().max()}, {c0.abs().max()}" 122 | ) 123 | diff = (c - c0).abs() / (atol + rtol * c0.abs()) 124 | t["comparison"] = False 125 | raise Exception( 126 | f"Comparison NOK : reverse_matmul issue for {k} sizes={sizes}," 127 | f" density={density}, block_count={block_count}," 128 | f"diff={diff}, blocks={blocks}, transpose={transpose}" 129 | ) 130 | else: 131 | if verbose: 132 | 
print(f"Comparison OK for reverse_matmul for {k}") 133 | print("max difference %s=" % t["kind"], (c - c0).abs().max()) 134 | t["comparison"] = True 135 | if verbose: 136 | print("c_cutlass=", c) 137 | torch.set_printoptions(profile="default") 138 | 139 | return timings 140 | 141 | def test0(self): 142 | tests = [ 143 | { 144 | "sizes": [32, 32, 32], 145 | "block_setups": [ 146 | [(0, 0)], 147 | ], 148 | }, 149 | { 150 | "sizes": [32, 64, 32], 151 | "block_setups": [ 152 | [(0, 0)], 153 | ], 154 | }, 155 | { 156 | "sizes": [64, 32, 32], 157 | "block_setups": [ 158 | [(0, 0)], 159 | ], 160 | }, 161 | { 162 | "sizes": [128, 32, 32], 163 | "block_setups": [ 164 | [(0, 0)], 165 | ], 166 | }, 167 | { 168 | "sizes": [128, 64, 32], 169 | "block_setups": [ 170 | [(0, 0)], 171 | [(0, 1)], 172 | [(0, 0), (0, 1)], 173 | ], 174 | }, 175 | ] 176 | tests += [ 177 | { 178 | "sizes": [32, 32, 64], 179 | "block_setups": [ 180 | [(0, 0)], 181 | [(1, 0)], 182 | [(0, 0), (1, 0)], 183 | ], 184 | } 185 | ] 186 | tests += [ 187 | { 188 | "sizes": [(64, 32), 32, 64], 189 | "block_setups": [ 190 | [(0, 0)], 191 | [(1, 0)], 192 | [(0, 0), (1, 0)], 193 | ], 194 | } 195 | ] 196 | tests += [ 197 | { 198 | "sizes": [(64, 128, 32), 128, 256], 199 | "block_setups": [ 200 | [(0, 0)], 201 | [(1, 0)], 202 | [(0, 0), (1, 0)], 203 | ], 204 | } 205 | ] 206 | tests += [ 207 | { 208 | "sizes": [32, 64, 64], 209 | "block_setups": [ 210 | [(0, 0), (1, 0), (0, 1)], 211 | ], 212 | } 213 | ] 214 | tests += [ 215 | { 216 | "sizes": [32, 64, 128], 217 | "block_setups": [ 218 | [(0, 0), (1, 0), (0, 1), (2, 0)], 219 | ], 220 | } 221 | ] 222 | tests += [ 223 | { 224 | "sizes": [1, 32, 32], 225 | "block_setups": [ 226 | [(0, 0)], 227 | ], 228 | } 229 | ] 230 | block_size = (32, 32) 231 | device = "cuda" 232 | for transpose in [False, True]: 233 | for test_info in tests: 234 | sizes = test_info["sizes"] 235 | for blocks in test_info["block_setups"]: 236 | _ = self.helper( 237 | sizes, 238 | block_size, 239 | density=None, 240 | blocks=blocks, 241 | device=device, 242 | verbose=False, 243 | transpose=transpose, 244 | ) 245 | 246 | def test1(self): 247 | size = 512 248 | test_sizes = [ 249 | [1, size * 2, size * 4], 250 | [8 * size * 16, size * 2, size * 4], 251 | [(4 * size * 2, 16), size * 2, size * 4], 252 | ] 253 | 254 | test_densities = [0.42, 1.0] 255 | 256 | import functools 257 | import operator 258 | 259 | block_size = (32, 32) 260 | iterations = 1 261 | 262 | results = {} 263 | for sizes in test_sizes: 264 | if isinstance(sizes[0], int): 265 | sizes_0 = sizes[0] 266 | else: 267 | sizes_0 = functools.reduce(operator.mul, sizes[0], 1) 268 | flops = float(2 * sizes_0 * sizes[1] * sizes[2]) 269 | 270 | for density in test_densities: 271 | for transpose in [False, True]: 272 | for i in range(1): 273 | timings = self.helper( 274 | sizes, 275 | block_size, 276 | density=density, 277 | iterations=iterations, 278 | verbose=False, 279 | transpose=transpose, 280 | ) 281 | 282 | if "pytorch" in timings: 283 | pytorch_time = timings["pytorch"]["elapsed"] 284 | else: 285 | pytorch_time = None 286 | 287 | for kind, d in timings.items(): 288 | if kind == "pytorch": 289 | continue 290 | if kind not in results: 291 | results[kind] = {True: 0, False: 0} 292 | if "comparison" in d: 293 | results[kind][d["comparison"]] += 1 294 | 295 | kind = d["kind"] 296 | kind_elapsed = d["elapsed"] 297 | if pytorch_time is None: 298 | ratio = "Unknown" 299 | else: 300 | ratio = kind_elapsed / pytorch_time 301 | gflops = flops * iterations / kind_elapsed / 1e6 302 | 
print( 303 | f"density={density}, transpose = {transpose}," 304 | f" elapsed={kind_elapsed}, gflops = {gflops}, ratio = {ratio}" 305 | ) 306 | 307 | # print(results) 308 | 309 | 310 | if __name__ == "__main__": 311 | unittest.main() 312 | -------------------------------------------------------------------------------- /pytorch_block_sparse/cutlass/gemm/k_split_control.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without 5 | * modification, are permitted provided that the following conditions are met: 6 | * * Redistributions of source code must retain the above copyright 7 | * notice, this list of conditions and the following disclaimer. 8 | * * Redistributions in binary form must reproduce the above copyright 9 | * notice, this list of conditions and the following disclaimer in the 10 | * documentation and/or other materials provided with the distribution. 11 | * * Neither the name of the NVIDIA CORPORATION nor the 12 | * names of its contributors may be used to endorse or promote products 13 | * derived from this software without specific prior written permission. 14 | * 15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
25 | * 26 | ******************************************************************************/ 27 | 28 | #pragma once 29 | 30 | /** 31 | * \file 32 | * Abstraction for coordinating inter-block k-splitting 33 | */ 34 | 35 | #include 36 | 37 | #include "../util/util.h" 38 | 39 | namespace cutlass { 40 | namespace gemm { 41 | 42 | 43 | /****************************************************************************** 44 | * Storage and initialization 45 | ******************************************************************************/ 46 | 47 | enum 48 | { 49 | NumFlagsSplitK = 4096 50 | }; 51 | 52 | 53 | /** 54 | * Global K-split semaphore flags 55 | * 56 | * TODO: use demand-allocated storage to provide copies for concurrent streams 57 | */ 58 | __device__ int d_flags_split_k[NumFlagsSplitK]; 59 | 60 | 61 | /** 62 | * Preparation kernel for zero-initializing semaphore flags 63 | */ 64 | __global__ inline void prepare_kernel(int *d_flags_split_k) 65 | { 66 | int tid = (blockIdx.x * blockDim.x) + threadIdx.x; 67 | if (tid < NumFlagsSplitK) 68 | d_flags_split_k[tid] = 0; 69 | } 70 | 71 | 72 | /****************************************************************************** 73 | * k_split_control 74 | ******************************************************************************/ 75 | 76 | /** 77 | * \brief Abstraction for coordinating inter-block k-splitting 78 | */ 79 | struct k_split_control 80 | { 81 | /// Extent of a thread block's partition along the GEMM K-axis 82 | // for tested cases, split_k is the number of k element that on z-dim block is responsible for 83 | int split_k; 84 | 85 | /// Whether or not to use a semaphore for inter-block k-splitting. 86 | bool use_semaphore; 87 | 88 | /// Pointer to semaphore 89 | int *d_flags; 90 | 91 | 92 | 93 | //------------------------------------------------------------------------- 94 | // Device API 95 | //------------------------------------------------------------------------- 96 | 97 | /** 98 | * Return the thread block's starting coordinate (k) within the 99 | * multiplicand matrices 100 | */ 101 | inline __device__ 102 | int block_begin_item_k() 103 | { 104 | return blockIdx.z * split_k; 105 | } 106 | 107 | 108 | /** 109 | * Return the thread block's ending coordinate (k) within the multiplicand 110 | * matrices (one-past) 111 | */ 112 | inline __device__ 113 | int block_end_item_k(int dim_k) 114 | { 115 | int next_start_k = block_begin_item_k() + split_k; 116 | return __NV_STD_MIN(next_start_k, dim_k); 117 | } 118 | 119 | 120 | /** 121 | * Whether the thread block is a secondary accumulator in an inter-block 122 | * k-splitting scheme 123 | */ 124 | inline __device__ 125 | bool is_secondary_accumulator() 126 | { 127 | return (blockIdx.z > 0); 128 | } 129 | 130 | 131 | /** 132 | * Wait for predecessor thread block(s) to produce the exclusive 133 | * partial-sums for this block-wide tile 134 | */ 135 | inline __device__ 136 | void wait() 137 | { 138 | // Wait on semaphore 139 | if ((use_semaphore) && (blockIdx.z > 0)) 140 | { 141 | if (threadIdx.x == 0) 142 | { 143 | int bid = (blockIdx.y * gridDim.x) + blockIdx.x; 144 | int hash = bid % NumFlagsSplitK; 145 | int found; 146 | int looking = blockIdx.z; 147 | while (true) 148 | { 149 | asm volatile ("ld.global.cg.u32 %0, [%1];\n" : "=r"(found) : "l"(d_flags + hash)); 150 | 151 | if (found == looking) 152 | break; 153 | 154 | /// Fence to keep load from being hoisted from the loop 155 | __syncwarp(0x00000001); 156 | } 157 | } 158 | 159 | __syncthreads(); 160 | } 161 | } 162 | 163 | 164 | /** 165 | * 
Signal the successor thread_block(s) that the inclusive partial-sums 166 | * from this block-wide tile are available 167 | */ 168 | inline __device__ 169 | void signal() 170 | { 171 | if (use_semaphore) 172 | { 173 | __syncthreads(); 174 | 175 | if (threadIdx.x == 0) 176 | { 177 | int bid = (blockIdx.y * gridDim.x) + blockIdx.x; 178 | int hash = bid % NumFlagsSplitK; 179 | int val = blockIdx.z + 1; 180 | 181 | asm volatile ("st.global.cg.u32 [%0], %1;\n" : : "l"(d_flags + hash), "r"(val)); 182 | } 183 | } 184 | } 185 | 186 | 187 | //------------------------------------------------------------------------- 188 | // Grid launch API 189 | //------------------------------------------------------------------------- 190 | 191 | /** 192 | * Constructor 193 | */ 194 | inline 195 | k_split_control( 196 | int *d_flags, 197 | int sm_count, 198 | int max_sm_occupancy, 199 | int dim_k, 200 | int block_tile_items_k, 201 | dim3 block_dims, 202 | dim3 &grid_dims) ///< [in,out] 203 | : 204 | d_flags(d_flags), 205 | split_k(dim_k) 206 | { 207 | // Compute wave efficiency 208 | float wave_efficiency = get_wave_efficiency( 209 | sm_count, 210 | max_sm_occupancy, 211 | block_dims, 212 | grid_dims); 213 | 214 | // printf("Here original split_k (dim_k) %d \n", split_k); // for tested cases, split_k is always dim_k 215 | 216 | // Update split-k if wave efficiency is less than some threshold 217 | if (wave_efficiency < 0.9) 218 | { 219 | int num_threadblocks = grid_dims.x * grid_dims.y * grid_dims.z; 220 | 221 | // Ideal number of thread blocks in grid 222 | int ideal_threadblocks = lcm(sm_count, num_threadblocks); 223 | 224 | // Desired number of partitions to split K-axis into 225 | int num_partitions = ideal_threadblocks / num_threadblocks; 226 | 227 | // Compute new k-split share 228 | int new_split_k = (dim_k + num_partitions - 1) / num_partitions; 229 | 230 | // Round split_k share to the nearest block_task_policy_t::BlockItemsK 231 | new_split_k = round_nearest(new_split_k, block_tile_items_k); 232 | 233 | // Recompute k-splitting factor with new_split_k 234 | num_partitions = (dim_k + new_split_k - 1) / new_split_k; 235 | 236 | // Update grid dims and k if we meet the minimum number of iterations worth the overhead of splitting 237 | // set min_iterations_k 16 if we want to check better performance than cublas on 1024 x 1024 matrix 238 | int min_iterations_k = 16;//8; 239 | 240 | if (((new_split_k / block_tile_items_k) > min_iterations_k) && // We're going to go through at least this many k iterations 241 | (sm_count * max_sm_occupancy < NumFlagsSplitK)) // We have enough semaphore flags allocated 242 | { 243 | grid_dims.z = num_partitions; 244 | split_k = new_split_k; 245 | } 246 | } 247 | 248 | // printf("Updated split_k %d \n", split_k); 249 | 250 | 251 | use_semaphore = (grid_dims.z > 1); 252 | } 253 | 254 | 255 | /** 256 | * Initializer 257 | */ 258 | cudaError_t prepare( 259 | cudaStream_t stream, ///< CUDA stream to launch kernels within. Default is stream0. 260 | bool debug_synchronous) ///< Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console if DEBUG is defined. Default is \p false. 
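  // When inter-block k-splitting is active, zero the global semaphore flags on the
  // given stream before the GEMM kernel launch.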
261 | 262 | { 263 | cudaError error = cudaSuccess; 264 | 265 | if (use_semaphore) 266 | { 267 | int block_threads = 128; 268 | int grid_dims = (NumFlagsSplitK + block_threads - 1) / block_threads; 269 | 270 | prepare_kernel<<<grid_dims, block_threads, 0, stream>>>(d_flags); 271 | 272 | // Check for failure to launch 273 | if (CUDA_PERROR_DEBUG(error = cudaPeekAtLastError())) 274 | return error; 275 | 276 | // Sync the stream if specified to flush runtime errors 277 | if (debug_synchronous && (CUDA_PERROR_DEBUG(error = cudaStreamSynchronize(stream)))) 278 | return error; 279 | } 280 | 281 | return error; 282 | } 283 | 284 | 285 | /** 286 | * Compute the efficiency of dispatch wave quantization 287 | */ 288 | float get_wave_efficiency( 289 | int sm_count, 290 | int max_sm_occupancy, 291 | dim3 block_dims, 292 | dim3 grid_dims) 293 | { 294 | // Heuristic for how many warps are needed to saturate an SM for a given 295 | // multiply-accumulate genre. (NB: We could make this more rigorous by 296 | // specializing on data types and SM width) 297 | int saturating_warps_per_sm = 16; 298 | 299 | int num_threadblocks = grid_dims.x * grid_dims.y * grid_dims.z; 300 | int threads_per_threadblock = block_dims.x * block_dims.y; 301 | int warps_per_threadblock = threads_per_threadblock / 32; 302 | int saturating_threadblocks_per_sm = (saturating_warps_per_sm + warps_per_threadblock - 1) / warps_per_threadblock; 303 | 304 | int saturating_residency = sm_count * saturating_threadblocks_per_sm; 305 | int full_waves = num_threadblocks / saturating_residency; 306 | int remainder_threadblocks = num_threadblocks % saturating_residency; 307 | int total_waves = (remainder_threadblocks == 0) ? full_waves : full_waves + 1; 308 | 309 | float last_wave_saturating_efficiency = float(remainder_threadblocks) / saturating_residency; 310 | 311 | return (float(full_waves) + last_wave_saturating_efficiency) / total_waves; 312 | } 313 | }; 314 | 315 | 316 | } // namespace gemm 317 | } // namespace cutlass 318 | -------------------------------------------------------------------------------- /pytorch_block_sparse/sparse_optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | 4 | from pytorch_block_sparse import BlockSparseMatrix 5 | 6 | 7 | class SparseOptimizerStrategy: 8 | def run(self, block_sparse_matrix): 9 | raise NotImplementedError() 10 | 11 | 12 | class MagnitudeSparseOptimizerStrategy(SparseOptimizerStrategy): 13 | def __init__( 14 | self, 15 | cleanup_ratio, 16 | new_coefficients_distribution="uniform", 17 | new_coefficients_scale=0.1, 18 | ): 19 | self.cleanup_ratio = cleanup_ratio 20 | self.new_coefficients_distribution = new_coefficients_distribution 21 | self.new_coefficients_scale = new_coefficients_scale 22 | 23 | def initialize_new_blocks(self, old_data, new_data): 24 | mean, std = old_data.mean(), old_data.std() 25 | 26 | if self.new_coefficients_distribution == "gaussian": 27 | new_data.normal_( 28 | mean=mean * self.new_coefficients_scale, 29 | std=std * self.new_coefficients_scale, 30 | ) 31 | elif self.new_coefficients_distribution == "uniform": 32 | new_data.uniform_(0, 1) 33 | new_data -= 0.5 34 | new_data *= 2 * std * self.new_coefficients_scale 35 | else: 36 | raise Exception("Unknown new coefficients method %s" % self.new_coefficients_distribution) 37 | 38 | def run(self, block_sparse_matrix): 39 | bsm = block_sparse_matrix 40 | # Get the norm of each block 41 | norms = bsm.block_norm() 42 | 43 | # Sort the norm 44 | _, indices = norms.sort() 45 |
46 | # Extract the worst blocks 47 | bad_blocks = indices[: int(indices.shape[0] * self.cleanup_ratio)] 48 | 49 | # Find available positions 50 | block_mask = ~bsm.block_mask_build(None) 51 | available = block_mask.nonzero() 52 | 53 | # Extract some random position 54 | empty_positions_indices = torch.randperm(available.shape[0])[: bad_blocks.shape[0]] 55 | new_positions = available[empty_positions_indices] 56 | 57 | block_replacements = torch.cat([new_positions, bad_blocks.unsqueeze(-1)], -1) 58 | 59 | bsm.block_replace(block_replacements) 60 | 61 | # bad_blocks 62 | new_block_mask = torch.zeros( 63 | bsm.data.shape[0] // bsm.block_shape[0], 64 | dtype=torch.bool, 65 | device=bsm.data.device, 66 | ) 67 | 68 | new_block_mask[bad_blocks] = True 69 | 70 | new_block_mask = new_block_mask.unsqueeze(-1) 71 | new_block_mask = new_block_mask.repeat_interleave(bsm.block_shape[0], dim=0) 72 | new_block_mask = new_block_mask.repeat_interleave(bsm.block_shape[1], dim=1) 73 | new_block_mask = new_block_mask.float() 74 | 75 | new_blocks = torch.zeros_like(bsm.data) 76 | 77 | self.initialize_new_blocks(bsm.data, new_blocks) 78 | 79 | new_blocks *= new_block_mask 80 | 81 | state_keep_mask = 1.0 - new_block_mask 82 | 83 | with torch.no_grad(): 84 | bsm.data *= state_keep_mask 85 | bsm.data += new_blocks 86 | 87 | return state_keep_mask 88 | 89 | 90 | class _RequiredParameter(object): 91 | """Singleton class representing a required parameter for an Optimizer.""" 92 | 93 | def __repr__(self): 94 | return "<required parameter>" 95 | 96 | 97 | required = _RequiredParameter() 98 | 99 | 100 | class OptimizerStateUpdater: 101 | def __init__(self, optimizer, sparse_object): 102 | self.optimizer = optimizer 103 | if not isinstance(sparse_object, BlockSparseMatrix): 104 | raise Exception(f"Unknown sparse_object type {sparse_object}") 105 | 106 | self.sparse_object = sparse_object 107 | 108 | def update_state_data(self, param, state_keep_mask): 109 | raise NotImplementedError() 110 | 111 | def update_state(self, state_keep_mask): 112 | if isinstance(self.sparse_object, BlockSparseMatrix): 113 | search_param = self.sparse_object.data 114 | else: 115 | raise Exception(f"Unknown sparse_object type {self.sparse_object}") 116 | 117 | found = False 118 | for param_group in self.optimizer.param_groups: 119 | for param in param_group["params"]: 120 | if param is search_param: 121 | found = True 122 | self.update_state_data(param, state_keep_mask) 123 | 124 | return found 125 | 126 | 127 | class AdamOptimizerStateUpdater(OptimizerStateUpdater): 128 | @staticmethod 129 | def is_compatible(optimizer): 130 | if isinstance(optimizer, optim.Adam): 131 | return True 132 | 133 | try: 134 | import transformers.optimization as transformers_optim 135 | except Exception: 136 | transformers_optim = None 137 | 138 | if transformers_optim is not None: 139 | if isinstance(optimizer, transformers_optim.AdamW): 140 | return True 141 | 142 | def update_state_data(self, param, state_keep_mask): 143 | opt = self.optimizer 144 | 145 | param_state = opt.state[param] 146 | 147 | for key in param_state: 148 | if key in ["exp_avg", "exp_avg_sq", "max_exp_avg_sq"]: 149 | param_state[key] *= state_keep_mask 150 | elif key == "step": 151 | # We cannot really alter the step info, it's global, so the bias_correction1 and bias_correction2 may 152 | # not be completely correct for the new coefficients, but it should not be a big issue 153 | pass 154 | else: 155 | raise Exception(f"Unknown key in Adam parameter state {key}") 156 | 157 | 158 | class
SparseOptimizer(torch.optim.Optimizer): 159 | METHODS = ["magnitude"] 160 | COEFFICIENTS_DISTRIBUTION = ["uniform", "gaussian"] 161 | allowed_keys = { 162 | "lr", 163 | "method", 164 | "new_coefficients_scale", 165 | "new_coefficients_distribution", 166 | } 167 | """optimizer = sparse_cleaner.SparseOptimizer([BlockSparseMatrix,BlockSparseMatrix], 168 | method="magnitude", new_coefficients_distribution="uniform") 169 | optimizer.add_param_group(dict(sparse_objects=[BlockSparseMatrix], 170 | lr=0.5, method="magnitude", 171 | new_coefficients_distribution="gaussian", new_coefficients_scale = 1.0))""" 172 | 173 | def __init__( 174 | self, 175 | sparse_objects, 176 | lr=1e-1, 177 | method="magnitude", 178 | new_coefficients_scale=0.1, 179 | new_coefficients_distribution="uniform", 180 | ): 181 | if not 0.0 < lr: 182 | raise ValueError("Invalid learning rate: {}".format(lr)) 183 | 184 | defaults = dict( 185 | lr=lr, 186 | method=method, 187 | new_coefficients_scale=new_coefficients_scale, 188 | new_coefficients_distribution=new_coefficients_distribution, 189 | ) 190 | 191 | super(SparseOptimizer, self).__init__([{"sparse_objects": sparse_objects}], defaults) 192 | self.attached_optimizers = [] 193 | 194 | @staticmethod 195 | def sparse_objects(model): 196 | ret = [] 197 | for name, module in model.named_modules(): 198 | if isinstance(module, BlockSparseMatrix): 199 | ret.append(module) 200 | 201 | return ret 202 | 203 | def attach_optimizer(self, optimizer): 204 | if optimizer in self.attached_optimizers: 205 | Warning("Optimizer already attached") 206 | return 207 | self.attached_optimizers.append(optimizer) 208 | 209 | def add_param_group(self, sparse_objects_group): 210 | assert isinstance(sparse_objects_group, dict), "param group must be a dict" 211 | 212 | for k in sparse_objects_group: 213 | if k == "sparse_objects": 214 | continue 215 | elif k not in self.allowed_keys: 216 | raise Exception("Unknown cleaning parameter %s" % k) 217 | 218 | sparse_objects = sparse_objects_group["sparse_objects"] 219 | 220 | if isinstance(sparse_objects, BlockSparseMatrix): 221 | sparse_objects_group["sparse_objects"] = [sparse_objects] 222 | else: 223 | sparse_objects_group["sparse_objects"] = list(sparse_objects) 224 | 225 | sparse_objects = sparse_objects_group["sparse_objects"] 226 | 227 | for p in sparse_objects: 228 | if isinstance(p, BlockSparseMatrix): 229 | continue 230 | else: 231 | raise Exception("I don't know how to clean this type of object: %s" % p) 232 | 233 | for name, default in self.defaults.items(): 234 | if default is required and name not in sparse_objects_group: 235 | raise ValueError("parameter group didn't specify a value of required optimization parameter " + name) 236 | else: 237 | sparse_objects_group.setdefault(name, default) 238 | 239 | if sparse_objects_group["method"] not in self.METHODS: 240 | raise Exception(f"Invalid Method {sparse_objects_group['method']}") 241 | 242 | if sparse_objects_group["new_coefficients_distribution"] not in self.COEFFICIENTS_DISTRIBUTION: 243 | raise Exception( 244 | f"Invalid new coefficients distribution {sparse_objects_group['new_coefficients_distribution']}" 245 | ) 246 | 247 | param_set = set() 248 | for group in self.param_groups: 249 | param_set.update(set(group["sparse_objects"])) 250 | 251 | if not param_set.isdisjoint(set(sparse_objects_group["sparse_objects"])): 252 | raise ValueError("some parameters appear in more than one parameter group") 253 | 254 | self.param_groups.append(sparse_objects_group) 255 | 256 | def clean( 257 | self, 
258 | p, 259 | method, 260 | clean_ratio, 261 | new_coefficients_scale, 262 | new_coefficients_distribution, 263 | ): 264 | if not isinstance(p, BlockSparseMatrix): 265 | raise Exception("I don't know how to clean this: %s" % p) 266 | 267 | if method == "magnitude": 268 | cleaner = MagnitudeSparseOptimizerStrategy( 269 | clean_ratio, 270 | new_coefficients_distribution=new_coefficients_distribution, 271 | new_coefficients_scale=new_coefficients_scale, 272 | ) 273 | else: 274 | raise Exception(f"Unknown cleaning method {method}") 275 | 276 | state_keep_mask = cleaner.run(p) 277 | 278 | if len(self.attached_optimizers) != 0: 279 | found = False 280 | for optimizer in self.attached_optimizers: 281 | if AdamOptimizerStateUpdater.is_compatible(optimizer): 282 | updater = AdamOptimizerStateUpdater(optimizer, p) 283 | found = found or updater.update_state(state_keep_mask) 284 | else: 285 | raise Exception(f"unsupported optimizer {optimizer.__class__}") 286 | 287 | if not found: 288 | raise Exception(f"Could not find sparse object {p} in optimizers {self.attached_optimizers}") 289 | else: 290 | Warning("No attached optimizer.") 291 | 292 | def step(self): 293 | for group in self.param_groups: 294 | clean_ratio = group["lr"] 295 | if clean_ratio == 0.0: 296 | continue 297 | for p in group["sparse_objects"]: 298 | self.clean( 299 | p, 300 | clean_ratio=clean_ratio, 301 | method=group["method"], 302 | new_coefficients_scale=group["new_coefficients_scale"], 303 | new_coefficients_distribution=group["new_coefficients_distribution"], 304 | ) 305 | -------------------------------------------------------------------------------- /pytorch_block_sparse/cutlass/gemm/dispatch_policies.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without 5 | * modification, are permitted provided that the following conditions are met: 6 | * * Redistributions of source code must retain the above copyright 7 | * notice, this list of conditions and the following disclaimer. 8 | * * Redistributions in binary form must reproduce the above copyright 9 | * notice, this list of conditions and the following disclaimer in the 10 | * documentation and/or other materials provided with the distribution. 11 | * * Neither the name of the NVIDIA CORPORATION nor the 12 | * names of its contributors may be used to endorse or promote products 13 | * derived from this software without specific prior written permission. 14 | * 15 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 19 | * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 22 | * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 | * 26 | ******************************************************************************/ 27 | 28 | #pragma once 29 | 30 | /** 31 | * \file 32 | * Architecture-specific GEMM block_task policies 33 | */ 34 | 35 | #include 36 | 37 | #include "../util/util.h" 38 | #include "block_task.h" 39 | #include "block_task_back.h" 40 | #include "grid_raster.h" 41 | #include "grid_raster_sparse.h" 42 | 43 | namespace cutlass { 44 | namespace gemm { 45 | 46 | 47 | /****************************************************************************** 48 | * tiling_strategy 49 | ******************************************************************************/ 50 | 51 | /** 52 | * Enumeration of tile-sizing granularities 53 | */ 54 | struct tiling_strategy : printable_t 55 | { 56 | /// \brief Enumerants 57 | enum kind_t 58 | { 59 | Unknown, 60 | Small, 61 | Medium, 62 | Large, 63 | Tall, 64 | Wide, 65 | Huge, 66 | Custom, 67 | CustomLarge, 68 | CustomBack, 69 | }; 70 | 71 | /// Enumerant value 72 | kind_t kind; 73 | 74 | /// Default constructor 75 | tiling_strategy() : kind(Unknown) {} 76 | 77 | /// Copy constructor 78 | tiling_strategy(const kind_t &other_kind) : kind(other_kind) {} 79 | 80 | /// Cast to kind_t 81 | operator kind_t() const { return kind; } 82 | 83 | /// Returns the instance as a string 84 | __host__ __device__ inline 85 | char const* to_string() const 86 | { 87 | switch (kind) 88 | { 89 | case Small: return "small"; 90 | case Medium: return "medium"; 91 | case Large: return "large"; 92 | case Tall: return "tall"; 93 | case Wide: return "wide"; 94 | case Huge: return "huge"; 95 | case Custom: return "Custom"; 96 | case CustomLarge: return "CustomLarge"; 97 | case CustomBack: return "CustomBack"; 98 | case Unknown: 99 | default: return "unknown"; 100 | } 101 | } 102 | 103 | /// Insert the formatted instance into the output stream 104 | void print(std::ostream& out) const { out << to_string(); } 105 | }; 106 | 107 | 108 | /****************************************************************************** 109 | * GEMM 110 | ******************************************************************************/ 111 | 112 | /** 113 | * GEMM task policy specialization for sgemm 114 | */ 115 | template < 116 | typename value_t, 117 | typename accum_t, 118 | matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A 119 | matrix_transform_t::kind_t TransformB, ///< Transformation op for matrix B 120 | tiling_strategy::kind_t TilingStrategy> ///< Tile-sizing classification 121 | struct gemm_policy; 122 | 123 | 124 | /****************************************************************************** 125 | * SGEMM 126 | ******************************************************************************/ 127 | /** 128 | * GEMM task policy specialization for Custom sgemm 129 | */ 130 | template < 131 | matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A 132 | matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B 133 | struct gemm_policy : 134 | block_task_policy< 135 | 256, // _BlockItemsY 136 | 32, // _BlockItemsX 137 | 32, // _BlockItemsK 138 | 8, // _ThreadItemsY 139 | 8, // _ThreadItemsX 140 | false, // _UseDoubleScratchTiles 141 | grid_raster_strategy::Default> // _RasterStrategy 142 | {}; 143 | 144 | /** 145 | * GEMM task policy specialization for Custom sgemm 146 | */ 147 | template < 148 | matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A 149 | matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B 150 | struct 
gemm_policy : 151 | block_task_policy< 152 | 32, // _BlockItemsY 153 | 32, // _BlockItemsX 154 | 32, // _BlockItemsK 155 | 4, // _ThreadItemsY 156 | 4, // _ThreadItemsX 157 | false, // _UseDoubleScratchTiles 158 | grid_raster_strategy::Default> // _RasterStrategy 159 | {}; 160 | 161 | 162 | /** 163 | * GEMM task policy specialization for CustomBack sgemm 164 | */ 165 | template < 166 | matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A 167 | matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B 168 | struct gemm_policy : 169 | block_task_back_policy< 170 | 32, // _BlockItemsY 171 | 32, // _BlockItemsX 172 | 32, // _BlockItemsK 173 | 4, // _ThreadItemsY 174 | 4, // _ThreadItemsX 175 | false, // _UseDoubleScratchTiles 176 | grid_raster_sparse_strategy::Sparse> // _RasterStrategy 177 | {}; 178 | 179 | # if 0 180 | /** 181 | * GEMM task policy specialization for small sgemm 182 | */ 183 | template < 184 | matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A 185 | matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B 186 | struct gemm_policy : 187 | block_task_policy< 188 | 16, // _BlockItemsY 189 | 16, // _BlockItemsX 190 | 16, // _BlockItemsK 191 | 2, // _ThreadItemsY 192 | 2, // _ThreadItemsX 193 | false, // _UseDoubleScratchTiles 194 | grid_raster_strategy::Default> // _RasterStrategy 195 | {}; 196 | 197 | 198 | /** 199 | * GEMM task policy specialization for medium sgemm 200 | */ 201 | template < 202 | matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A 203 | matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B 204 | struct gemm_policy : 205 | block_task_policy< 206 | 32, // _BlockItemsY 207 | 32, // _BlockItemsX 208 | 8, // _BlockItemsK 209 | 4, // _ThreadItemsY 210 | 4, // _ThreadItemsX 211 | false, // _UseDoubleScratchTiles 212 | grid_raster_strategy::Default> // _RasterStrategy 213 | {}; 214 | 215 | /** 216 | * GEMM task policy specialization for large sgemm 217 | */ 218 | template < 219 | matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A 220 | matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B 221 | struct gemm_policy : 222 | block_task_policy< 223 | 64, // _BlockItemsY 224 | 64, // _BlockItemsX 225 | 8, // _BlockItemsK 226 | 8, // _ThreadItemsY 227 | 8, // _ThreadItemsX 228 | false, // _UseDoubleScratchTiles 229 | grid_raster_strategy::Default> // _RasterStrategy 230 | {}; 231 | 232 | /** 233 | * GEMM task policy specialization for tall sgemm 234 | */ 235 | template < 236 | matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A 237 | matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B 238 | struct gemm_policy : 239 | block_task_policy< 240 | 128, // _BlockItemsY 241 | 32, // _BlockItemsX 242 | 8, // _BlockItemsK 243 | 8, // _ThreadItemsY 244 | 4, // _ThreadItemsX 245 | false, // _UseDoubleScratchTiles 246 | grid_raster_strategy::Default> // _RasterStrategy 247 | {}; 248 | 249 | /** 250 | * GEMM task policy specialization for wide sgemm 251 | */ 252 | template < 253 | matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A 254 | matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B 255 | struct gemm_policy : 256 | block_task_policy< 257 | 32, // _BlockItemsY 258 | 128, // _BlockItemsX 259 | 8, // _BlockItemsK 260 | 4, // _ThreadItemsY 261 | 8, // _ThreadItemsX 262 | false, // _UseDoubleScratchTiles 263 | 
grid_raster_strategy::Default> // _RasterStrategy 264 | {}; 265 | 266 | /** 267 | * GEMM task policy specialization for huge sgemm 268 | */ 269 | template < 270 | matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A 271 | matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B 272 | struct gemm_policy : 273 | block_task_policy< 274 | 128, // _BlockItemsY 275 | 128, // _BlockItemsX 276 | 8, // _BlockItemsK 277 | 8, // _ThreadItemsY 278 | 8, // _ThreadItemsX 279 | false, // _UseDoubleScratchTiles 280 | grid_raster_strategy::Default> // _RasterStrategy 281 | {}; 282 | #endif 283 | 284 | /****************************************************************************** 285 | * DGEMM 286 | ******************************************************************************/ 287 | 288 | /** 289 | * GEMM task policy specialization for Custom dgemm 290 | */ 291 | template < 292 | matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A 293 | matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B 294 | struct gemm_policy : 295 | block_task_policy< 296 | 32, // _BlockItemsY 297 | 32, // _BlockItemsX 298 | 32, // _BlockItemsK 299 | 4, // _ThreadItemsY 300 | 4, // _ThreadItemsX 301 | false, // _UseDoubleScratchTiles 302 | grid_raster_strategy::Default> // _RasterStrategy 303 | {}; 304 | 305 | # if 0 306 | /** 307 | * GEMM task policy specialization for small dgemm 308 | */ 309 | template < 310 | matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A 311 | matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B 312 | struct gemm_policy : 313 | block_task_policy< 314 | 16, // _BlockItemsY 315 | 16, // _BlockItemsX 316 | 16, // _BlockItemsK 317 | 2, // _ThreadItemsY 318 | 2, // _ThreadItemsX 319 | false, // _UseDoubleScratchTiles 320 | grid_raster_strategy::Default> // _RasterStrategy 321 | {}; 322 | 323 | 324 | /** 325 | * GEMM task policy specialization for medium dgemm 326 | */ 327 | template < 328 | matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A 329 | matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B 330 | struct gemm_policy : 331 | block_task_policy< 332 | 32, // _BlockItemsY 333 | 32, // _BlockItemsX 334 | 16, // _BlockItemsK 335 | 4, // _ThreadItemsY 336 | 4, // _ThreadItemsX 337 | false, // _UseDoubleScratchTiles 338 | grid_raster_strategy::Default> // _RasterStrategy 339 | {}; 340 | 341 | /** 342 | * GEMM task policy specialization for large dgemm 343 | */ 344 | template < 345 | matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A 346 | matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B 347 | struct gemm_policy : 348 | block_task_policy< 349 | 64, // _BlockItemsY 350 | 64, // _BlockItemsX 351 | 8, // _BlockItemsK 352 | 4, // _ThreadItemsY 353 | 4, // _ThreadItemsX 354 | false, // _UseDoubleScratchTiles 355 | grid_raster_strategy::Default> // _RasterStrategy 356 | {}; 357 | 358 | /** 359 | * GEMM task policy specialization for tall dgemm 360 | */ 361 | template < 362 | matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A 363 | matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B 364 | struct gemm_policy : 365 | block_task_policy< 366 | 128, // _BlockItemsY 367 | 32, // _BlockItemsX 368 | 8, // _BlockItemsK 369 | 8, // _ThreadItemsY 370 | 4, // _ThreadItemsX 371 | false, // _UseDoubleScratchTiles 372 | grid_raster_strategy::Default> // 
_RasterStrategy 373 | {}; 374 | 375 | /** 376 | * GEMM task policy specialization for wide dgemm 377 | */ 378 | template < 379 | matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A 380 | matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B 381 | struct gemm_policy : 382 | block_task_policy< 383 | 32, // _BlockItemsY 384 | 128, // _BlockItemsX 385 | 8, // _BlockItemsK 386 | 4, // _ThreadItemsY 387 | 8, // _ThreadItemsX 388 | false, // _UseDoubleScratchTiles 389 | grid_raster_strategy::Default> // _RasterStrategy 390 | {}; 391 | 392 | /** 393 | * GEMM task policy specialization for huge dgemm 394 | */ 395 | template < 396 | matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A 397 | matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B 398 | struct gemm_policy : 399 | block_task_policy< 400 | 64, // _BlockItemsY 401 | 128, // _BlockItemsX 402 | 8, // _BlockItemsK 403 | 8, // _ThreadItemsY 404 | 8, // _ThreadItemsX 405 | false, // _UseDoubleScratchTiles 406 | grid_raster_strategy::Default> // _RasterStrategy 407 | {}; 408 | 409 | #endif 410 | } // namespace gemm 411 | } // namespace cutlass 412 | --------------------------------------------------------------------------------
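
The SparseOptimizer defined in sparse_optimizer.py is driven entirely by calls to its step() method: each parameter group's lr is reused as the per-step cleanup ratio, the lowest-norm blocks of every registered BlockSparseMatrix are moved to empty positions, and any attached Adam-style optimizer has its exp_avg / exp_avg_sq state masked so that the freshly re-allocated blocks restart from clean statistics. Below is a minimal usage sketch built only from the methods shown above; `model`, `dataloader` and the hyper-parameter values are placeholders, not part of the repository.

import torch

from pytorch_block_sparse import SparseOptimizer

# Collect the BlockSparseMatrix objects discoverable inside the (placeholder) model.
sparse_objects = SparseOptimizer.sparse_objects(model)

# lr doubles as the cleanup ratio: here 10% of the lowest-norm blocks
# are re-allocated at every sparse step.
sparse_optimizer = SparseOptimizer(
    sparse_objects,
    lr=0.1,
    method="magnitude",
    new_coefficients_distribution="uniform",
    new_coefficients_scale=0.1,
)

# Attaching the dense optimizer lets SparseOptimizer mask its Adam state
# (exp_avg, exp_avg_sq) for the blocks that get replaced.
dense_optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
sparse_optimizer.attach_optimizer(dense_optimizer)

for step_index, batch in enumerate(dataloader):
    loss = model(batch)          # placeholder forward pass returning a scalar loss
    loss.backward()
    dense_optimizer.step()
    dense_optimizer.zero_grad()

    if step_index % 1000 == 0:   # arbitrary cleanup frequency
        sparse_optimizer.step()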