├── .gitignore ├── LICENSE ├── README.md ├── outer_softmax.py ├── outer_softmax_online.py ├── softmax_loop_along_reduce_axis_v1.py ├── softmax_loop_along_reduce_axis_v2.py ├── softmax_naive.py ├── softmax_online_v1.py ├── softmax_online_v2.py ├── softmax_online_v2_evict.py ├── softmax_online_v2_rev.py ├── softmax_online_v2_spec.py ├── softmax_online_v2_spec_rev.py ├── softmax_online_v2_spec_rev_evict.py └── softmax_split.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2024, Clement Chan 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # optimize_softmax 2 | Optimize softmax in triton in many cases 3 | -------------------------------------------------------------------------------- /outer_softmax.py: -------------------------------------------------------------------------------- 1 | import triton 2 | from triton import language as tl 3 | import torch 4 | 5 | @triton.jit 6 | def softmax_kernel( 7 | output_ptr, 8 | input_ptr, 9 | M, 10 | N, 11 | K, 12 | TILE_N: tl.constexpr, 13 | TILE_K: tl.constexpr, 14 | ): 15 | pid_k = tl.program_id(0) 16 | pid_m = tl.program_id(1) 17 | 18 | k_offsets = pid_k * TILE_K + tl.arange(0, TILE_K) 19 | n_offsets = tl.arange(0, TILE_N) 20 | offset = pid_m * N * K + n_offsets[:, None] * K + k_offsets 21 | mask = (n_offsets[:, None] < N) & (k_offsets < K) 22 | input_ptrs = input_ptr + offset 23 | inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(output_ptr.type.element_ty) 24 | m = tl.max(inp, 0) 25 | e = tl.exp(inp - m[None, :]) 26 | z = tl.sum(e, 0) 27 | out = e / z 28 | output_ptrs = output_ptr + offset 29 | tl.store(output_ptrs, out, mask=mask) 30 | 31 | 32 | def softmax(x, TILE_K): 33 | inp = x.contiguous() 34 | M = 1 35 | N, K = inp.shape 36 | out = torch.empty_like(x) 37 | TILE_N = triton.next_power_of_2(N) 38 | # TILE_K = 1 39 | grid = (triton.cdiv(K, TILE_K), M, 1) 40 | softmax_kernel[grid]( 41 | out, 42 | inp, 43 | M, 44 | N, 45 | K, 46 | TILE_N, 47 | TILE_K, 48 | ) 49 | return out 50 | 51 | 52 | import pytest 53 | 54 | @pytest.mark.parametrize("n", [10, 128]) 55 | @pytest.mark.parametrize("m", [512, 1024, 32 * 1024]) 56 | @pytest.mark.parametrize("TILE_K", [1, 2, 4]) 57 | def test_softmax(m, n, TILE_K): 58 | x = torch.randn((m, n), device="cuda") 59 | hyp = softmax(x, TILE_K) 60 | ref = torch.softmax(x, dim=0) 61 | torch.testing.assert_close(hyp, ref) 62 | 63 | def benchmark_softmax(m, n): 64 | x = torch.randn((m, n), device="cuda") 65 | t1 = triton.testing.do_bench(lambda: softmax(x, 1), return_mode="median") 66 | t2 = triton.testing.do_bench(lambda: softmax(x, 2), return_mode="median") 67 | t3 = triton.testing.do_bench(lambda: softmax(x, 4), return_mode="median") 68 | t4 = triton.testing.do_bench(lambda: torch.softmax(x, dim=0), return_mode="median") 69 | def throughput(t): 70 | return x.numel() * x.element_size() * 2 * 1e-9 / (t * 1e-3) 71 | return throughput(t1), throughput(t2), throughput(t3), throughput(t4) 72 | 73 | import pandas as pd 74 | def run_benchmark(): 75 | records = [] 76 | for n in [10, 128, 1024, 4096]: 77 | for m in [512, 1024, 2048, 4096, 8192, 16 * 1024, 32* 1024, 64 * 1024, 128 * 1024]: 78 | if m * n * 4 > 1024 * 1024 * 1024 * 4: 79 | continue 80 | t1, t2, t3, t4 = benchmark_softmax(m, n) 81 | record = (m, n, t1, t2, t3, t4) 82 | records.append(record) 83 | df = pd.DataFrame.from_records(records, columns=["reduce_size", "post_size", "naive_outer_k1", "naive_outer_k2", "naive_outer_k4", "torch"]) 84 | print(df) 85 | df.to_excel("naive_outer.xlsx") 86 | 87 | def run_an_example(m, n, tile_k): 88 | x = torch.randn((m, n), device="cuda") 89 | y = softmax(x, tile_k) 90 | 91 | if __name__ == "__main__": 92 | run_benchmark() 93 | # run_an_example(4096, 4096, ) -------------------------------------------------------------------------------- /outer_softmax_online.py: -------------------------------------------------------------------------------- 1 | import triton 2 | from triton 
import language as tl 3 | import torch 4 | 5 | 6 | @triton.jit 7 | def next_multiple_of(a, b): 8 | # the smallest x>=a that x%b ==0 9 | return tl.cidv(a, b) * b 10 | 11 | 12 | @triton.jit 13 | def prev_multiple_of(a, b): 14 | # the largest x 1024 * 1024 * 1024 * 4: 116 | continue 117 | for tile_k in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]: 118 | t1 = benchmark_softmax(m, n, tile_k) 119 | record = (m, n, tile_k, t1, triton.cdiv(n, tile_k)) 120 | records.append(record) 121 | df = pd.DataFrame.from_records(records, columns=["reduce_size", "post_size", "tiel_k", "online", "grid"]) 122 | print(df) 123 | df.to_excel("online_outer.xlsx") 124 | 125 | def run_an_example(m, n, tile_k): 126 | x = torch.randn((m, n), device="cuda") 127 | y = softmax(x, tile_k) 128 | 129 | if __name__ == "__main__": 130 | run_benchmark() 131 | # run_an_example(4096, 32 * 1024, 64) 132 | # run_an_example(4096, 32 * 1024, 2) 133 | # run_an_example(4096, 32 * 1024, 1024) 134 | -------------------------------------------------------------------------------- /softmax_loop_along_reduce_axis_v1.py: -------------------------------------------------------------------------------- 1 | import triton 2 | from triton import language as tl 3 | import torch 4 | 5 | @triton.jit 6 | def softmax_kernel_loop_v1( 7 | output_ptr, 8 | input_ptr, 9 | M, 10 | N, 11 | TILE_N: tl.constexpr, 12 | ): 13 | pid_m = tl.program_id(0) 14 | m = tl.full((), value=-float("inf"), dtype=output_ptr.dtype.element_ty) 15 | for start_n in range(0, N, TILE_N): 16 | n_offsets = start_n + tl.arange(0, TILE_N) 17 | offset = pid_m * N + n_offsets 18 | input_ptrs = input_ptr + offset 19 | mask = n_offsets < N 20 | inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(output_ptr.dtype.element_ty) 21 | m = tl.maximum(m, tl.max(inp, 0)) 22 | 23 | z = tl.full((), value=0, dtype=output_ptr.dtype.element_ty) 24 | for start_n in range(0, N, TILE_N): 25 | n_offsets = start_n + tl.arange(0, TILE_N) 26 | offset = pid_m * N + n_offsets 27 | input_ptrs = input_ptr + offset 28 | mask = n_offsets < N 29 | inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(output_ptr.dtype.element_ty) 30 | e = tl.exp(inp - m) 31 | z += tl.sum(e) 32 | 33 | for start_n in range(0, N, TILE_N): 34 | n_offsets = start_n + tl.arange(0, TILE_N) 35 | offset = pid_m * N + n_offsets 36 | input_ptrs = input_ptr + offset 37 | mask = n_offsets < N 38 | inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(output_ptr.dtype.element_ty) 39 | e = tl.exp(inp - m) 40 | out = e / z 41 | output_ptrs = output_ptr + offset 42 | tl.store(output_ptrs, out, mask=mask) 43 | 44 | 45 | def softmax(x): 46 | M, N = x.shape 47 | out = torch.empty_like(x) 48 | # TODO tune 49 | TILE_N = min(4096, triton.next_power_of_2(N)) 50 | grid = (M, 1, 1) 51 | softmax_kernel_loop_v1[grid](out, x, M, N, TILE_N) 52 | return out 53 | 54 | import pytest 55 | 56 | @pytest.mark.parametrize("n", [512, 1024, 32 * 1024, 128 * 1024]) 57 | @pytest.mark.parametrize("m", [10, 128]) 58 | def test_softmax(m, n): 59 | x = torch.randn((m, n), device="cuda") 60 | hyp = softmax(x) 61 | ref = torch.softmax(x, dim=-1) 62 | torch.testing.assert_close(hyp, ref) 63 | 64 | def benchmark_softmax(m, n): 65 | x = torch.randn((m, n), device="cuda") 66 | t1 = triton.testing.do_bench(lambda: softmax(x), return_mode="median") 67 | t2 = triton.testing.do_bench(lambda: torch.softmax(x, dim=-1), return_mode="median") 68 | def throughput(t): 69 | return x.numel() * x.element_size() * 2 * 1e-9 / (t * 1e-3) 70 | return throughput(t1), throughput(t2) 
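# --- annotation: the three-pass structure above, sketched in plain PyTorch ---
# softmax_kernel_loop_v1 walks each row three times with scalar running
# statistics: pass 1 finds the row max m, pass 2 accumulates z = sum(exp(x - m)),
# pass 3 writes exp(x - m) / z. The helper below is only an illustrative sketch
# (its name and default tile size are not part of this repository); it relies on
# the module-level `import torch` above.
def _three_pass_softmax_sketch(x, tile_n=4096):
    m = torch.full(x.shape[:-1], float("-inf"), dtype=x.dtype, device=x.device)
    for s in range(0, x.shape[-1], tile_n):  # pass 1: row max
        m = torch.maximum(m, x[..., s:s + tile_n].amax(dim=-1))
    z = torch.zeros_like(m)
    for s in range(0, x.shape[-1], tile_n):  # pass 2: normalizer
        z += torch.exp(x[..., s:s + tile_n] - m[..., None]).sum(dim=-1)
    out = torch.empty_like(x)
    for s in range(0, x.shape[-1], tile_n):  # pass 3: normalized output
        out[..., s:s + tile_n] = torch.exp(x[..., s:s + tile_n] - m[..., None]) / z[..., None]
    return out
# Note on the metric: throughput() charges one read plus one write per element
# (numel * element_size * 2 bytes). For a 4096 x 4096 float32 input that is
# 2 * 4096 * 4096 * 4 B ~= 0.134 GB, so a 1 ms median corresponds to ~134 GB/s.
# Since this kernel actually reads the input three times, the reported GB/s
# understates its true memory traffic.
# ------------------------------------------------------------------------------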
71 | 72 | import pandas as pd 73 | def run_benchmark(): 74 | records = [] 75 | for m in [10, 128, 1024, 4096]: 76 | for n in [512, 1024, 2048, 4096, 8192, 16 * 1024, 32* 1024, 64* 1024, 128 * 1024, 256 * 1024, 512 * 1024, 1024 * 1024]: 77 | if m * n * 4 > 1024 * 1024 * 1024 * 4: 78 | continue 79 | t1, t2 = benchmark_softmax(m, n) 80 | record = (m, n, t1, t2) 81 | records.append(record) 82 | df = pd.DataFrame.from_records(records, columns=["pre_size", "reduce_size", "loop_v1", "torch"]) 83 | print(df) 84 | df.to_excel("loop_v1.xlsx") 85 | 86 | if __name__ == "__main__": 87 | run_benchmark() 88 | -------------------------------------------------------------------------------- /softmax_loop_along_reduce_axis_v2.py: -------------------------------------------------------------------------------- 1 | import triton 2 | from triton import language as tl 3 | import torch 4 | 5 | @triton.jit 6 | def softmax_kernel_loop_v2( 7 | output_ptr, 8 | input_ptr, 9 | M, 10 | N, 11 | TILE_N: tl.constexpr, 12 | ): 13 | pid_m = tl.program_id(0) 14 | m = tl.full((TILE_N, ), value=-float("inf"), dtype=output_ptr.dtype.element_ty) 15 | for start_n in range(0, N, TILE_N): 16 | n_offsets = start_n + tl.arange(0, TILE_N) 17 | offset = pid_m * N + n_offsets 18 | input_ptrs = input_ptr + offset 19 | mask = n_offsets < N 20 | inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(output_ptr.dtype.element_ty) 21 | m = tl.maximum(m, inp) 22 | m = tl.max(m, 0) 23 | 24 | z = tl.full((TILE_N, ), value=0, dtype=output_ptr.dtype.element_ty) 25 | for start_n in range(0, N, TILE_N): 26 | n_offsets = start_n + tl.arange(0, TILE_N) 27 | offset = pid_m * N + n_offsets 28 | input_ptrs = input_ptr + offset 29 | mask = n_offsets < N 30 | inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(output_ptr.dtype.element_ty) 31 | e = tl.exp(inp - m) 32 | z += e 33 | z = tl.sum(z, 0) 34 | 35 | for start_n in range(0, N, TILE_N): 36 | n_offsets = start_n + tl.arange(0, TILE_N) 37 | offset = pid_m * N + n_offsets 38 | input_ptrs = input_ptr + offset 39 | mask = n_offsets < N 40 | inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(output_ptr.dtype.element_ty) 41 | e = tl.exp(inp - m) 42 | out = e / z 43 | output_ptrs = output_ptr + offset 44 | tl.store(output_ptrs, out, mask=mask) 45 | 46 | 47 | def softmax(x): 48 | M, N = x.shape 49 | out = torch.empty_like(x) 50 | # TODO: tune 51 | TILE_N = min(4096, triton.next_power_of_2(N)) 52 | grid = (M, 1, 1) 53 | softmax_kernel_loop_v2[grid](out, x, M, N, TILE_N) 54 | return out 55 | 56 | import pytest 57 | 58 | @pytest.mark.parametrize("n", [512, 1024, 32 * 1024, 128 * 1024]) 59 | @pytest.mark.parametrize("m", [10, 128]) 60 | def test_softmax(m, n): 61 | x = torch.randn((m, n), device="cuda") 62 | hyp = softmax(x) 63 | ref = torch.softmax(x, dim=-1) 64 | torch.testing.assert_close(hyp, ref) 65 | 66 | def benchmark_softmax(m, n): 67 | x = torch.randn((m, n), device="cuda") 68 | t1 = triton.testing.do_bench(lambda: softmax(x), return_mode="median") 69 | t2 = triton.testing.do_bench(lambda: torch.softmax(x, dim=-1), return_mode="median") 70 | def throughput(t): 71 | return x.numel() * x.element_size() * 2 * 1e-9 / (t * 1e-3) 72 | return throughput(t1), throughput(t2) 73 | 74 | import pandas as pd 75 | def run_benchmark(): 76 | records = [] 77 | for m in [10, 128, 1024, 4096]: 78 | for n in [512, 1024, 2048, 4096, 8192, 16 * 1024, 32* 1024, 64* 1024, 128 * 1024, 256 * 1024, 512 * 1024, 1024 * 1024]: 79 | if m * n * 4 > 1024 * 1024 * 1024 * 4: 80 | continue 81 | t1, t2 = 
benchmark_softmax(m, n) 82 | record = (m, n, t1, t2) 83 | records.append(record) 84 | df = pd.DataFrame.from_records(records, columns=["pre_size", "reduce_size", "loop_v2", "torch"]) 85 | print(df) 86 | df.to_excel("loop_v2.xlsx") 87 | 88 | if __name__ == "__main__": 89 | run_benchmark() -------------------------------------------------------------------------------- /softmax_naive.py: -------------------------------------------------------------------------------- 1 | import triton 2 | from triton import language as tl 3 | import torch 4 | 5 | @triton.jit 6 | def softmax_kernel( 7 | output_ptr, 8 | input_ptr, 9 | M, 10 | N, 11 | TILE_N: tl.constexpr, 12 | ): 13 | pid_m = tl.program_id(0) 14 | n_offsets = tl.arange(0, TILE_N) 15 | offset = pid_m * N + n_offsets 16 | input_ptrs = input_ptr + offset 17 | mask = n_offsets < N 18 | inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(output_ptr.dtype.element_ty) 19 | m = tl.max(inp, 0) 20 | e = tl.exp(inp - m) 21 | z = tl.sum(e, 0) 22 | out = e / z 23 | output_ptrs = output_ptr + offset 24 | tl.store(output_ptrs, out, mask=mask) 25 | 26 | 27 | def softmax(x): 28 | M, N = x.shape 29 | out = torch.empty_like(x) 30 | TILE_N = triton.next_power_of_2(N) 31 | grid = (M, 1, 1) 32 | softmax_kernel[grid](out, x, M, N, TILE_N) 33 | return out 34 | 35 | import pytest 36 | 37 | @pytest.mark.parametrize("n", [512, 1024, 32 * 1024, 128 * 1024]) 38 | @pytest.mark.parametrize("m", [10, 128, 1024]) 39 | def test_softmax(m, n): 40 | x = torch.randn((m, n), device="cuda") 41 | hyp = softmax(x) 42 | ref = torch.softmax(x, dim=-1) 43 | torch.testing.assert_close(hyp, ref) 44 | 45 | def benchmark_softmax(m, n): 46 | x = torch.randn((m, n), device="cuda") 47 | t1 = triton.testing.do_bench(lambda: softmax(x), return_mode="median") 48 | t2 = triton.testing.do_bench(lambda: torch.softmax(x, dim=-1), return_mode="median") 49 | def throughput(t): 50 | return x.numel() * x.element_size() * 2 * 1e-9 / (t * 1e-3) 51 | return throughput(t1), throughput(t2) 52 | 53 | import pandas as pd 54 | def run_benchmark(): 55 | records = [] 56 | for m in [10, 128, 1024, 4096]: 57 | for n in [512, 1024, 2048, 4096, 8192, 16 * 1024, 32* 1024, 64* 1024, 128 * 1024]: 58 | t1, t2 = benchmark_softmax(m, n) 59 | record = (m, n, t1, t2) 60 | records.append(record) 61 | df = pd.DataFrame.from_records(records, columns=["pre_size", "reduce_size", "naive", "torch"]) 62 | print(df) 63 | df.to_excel("naive.xlsx") 64 | 65 | def run_an_example(m, n): 66 | x = torch.randn((m, n), device="cuda") 67 | y = softmax(x) 68 | 69 | if __name__ == "__main__": 70 | run_an_example(4096, 4 * 1024) 71 | 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /softmax_online_v1.py: -------------------------------------------------------------------------------- 1 | import triton 2 | from triton import language as tl 3 | import torch 4 | 5 | @triton.jit 6 | def softmax_kernel_online_v1( 7 | output_ptr, 8 | input_ptr, 9 | M, 10 | N, 11 | TILE_N: tl.constexpr, 12 | ): 13 | pid_m = tl.program_id(0) 14 | m = tl.full((), value=-float("inf"), dtype=output_ptr.dtype.element_ty) 15 | z = tl.full((), value=0, dtype=output_ptr.dtype.element_ty) 16 | for start_n in range(0, N, TILE_N): 17 | n_offsets = start_n + tl.arange(0, TILE_N) 18 | offset = pid_m * N + n_offsets 19 | input_ptrs = input_ptr + offset 20 | mask = n_offsets < N 21 | inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(output_ptr.dtype.element_ty) 22 | new_m = tl.maximum(m, tl.max(inp, 0)) 23 | new_z 
= tl.exp(m - new_m) * z + tl.sum(tl.exp(inp - new_m), 0) 24 | m = new_m 25 | z = new_z 26 | 27 | for start_n in range(0, N, TILE_N): 28 | n_offsets = start_n + tl.arange(0, TILE_N) 29 | offset = pid_m * N + n_offsets 30 | input_ptrs = input_ptr + offset 31 | mask = n_offsets < N 32 | inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(output_ptr.dtype.element_ty) 33 | e = tl.exp(inp - m) 34 | out = e / z 35 | output_ptrs = output_ptr + offset 36 | tl.store(output_ptrs, out, mask=mask) 37 | 38 | 39 | def softmax(x): 40 | M, N = x.shape 41 | out = torch.empty_like(x) 42 | # TODO tune 43 | TILE_N = min(4096, triton.next_power_of_2(N)) 44 | grid = (M, 1, 1) 45 | softmax_kernel_online_v1[grid](out, x, M, N, TILE_N) 46 | return out 47 | 48 | import pytest 49 | 50 | @pytest.mark.parametrize("n", [512, 1024, 32 * 1024, 128 * 1024]) 51 | @pytest.mark.parametrize("m", [10, 128]) 52 | def test_softmax(m, n): 53 | x = torch.randn((m, n), device="cuda") 54 | hyp = softmax(x) 55 | ref = torch.softmax(x, dim=-1) 56 | torch.testing.assert_close(hyp, ref) 57 | 58 | def benchmark_softmax(m, n): 59 | x = torch.randn((m, n), device="cuda") 60 | t1 = triton.testing.do_bench(lambda: softmax(x), return_mode="median") 61 | t2 = triton.testing.do_bench(lambda: torch.softmax(x, dim=-1), return_mode="median") 62 | def throughput(t): 63 | return x.numel() * x.element_size() * 2 * 1e-9 / (t * 1e-3) 64 | return throughput(t1), throughput(t2) 65 | 66 | import pandas as pd 67 | def run_benchmark(): 68 | records = [] 69 | for m in [10, 128, 1024, 4096]: 70 | for n in [512, 1024, 2048, 4096, 8192, 16 * 1024, 32* 1024, 64* 1024, 128 * 1024, 256 * 1024, 512 * 1024, 1024 * 1024]: 71 | if m * n * 4 > 1024 * 1024 * 1024 * 4: 72 | continue 73 | t1, t2 = benchmark_softmax(m, n) 74 | record = (m, n, t1, t2) 75 | records.append(record) 76 | df = pd.DataFrame.from_records(records, columns=["pre_size", "reduce_size", "online_v1", "torch"]) 77 | print(df) 78 | df.to_excel("online_v1.xlsx") 79 | 80 | if __name__ == "__main__": 81 | run_benchmark() -------------------------------------------------------------------------------- /softmax_online_v2.py: -------------------------------------------------------------------------------- 1 | import triton 2 | from triton import language as tl 3 | import torch 4 | 5 | @triton.jit 6 | def softmax_kernel_online_v2( 7 | output_ptr, 8 | input_ptr, 9 | M, 10 | N, 11 | TILE_N: tl.constexpr, 12 | ): 13 | pid_m = tl.program_id(0) 14 | m = tl.full((TILE_N,), value=-float("inf"), dtype=output_ptr.dtype.element_ty) 15 | z = tl.full((TILE_N,), value=0, dtype=output_ptr.dtype.element_ty) 16 | for start_n in range(0, N, TILE_N): 17 | n_offsets = start_n + tl.arange(0, TILE_N) 18 | offset = pid_m * N + n_offsets 19 | input_ptrs = input_ptr + offset 20 | mask = n_offsets < N 21 | inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(output_ptr.dtype.element_ty) 22 | new_m = tl.maximum(m, inp) 23 | new_z = tl.exp(m - new_m) * z + tl.exp(inp - new_m) 24 | m = new_m 25 | z = new_z 26 | final_m = tl.max(m, 0) 27 | z = tl.sum(tl.exp(m - final_m) * z) 28 | m = final_m 29 | 30 | for start_n in range(0, N, TILE_N): 31 | n_offsets = start_n + tl.arange(0, TILE_N) 32 | offset = pid_m * N + n_offsets 33 | input_ptrs = input_ptr + offset 34 | mask = n_offsets < N 35 | inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(output_ptr.dtype.element_ty) 36 | e = tl.exp(inp - m) 37 | out = e / z 38 | output_ptrs = output_ptr + offset 39 | tl.store(output_ptrs, out, mask=mask) 40 | 41 | 42 | def 
softmax(x): 43 | M, N = x.shape 44 | out = torch.empty_like(x) 45 | # TODO tune 46 | TILE_N = min(4096, triton.next_power_of_2(N)) 47 | grid = (M, 1, 1) 48 | softmax_kernel_online_v2[grid](out, x, M, N, TILE_N) 49 | return out 50 | 51 | import pytest 52 | 53 | @pytest.mark.parametrize("n", [512, 1024, 32 * 1024, 128 * 1024]) 54 | @pytest.mark.parametrize("m", [10, 128]) 55 | def test_softmax(m, n): 56 | x = torch.randn((m, n), device="cuda") 57 | hyp = softmax(x) 58 | ref = torch.softmax(x, dim=-1) 59 | torch.testing.assert_close(hyp, ref) 60 | 61 | def benchmark_softmax(m, n): 62 | x = torch.randn((m, n), device="cuda") 63 | t1 = triton.testing.do_bench(lambda: softmax(x), return_mode="median") 64 | t2 = triton.testing.do_bench(lambda: torch.softmax(x, dim=-1), return_mode="median") 65 | def throughput(t): 66 | return x.numel() * x.element_size() * 2 * 1e-9 / (t * 1e-3) 67 | return throughput(t1), throughput(t2) 68 | 69 | import pandas as pd 70 | def run_benchmark(): 71 | records = [] 72 | for m in [10, 128, 1024, 4096]: 73 | for n in [512, 1024, 2048, 4096, 8192, 16 * 1024, 32* 1024, 64* 1024, 128 * 1024, 256 * 1024, 512 * 1024, 1024 * 1024]: 74 | if m * n * 4 > 1024 * 1024 * 1024 * 4: 75 | continue 76 | t1, t2 = benchmark_softmax(m, n) 77 | record = (m, n, t1, t2) 78 | records.append(record) 79 | df = pd.DataFrame.from_records(records, columns=["pre_size", "reduce_size", "online_v2", "torch"]) 80 | print(df) 81 | df.to_excel("online_v2.xlsx") 82 | 83 | if __name__ == "__main__": 84 | run_benchmark() -------------------------------------------------------------------------------- /softmax_online_v2_evict.py: -------------------------------------------------------------------------------- 1 | import triton 2 | from triton import language as tl 3 | import torch 4 | 5 | @triton.jit 6 | def softmax_kernel_online_v2( 7 | output_ptr, 8 | input_ptr, 9 | M, 10 | N, 11 | TILE_N: tl.constexpr, 12 | ): 13 | pid_m = tl.program_id(0) 14 | m = tl.full((TILE_N,), value=-float("inf"), dtype=output_ptr.dtype.element_ty) 15 | z = tl.full((TILE_N,), value=0, dtype=output_ptr.dtype.element_ty) 16 | for start_n in range(0, N, TILE_N): 17 | n_offsets = start_n + tl.arange(0, TILE_N) 18 | offset = pid_m * N + n_offsets 19 | input_ptrs = input_ptr + offset 20 | mask = n_offsets < N 21 | inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(output_ptr.dtype.element_ty) 22 | new_m = tl.maximum(m, inp) 23 | new_z = tl.exp(m - new_m) * z + tl.exp(inp - new_m) 24 | m = new_m 25 | z = new_z 26 | final_m = tl.max(m, 0) 27 | z = tl.sum(tl.exp(m - final_m) * z) 28 | m = final_m 29 | 30 | for start_n in range(0, N, TILE_N): 31 | n_offsets = start_n + tl.arange(0, TILE_N) 32 | offset = pid_m * N + n_offsets 33 | input_ptrs = input_ptr + offset 34 | mask = n_offsets < N 35 | inp = tl.load(input_ptrs, mask=mask, other=-float("inf"), eviction_policy="evict_first").to(output_ptr.dtype.element_ty) 36 | e = tl.exp(inp - m) 37 | out = e / z 38 | output_ptrs = output_ptr + offset 39 | tl.store(output_ptrs, out, mask=mask) 40 | 41 | 42 | def softmax(x): 43 | M, N = x.shape 44 | out = torch.empty_like(x) 45 | # TODO tune 46 | TILE_N = min(4096, triton.next_power_of_2(N)) 47 | grid = (M, 1, 1) 48 | softmax_kernel_online_v2[grid](out, x, M, N, TILE_N) 49 | return out 50 | 51 | import pytest 52 | 53 | @pytest.mark.parametrize("n", [512, 1024, 32 * 1024, 128 * 1024]) 54 | @pytest.mark.parametrize("m", [10, 128]) 55 | def test_softmax(m, n): 56 | x = torch.randn((m, n), device="cuda") 57 | hyp = softmax(x) 58 | ref = 
torch.softmax(x, dim=-1) 59 | torch.testing.assert_close(hyp, ref) 60 | 61 | def benchmark_softmax(m, n): 62 | x = torch.randn((m, n), device="cuda") 63 | t1 = triton.testing.do_bench(lambda: softmax(x), return_mode="median") 64 | t2 = triton.testing.do_bench(lambda: torch.softmax(x, dim=-1), return_mode="median") 65 | def throughput(t): 66 | return x.numel() * x.element_size() * 2 * 1e-9 / (t * 1e-3) 67 | return throughput(t1), throughput(t2) 68 | 69 | import pandas as pd 70 | def run_benchmark(): 71 | records = [] 72 | for m in [10, 128, 1024, 4096]: 73 | for n in [512, 1024, 2048, 4096, 8192, 16 * 1024, 32* 1024, 64* 1024, 128 * 1024, 256 * 1024, 512 * 1024, 1024 * 1024]: 74 | if m * n * 4 > 1024 * 1024 * 1024 * 4: 75 | continue 76 | t1, t2 = benchmark_softmax(m, n) 77 | record = (m, n, t1, t2) 78 | records.append(record) 79 | df = pd.DataFrame.from_records(records, columns=["pre_size", "reduce_size", "online_v2_evict", "torch"]) 80 | print(df) 81 | df.to_excel("online_v2_evict.xlsx") 82 | 83 | if __name__ == "__main__": 84 | run_benchmark() -------------------------------------------------------------------------------- /softmax_online_v2_rev.py: -------------------------------------------------------------------------------- 1 | import triton 2 | from triton import language as tl 3 | import torch 4 | 5 | @triton.jit 6 | def next_multiple_of(a, b): 7 | # the smallest x>=a that x%b ==0 8 | return tl.cidv(a, b) * b 9 | 10 | 11 | @triton.jit 12 | def prev_multiple_of(a, b): 13 | # the largest x 1024 * 1024 * 1024 * 4: 87 | continue 88 | t1, t2 = benchmark_softmax(m, n) 89 | record = (m, n, t1, t2) 90 | records.append(record) 91 | df = pd.DataFrame.from_records(records, columns=["pre_size", "reduce_size", "online_v2_rev", "torch"]) 92 | print(df) 93 | df.to_excel("online_v2_rev.xlsx") 94 | 95 | if __name__ == "__main__": 96 | run_benchmark() -------------------------------------------------------------------------------- /softmax_online_v2_spec.py: -------------------------------------------------------------------------------- 1 | import triton 2 | from triton import language as tl 3 | import torch 4 | 5 | @triton.jit 6 | def next_multiple_of(a, b): 7 | # the smallest x>=a that x%b ==0 8 | return tl.cidv(a, b) * b 9 | 10 | 11 | @triton.jit 12 | def prev_multiple_of(a, b): 13 | # the largest x 1024 * 1024 * 1024 * 4: 107 | continue 108 | t1, t2 = benchmark_softmax(m, n) 109 | record = (m, n, t1, t2) 110 | records.append(record) 111 | df = pd.DataFrame.from_records(records, columns=["pre_size", "reduce_size", "online_v2_spec", "torch"]) 112 | print(df) 113 | df.to_excel("online_v2_spec.xlsx") 114 | 115 | if __name__ == "__main__": 116 | run_benchmark() -------------------------------------------------------------------------------- /softmax_online_v2_spec_rev.py: -------------------------------------------------------------------------------- 1 | import triton 2 | from triton import language as tl 3 | import torch 4 | 5 | @triton.jit 6 | def next_multiple_of(a, b): 7 | # the smallest x>=a that x%b ==0 8 | return tl.cidv(a, b) * b 9 | 10 | 11 | @triton.jit 12 | def prev_multiple_of(a, b): 13 | # the largest x 1024 * 1024 * 1024 * 4: 107 | continue 108 | t1, t2 = benchmark_softmax(m, n) 109 | record = (m, n, t1, t2) 110 | records.append(record) 111 | df = pd.DataFrame.from_records(records, columns=["pre_size", "reduce_size", "online_v2_spec_rev", "torch"]) 112 | print(df) 113 | df.to_excel("online_v2_spec_rev.xlsx") 114 | 115 | if __name__ == "__main__": 116 | run_benchmark() 
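The online kernels above (softmax_kernel_online_v1/v2, and judging by their names the evict/rev/spec variants as well) are built around the same single-pass recurrence: when a new tile arrives, the running maximum m and the running normalizer z are rescaled together, so each row is read only twice (once for the statistics, once to produce the output) instead of three times as in the loop_v1/v2 kernels. A minimal PyTorch sketch of that recurrence, independent of Triton (the function name and default tile size are illustrative, not part of the repository):

import torch

def online_softmax_sketch(x: torch.Tensor, tile_n: int = 4096) -> torch.Tensor:
    N = x.shape[-1]
    m = torch.full(x.shape[:-1], float("-inf"), dtype=x.dtype, device=x.device)
    z = torch.zeros(x.shape[:-1], dtype=x.dtype, device=x.device)
    for start in range(0, N, tile_n):
        tile = x[..., start:start + tile_n]
        new_m = torch.maximum(m, tile.amax(dim=-1))
        # rescale the old sum to the new maximum before adding this tile's part
        z = z * torch.exp(m - new_m) + torch.exp(tile - new_m[..., None]).sum(dim=-1)
        m = new_m
    return torch.exp(x - m[..., None]) / z[..., None]

This mirrors softmax_kernel_online_v1, which keeps scalar m and z; online_v2 instead keeps them as TILE_N-wide vectors and only reduces them to scalars after the loop, which vectorizes the update inside each program. The _evict variants additionally pass eviction_policy="evict_first" to the second-pass load, since that data is not reused afterwards.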
-------------------------------------------------------------------------------- /softmax_online_v2_spec_rev_evict.py: -------------------------------------------------------------------------------- 1 | import triton 2 | from triton import language as tl 3 | import torch 4 | 5 | @triton.jit 6 | def next_multiple_of(a, b): 7 | # the smallest x>=a that x%b ==0 8 | return tl.cidv(a, b) * b 9 | 10 | 11 | @triton.jit 12 | def prev_multiple_of(a, b): 13 | # the largest x 1024 * 1024 * 1024 * 4: 107 | continue 108 | t1, t2 = benchmark_softmax(m, n) 109 | record = (m, n, t1, t2) 110 | records.append(record) 111 | df = pd.DataFrame.from_records(records, columns=["pre_size", "reduce_size", "online_v2_spec_rev_evict", "torch"]) 112 | print(df) 113 | df.to_excel("online_v2_spec_rev_evict.xlsx") 114 | 115 | if __name__ == "__main__": 116 | run_benchmark() -------------------------------------------------------------------------------- /softmax_split.py: -------------------------------------------------------------------------------- 1 | import math 2 | import triton 3 | from triton import language as tl 4 | import torch 5 | 6 | @triton.jit 7 | def logsumexp_kernel( 8 | out_ptr, 9 | in_ptr, 10 | M, 11 | N, 12 | TILE_N: tl.constexpr, 13 | ): 14 | pid_n = tl.program_id(0) 15 | num_programs_n = tl.num_programs(0) 16 | pid_m = tl.program_id(1) 17 | 18 | n_offsets = pid_n * TILE_N + tl.arange(0, TILE_N) 19 | mask = n_offsets < N 20 | offset = pid_m * N + n_offsets 21 | inp = tl.load(in_ptr + offset, mask=mask, other=-float("inf")).to(out_ptr.dtype.element_ty) 22 | m = tl.max(inp, 0) 23 | e = tl.exp(inp - m) 24 | z = tl.sum(e, 0) 25 | logz = m + tl.log(z) 26 | 27 | output_ptrs = out_ptr + pid_m * num_programs_n + pid_n 28 | tl.store(output_ptrs, logz) 29 | 30 | @triton.jit 31 | def combine_logsumexp_kernel(out_ptr, inp_ptr, M, N, TILE_N: tl.constexpr): 32 | pid_m = tl.program_id(0) 33 | n_offsets = tl.arange(0, TILE_N) 34 | mask = n_offsets < N 35 | logzs = tl.load(inp_ptr + pid_m * N + n_offsets, other=-float("inf"), mask=mask).to(out_ptr.dtype.element_ty) 36 | m = tl.max(logzs, 0) 37 | e = tl.exp(logzs - m) 38 | z = tl.sum(e, 0) 39 | logz = m + tl.log(z) 40 | tl.store(out_ptr + pid_m, logz) 41 | 42 | @triton.jit 43 | def softmax_kernel(out_ptr, in_ptr, logz_ptr, M, N, TILE_N: tl.constexpr): 44 | pid_n = tl.program_id(0) 45 | pid_m = tl.program_id(1) 46 | n_offsets = pid_n * TILE_N + tl.arange(0, TILE_N) 47 | offset = pid_m * N + n_offsets 48 | mask = n_offsets < N 49 | inp = tl.load(in_ptr + offset, mask=mask, other=-float("inf")).to(out_ptr.dtype.element_ty) 50 | logz = tl.load(logz_ptr + pid_m).to(out_ptr.dtype.element_ty) 51 | out = tl.exp(inp - logz) 52 | tl.store(out_ptr + offset, out, mask=mask) 53 | 54 | 55 | 56 | def softmax(x): 57 | M, N = x.shape 58 | 59 | num_sms = torch.cuda.get_device_properties(x.device).multi_processor_count 60 | 61 | TILE_N = min(4096, triton.next_power_of_2(N)) 62 | num_tiles_n = triton.cdiv(N, TILE_N) 63 | logz = torch.empty((M, num_tiles_n), dtype=x.dtype, device=x.device) 64 | grid = (num_tiles_n, M, 1) 65 | logsumexp_kernel[grid](logz, x, M, N, TILE_N) 66 | 67 | combined_logz = torch.empty((M, ), dtype=x.dtype, device=x.device) 68 | TILE_N = triton.next_power_of_2(num_tiles_n) 69 | grid = (M, 1, 1) 70 | combine_logsumexp_kernel[grid](combined_logz, logz, M, num_tiles_n, TILE_N) 71 | 72 | out = torch.empty_like(x) 73 | TILE_N = min(4096, triton.next_power_of_2(N)) 74 | num_tiles_n = triton.cdiv(N, TILE_N) 75 | grid = (num_tiles_n, M, 1) 76 | softmax_kernel[grid](out, x, 
combined_logz, M, N, TILE_N) 77 | return out 78 | 79 | import pytest 80 | 81 | @pytest.mark.parametrize("n", [128 * 1024, 1024 * 1024, 8192 * 1024]) 82 | @pytest.mark.parametrize("m", [10, ]) 83 | def test_softmax(m, n): 84 | x = torch.randn((m, n), device="cuda") 85 | hyp = softmax(x) 86 | ref = torch.softmax(x, dim=-1) 87 | torch.testing.assert_close(hyp, ref) 88 | 89 | def benchmark_softmax(m, n): 90 | x = torch.randn((m, n), device="cuda") 91 | t1 = triton.testing.do_bench(lambda: softmax(x), return_mode="median") 92 | t2 = triton.testing.do_bench(lambda: torch.softmax(x, dim=-1), return_mode="median") 93 | def throughput(t): 94 | return x.numel() * x.element_size() * 2 * 1e-9 / (t * 1e-3) 95 | return throughput(t1), throughput(t2) 96 | 97 | import pandas as pd 98 | def run_benchmark(): 99 | records = [] 100 | for m in [1, 8, 16, 32]: 101 | for n in [16 * 1024, 32* 1024, 64* 1024, 128 * 1024, 256 * 1024, 512 * 1024, 1024 * 1024, 2048 * 1024, 4096 * 1024, 8192 * 1024]: 102 | if m * n * 4 > 1024 * 1024 * 1024 * 4: 103 | continue 104 | t1, t2 = benchmark_softmax(m, n) 105 | record = (m, n, t1, t2) 106 | records.append(record) 107 | df = pd.DataFrame.from_records(records, columns=["pre_size", "reduce_size", "split", "torch"]) 108 | print(df) 109 | df.to_excel("split.xlsx") 110 | 111 | if __name__ == "__main__": 112 | run_benchmark() --------------------------------------------------------------------------------
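softmax_split.py takes a different route for very long rows: the reduction itself is split across programs. logsumexp_kernel writes one partial logsumexp per (row, tile), combine_logsumexp_kernel folds those partials into a single logz per row, and the final softmax_kernel computes exp(x - logz) elementwise. The same three-stage decomposition in plain PyTorch (the function name and default tile size are illustrative, not part of the repository):

import torch

def split_softmax_sketch(x: torch.Tensor, tile_n: int = 4096) -> torch.Tensor:
    N = x.shape[-1]
    # stage 1: one logsumexp per tile of the reduce axis
    partial = torch.stack(
        [torch.logsumexp(x[..., s:s + tile_n], dim=-1) for s in range(0, N, tile_n)],
        dim=-1,
    )
    # stage 2: combine the partial logsumexps into one value per row
    logz = torch.logsumexp(partial, dim=-1)
    # stage 3: elementwise normalization
    return torch.exp(x - logz[..., None])

Splitting along the reduce axis is what keeps the GPU busy when M is small (the benchmark sweeps m from 1 to 32): a one-program-per-row grid would launch only a handful of programs, whereas here the first stage launches cdiv(N, TILE_N) * M programs.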